Comments (15)
After properly wrapping the model with transformers PreTrainedModel and using the 1.4B model, surprisingly there is no more overflow: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-1.4b-trainer-deepspeed3-bf16.ipynb
- Tested saving with safetensors.
- Loading existing checkpoints to continue pretraining works.
- With 80GB VRAM, the maximum batch size is 8 at 4k context length; one step takes ~300ms.
from mamba.
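For readers who want to reproduce this, here is a minimal sketch of such a wrapper. The config fields, their defaults, and the mamba_ssm 1.x keyword-argument MambaLMHeadModel constructor are assumptions for illustration; the linked notebook remains the authoritative version.

import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

class MambaConfig(PretrainedConfig):
    model_type = "mamba"

    def __init__(self, d_model=2048, n_layer=48, vocab_size=32000, **kwargs):
        # hypothetical defaults, not taken from the notebook
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        super().__init__(**kwargs)

class MambaForCausalLM(PreTrainedModel):
    config_class = MambaConfig

    def __init__(self, config):
        super().__init__(config)
        # assumes the mamba_ssm 1.x keyword constructor
        self.backbone = MambaLMHeadModel(
            d_model=config.d_model,
            n_layer=config.n_layer,
            vocab_size=config.vocab_size,
        )

    def forward(self, input_ids, labels=None, **kwargs):
        logits = self.backbone(input_ids).logits
        loss = None
        if labels is not None:
            # standard causal-LM shift: predict token t+1 from tokens up to t
            loss = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
            )
        return {"loss": loss, "logits": logits}

With the wrapper in place, model.save_pretrained(path, safe_serialization=True) writes safetensors checkpoints, and MambaForCausalLM.from_pretrained(path) reloads them to continue pretraining.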
Hi, I am using the pytorch_lightning trainer with bf16-mixed precision, i.e., model params in fp32 with bf16 autocast (pytorch_lightning also uses PyTorch AMP under the hood), to train Mamba on DNA data; however, the training is still quite unstable.
Weirdly, when I trained with the same code and same model (different size) with fp16-mixed precision on NLP data (wiki103), such instability did not occur. Any suggestions?
from mamba.
I have also observed similar instabilities when using the PL trainer. In my case I am creating custom embeddings which I concatenate with the text embeddings from Mamba before feeding them to the mixer model. I have tried lowering the learning rate and disabling AMP, but my loss still goes to NaN.
from mamba.
The 2.8B uses a total batch size of 1M tokens (following the GPT-3 paper) and seqlen=2k. Activation memory should be around the same as an optimized transformer (e.g. with FlashAttention), so it should fit in a single A100 80GB (you might need to adjust gradient accumulation to fit).
For models around 1B-3B on 8x A100s, sharding the optimizer states (e.g. with PyTorch's distributed optimizer, equivalent to ZeRO stage 1) will help reduce the amount of memory needed.
from mamba.
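A minimal sketch of the optimizer-state sharding suggested above, using PyTorch's built-in ZeroRedundancyOptimizer (equivalent to ZeRO stage 1). This assumes torch.distributed is already initialized (e.g. via torchrun) and that model and the hyperparameters are placeholders.

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# each rank stores only its shard of the AdamW moment buffers,
# cutting optimizer-state memory roughly by the world size
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=3e-4,
    weight_decay=0.1,
)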
I am using pytorch's FSDP
What is your auto_wrap_policy, if you don't mind sharing it? @binxuan
from mamba.
This worked for me:

from functools import partial

from mamba_ssm.modules.mamba_simple import Block
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# treat each Mamba Block as its own FSDP-wrapped unit
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},
)
I was also able to pretrain a 200M Mamba LM on 12B tokens yesterday in bf16 without issue. Testing FSDP + bf16 right now, and everything works as expected as well.
from mamba.
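Building on the snippet above, here is a hedged sketch of how that policy might be passed to FSDP together with bf16 mixed precision; the surrounding setup names are assumptions, not from the thread.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

# compute, gradient reduction, and buffers in bf16;
# FSDP keeps the sharded master parameters in fp32
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,                              # the Mamba LM to shard
    auto_wrap_policy=auto_wrap_policy,  # wraps each Block separately
    mixed_precision=bf16_policy,
    device_id=torch.cuda.current_device(),
)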
We don't use DeepSpeed, just PyTorch AMP (bf16) to train models. Can you try that?
from mamba.
Model parameters should be in fp32, just like in the PyTorch AMP docs.
from mamba.
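Putting the last two comments together, here is a minimal sketch of that recipe (model, dataloader, and hyperparameters are placeholders): parameters stay in fp32 and only the forward pass runs under bf16 autocast; unlike fp16, bf16 needs no GradScaler.

import torch
import torch.nn.functional as F

model = model.cuda()  # parameters remain fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for input_ids, labels in dataloader:
    input_ids, labels = input_ids.cuda(), labels.cuda()
    # forward pass under bf16 autocast; master weights stay fp32
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        logits = model(input_ids).logits
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        )
    loss.backward()       # gradients accumulate in fp32
    optimizer.step()
    optimizer.zero_grad()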
Looks good on AMP bf16: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-trainer-bf16.ipynb. Also, I saw there is a 2.8B checkpoint; what batch size did it use, and was it trained with 2k context length as in the paper? Does the 2.8B model fit in a single A100 80GB?
from mamba.
Thanks!
from mamba.
I tried ZeRO-2 with fp32; sometimes it overflows, sometimes it doesn't. After that I tried a 1e-6 learning rate with gradient-norm clipping at 0.1, and it no longer overflows, but a 0.1 gradient norm is too low for pretraining.
from mamba.
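For reference, the clipping described above is the standard PyTorch call, placed between loss.backward() and optimizer.step(); 0.1 is the unusually low max norm from the comment.

import torch

# aggressive clipping from the comment above; pretraining typically uses ~1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)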
Thanks for your explorations! I added a little warning to the README about the dtype, but it's quite useful for people to post their observations here. We've actually never seen these types of instabilities during training; as Tri said, we just use native PyTorch AMP.
from mamba.
I am using PyTorch's FSDP with bf16 for training. It looks like I encountered a similar issue with NaN loss.
from mamba.
@gpantaz and @lwang2070, were you able to fix the issue with bf16 or fp16 training with PL? I also see training issues for my model.
from mamba.
No sadly :/
from mamba.
Related Issues (20)
- Question for 'self.use_mem_eff_path and inference_params' HOT 4
- triton.runtime.autotuner.OutOfResources: out of resource: shared memory, Required: 254208, Hardware limit: 101376. HOT 5
- I want to ask does anyone know how to solve this problem
- /anaconda3/lib/python3.11/site-packages/causal_conv1d_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c107WarningC1ENS_7variantIJNS0_11UserWarningENS0_18DeprecationWarningEEEERKNS_14SourceLocationENSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEb HOT 1
- Mamba-2 Error: `'NoneType' object has no attribute 'causal_conv1d_fwd'` HOT 8
- Used selective_scan_cuda and causal_conv1d_cuda, but still very slow to train HOT 1
- mamba / self-attention hybrid generation
- Inference multiple tokens HOT 2
- Error when using FP16 or Mixed precision HOT 3
- How to use Mamba2?
- How to extract whole sentence embeddings HOT 1
- Does mamba support data packing?
- Slow Mamba 2 training speeds with higher d_state values HOT 1
- Where is ‘Block’ class in the new version mamba? HOT 1
- mamba_ssm Install Failure HOT 9
- Sequence parallelism in the mixer (Context Parallelism)
- Support Mamba-codestral
- Why does it take so long to build HOT 1
- Is mamba suitable for time-series classification task? HOT 1
- Question on Comparison between Mamba and S4 HOT 1