
Training Script · state-spaces/mamba (open, 7 comments)

state-spaces commented on August 28, 2024
Training Script


Comments (7)

tridao commented on August 28, 2024

You can use whichever training script / library you'd like, e.g. Megatron, DeepSpeed, Lightning, HF Accelerate, etc. You just have to replace the model definition.

Examples:
• Lightning has lit-gpt: https://github.com/Lightning-AI/lit-gpt
• FlashAttention has training code; you can swap in the model: https://github.com/Dao-AILab/flash-attention/tree/main/training
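For illustration, here is a minimal sketch of what "just replace the model definition" amounts to in a bare PyTorch loop. The config values and the dummy batch are assumptions made to keep the example self-contained, not a recommended recipe:

import torch
import torch.nn.functional as F
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

device = "cuda"  # mamba_ssm's fused kernels require a CUDA device
config = MambaConfig(d_model=768, n_layer=24, vocab_size=50277)
model = MambaLMHeadModel(config, device=device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Dummy batch of token ids standing in for a real dataloader (assumption).
input_ids = torch.randint(0, config.vocab_size, (4, 512), device=device)

model.train()
logits = model(input_ids).logits  # forward() returns a namedtuple with .logits
# MambaLMHeadModel does not compute a loss itself, so shift and do it manually.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    input_ids[:, 1:].reshape(-1),
)
loss.backward()
optimizer.step()
optimizer.zero_grad()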


justusmattern27 commented on August 28, 2024

We've managed to train Mamba by modifying the Hugging Face Trainer class. Here is our implementation; we were actually able to train a chat model that seems to perform quite well.
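Their linked implementation isn't reproduced here, but a hedged sketch of that kind of Trainer modification might look like the following (the class name MambaTrainer and the loss handling are assumptions):

import torch.nn.functional as F
from transformers import Trainer

class MambaTrainer(Trainer):
    # Hypothetical subclass: MambaLMHeadModel.forward() accepts no `labels`
    # kwarg, so compute the causal LM loss here instead of inside the model.
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs["input_ids"]
        labels = inputs.get("labels", input_ids)
        outputs = model(input_ids)
        logits = outputs.logits
        # Standard next-token shift for causal language modeling.
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            labels[:, 1:].reshape(-1),
        )
        return (loss, outputs) if return_outputs else loss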


geronimi73 commented on August 28, 2024

Does not seem to be so straightforward with the HF Trainer, quite literally:

  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1555, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 1860, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2725, in training_step
    loss = self.compute_loss(model, inputs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2748, in compute_loss
    outputs = model(**inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 680, in forward
    return model_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 668, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
    return func(*args, **kwargs)
TypeError: MambaLMHeadModel.forward() got an unexpected keyword argument 'labels'

No labels in forward()?

It would be very nice if you could provide a simple, minimal example of how to use the models with the HF Trainer. Thank you!


binxuan commented on August 28, 2024

> We've managed to train Mamba by modifying the Hugging Face Trainer class. Here is our implementation; we were actually able to train a chat model that seems to perform quite well.

Cool, nice work! Are you using fp32 for this fine-tuning?


Calvinnncy97 commented on August 28, 2024

Hmm... Doesn't seem to work out of the box with lit-gpt.

Minimal example:

# `fabric` is a lightning.fabric.Fabric instance from the surrounding
# lit-gpt pretraining script; DeepSpeedStrategy comes from Lightning.
from lightning.fabric.strategies import DeepSpeedStrategy

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf
from mamba_ssm.models.config_mamba import MambaConfig

# Under DeepSpeed ZeRO-3, empty_init=True makes init_module create
# parameters as empty placeholders to be materialized later.
with fabric.init_module(
    empty_init=isinstance(fabric.strategy, DeepSpeedStrategy)
):
    config = load_config_hf('state-spaces/mamba-2.8b')
    model = MambaLMHeadModel(MambaConfig(**config))

This gives the following error:

Traceback (most recent call last):
  File "/home/me/lit-gpt/pretrain/mamba.py", line 782, in <module>
    setup(run_config)
  File "/home/me/lit-gpt/pretrain/mamba.py", line 306, in setup
    main(fabric, run_config)
  File "/home/me/lit-gpt/pretrain/mamba.py", line 342, in main
    model = MambaLMHeadModel(MambaConfig(**model_config))
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 199, in __init__
    self.backbone = MixerModel(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 118, in __init__
    [
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 119, in <listcomp>
    create_block(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 42, in create_block
    block = Block(
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 316, in __init__
    self.mixer = mixer_cls(dim)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/home/me/lit-gpt/mamba-venv/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py", line 99, in __init__
    self.dt_proj.bias.copy_(inv_dt)
RuntimeError: The size of tensor a (0) must match the size of tensor b (5120) at non-singleton dimension 0


thistleknot commented on August 28, 2024

If that's true, how the heck am I supposed to pass an attention mask?
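For what it's worth: Mamba has no attention, so there is no attention mask to pass. A common workaround (an assumption here, not an official recipe) is to mask padded positions out of the loss instead:

import torch.nn.functional as F

# input_ids: (B, L) token ids; pad_mask: (B, L), 0 at padded positions.
# Both are assumed to come from a tokenizer with padding enabled.
logits = model(input_ids).logits
labels = input_ids.clone()
labels[pad_mask == 0] = -100  # F.cross_entropy ignores -100 by default
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    labels[:, 1:].reshape(-1),
)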


Calvinnncy97 commented on August 28, 2024

DeepSpeed ZeRO stage 2 works, but not stage 3. I don't have a fix at the moment. The problem is this line in mamba_simple.py:

with torch.no_grad():
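A minimal repro of the failure mode, assuming the traceback above is being read correctly: ZeRO-3's empty init hands __init__ size-0 parameter placeholders, so the in-place copy_ that Mamba performs under torch.no_grad() has nothing to write into:

import torch

bias = torch.nn.Parameter(torch.empty(0))  # what ZeRO-3 empty init provides
inv_dt = torch.rand(5120)                  # the dt bias Mamba computes (d=5120)
with torch.no_grad():
    bias.copy_(inv_dt)  # RuntimeError: tensor a (0) vs tensor b (5120)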

