alxndrtl / mamba.py
A simple and efficient Mamba implementation in pure PyTorch and MLX.
License: MIT License
Hi,
Thanks for the repo! This is really useful!
Hi,
Firstly, thanks for making this repo. I found it very useful for understanding the scan algorithm. However, the discretization in this repo seems to differ from Eq. 4 of the paper. Do you have any comments on this? Also, I wonder why the original paper needs the discretization step in the first place, since the discrete versions of A and B could be conditioned on the input directly. I imagine it has something to do with the initialization, but I am not sure.
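To make the question concrete, here is a minimal sketch comparing the zero-order-hold rule of Eq. 4 with a simplified first-order rule for a single scalar entry. It assumes A is diagonal (so each entry can be treated as a scalar) and assumes the difference being asked about is B̄ = ΔB used in place of the full ZOH expression; the function names and numeric values are hypothetical, chosen only for illustration.

```python
import math

def zoh_discretize(delta, a, b):
    """Zero-order hold for one scalar entry: A_bar = exp(dA), B_bar = (dA)^-1 (exp(dA) - 1) dB."""
    a_bar = math.exp(delta * a)
    # expm1 keeps precision when delta * a is small
    b_bar = math.expm1(delta * a) / (delta * a) * (delta * b)
    return a_bar, b_bar

def simplified_discretize(delta, a, b):
    """Same A_bar, but B_bar is replaced by its first-order approximation delta * b (assumption)."""
    a_bar = math.exp(delta * a)
    b_bar = delta * b
    return a_bar, b_bar

# Hypothetical values: small step size, stable (negative) A entry
delta, a, b = 0.01, -1.0, 0.5
zoh = zoh_discretize(delta, a, b)
simple = simplified_discretize(delta, a, b)
rel_gap = abs(zoh[1] - simple[1]) / abs(zoh[1])
print(zoh, simple, rel_gap)
```

Since (exp(x) - 1) / x tends to 1 as x tends to 0, the two B̄ values agree to first order when ΔA is small, which is one possible reason a simplified rule can work in practice.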
Line 248 in b9f315d
from commit hash 6a49341:
I ran generate.py with a fine-tuned Mamba model, kuotient/mamba-ko-2.8b, and got the error below. How can I deal with this error?
(venv_mamba_py) ******@Mac-Studio-2022-01 scripts % python generate.py --prompt="Mamba is a type of" --hf_model_name="kuotient/mamba-ko-2.8b" --n_tokens=100
Traceback (most recent call last):
File "/Users/******/test/mamba.py/mlx/scripts/generate.py", line 31, in <module>
model = MambaLM.from_pretrained(args.hf_model_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/******/test/mamba.py/mlx/mamba_lm_mlx.py", line 150, in from_pretrained
mlx_state_dict = map_mambassm_torch_to_mlx(state_dict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/******/test/mamba.py/mlx/utils.py", line 53, in map_mambassm_torch_to_mlx
return map_mambapy_torch_to_mlx(new_state_dict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/******/test/mamba.py/mlx/utils.py", line 37, in map_mambapy_torch_to_mlx
new_state_dict[key] = value.numpy()
^^^^^^^^^^^^^
TypeError: Got unsupported ScalarType BFloat16
My environments are:
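Regarding the BFloat16 error in the traceback above: NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor fails. One possible workaround, sketched here under the assumption that upcasting to float32 is acceptable for the checkpoint, is to convert each tensor before the NumPy conversion. The helper name to_numpy_state_dict is hypothetical and is not part of the repo.

```python
import torch

def to_numpy_state_dict(state_dict):
    """Convert a PyTorch state dict to NumPy arrays, upcasting bfloat16 first.

    Hypothetical helper for illustration; not the repo's own function.
    """
    converted = {}
    for key, value in state_dict.items():
        if value.dtype == torch.bfloat16:
            # NumPy cannot represent bfloat16, so upcast before calling .numpy()
            value = value.to(torch.float32)
        converted[key] = value.detach().cpu().numpy()
    return converted

# Tiny stand-in for a checkpoint stored in bfloat16
example = {"layer.weight": torch.ones(2, 3, dtype=torch.bfloat16)}
arrays = to_numpy_state_dict(example)
print(arrays["layer.weight"].dtype)
```

Note that float32 doubles the memory needed for the weights compared with bfloat16.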
Hello. In the function selective_scan_seq, there are two points that confuse me:
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
h = deltaA[:, t] * h + BX[:, t]
However, in the paper, the equation is
h_t = Ā h_{t-1} + B̄ x_t
Both terms on the right-hand side of the equation are matrix multiplications.
Do these two lines of code use some trick to convert the matrix multiplications into elementwise ones?
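A small sketch of why the two forms can coincide, assuming A is diagonal (as in Mamba's parameterization) and that each channel receives a scalar input x_t. Under these assumptions, multiplying by a diagonal matrix is the same as an elementwise product with its diagonal, and B̄ x_t is just B̄ scaled by x_t. All values below are hypothetical.

```python
def matvec(matrix, vec):
    """Plain matrix-vector product on nested lists."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]

# One channel with state size N = 3 (hypothetical values)
a_diag = [0.9, 0.5, 0.1]   # diagonal entries of the discretized A
b_bar = [0.2, -0.3, 0.4]   # discretized B for this channel, shape (N,)
h_prev = [1.0, 2.0, 3.0]   # previous hidden state, shape (N,)
x_t = 2.0                  # scalar input for this channel at step t

# Matrix form: h_t = A_bar @ h_prev + B_bar @ x_t
A_full = [[a_diag[i] if i == j else 0.0 for j in range(3)] for i in range(3)]
B_full = [[b] for b in b_bar]  # shape (N, 1)
h_matrix = [ah + bx for ah, bx in zip(matvec(A_full, h_prev), matvec(B_full, [x_t]))]

# Elementwise form, mirroring h = deltaA[:, t] * h + BX[:, t]
h_elementwise = [a * h + b * x_t for a, b, h in zip(a_diag, b_bar, h_prev)]

print(h_matrix, h_elementwise)
```

In the batched code, the same elementwise product is applied independently across the B, ED and N axes.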
Hello, does this project have similar training functions to llama2.c?
I find that the pscan method used in this Mamba implementation uses a huge amount of memory! Any idea how to reduce memory consumption, or how to replace pscan with another implementation?
Many thanks!
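As a rough way to reason about the memory question above, the sketch below estimates the size of one tensor of shape (B, L, ED, N), the shape of the inputs to the scan, and compares it with a single-step state of shape (B, ED, N). The sizes are hypothetical and float32 is assumed; the number of such tensors actually kept alive (inputs, outputs, gradients) is not counted.

```python
def scan_tensor_bytes(batch, seq_len, d_inner, d_state, bytes_per_elem=4):
    """Bytes for one (B, L, ED, N) tensor materialized over the full sequence."""
    return batch * seq_len * d_inner * d_state * bytes_per_elem

def step_state_bytes(batch, d_inner, d_state, bytes_per_elem=4):
    """Bytes for one (B, ED, N) hidden state, as kept by a step-by-step recurrence."""
    return batch * d_inner * d_state * bytes_per_elem

# Hypothetical sizes, chosen only for illustration
full = scan_tensor_bytes(batch=8, seq_len=1024, d_inner=1536, d_state=16)
step = step_state_bytes(batch=8, d_inner=1536, d_state=16)
print(full / 2**20, "MiB per full-sequence tensor")  # 768.0
print(full // step, "x larger than one step state")   # 1024, i.e. seq_len
```

The ratio between the two equals the sequence length, so the full-sequence tensors dominate as L grows.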
Would it be possible to put an explicit OSS license on the codebase, just to remove the approval burden when experimenting based on this code? I want to prototype a bit on a laptop to get the plumbing right before moving to cloud GPUs to actually train, and this looks like the best CPU-friendly implementation I have found for that purpose.
The 1.4B model takes 10-11GB of RAM at inference. (own test, M2 Pro 16GB)
The 2.8B model takes around 50GB at inference. (https://twitter.com/awnihannun/status/1749515431336112275)
This is not due to loading the model from HF (the memory footprint is the same if the model is initialized with random weights).
It is not due to ssm_step either.
However, turning off the convolution at inference reduces the memory footprint (by 3GB for the 1.4B model: from 10GB to around 7GB). It also greatly speeds up inference (but of course, the forward pass is then no longer correct).
Files concerned:
mamba_mlx.py (step functions)
misc.py
The depthwise conv implemented in misc.py seems to be part of the problem.
As noted in the file, the PyTorch version uses groups=channels (true depthwise), while the MLX depthwise conv in misc.py uses groups=1 with some weights set to 0 (the only workaround found).
This results in a (d_model, 4, d_model) filter size, against (d_model, 4) for the "true" depthwise conv.
Either:
- wait for MLX to implement groups=channels for conv1d
- find another workaround (one possibility is to create d_model conv objects, each with 1 input and 1 output channel; this results in a big for loop which is around 45x slower than the current workaround, but memory usage is greatly reduced, by a factor of d_model = 2560)
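The filter sizes quoted above can be checked with a little arithmetic. This sketch assumes float32 weights, counts a single conv layer only, and uses d_model = 2560 and kernel size 4 as given in the post; the helper name is hypothetical.

```python
def conv_weight_elems(channels, kernel_size, depthwise):
    """Number of weights in a 1D conv filter.

    depthwise=True  -> (channels, kernel_size)
    depthwise=False -> (channels, kernel_size, channels), the zero-padded workaround
    """
    return channels * kernel_size * (1 if depthwise else channels)

d_model, kernel_size = 2560, 4
dense = conv_weight_elems(d_model, kernel_size, depthwise=False)
true_dw = conv_weight_elems(d_model, kernel_size, depthwise=True)

dense_mib = dense * 4 / 2**20  # float32 assumed: 4 bytes per weight
print(dense, true_dw, dense // true_dw)  # 26214400 10240 2560
print(dense_mib, "MiB per layer for the zero-padded filter")  # 100.0
```

The zero-padded filter is d_model times larger than the true depthwise one, which matches the reduction factor mentioned for the per-channel workaround.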
Hey! Awesome work on this project! I know it's not technically vanilla Mamba but I've been trying to convert the new SSM-Transformers Jamba into MLX for more efficient training and usability but am having a difficult time. My specialty is in the training/datasets world and not the strongest in the core math behind the model architectures beyond the basic implementations.
Would somebody know of an easier way to get Jamba converted into MLX? I truly think Jamba has A LOT to offer and could do some awesome stuff in the MLX format and for local model training with Mac
I've provided the modeling script released by AI21 for quick reference. Is this feasible or just way too complicated at the moment?
Hey! I ported mamba to transformers and think your approach to replace the naive scan would be great there!
Would you like to open a PR? 🤗 (to https://github.com/huggingface/transformers)
Thank you for sharing your fantastic work.
We noticed from the figure that Mamba's running time does not increase as d_state grows.
However, we found the following in the code (selective_scan_fwd_kernel.cuh#163):
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
...
if constexpr (kIsVariableB) {
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
}
}
This is a for loop over state_idx that reads from HBM into shared memory.
I then tested the speed again and found that Mamba's running time rises linearly as d_state increases, which is consistent with the code.
import time
import torch
# selective_scan_fn is assumed to come from the poster's selective-scan build;
# its import is not shown in the original post.

device = torch.device("cuda")
dtype = torch.float32
B, L, G, D, N, R = 3, 4096, 4, 192, 16, 192 // 16
xi = torch.randn((B, G * D, L), device=device, dtype=dtype)
Ai = torch.randn((G * D, N), device=device, dtype=dtype)
Di = torch.randn((G * D), device=device, dtype=dtype)
dti = torch.randn((B, G * D, L), device=device, dtype=dtype)
Bi = torch.randn((B, G, N, L), device=device, dtype=dtype)
Ci = torch.randn((B, G, N, L), device=device, dtype=dtype)
tpb = torch.randn((G * D), device=device, dtype=dtype)
# Same inputs with a 4x larger state dimension
Ai2 = torch.randn((G * D, 4*N), device=device, dtype=dtype)
Bi2 = torch.randn((B, G, 4*N, L), device=device, dtype=dtype)
Ci2 = torch.randn((B, G, 4*N, L), device=device, dtype=dtype)
tim0 = time.time()
for _ in range(1000):
    y = selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True)
# Wait for all queued kernels to finish before reading the clock
torch.cuda.synchronize()
torch.cuda.empty_cache()
tim1 = time.time()
for _ in range(1000):
    y = selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True)
torch.cuda.synchronize()
torch.cuda.empty_cache()
tim2 = time.time()
print(tim1-tim0, tim2-tim1, torch.cuda.max_memory_allocated()) # 0.7172577381134033 2.400775194168091 185063424
time.sleep(100000)
So what did I miss?
Hi Alex,
I have tried to generate an ONNX file for inference.
I called the ONNX export inside the generate function, here: https://github.com/alxndrTL/mamba.py/blob/main/mamba_lm.py#L134
as follows:
torch.onnx.export(model, inputids, "mamba.onnx", opset_version=12)
It throws this error:
TypeError: MambaLM.forward() missing 1 required positional argument: 'tokens'.
Any idea how I can generate the ONNX file? Is there a better way of generating an ONNX file for inference?
Support for batch size > 1: how is this work progressing?
Segmentation fault: 11 while running Mamba inference with MLX
https://github.com/alxndrTL/mamba.py/tree/main/mlx
python3 generate.py --prompt="Mamba is a type of" --hf_model_name="state-spaces/mamba-130m" --n_tokens=100
on an Apple M1 Pro
I found that the line that causes the error is
mlx_weights = torch.zeros(channels, kernel_size, channels)
in the function torch_to_mlx_depthwise_weights,
but I don't know how to fix it.