
Comments (15)

huseinzol05 commented on July 24, 2024

After properly wrapping the model with transformers PreTrainedModel and using the 1.4B model, there is surprisingly no more overflow (a minimal wrapper sketch follows the list below): https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-1.4b-trainer-deepspeed3-bf16.ipynb

  1. Tested saving with safetensors.
  2. Loaded existing checkpoints to continue pretraining.
  3. With 80 GB VRAM, the maximum batch size is 8 at 4k context length; one step takes ~300 ms.
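
For reference, a minimal sketch of that kind of wrapper (assuming a recent mamba_ssm API where MambaLMHeadModel takes a MambaConfig; the linked notebook's actual code may differ). Wrapping the raw model this way makes save_pretrained/from_pretrained and safetensors serialization work as usual:

import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

class MambaHFConfig(PretrainedConfig):
    model_type = "mamba"

    def __init__(self, d_model=2048, n_layer=48, vocab_size=32000, **kwargs):
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        super().__init__(**kwargs)

class MambaForCausalLM(PreTrainedModel):
    config_class = MambaHFConfig

    def __init__(self, config):
        super().__init__(config)
        # wrap the raw mamba_ssm model so transformers handles (de)serialization
        self.model = MambaLMHeadModel(
            MambaConfig(d_model=config.d_model, n_layer=config.n_layer,
                        vocab_size=config.vocab_size)
        )

    def forward(self, input_ids, labels=None, **kwargs):
        logits = self.model(input_ids).logits
        loss = None
        if labels is not None:
            # standard causal-LM shift: predict token t+1 from tokens <= t
            loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                                   labels[:, 1:].reshape(-1))
        return {"loss": loss, "logits": logits}

# model = MambaForCausalLM(MambaHFConfig())
# model.save_pretrained("mamba-1.4b", safe_serialization=True)  # writes safetensors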

lwang2070 commented on July 24, 2024

Hi, I am using the pytorch_lightning Trainer with bf16-mixed precision, i.e., model params in fp32 and compute in bf16 (pytorch_lightning also uses PyTorch AMP under the hood), to train Mamba on DNA data. However, the training is still quite unstable.

[plot: training loss on hg38]

Weirdly, I trained the same code and same model (different size) with fp16-mixed precision on NLP data (wiki103), and such instability did not occur. Any suggestions?
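
For context, a minimal sketch of the Lightning setup being described (module and dataloader names are placeholders); "bf16-mixed" keeps parameters in fp32 and autocasts compute to bf16 through AMP:

import pytorch_lightning as pl

trainer = pl.Trainer(
    precision="bf16-mixed",   # fp32 params, bf16 compute; "16-mixed" would use fp16 + GradScaler
    gradient_clip_val=1.0,    # gradient clipping is a common first knob when loss spikes
    max_steps=100_000,
)
# trainer.fit(mamba_module, train_dataloaders=train_loader)  # placeholders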

gpantaz commented on July 24, 2024

Hi, I am using the pytorch_lightning Trainer with bf16-mixed precision, i.e., model params in fp32 and compute in bf16 (pytorch_lightning also uses PyTorch AMP under the hood), to train Mamba on DNA data. However, the training is still quite unstable.

[plot: training loss on hg38]

Weirdly, I trained the same code and same model (different size) with fp16-mixed precision on NLP data (wiki103), and such instability did not occur. Any suggestions?

I have also observed similar instabilities using the PL Trainer. In my case I create custom embeddings which I concatenate with the text embeddings from Mamba before feeding them to the mixer model. I have tried lowering the learning rate and disabling AMP, but my loss still goes to NaN.
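
A generic sketch of that kind of setup (this is not gpantaz's actual code; the class name, projection, and dimensions are illustrative), concatenating projected custom embeddings with token embeddings before a stack of Mamba mixer layers:

import torch
import torch.nn as nn
from mamba_ssm import Mamba

class PrefixEmbeddingMamba(nn.Module):
    def __init__(self, vocab_size=32000, d_model=512, d_extra=1024, n_layer=4):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.extra_proj = nn.Linear(d_extra, d_model)   # project custom features to d_model
        self.norms = nn.ModuleList(nn.LayerNorm(d_model) for _ in range(n_layer))
        self.layers = nn.ModuleList(Mamba(d_model=d_model) for _ in range(n_layer))
        self.norm_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids, extra_features):
        # prepend projected custom embeddings to the token embeddings along the sequence dim
        x = torch.cat([self.extra_proj(extra_features), self.tok_emb(input_ids)], dim=1)
        for norm, layer in zip(self.norms, self.layers):
            x = x + layer(norm(x))                      # pre-norm residual around each mixer
        return self.lm_head(self.norm_f(x))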

tridao commented on July 24, 2024

The 2.8B model uses a total batch size of 1M tokens (following the GPT-3 paper) and seqlen=2k. Activation memory should be about the same as an optimized Transformer (e.g. with FlashAttention), so it should fit in a single A100 80GB (you might need to adjust gradient accumulation to fit).
For models around 1B-3B on 8x A100s, sharding the optimizer states (e.g. with PyTorch's distributed optimizer, equivalent to ZeRO stage 1) will help reduce the amount of memory needed.
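
Spelling out the numbers, plus a sketch of optimizer-state sharding with PyTorch's ZeroRedundancyOptimizer (equivalent to ZeRO stage 1); the micro-batch, world size, and optimizer hyperparameters below are illustrative, not the actual training recipe:

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# 1M tokens per optimizer step at seqlen 2k -> 512 sequences per global batch
global_batch_tokens = 1_048_576
seqlen = 2048
sequences_per_step = global_batch_tokens // seqlen              # 512
micro_batch, world_size = 8, 8
grad_accum = sequences_per_step // (micro_batch * world_size)   # 8 accumulation steps

# shard optimizer states across ranks (requires torch.distributed to be initialized)
# optimizer = ZeroRedundancyOptimizer(
#     model.parameters(),
#     optimizer_class=torch.optim.AdamW,
#     lr=2e-4, betas=(0.9, 0.95), weight_decay=0.1,
# )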

geronimi73 commented on July 24, 2024

I am using pytorch's FSDP

What is your auto_wrap_policy, if you don't mind sharing it? @binxuan

btrude commented on July 24, 2024

I am using pytorch's FSDP

What is your auto_wrap_policy, if you don't mind sharing it? @binxuan

This worked for me:

from functools import partial

from mamba_ssm.modules.mamba_simple import Block
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# wrap each Mamba Block as its own FSDP unit
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},
)

I was also able to pretrain a 200M Mamba LM on 12B tokens yesterday in bf16 without issue. I am testing FSDP + bf16 right now and everything works as expected as well.
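
A hedged sketch of using that policy with FSDP in bf16 (the MixedPrecision settings are illustrative rather than btrude's exact configuration, and torch.distributed must already be initialized):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

def wrap_with_fsdp(model):
    # `model` is a Mamba LM built elsewhere; each Block becomes its own FSDP unit
    return FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        device_id=torch.cuda.current_device(),
    )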

tridao commented on July 24, 2024

We don't use DeepSpeed, just PyTorch AMP (bf16) to train models. Can you try that?

tridao commented on July 24, 2024

Model parameters should be in fp32, just as in the PyTorch AMP docs.
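
A minimal sketch of what that looks like (the tiny stand-in model and random data are placeholders; in this thread it would be a Mamba LM): parameters stay in fp32 and compute is autocast to bf16, with no GradScaler needed.

import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 1000)).cuda().float()  # fp32 params
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

for step in range(10):
    input_ids = torch.randint(0, 1000, (8, 128), device="cuda")
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):  # bf16 compute
        logits = model(input_ids)
        loss = F.cross_entropy(logits[:, :-1].reshape(-1, 1000),
                               input_ids[:, 1:].reshape(-1))
    loss.backward()                # gradients accumulate in fp32; no GradScaler for bf16
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()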

huseinzol05 commented on July 24, 2024

Looks good with AMP bf16: https://github.com/mesolitica/malaya/blob/5.1/pretrained-model/mamba/causallm-130m-trainer-bf16.ipynb. I also saw there is a 2.8B checkpoint; what batch size does it use, and does it use a 2k context length as in the paper? Does the 2.8B model fit in a single A100 80GB?

huseinzol05 commented on July 24, 2024

Thanks!

huseinzol05 commented on July 24, 2024

I tried ZeRO-2 with FP32; sometimes it overflows and sometimes it doesn't. After that I tried a 1e-6 learning rate with gradient-norm clipping at 0.1, which no longer overflows, but a 0.1 gradient-norm clip is too low for pretraining.
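
For reference, a rough sketch of the DeepSpeed settings being described (ZeRO stage 2, pure fp32, lr 1e-6, gradient clipping 0.1); the batch-size numbers and optimizer extras are illustrative:

ds_config = {
    "train_micro_batch_size_per_gpu": 8,     # illustrative
    "gradient_accumulation_steps": 8,        # illustrative
    "gradient_clipping": 0.1,                # the clip that avoided overflow, but too aggressive
    "zero_optimization": {"stage": 2},
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 1e-6, "betas": [0.9, 0.95], "weight_decay": 0.1},
    },
    # no "fp16"/"bf16" section -> pure fp32 training
}
# model_engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)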

albertfgu commented on July 24, 2024

Thanks for your explorations! I added a little warning to the README about the dtype, but it's quite useful for people to post their observations here. We've actually never seen these types of instabilities during training; as Tri said, we just use native PyTorch AMP.

binxuan commented on July 24, 2024

I am using PyTorch's FSDP with bf16 for training. It looks like I encountered a similar issue with NaN loss.

apoorv2904 commented on July 24, 2024

@gpantaz and @lwang2070, were you able to fix the bf16/fp16 training issue with PL? I also see training issues with my model.

gpantaz commented on July 24, 2024

No sadly :/

