
Model trained with Flash Attention 2.0 raises "RuntimeError: query and key must have the same dtype" when generating

antonioalegria commented on June 9, 2024
Model trained with Flash Attention 2.0 raises "RuntimeError: query and key must have the same dtype" when generating


Comments (13)

zucchini-nlp commented on June 9, 2024

@edchengg

I was able to localize the error. When using the Trainer, there is a line that prepares the model with accelerate, which in turn adds something like model.forward = convert_outputs_to_fp32(new_forward), casting all model outputs to fp32. That is why the past_key_values cache was cast to fp32, causing errors when generating with the KV cache.

@gante, I think the issue will be resolved when we get rid of the legacy cache format, so I will leave it to you here :)
UPD: Or maybe not, I just found this issue, which causes an error if the cache class is used in the prediction loop.
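To make this concrete, here is a minimal sketch of the failure mode (plain PyTorch, not transformers or accelerate code; cast_nested_to_fp32 is a hypothetical stand-in for accelerate's fp32 output conversion):

import torch

def cast_nested_to_fp32(obj):
    # Simplified stand-in for accelerate's fp32 output conversion: walk nested
    # containers and upcast every floating-point tensor to float32.
    if isinstance(obj, torch.Tensor) and obj.is_floating_point():
        return obj.float()
    if isinstance(obj, (tuple, list)):
        return type(obj)(cast_nested_to_fp32(o) for o in obj)
    return obj

# Legacy past_key_values format: a tuple over layers of (key, value) tensors, here in bf16
legacy_cache = (
    (torch.randn(1, 8, 5, 64, dtype=torch.bfloat16),
     torch.randn(1, 8, 5, 64, dtype=torch.bfloat16)),
)
casted = cast_nested_to_fp32(legacy_cache)
print(casted[0][0].dtype)  # torch.float32

On the next decoding step those fp32 tensors are fed back as past_key_values and meet the model's bf16 query states inside FA2, which is exactly the dtype mismatch from the traceback.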


ArthurZucker commented on June 9, 2024

cc @gante and @zucchini-nlp that's something we should pay attention to!


edchengg commented on June 9, 2024

I got the same problem when running model inference with FlashAttention-2.
Below is the code to reproduce the error @zucchini-nlp @gante @ArthurZucker :

import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

checkpoint = "codellama/CodeLlama-7b-hf"
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
}

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")

# Create a small dataset with 5 sentences
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "I love eating pizza and watching movies.",
    "The sun is shining brightly today.",
    "She plays the piano beautifully.",
    "He enjoys reading books in his free time."
]

encodings = tokenizer(texts, padding=True, return_tensors="pt")

# Create the dataset
dataset = Dataset.from_dict({
    "input_ids": encodings.input_ids,
    "attention_mask": encodings.attention_mask,
    "labels": encodings.input_ids
})

# Define the training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    bf16=True,  # adding this leads to the error
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    generation_max_length=50,
    generation_num_beams=1,
)

# Create the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
)

# Sanity check: a plain generate() call works
test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out))  # Okay

# Run prediction on the dataset
predictions = trainer.predict(dataset)  # Error
print(tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True))

If I set bf16=True in the training arguments, trainer.predict() crashes but model.generate() is okay!

    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: query and key must have the same dtype

It works fine if I remove bf16 from the arguments.
The problem is that I need to use bf16, otherwise inference is very slow.
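For what it's worth, a quick diagnostic sketch (assuming the repro above has already been run on a GPU with FA2 installed) shows why the plain generate() path is fine: outside the Trainer nothing rewraps model.forward, so both the weights and the returned KV cache stay in bf16:

import torch

print(next(model.parameters()).dtype)  # torch.bfloat16

with torch.no_grad():
    out = model(**test_input, use_cache=True)
# Indexing works for both the legacy tuple format and DynamicCache
print(out.past_key_values[0][0].dtype)  # torch.bfloat16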


zucchini-nlp commented on June 9, 2024

@ArthurZucker sorry, totally forgot to reply here. While I was thinking that we might need to pass a "compute_dtype" into all cache classes, I found that #30536 should actually solve the problem.

In the linked PR we never go back and forth between the legacy and new cache formats, thus no problems with incorrect casting to fp32. I will verify when the PR is merged and close this issue :)
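For context, a rough sketch of the legacy round trip that the PR avoids (assuming the DynamicCache class with its to_legacy_cache / from_legacy_cache helpers; the explicit .float() below just simulates the fp32 output conversion discussed earlier):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(1, 8, 5, 64, dtype=torch.bfloat16),
             torch.randn(1, 8, 5, 64, dtype=torch.bfloat16),
             layer_idx=0)

# If the legacy tuple handed around between steps gets upcast to fp32,
# rebuilding the cache from it bakes the fp32 tensors into the new Cache object.
legacy = tuple(tuple(t.float() for t in layer) for layer in cache.to_legacy_cache())
rebuilt = DynamicCache.from_legacy_cache(legacy)
print(rebuilt.key_cache[0].dtype)  # torch.float32, while new query_states stay bf16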


ArthurZucker commented on June 9, 2024

Hey, this seems related to this issue. Also make sure the model is in fp16 or bfloat16.


antonioalegria commented on June 9, 2024

Training Args has fp16=True. But I have tried, after training, converting the model to torch.float16 as well as every module inside it; it doesn't make a difference, I get the same error. I haven't found any workaround that works.


imirzadeh commented on June 9, 2024

I have the same issue (generating with FA2). I checked the inputs to the FA2 module and they are all fp16. In my case, using phi-2, I tracked the issue down to the flash attention code in the modeling file:

# before >> q: torch.bfloat16, k: torch.bfloat16, v: torch.bfloat16
if past_key_value is not None:
    cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# after  >> q: torch.bfloat16, k: torch.float32, v: torch.float32

This past_key_value.update somehow upcasts the KVs to fp32.
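A minimal standalone illustration of that upcasting (plain torch, not the phi-2 modeling code): once one side of the concatenation is already fp32, the updated cache comes out fp32 as well:

import torch

cached_keys = torch.randn(1, 8, 10, 64, dtype=torch.float32)  # stale cache entry already upcast to fp32
new_keys = torch.randn(1, 8, 1, 64, dtype=torch.bfloat16)     # fresh bf16 key_states for the current step

updated = torch.cat([cached_keys, new_keys], dim=-2)
print(updated.dtype)  # torch.float32 -> later mismatches the bf16 query_states inside flash_attn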


ylacombe commented on June 9, 2024

Hey @ArthurZucker!
I'm facing the same issue! It seems to come from using DynamicCache with autocast: somewhere along the way the past key values stored in the Cache get upcast. It's probably happening where @imirzadeh identified above.

I don't see a clear way of dealing with this, as I'm not an autocast expert, but an ugly fix could be to add key_values.dtype == torch.float32 to the condition here:

input_dtype = query_states.dtype
if input_dtype == torch.float32:
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    elif hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype

    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )

    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)

Let me know what you think and how I can help further.


zucchini-nlp commented on June 9, 2024

@antonioalegria In the case of this code snippet the model is loaded in float32, and I believe running trainer.train() does not change the model itself to fp16. I can confirm that running your script fails for me, but the script below, which loads the model in fp16, works fine:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

checkpoint = "trainer_ckpt_path"
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.float16,  # need to indicate fp16 here to work with FA2
}

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")

test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out))


gante commented on June 9, 2024

@antonioalegria as @zucchini-nlp wrote :) The fp16=True flag turns on autocast for training, so it doesn't crash at train time. However, at inference time, there is no autocast, so the model is still seen as FP32. FA2 does not support FP32, and thus it will crash.

You can also train in FP32 (with autocast) and then manually cast to FP16 / reload the trained model in FP16.
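As a hedged sketch of those two options (model, trainer, and the "output" path are placeholders carried over from the earlier snippets):

import torch
from transformers import AutoModelForCausalLM

# Option (a): cast the trained fp32 model to half precision in place before generating with FA2
model = model.to(torch.float16).eval()

# Option (b): save the checkpoint and reload it directly in fp16 (same idea as the snippet above)
trainer.save_model("output")
model = AutoModelForCausalLM.from_pretrained(
    "output",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)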


ManuelFay commented on June 9, 2024

Same here, got it running with fp16, but bf16 raises the error on an A100 cluster with the Idefics2 model!
Good luck with the fix!


ambroser53 commented on June 9, 2024

I'm also running into this error, but only when quantizing the model. I made an edit much like the one @ylacombe suggested above, but mine simply ensures that all three states have the same dtype:

input_dtype = query_states.dtype
if input_dtype == torch.float32 or not (query_states.dtype == key_states.dtype == value_states.dtype):
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    elif hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype

    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )

    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)


ArthurZucker commented on June 9, 2024

@zucchini-nlp are we just waiting for the deprecation of the to_legacy_cache func?

