
Comments (8)

masc-it commented on July 4, 2024

A simple script to pretrain Mamba on the causal language modeling (CLM) objective, adapted from mamba-chat:

import os
import torch
import argparse
import argparse
from datasets import load_from_disk
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer, TrainingArguments, TrainerCallback, Trainer
from transformers import DataCollatorForLanguageModeling

class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return lm_loss

    def save_model(self, output_dir, _internal_call):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
            
        torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
        self.tokenizer.save_pretrained(output_dir)

def train(args):
    model = MambaLMHeadModel.from_pretrained(args.model, dtype=torch.bfloat16, device="cuda")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    tokenizer.eos_token = "<|endoftext|>"
    tokenizer.pad_token = tokenizer.eos_token
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)  # causal LM collation (no masking)

    ds_tokenized = load_from_disk("your_pretokenized_ds_mamba")

    trainer = MambaTrainer(
        model=model,
        train_dataset=ds_tokenized["train"],
        tokenizer=tokenizer,
        args=TrainingArguments(
            learning_rate=args.learning_rate,
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            dataloader_num_workers=2,
            optim=args.optim,
            output_dir="out",
            logging_steps=50,
            weight_decay=1e-2,
            evaluation_strategy="no",  # eval omitted for now (see note below the script)
            save_strategy="epoch"
        ),
        
        data_collator=data_collator
    )

    trainer.train()
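
A minimal CLI entry point matching the arguments the script reads; the defaults below are only illustrative:

if __name__ == "__main__":
    # Illustrative wiring only; adjust model, tokenizer and hyperparameters to your setup.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="state-spaces/mamba-130m")
    parser.add_argument("--tokenizer", default="EleutherAI/gpt-neox-20b")
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--optim", default="adamw_torch")
    train(parser.parse_args())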

I'm having issues with the eval step, so I've omitted it.
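
If you want to try re-enabling it, here is a possible sketch (untested), assuming the failure comes from Trainer's default prediction_step passing labels/attention_mask into MambaLMHeadModel, which only accepts input_ids; it reuses compute_loss for evaluation:

class MambaTrainer(Trainer):
    # compute_loss and save_model as above

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # Reuse the manual CLM loss; copy inputs because compute_loss pops keys.
        with torch.no_grad():
            loss = self.compute_loss(model, dict(inputs))
        return (loss, None, None)

You would also need to pass an eval split to the trainer (e.g. eval_dataset=ds_tokenized["test"], if your dataset has one) and switch evaluation_strategy back to "epoch".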


Eupham commented on July 4, 2024

Never mind, this works for me. Adapted from here:
import torch
from transformers import AutoTokenizer, TrainingArguments, Trainer
from datasets import load_dataset, load_metric
import numpy as np
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model_name = "state-spaces/mamba-130m"  # Make sure this model is compatible with MambaLMHeadModel
CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
tokenizer.eos_token = ""
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

model = MambaLMHeadModel.from_pretrained(model_name)  # Changed to MambaLMHeadModel
model.to(device)

dataset = load_dataset('databricks/databricks-dolly-15k')

def tokenize_function(examples):
    concatenated_texts = [instr + " " + ctxt + " " + resp
                          for instr, ctxt, resp in zip(examples['instruction'],
                                                       examples['context'],
                                                       examples['response'])]
    return tokenizer(concatenated_texts, truncation=True, padding='max_length', max_length=500)

tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=12)

accuracy_metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(
    output_dir="./model_output",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    save_steps=10_000,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=500,
    do_train=True,
    fp16=False
)

class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return lm_loss

trainer = MambaTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    compute_metrics=compute_metrics
)

trainer.train()


albertfgu commented on July 4, 2024

We do not have one currently. Hopefully the community will help out!


pjsample commented on July 4, 2024

I haven't tried it, but this repo appears to implement a trainer:
https://github.com/havenhq/mamba-chat/tree/main


Eupham commented on July 4, 2024

Sorry, I'm trying to figure out where trainer.mamba_trainer comes from.


masc-it commented on July 4, 2024

@Eupham snippet fixed; it was just a stale import.


Eupham commented on July 4, 2024

Is there any chance of a notebook showing how to train with this? I'm doing something wrong in my attempts.
import os
import torch
from datasets import load_dataset, load_from_disk
from transformers import AutoTokenizer, TrainingArguments, Trainer
from transformers import DataCollatorForLanguageModeling
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Tokenize and save the IMDb dataset
def tokenize_and_save_imdb(tokenizer, save_path):
    dataset = load_dataset("imdb")

    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    tokenized_dataset = dataset.map(tokenize_function, batched=True)
    tokenized_dataset.save_to_disk(save_path)

class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return (lm_loss, (lm_logits,)) if return_outputs else lm_loss

    def save_model(self, output_dir, _internal_call):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        torch.save(self.model.state_dict(), f"{output_dir}/pytorch_model.bin")
        self.tokenizer.save_pretrained(output_dir)

def train(model_name, tokenizer_name, dataset_path, output_dir, learning_rate, num_epochs, batch_size, gradient_accumulation_steps):
    model = MambaLMHeadModel.from_pretrained(model_name, dtype=torch.bfloat16, device="cuda")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.eos_token = ""
    tokenizer.pad_token = tokenizer.eos_token

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    ds_tokenized = load_from_disk(dataset_path)

    trainer = MambaTrainer(
        model=model,
        train_dataset=ds_tokenized["train"],
        tokenizer=tokenizer,
        args=TrainingArguments(
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            dataloader_num_workers=2,
            output_dir=output_dir,
            logging_steps=50,
            weight_decay=1e-2,
            evaluation_strategy="epoch",
            save_strategy="epoch"
        ),
        data_collator=data_collator
    )

    trainer.train()

# Set your parameters here
model_name = "clibrain/mamba-2.8b-instruct-openhermes"
tokenizer_name = "clibrain/mamba-2.8b-instruct-openhermes"
dataset_path = "./tokenized_imdb"
output_dir = "./output"
learning_rate = 2e-5
num_epochs = 1
batch_size = 1
gradient_accumulation_steps = 2

# Tokenize and save IMDb dataset if not already done
if not os.path.exists(dataset_path):
    print("Tokenizing and saving IMDb dataset...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenize_and_save_imdb(tokenizer, dataset_path)

# Run training
train(model_name, tokenizer_name, dataset_path, output_dir, learning_rate, num_epochs, batch_size, gradient_accumulation_steps)

/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:72: UserWarning: The secret `HF_TOKEN` does not exist in your Colab secrets. To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. You will be able to reuse this secret in all of your notebooks. Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.

RuntimeError Traceback (most recent call last)
in <cell line: 88>()
86
87 # Run training
---> 88 train(model_name, tokenizer_name, dataset_path, output_dir, learning_rate, num_epochs, batch_size, gradient_accumulation_steps)

25 frames
in _layer_norm_fwd_1pass_kernel(X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd, stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, N, eps, IS_RMS_NORM, BLOCK_N, HAS_RESIDUAL, STORE_RESIDUAL_OUT, HAS_BIAS, grid, num_warps, num_stages, extern_libs, stream, warmup, device, device_type)

/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py in ptx_to_cubin(ptx, arch)
148 '''
149 ptxas, _ = path_to_ptxas()
--> 150 return compile_ptx_to_cubin(ptx, ptxas, arch)
151
152

RuntimeError: Internal Triton PTX codegen error:
ptxas /tmp/compile-ptx-src-a9e01c, line 984; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 984; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 986; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 986; error : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
...
...
ptxas /tmp/compile-ptx-src-a9e01c, line 2805; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 2807; error : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-a9e01c, line 2807; error : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal : Ptx assembly aborted due to errors
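
The ptxas errors mean the Triton kernels were compiled for bfloat16 on a GPU that doesn't support it: bf16 needs compute capability sm_80 (Ampere) or newer, while Colab's default T4 is sm_75. A likely workaround, assuming a pre-Ampere card, is to load the model in a dtype the hardware supports, e.g.:

# bf16 requires compute capability >= 8.0; fall back to float32 on older GPUs.
major, _ = torch.cuda.get_device_capability()
dtype = torch.bfloat16 if major >= 8 else torch.float32
model = MambaLMHeadModel.from_pretrained(model_name, dtype=dtype, device="cuda")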


zabealbe commented on July 4, 2024

The suggested MambaTrainer code also trains the model to predict the padding tokens.

Why? Hugging Face collators such as DataCollatorForLanguageModeling set padding tokens to -100, which torch.nn.CrossEntropyLoss() ignores by default, but they do so only in the labels, not in the input_ids.

To prevent the model from being trained to predict padding tokens, I found it sufficient to change the code below

class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids).logits

        labels = input_ids.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))

        return lm_loss

To

class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        labels = inputs.pop("labels")
        lm_logits = model(input_ids).logits

        labels = labels.to(lm_logits.device)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()

        lm_loss = torch.nn.CrossEntropyLoss()(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )

        return (lm_loss, lm_logits) if return_outputs else lm_loss
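
As a quick sanity check (a standalone snippet; the gpt-neox tokenizer here is just an illustrative choice), the collator pads input_ids with pad_token_id but sets the same positions to -100 in labels, which is why the loss has to be computed against labels:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # illustrative tokenizer
tokenizer.pad_token = tokenizer.eos_token
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = collator([tokenizer("short"), tokenizer("a noticeably longer example sentence")])
print(batch["input_ids"][0])  # padded with tokenizer.pad_token_id
print(batch["labels"][0])     # same ids, but -100 at the padded positions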


