torch-cublas-hgemm's Introduction

CublasOps: High-Performance Linear Layers with cuBLAS and cuBLASLt

CublasOps is a PyTorch extension library that provides high-performance linear layers for half-precision (FP16) matrix multiplications using NVIDIA's cuBLAS and cuBLASLt libraries. It offers fast and efficient execution of A x B^T matrix multiplications with optional bias addition and activation functions (ReLU or GELU).

Features

  • Fast half-precision (FP16) matrix multiplications using cuBLAS and cuBLASLt with FP16 accumulation (roughly a 2x speedup over PyTorch's default matmul on an RTX 4090)
  • Support for fused operations: matmul + bias + activation (ReLU or GELU)
  • Easy-to-use linear layers: CublasLinear, CublasLinearRelu, and CublasLinearGelu
  • Seamless integration with PyTorch models
  • Batched and non-batched operations

For example, using CublasLinear with 4096 input and output features and a (2, 4096, 4096) input tensor on my RTX 4090:

CUBLAS INFERENCE: 
FLOPS: 274877906944
TFLOP/s: 305.801

TORCH INFERENCE: 
FLOPS: 274877906944
TFLOP/s: 166.989
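
A minimal sketch of how such a comparison might be reproduced. The timing harness and TFLOP/s arithmetic below are illustrative, not the repo's own benchmark script; only CublasLinear comes from this library.

import torch
from cublas_ops import CublasLinear

B, M, K, N = 2, 4096, 4096, 4096
x = torch.randn((B, M, K), device="cuda", dtype=torch.float16)

cublas_linear = CublasLinear(K, N, bias=False, device="cuda", dtype=torch.float16)
torch_linear = torch.nn.Linear(K, N, bias=False, device="cuda", dtype=torch.float16)

@torch.no_grad()
def bench(name, fn, iters=100):
    for _ in range(10):  # warmup
        fn(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    secs = start.elapsed_time(end) / 1e3 / iters  # elapsed_time() returns ms
    flops = 2 * B * M * K * N                     # 2 ops (mul + add) per MAC
    print(f"{name}: FLOPS: {flops} TFLOP/s: {flops / secs / 1e12:.3f}")

bench("CUBLAS INFERENCE", cublas_linear)
bench("TORCH INFERENCE", torch_linear)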

Installation

To install CublasOps, follow these steps:

  1. Make sure you have PyTorch installed with CUDA support.
  2. Clone the repository:
    git clone https://github.com/aredden/torch-cublas-hgemm.git
    
  3. Navigate to the cloned repository:
    cd torch-cublas-hgemm
    
  4. Build and install the extension:
    python -m pip install -U -v .
    

Usage

Here's a simple example of how to use CublasOps in your PyTorch code:

import torch
from cublas_ops import CublasLinear, CublasLinearGelu, CublasLinearRelu

in_features = 64
out_features = 64
bias = True  # or False

# (A x B^T + bias)
linear = CublasLinear(in_features, out_features, bias=bias, device='cuda', dtype=torch.float16)
input_tensor = torch.randn((2, 8, 64)).cuda().half()
# or...
input_tensor = torch.randn((8, 64)).cuda().half()

output_tensor = linear(input_tensor)

# For fused GELU: gelu(A x B^T + bias)
linear_gelu = CublasLinearGelu(in_features, out_features, bias=bias, device='cuda', dtype=torch.float16)

# For fused ReLU: relu(A x B^T + bias)
linear_relu = CublasLinearRelu(in_features, out_features, bias=bias, device='cuda', dtype=torch.float16)

API Reference

Linear Layers

  • CublasLinear(in_features, out_features, bias=True, device=None, dtype=torch.float16, epilogue_str="NONE")
    • A linear layer that computes A x B^T with optional bias addition; epilogue_str selects an optional fused activation ("NONE", "RELU", or "GELU").
  • CublasLinearGelu(in_features, out_features, bias=True, device=None, dtype=torch.float16)
    • A linear layer with fused GELU activation: gelu(A x B^T + bias).
  • CublasLinearRelu(in_features, out_features, bias=True, device=None, dtype=torch.float16)
    • A linear layer with fused ReLU activation: relu(A x B^T + bias).
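
Judging from the epilogue_str parameter in the CublasLinear signature above, the fused variants appear to be shorthand for passing the corresponding epilogue string. A minimal sketch, assuming that equivalence holds (it is inferred from the signatures, not stated by the repo):

import torch
from cublas_ops import CublasLinear, CublasLinearGelu

# Assumed equivalent, based on the epilogue_str parameter above:
gelu_a = CublasLinear(64, 64, bias=True, device='cuda', dtype=torch.float16, epilogue_str="GELU")
gelu_b = CublasLinearGelu(64, 64, bias=True, device='cuda', dtype=torch.float16)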

Low-Level Functions

  • cublas_half_matmul_simple(a: torch.Tensor, b: torch.Tensor)
    • Performs a simple A x B^T matrix multiplication using cuBLAS.
  • cublas_half_matmul_batched_simple(a: torch.Tensor, b: torch.Tensor)
    • Performs a batched A x B^T matrix multiplication using cuBLAS. At least one of A/B should have 3 dimensions, with the other having 2 or 3.
  • cublaslt_fused_half_matmul_simple(a: torch.Tensor, b: torch.Tensor, bias: Optional[torch.Tensor] = None, epilogue_str: Optional[Literal["NONE", "RELU", "GELU"]] = "NONE")
    • Performs a fused activation(A x B^T + bias) matrix multiplication using cuBLASLt, where both the bias addition and the activation are optional.
  • cublaslt_fused_half_matmul_batched_simple(a: torch.Tensor, b: torch.Tensor, bias: Optional[torch.Tensor] = None, epilogue_str: Optional[Literal["NONE", "RELU", "GELU"]] = "NONE")
    • Performs a fused activation(A x B^T + bias) batched matrix multiplication using cuBLASLt, where both the bias addition and the activation are optional. At least one of A/B should have 3 dimensions, with the other having 2 or 3.
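
A short usage sketch for the low-level functions. The shapes are illustrative; note that b is laid out as (out_features, in_features), since the product computed is A x B^T.

import torch
from cublas_ops import (
    cublas_half_matmul_simple,
    cublas_half_matmul_batched_simple,
    cublaslt_fused_half_matmul_simple,
)

a = torch.randn((8, 64), device="cuda", dtype=torch.float16)
b = torch.randn((32, 64), device="cuda", dtype=torch.float16)  # multiplied as b^T
bias = torch.randn((32,), device="cuda", dtype=torch.float16)

out = cublas_half_matmul_simple(a, b)  # shape (8, 32), i.e. a @ b.T
out_fused = cublaslt_fused_half_matmul_simple(a, b, bias=bias, epilogue_str="GELU")

# Batched variant; both operands given 3 dimensions here:
a_batched = torch.randn((4, 8, 64), device="cuda", dtype=torch.float16)
b_batched = torch.randn((4, 32, 64), device="cuda", dtype=torch.float16)
out_batched = cublas_half_matmul_batched_simple(a_batched, b_batched)  # (4, 8, 32)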

Contributing

Contributions to CublasOps are welcome! If you encounter any issues, have suggestions for improvements, or want to add new features, please open an issue or submit a pull request on the GitHub repository.

Acknowledgments

CublasOps is built upon the powerful cuBLAS and cuBLASLt libraries provided by NVIDIA. We would like to thank the NVIDIA team for their excellent work on these libraries.


torch-cublas-hgemm's Issues

Matmul errors out when one tensor is batched and another isn't

Cool idea! Proud to submit a first bug report :)

This PyTorch code (Ubuntu, CUDA 12.1, Torch 2.2.2, Nvidia 4090):

>>> import cublas_ops
>>> import torch
>>> x = torch.ones([1, 2560, 8192], dtype=torch.float16, device="cuda:0")
>>> x
tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]], device='cuda:0',
       dtype=torch.float16)
>>> y = torch.ones([8192, 28672], dtype=torch.float16, device="cuda:0")
>>> z = cublas_ops.cublas_half_matmul_batched_simple(x, y)

fails with this stack trace:

 ** On entry to HgemmStridedBatched parameter number 10 had an illegal value
cuBLAS API failed with status 7
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/cublas_ops/__init__.py", line 39, in cublas_half_matmul_batched_simple
    return _cublas_hgemm_batched_simple(a, b)
RuntimeError: cuBLAS API failed

but this code works:

>>> y = torch.ones([1, 8192, 28672], dtype=torch.float16, device="cuda:0")
>>> z = cublas_ops.cublas_half_matmul_batched_simple(x, y)

Error installing

Prior to seeing your note in the README at https://github.com/aredden/flux-fp8-api, I went ahead and tried to install this because it was in the requirements.txt for flux-fp8-api, but got an error on Ubuntu 22.04. I ended up having to change line 31 in cublas_ops/__init__.py as follows:

def get(cls, __name: str, device: torch.device) -> torch.Any:

to this:

def get(cls, __name: str, device: torch.device) -> Any:

and I had to include:

from typing import Any

I'm way out of my league here, but wanted to report it all the same. Thanks!
