
Comments (11)

RUFFY-369 commented on June 9, 2024

@mostafamdy I also think the same: **input_kwargs can't contain decoder_input_ids without self.is_encoder_decoder being True. That said, self.is_encoder_decoder gets its value from the model's config, and there is no is_encoder_decoder attribute defined in Gemma's configuration file:

class GemmaConfig(PretrainedConfig):

Can you debug by printing ppo_trainer.is_encoder_decoder? Or maybe just set it to False and check if the error is gone.
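A quick way to check this (a minimal sketch; note that PretrainedConfig supplies a default value for is_encoder_decoder even when a model's own config class doesn't declare it):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/gemma-2b")
# GemmaConfig doesn't declare is_encoder_decoder itself; the value
# comes from the PretrainedConfig default, which is False
print(getattr(config, "is_encoder_decoder", "not set"))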


mostafamdy commented on June 9, 2024

Do you know how we can set it to False without changing the source code?


amyeroberts commented on June 9, 2024

cc @ArthurZucker


RUFFY-369 commented on June 9, 2024

@mostafamdy here is the code to create a PPOTrainer instance, which you may have used (I don't have your script):

from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
from trl.core import respond_to_batch

access_token = 'to_fill'
model = AutoModelForCausalLMWithValueHead.from_pretrained('google/gemma-2b', token=access_token)
# the config doesn't define an 'is_encoder_decoder' attribute of its own
print("config", model.pretrained_model.config)

model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=access_token)
tokenizer.pad_token = tokenizer.eos_token

# initialize trainer config
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

So, AutoModelForCausalLMWithValueHead takes google/gemma-2b or google/gemma-7b as the pretrained model name for PPOTrainer. The AutoModelForCausalLMWithValueHead wrapper itself has no is_encoder_decoder attribute; the attribute only exists on the wrapped gemma-2b config that is accessed for the ValueHead class.
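A couple of prints make this concrete (a sketch, reusing model and ppo_trainer from the snippet above):

# the wrapped model's config falls back to the PretrainedConfig default
print(model.pretrained_model.config.is_encoder_decoder)  # expected: False for Gemma
# check what value the trainer actually ended up with
print(ppo_trainer.is_encoder_decoder)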

The TypeError: GemmaForCausalLM.forward() got an unexpected keyword argument 'decoder_input_ids' that you get can be traced to the line in ppo_trainer.py where decoder_input_ids is added to the model inputs. So, clearly, we need to set is_encoder_decoder in PPOTrainer to False: somehow it is set to True, which leads to decoder_input_ids being passed to GemmaForCausalLM.forward().
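The failing call essentially reduces to passing an encoder-decoder-only argument to a decoder-only forward(); an illustration only, reusing query_tensor from the snippet above:

# a decoder-only forward() signature has no decoder_input_ids parameter
model.pretrained_model(input_ids=query_tensor, decoder_input_ids=query_tensor)
# TypeError: GemmaForCausalLM.forward() got an unexpected keyword argument 'decoder_input_ids'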

So, try this simple line of code for changing the value of self.is_encoder_decoder in PPOTrainer:

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
ppo_trainer.is_encoder_decoder = False


mostafamdy commented on June 9, 2024

Thanks @RUFFY-369, I tried this but it didn't work for me:

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
ppo_trainer.is_encoder_decoder = False

It worked after adding this code:

# this line is very important
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

Here is the full code:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
from peft import LoraConfig, get_peft_model
from trl import AutoModelForCausalLMWithValueHead

# config.model_name is defined elsewhere in the script
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    device_map="auto",
    torch_dtype="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,                      # load weights in 4-bit format
        bnb_4bit_quant_type="nf4",              # NF4 (normal float 4) quantization
        bnb_4bit_compute_dtype=torch.bfloat16,  # use bfloat16 for computation
        bnb_4bit_use_double_quant=True          # use double quantization
    ),
    trust_remote_code=True
)

# this line is very important: make the embedding outputs require grad so
# there is a gradient path through the frozen quantized base weights
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

tokenizer = GemmaTokenizer.from_pretrained(config.model_name)

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1
)

model = get_peft_model(model, peft_config)

print("Using lora")
print(model.print_trainable_parameters())

model = AutoModelForCausalLMWithValueHead.from_pretrained(model,
                                                          torch_dtype=torch.bfloat16,
                                                          is_trainable=True)
# print_number_of_trainable_model_parameters is a user-defined helper
print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(model)}\n')
print(model.v_head)
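For completeness, here is a hypothetical continuation (not part of the original snippet) showing how this model could be wired into a single PPO step; the PPOConfig values, the ref-model handling, and the dummy reward are placeholder assumptions:

from trl import PPOConfig, PPOTrainer
from trl.core import respond_to_batch

ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
# ref model left as None; with a PEFT model trl can disable the adapters
# to recover the reference policy (assumption about the trl version used)
ppo_trainer = PPOTrainer(ppo_config, model, None, tokenizer)
ppo_trainer.is_encoder_decoder = False  # workaround discussed earlier in the thread

# move the query to the model's device if needed
query_tensor = tokenizer.encode("This morning I went to the ", return_tensors="pt")
response_tensor = respond_to_batch(model, query_tensor)

reward = [torch.tensor(1.0)]  # dummy reward, for illustration only
stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)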

Have a nice day 😄


mostafamdy commented on June 9, 2024

I don't know if it is correct or not; I found this in the PPO trainer test file.


RUFFY-369 commented on June 9, 2024

@mostafamdy Yeah, I checked that test script out while tracing the value changes of is_encoder_decoder and decoder_input_ids, since I didn't have your script. Also, apologies that I couldn't test the code I suggested: CUDA was running out of memory on my system, probably because another model was in its training phase.

So, about the code you posted: are you using all of it from the test file, or just bits of it to make your script work?


mostafamdy commented on June 9, 2024

Thank you so much for your help!
No, I used only this part of the code:

# this line is very important
def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)


RUFFY-369 commented on June 9, 2024

You're welcome! Glad I could be of help. 😄
Oh, okay, good that it worked for you; one of the nice things about test files is that they often contain solutions to general errors too.

Have a nice day 👍 😄


ArthurZucker commented on June 9, 2024

Hey both, is the issue that the newly resized embeddings don't require grad even if the rest does?


RUFFY-369 commented on June 9, 2024

Hi @ArthurZucker, what I found out was that there was a PR for a DPO + gradient checkpointing issue where, if "one uses gradient_checkpointing we need to attach hooks to enable inputs to have requires grad to true, otherwise the training will either silently fail or completely fail".
And the fix was as follows:

elif getattr(args, "gradient_checkpointing", False):
    # For backward compatibility with older versions of transformers
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

It's the same as what @mostafamdy found in test_ppo_trainer.
But the main ppo_trainer.py file doesn't have this gradient_checkpointing fix, unlike other files such as dpo_trainer.py. That's why the PPOTrainer instance runs into this error, and why it gets fixed by the same block of code that appears in the test file and in dpo_trainer.py.
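To make the failure mode concrete, here is a self-contained toy example (plain torch; a frozen embedding stands in for the quantized base model) showing why the hook matters:

import torch

emb = torch.nn.Embedding(10, 4)
emb.weight.requires_grad_(False)  # stands in for frozen/4-bit base weights

ids = torch.tensor([[1, 2, 3]])
print(emb(ids).requires_grad)  # False -> checkpointed segments get no grad path

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

emb.register_forward_hook(make_inputs_require_grad)
print(emb(ids).requires_grad)  # True -> gradients can flow in the checkpointed backward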

