Galore finetuning #stopped (transformers issue, open)

j-datta commented on September 27, 2024

Comments (6)

amyeroberts commented on September 27, 2024

cc @younesbelkada @SunMarc

younesbelkada commented on September 27, 2024

Hi @j-datta
Initializing GaLore training might take a while, but not 2 hours IMO. I suspect your model might be mistakenly initialized on CPU. Can you make sure the model is on GPU?

j-datta commented on September 27, 2024

Hi @younesbelkada
I've used these lines of code:

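import torch

# Pick the first visible CUDA device if there is one, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# This prints the requested device, not where the parameters actually live
print(f"Model is on device: {device}")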

I've no idea why this is happening.
I am using two GPUs here.
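Printing device only echoes the device that was requested; it does not confirm where the weights ended up. A minimal sketch that tallies parameters per device, assuming a PyTorch nn.Module (or anything exposing .parameters() whose items have .device and .numel()). summarize_param_devices, FakeParam and FakeModel are hypothetical names used only for illustration, not transformers APIs:

```python
from collections import Counter


def summarize_param_devices(model):
    """Return {device_name: parameter_count} summed over model.parameters().

    Only relies on .parameters() and on each parameter's .device and .numel(),
    so it works for any torch.nn.Module.
    """
    counts = Counter()
    for param in model.parameters():
        counts[str(param.device)] += param.numel()
    return dict(counts)


# Hypothetical stand-ins so the sketch runs without torch or a GPU.
class FakeParam:
    def __init__(self, device, n):
        self.device = device
        self._n = n

    def numel(self):
        return self._n


class FakeModel:
    def parameters(self):
        return [FakeParam("cuda:0", 100), FakeParam("cuda:0", 50), FakeParam("cpu", 10)]


if __name__ == "__main__":
    print(summarize_param_devices(FakeModel()))  # {'cuda:0': 150, 'cpu': 10}
```

With the real model, calling summarize_param_devices(model) right before trainer.train() shows the actual placement; any "cpu" key means some weights never reached a GPU. torch.cuda.device_count() reports how many GPUs the process can see.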

younesbelkada commented on September 27, 2024

Hi @j-datta
Can you try the default GaLore arguments (i.e. remove optim_args="rank=64, update_proj_gap=100, scale=0.10",) and see if this helps? Maybe a high rank and update_proj_gap slow down the initialization step.
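To see which GaLore settings dropping optim_args would fall back to, here is a minimal sketch that parses an optim_args string (strip spaces, split on commas, then on "=") and merges it over a set of defaults. The default values (rank=128, update_proj_gap=200, scale=0.25, proj_type="std") are an assumption based on recent transformers releases; check Trainer.get_optimizer_cls_and_kwargs in the installed version. parse_galore_args is a hypothetical helper:

```python
# ASSUMPTION: these defaults mirror recent transformers releases; verify locally.
GALORE_DEFAULTS = {"rank": 128, "update_proj_gap": 200, "scale": 0.25, "proj_type": "std"}
_CASTS = {"rank": int, "update_proj_gap": int, "scale": float, "proj_type": str}


def parse_galore_args(optim_args=None):
    """Parse a string like "rank=64, scale=0.1" and merge it over the defaults."""
    settings = dict(GALORE_DEFAULTS)  # copy so the defaults are never mutated
    if optim_args:
        # Drop spaces, split on commas, then split each pair on the first "=".
        for pair in optim_args.replace(" ", "").split(","):
            if not pair:
                continue  # tolerate a trailing comma
            key, value = pair.split("=", 1)
            settings[key] = _CASTS.get(key, str)(value)
    return settings


if __name__ == "__main__":
    print(parse_galore_args("rank=64, update_proj_gap=100, scale=0.10"))
    print(parse_galore_args())  # what removing optim_args falls back to
```

Under these assumed defaults, removing optim_args moves rank from 64 to 128 and update_proj_gap from 100 to 200, so the projection matrices would be recomputed half as often but would be larger.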

j-datta commented on September 27, 2024

Hello @younesbelkada
When I tried to run the following script:

import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl
from trl import SFTConfig

train_dataset = datasets.load_dataset('rajpurkar/squad_v2', split='train')
def preprocess_function(examples):
    inputs = [q + " " + c for q, c in zip(examples["question"], examples["context"])]
    targets = [a["text"][0] if len(a["text"]) > 0 else "" for a in examples["answers"]]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
args = SFTConfig(
    output_dir="/home/IAIS/jdatta/teacher_model/test-galore",
    max_steps=5000,
    per_device_train_batch_size=1,
    fp16=True,
    dataset_text_field='input_ids',
    max_seq_length=128,
    #num_train_epochs=3,
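    optim="galore_adamw_8bit",
    # c_attn and c_proj are GPT-2-style module names and do not occur in Mistral, whose
    # projections are named q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj and down_proj
    optim_target_modules=["c_attn", "c_proj", "q_proj", "k_proj", "v_proj", "down_proj", "up_proj"],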
)
model_id = "mistralai/Mistral-7B-v0.1"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
train_dataset=train_dataset.map(preprocess_function, batched=True,remove_columns=train_dataset.column_names)
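# from_config builds the Mistral architecture with randomly initialized weights (no pretrained
# weights are loaded); .half() then casts every weight to fp16
model = AutoModelForCausalLM.from_config(config).half()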

trainer = trl.SFTTrainer(
    model=model, 
    args=args,
    train_dataset=train_dataset,
)

trainer.train()

Now it's showing an OOM error.
I'm using 2 Tesla V100 GPUs here.
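A rough back-of-the-envelope check helps explain the OOM. The sketch below counts Mistral-7B parameters from its published config dimensions (hidden size 4096, intermediate size 14336, 32 layers, 32 attention heads, 8 key/value heads, vocabulary 32000, untied LM head) and converts them to fp16 bytes. It assumes every parameter is trained and that full-size gradients exist before the GaLore projection (the non-layerwise optimizers), and it ignores activations, optimizer state, the CUDA context and fragmentation, so real usage will be higher. mistral_param_count is a hypothetical helper:

```python
def mistral_param_count(hidden=4096, intermediate=14336, layers=32,
                        heads=32, kv_heads=8, vocab=32000):
    """Count parameters of a Mistral-style decoder (no biases, untied LM head)."""
    head_dim = hidden // heads
    kv_dim = kv_heads * head_dim
    attn = 2 * hidden * hidden + 2 * hidden * kv_dim  # q/o projections + k/v projections
    mlp = 3 * hidden * intermediate                   # gate, up and down projections
    norms = 2 * hidden                                # two RMSNorm weight vectors per layer
    per_layer = attn + mlp + norms
    embeddings = 2 * vocab * hidden                   # input embeddings + separate lm_head
    return layers * per_layer + embeddings + hidden   # + final RMSNorm


GIB = 1024 ** 3
n = mistral_param_count()
fp16_weights = 2 * n  # .half(): 2 bytes per parameter
fp16_grads = 2 * n    # one fp16 gradient per trainable parameter

if __name__ == "__main__":
    print(f"parameters:           {n:,}")
    print(f"fp16 weights:         {fp16_weights / GIB:.1f} GiB")
    print(f"fp16 weights + grads: {(fp16_weights + fp16_grads) / GIB:.1f} GiB")
```

Under these assumptions the weights alone take about 13.5 GiB and weights plus gradients about 27 GiB for each copy of the model. A 16 GB V100 cannot hold that, and a 32 GB V100 leaves little room for activations and optimizer state. Two GPUs do not pool their memory here: when the script is launched with plain python, the Trainer typically wraps the model in torch.nn.DataParallel, and both DataParallel and DDP keep a full replica on each GPU.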

amyeroberts commented on September 27, 2024

cc @SunMarc