Comments (8)
A simple script to pretrain Mamba with the Causal Language Modeling (CLM) objective, extrapolated from mamba-chat:
```python
import os
import argparse

import torch
from datasets import load_from_disk
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments, Trainer
from transformers import DataCollatorForLanguageModeling


class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        # Shift so each position predicts the next token
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return lm_loss

    def save_model(self, output_dir, _internal_call=False):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
        self.tokenizer.save_pretrained(output_dir)


def train(args):
    model = MambaLMHeadModel.from_pretrained(args.model, dtype=torch.bfloat16, device="cuda")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    tokenizer.eos_token = "<|endoftext|>"
    tokenizer.pad_token = tokenizer.eos_token

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    ds_tokenized = load_from_disk("your_pretokenized_ds_mamba")

    trainer = MambaTrainer(
        model=model,
        train_dataset=ds_tokenized["train"],
        tokenizer=tokenizer,
        args=TrainingArguments(
            learning_rate=args.learning_rate,
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            dataloader_num_workers=2,
            optim=args.optim,
            output_dir="out",
            logging_steps=50,
            weight_decay=1e-2,
            evaluation_strategy="no",  # eval omitted; no eval_dataset is passed
            save_strategy="epoch",
        ),
        data_collator=data_collator,
    )
    trainer.train()
```
I'm having issues with the eval step, so I've omitted it.
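The script defines `train(args)` but never builds a parser or calls it. A minimal entry point could look like the sketch below; the flag names are only inferred from the `args.*` attributes used above, and the defaults are purely illustrative (the `state-spaces/mamba-130m` checkpoint and the GPT-NeoX-20B tokenizer are just common choices for Mamba).

```python
# Hypothetical CLI wiring; flag names are inferred from the args.* attributes above.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="state-spaces/mamba-130m")
    parser.add_argument("--tokenizer", default="EleutherAI/gpt-neox-20b")
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--optim", default="adamw_torch")
    train(parser.parse_args())
```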
Nevermind. This works for me. Adapted from here
```python
import torch
import numpy as np
from transformers import AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset, load_metric
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model_name = "state-spaces/mamba-130m"  # Make sure this model is compatible with MambaLMHeadModel
CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

model = MambaLMHeadModel.from_pretrained(model_name)  # Changed to MambaLMHeadModel
model.to(device)

dataset = load_dataset('databricks/databricks-dolly-15k')

def tokenize_function(examples):
    concatenated_texts = [
        instr + " " + ctxt + " " + resp
        for instr, ctxt, resp in zip(
            examples['instruction'], examples['context'], examples['response']
        )
    ]
    return tokenizer(concatenated_texts, truncation=True, padding='max_length', max_length=500)

tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=12)

accuracy_metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(
    output_dir="./model_output",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    save_steps=10_000,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=500,
    do_train=True,
    fp16=False,
)

class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return lm_loss

trainer = MambaTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    compute_metrics=compute_metrics,
)

trainer.train()
```
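Note that the zephyr chat template is copied onto the tokenizer but `tokenize_function` never applies it; the dolly fields are just concatenated with spaces. If you actually want the template rendered into the training text, a minimal sketch (keeping the same dolly field names and max length as above) could be:

```python
def tokenize_function(examples):
    texts = []
    for instr, ctxt, resp in zip(examples["instruction"], examples["context"], examples["response"]):
        messages = [
            {"role": "user", "content": (instr + "\n" + ctxt).strip()},
            {"role": "assistant", "content": resp},
        ]
        # Render the conversation with the chat template copied from zephyr above.
        texts.append(tokenizer.apply_chat_template(messages, tokenize=False))
    return tokenizer(texts, truncation=True, padding="max_length", max_length=500)
```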
We do not have one currently. Hopefully the community will help out!
I haven't tried it, but this repo appears to implement a trainer:
https://github.com/havenhq/mamba-chat/tree/main
Sorry, I'm trying to figure out where we get `trainer.mamba_trainer` from.
@Eupham snippet fixed, it was just some old import.
Is there any chance of a notebook showing how to train with this? I'm doing something wrong in my attempts.
```python
import os
import torch
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, TrainingArguments, Trainer
from transformers import DataCollatorForLanguageModeling
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel


# Tokenize and save the IMDb dataset
def tokenize_and_save_imdb(tokenizer, save_path):
    dataset = load_dataset("imdb")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.save_to_disk(save_path)


class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return (lm_loss, (lm_logits,)) if return_outputs else lm_loss

    def save_model(self, output_dir, _internal_call=False):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
        self.tokenizer.save_pretrained(output_dir)


def train(model_name, tokenizer_name, dataset_path, output_dir, learning_rate, num_epochs, batch_size, gradient_accumulation_steps):
    model = MambaLMHeadModel.from_pretrained(model_name, dtype=torch.bfloat16, device="cuda")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.eos_token = "<|endoftext|>"
    tokenizer.pad_token = tokenizer.eos_token

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    ds_tokenized = load_from_disk(dataset_path)

    trainer = MambaTrainer(
        model=model,
        train_dataset=ds_tokenized["train"],
        tokenizer=tokenizer,
        args=TrainingArguments(
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            dataloader_num_workers=2,
            output_dir=output_dir,
            logging_steps=50,
            weight_decay=1e-2,
            evaluation_strategy="epoch",
            save_strategy="epoch",
        ),
        data_collator=data_collator,
    )
    trainer.train()


# Set your parameters here
model_name = "clibrain/mamba-2.8b-instruct-openhermes"
tokenizer_name = "clibrain/mamba-2.8b-instruct-openhermes"
dataset_path = "./tokenized_imdb"
output_dir = "./output"
learning_rate = 2e-5
num_epochs = 1
batch_size = 1
gradient_accumulation_steps = 2

# Tokenize and save IMDb dataset if not already done
if not os.path.exists(dataset_path):
    print("Tokenizing and saving IMDb dataset...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenize_and_save_imdb(tokenizer, dataset_path)

# Run training
train(model_name, tokenizer_name, dataset_path, output_dir, learning_rate, num_epochs, batch_size, gradient_accumulation_steps)
```
```
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:72: UserWarning: The secret `HF_TOKEN` does not exist in your Colab secrets. To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. You will be able to reuse this secret in all of your notebooks. Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.

RuntimeError                              Traceback (most recent call last)
in <cell line: 88>()
     86
     87 # Run training
---> 88 train(model_name, tokenizer_name, dataset_path, output_dir, learning_rate, num_epochs, batch_size, gradient_accumulation_steps)

25 frames
in _layer_norm_fwd_1pass_kernel(X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd, stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, N, eps, IS_RMS_NORM, BLOCK_N, HAS_RESIDUAL, STORE_RESIDUAL_OUT, HAS_BIAS, grid, num_warps, num_stages, extern_libs, stream, warmup, device, device_type)

/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py in ptx_to_cubin(ptx, arch)
    148     '''
    149     ptxas, _ = path_to_ptxas()
--> 150     return compile_ptx_to_cubin(ptx, ptxas, arch)
    151
    152

RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-a9e01c, line 984; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 984; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 986; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 986; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
...
...
ptxas /tmp/compile-ptx-src-a9e01c, line 2805; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 2807; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 2807; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal   : Ptx assembly aborted due to errors
```
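The ptxas errors are the actual failure here: bf16 instructions require compute capability sm_80 (Ampere) or newer, and the model above is loaded with `dtype=torch.bfloat16`, so Triton cannot compile its kernels on an older GPU such as a Colab T4 (sm_75). One possible workaround, sketched below rather than guaranteed, is to choose the dtype from the device's capability before loading:

```python
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# bf16 kernels need compute capability >= 8.0; fall back to fp32 on older GPUs.
# (Whether fp16 works end-to-end with these kernels may vary, so fp32 is the safe fallback.)
major, _ = torch.cuda.get_device_capability()
dtype = torch.bfloat16 if major >= 8 else torch.float32

model = MambaLMHeadModel.from_pretrained(model_name, dtype=dtype, device="cuda")
```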
The suggested MambaTrainer code trains the model to learn the padding tokens as well. Why? Hugging Face collators such as `DataCollatorForLanguageModeling` set the padding tokens to -100, which is ignored by `torch.nn.CrossEntropyLoss()`, but they do so only in the labels, not in the input_ids. To prevent the model from being trained to predict padding tokens, I found it sufficient to change this code:
```python
class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return lm_loss
```
to this:
```python
class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        labels = inputs.pop("labels")
        lm_logits = model(input_ids).logits

        labels = labels.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        lm_loss = torch.nn.CrossEntropyLoss()(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )
        return (lm_loss, lm_logits) if return_outputs else lm_loss
```
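To see the masking described above, a tiny check like the sketch below (the GPT-NeoX tokenizer is just an illustrative choice) shows that the collator copies `input_ids` into `labels` and replaces padding positions with -100:

```python
# Minimal sketch: inspect what DataCollatorForLanguageModeling(mlm=False) produces.
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = collator([tokenizer("short example", padding="max_length", max_length=8)])

print(batch["input_ids"][0])  # padding positions still hold the pad token id
print(batch["labels"][0])     # the same positions hold -100, so CrossEntropyLoss ignores them
```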