
mamba.py's Introduction

My favorite projects: [GitHub ReadMe cards]

mamba.py's People

Contributors

alxndrtl, beebopkim


mamba.py's Issues

Discretization step seems to be different from the paper

Hi,

Firstly, thanks for making this repo. I found it very useful in understanding the scan algorithm. However, the discretization in this repo seems to be different from Eq. 4 of the paper. Do you have any comments on this? Also, I wonder why the original paper needs the discretization step in the first place, since it is possible to make the discrete versions of A and B conditioned on the input directly. I imagine it must have something to do with the initialization, but I am not sure.

deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

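For reference, here is a small numeric sketch (not from the repo; a diagonal A is assumed, so everything is element-wise) contrasting the zero-order-hold discretization of Eq. 4 with the simplified first-order discretization of B that the quoted line implements:

    import torch

    delta = 0.1 * torch.rand(8)   # per-channel step sizes
    A = -torch.rand(8)            # diagonal entries of A
    B = torch.randn(8)

    # Zero-order hold (Eq. 4): A_bar = exp(delta*A), B_bar = (delta*A)^(-1) (exp(delta*A) - 1) * delta*B
    A_bar = torch.exp(delta * A)
    B_bar_zoh = (A_bar - 1) / A * B

    # Simplified discretization, as in the quoted line: B_bar ~ delta * B
    B_bar_simple = delta * B

    print((B_bar_zoh - B_bar_simple).abs().max())  # small when delta*A is small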

MLX inference error with BFloat16

from commit hash 6a49341:

What I did was execute generate.py for a Mamba fine-tuned model, kuotient/mamba-ko-2.8b, and the error below occurred. How can I deal with this error?

(venv_mamba_py) ******@Mac-Studio-2022-01 scripts % python generate.py --prompt="Mamba is a type of" --hf_model_name="kuotient/mamba-ko-2.8b" --n_tokens=100

Traceback (most recent call last):
  File "/Users/******/test/mamba.py/mlx/scripts/generate.py", line 31, in <module>
    model = MambaLM.from_pretrained(args.hf_model_name)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/mamba_lm_mlx.py", line 150, in from_pretrained
    mlx_state_dict = map_mambassm_torch_to_mlx(state_dict)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/utils.py", line 53, in map_mambassm_torch_to_mlx
    return map_mambapy_torch_to_mlx(new_state_dict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/******/test/mamba.py/mlx/utils.py", line 37, in map_mambapy_torch_to_mlx
    new_state_dict[key] = value.numpy()
                          ^^^^^^^^^^^^^
TypeError: Got unsupported ScalarType BFloat16

My environments are:
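A possible workaround (an assumption on my part, not a confirmed upstream fix) is to cast bfloat16 tensors to float32 before the .numpy() call that fails in map_mambapy_torch_to_mlx, since NumPy has no bfloat16 dtype:

    import torch

    def tensor_to_numpy(value: torch.Tensor):
        # NumPy cannot represent bfloat16, so upcast first
        if value.dtype == torch.bfloat16:
            value = value.to(torch.float32)
        return value.numpy()

    # the failing line in map_mambapy_torch_to_mlx would then become:
    # new_state_dict[key] = tensor_to_numpy(value)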

Why use element-wise multiplication rather than matrix multiplication in the function `selective_scan_seq`

Hello. In the function selective_scan_seq, there are two lines that confuse me:

  • BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
  • h = deltaA[:, t] * h + BX[:, t]
    These two lines appear to be element-wise multiplications.

However, in the paper, the equation is
$$h_t = \bar{A} h_{t-1} + \bar{B} x_{t}$$
Both terms on the right-hand side of the equation are matrix multiplications.

I am curious whether these two lines of code use some trick to convert the matrix multiplication into an element-wise one.
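For context, A in Mamba is parameterized as diagonal (it is stored as an (ED, N) tensor of per-channel diagonal entries), so the matrix product $\bar{A} h_{t-1}$ collapses to an element-wise product. A tiny numeric check (not from the repo) for a single channel:

    import torch

    N = 16
    a_diag = torch.randn(N)  # diagonal entries of A_bar for one channel
    h = torch.randn(N)       # hidden state for that channel

    matmul_result = torch.diag(a_diag) @ h  # full (N, N) matrix multiplication
    elementwise_result = a_diag * h         # element-wise, as in selective_scan_seq

    print(torch.allclose(matmul_result, elementwise_result))  # True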

huge huge memory usage!!

I find that the pscan method used in this Mamba implementation uses a huge amount of memory. Any idea how to reduce memory consumption, or how to replace the pscan method with another implementation?

Thanks a lot!
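For reference, the memory-light alternative is a sequential scan: it avoids materializing the parallel-scan intermediates at the cost of a Python loop over the sequence. A minimal sketch (shape conventions as in the other issues; an illustration, not a drop-in replacement for the repo's pscan):

    import torch

    def selective_scan_sequential(deltaA, deltaB, x, C):
        # deltaA, deltaB: (B, L, ED, N); x: (B, L, ED); C: (B, L, N)
        Bsz, L, ED, N = deltaA.shape
        h = torch.zeros(Bsz, ED, N, device=deltaA.device, dtype=deltaA.dtype)
        ys = []
        for t in range(L):
            # element-wise recurrence: h_t = deltaA_t * h_{t-1} + deltaB_t * x_t
            h = deltaA[:, t] * h + deltaB[:, t] * x[:, t].unsqueeze(-1)
            ys.append((h @ C[:, t].unsqueeze(-1)).squeeze(-1))  # (B, ED)
        return torch.stack(ys, dim=1)  # (B, L, ED)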

Can we get an explicit license?

Would it be possible to put an explicit OSS license on the codebase, just to remove the approval burden when experimenting based on this code? I want to prototype a bit on my laptop to get the plumbing right before moving to cloud GPUs to actually train, and this looks like the best CPU-friendly implementation I have found for that purpose.

MLX memory usage at inference

The 1.4B model takes 10-11GB of RAM at inference. (own test, M2 Pro 16GB)
The 2.8B model takes around 50GB at inference. (https://twitter.com/awnihannun/status/1749515431336112275)

This is not due to loading the model from HF (the memory footprint is the same if the model is initialized with random weights).
Nor is it due to the ssm_step.

However, turning off the convolution at inference reduces the memory footprint (by 3GB for the 1.4B model: from 10GB to around 7GB). It also greatly speeds up inference (but of course, the forward pass is then no longer correct).

Files concerned:

  • mamba_mlx.py (step functions)
  • misc.py

The depthwise conv implemented in misc.py seems to be part of the problem.
As noted in the file, the PyTorch version uses groups=channels (a true depthwise conv), while the MLX depthwise conv in misc.py uses groups=1 with some weights set to 0 (the only workaround found).
This results in a (d_model, 4, d_model) filter, versus (d_model, 4) for the true depthwise conv.

Either:
  • wait for MLX to implement groups=channels for conv1d, or
  • find another workaround. One possibility is to create d_model conv objects, each with 1 input and 1 output channel, but this results in a big for loop which is around 45x slower than the current workaround; of course, memory usage is then greatly reduced (by a factor of d_model = 2560). A sketch of the size difference follows below.
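To make the size difference concrete, here is a hedged PyTorch illustration (the MLX weight layout is (out_channels, kernel_size, in_channels) rather than PyTorch's (out_channels, in_channels/groups, kernel_size), but the blow-up factor of d_model is the same):

    import torch.nn as nn

    d_model, kernel_size = 2560, 4

    depthwise = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model)
    emulated = nn.Conv1d(d_model, d_model, kernel_size, groups=1)  # the zero-masked workaround

    print(depthwise.weight.shape)  # torch.Size([2560, 1, 4])
    print(emulated.weight.shape)   # torch.Size([2560, 2560, 4])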

Possible SSM-Transformers implementation?

Hey! Awesome work on this project! I know it's not technically vanilla Mamba, but I've been trying to convert the new SSM-Transformer Jamba into MLX for more efficient training and usability, and am having a difficult time. My specialty is the training/datasets world, and I'm not the strongest in the core math behind model architectures beyond the basic implementations.

Would somebody know of an easier way to get Jamba converted into MLX? I truly think Jamba has A LOT to offer and could do some awesome stuff in the MLX format and for local model training on a Mac.

I've provided the modeling script released by AI21 for quick reference. Is this feasible or just way too complicated at the moment?

modeling_jamba.txt

About the speed test

Thank you for sharing your fantastic work.
We noticed from the image that, as the dimension d_state rises, Mamba's runtime does not rise.
However, we found the following in the code (selective_scan_fwd_kernel.cuh#163):

for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
    ...
    if constexpr (kIsVariableB) {
        load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
            smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
    }
}

which shows a for loop over state_idx that reads from HBM into shared memory.

Then I tested the speed again and found that as d_state rises, Mamba's runtime rises linearly, which is consistent with the code.

    import time

    import torch
    # selective_scan_fn is the CUDA selective-scan wrapper from the official
    # mamba_ssm package (import path assumed here):
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

    device = torch.device("cuda")
    dtype = torch.float32
    B, L, G, D, N, R = 3, 4096, 4, 192, 16, 192 // 16
    xi = torch.randn((B, G * D, L), device=device, dtype=dtype)
    Ai = torch.randn((G * D, N), device=device, dtype=dtype)
    Di = torch.randn((G * D), device=device, dtype=dtype)
    dti = torch.randn((B, G * D, L), device=device, dtype=dtype)
    Bi = torch.randn((B, G, N, L), device=device, dtype=dtype)
    Ci = torch.randn((B, G, N, L), device=device, dtype=dtype)
    tpb = torch.randn((G * D), device=device, dtype=dtype)

    # same shapes, but with d_state multiplied by 4
    Ai2 = torch.randn((G * D, 4 * N), device=device, dtype=dtype)
    Bi2 = torch.randn((B, G, 4 * N, L), device=device, dtype=dtype)
    Ci2 = torch.randn((B, G, 4 * N, L), device=device, dtype=dtype)

    tim0 = time.time()
    for _ in range(1000):
        y = selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    tim1 = time.time()
    for _ in range(1000):
        y = selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    tim2 = time.time()
    print(tim1 - tim0, tim2 - tim1, torch.cuda.max_memory_allocated())  # 0.7172577381134033 2.400775194168091 185063424
    time.sleep(100000)  # keep the process alive

So what did I miss?
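One thing worth checking (an observation about the measurement, not about the kernel): the first loop above is timed without a warmup and without a synchronize before reading tim0, so it may absorb one-off launch and allocation costs. A hedged sketch of a more careful timing harness:

    import torch

    def time_cuda(fn, iters=1000, warmup=10):
        for _ in range(warmup):  # warm up kernels before timing
            fn()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / 1000.0  # seconds

    # e.g. time_cuda(lambda: selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True))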

Onnx export for the inference

Hi Alex,

I have tried to generate the ONNX file for inference.

I called torch.onnx.export inside the generate function, here: https://github.com/alxndrTL/mamba.py/blob/main/mamba_lm.py#L134

as follows:

torch.onnx.export(model, inputids, "mamba.onnx", opset_version=12)

It throws the following error:

TypeError: MambaLM.forward() missing 1 required positional argument: 'tokens'.

Any idea how I can generate the ONNX file? Is there a better way to generate an ONNX file for inference?
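For what it's worth, here is a minimal sketch of the call shape torch.onnx.export expects, with the example input passed as a tuple that is unpacked into forward(tokens). Whether this resolves the specific error depends on what inputids actually was; the shapes and names below are assumptions:

    import torch

    # hypothetical example prompt tokens; shape and vocab size are assumptions
    tokens = torch.randint(0, 50000, (1, 16))

    torch.onnx.export(
        model,                 # the MambaLM instance from the issue
        (tokens,),             # tuple of example inputs, unpacked into forward()
        "mamba.onnx",
        opset_version=12,
        input_names=["tokens"],
        output_names=["logits"],
        dynamic_axes={"tokens": {0: "batch", 1: "seq_len"}},
    )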

Segmentation fault with MLX

Segmentation fault: 11 while running Mamba inference with MLX

https://github.com/alxndrTL/mamba.py/tree/main/mlx

python3 generate.py --prompt="Mamba is a type of" --hf_model_name="state-spaces/mamba-130m" --n_tokens=100

on an Apple M1 Pro

I found that the line that causes the error is

mlx_weights = torch.zeros(channels, kernel_size, channels)

in the function torch_to_mlx_depthwise_weights,

but I don't know how to fix it.
