Comments (13)
I could localize the error. When using the Trainer, there is a line that prepares the model with Accelerate, which in turn does something like model.forward = convert_outputs_to_fp32(new_forward), casting all model outputs to fp32. That is why the past_key_values cache was cast to fp32, causing errors when generating with the KV cache.
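A minimal sketch of the mechanism described above, assuming the Accelerate wrapper recursively upcasts every fp16/bf16 tensor in the forward output. `upcast_half_outputs` and `toy_forward` are hypothetical stand-ins, not the real Accelerate or transformers code; the point is that a legacy-format cache (a tuple of per-layer (key, value) pairs) is part of the output, so it gets upcast along with the logits.

```python
import functools

import torch


def upcast_half_outputs(fn):
    """Hypothetical stand-in for Accelerate's output-casting wrapper:
    recursively converts every fp16/bf16 tensor in the return value to fp32."""

    def _cast(obj):
        if isinstance(obj, torch.Tensor) and obj.dtype in (torch.float16, torch.bfloat16):
            return obj.float()
        if isinstance(obj, (tuple, list)):
            return type(obj)(_cast(o) for o in obj)
        if isinstance(obj, dict):
            return {k: _cast(v) for k, v in obj.items()}
        return obj

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return _cast(fn(*args, **kwargs))

    return wrapper


@upcast_half_outputs
def toy_forward(hidden):
    # Toy bf16 "model": returns logits plus a legacy-format cache,
    # i.e. a tuple holding one (key, value) pair per layer.
    hidden = hidden.to(torch.bfloat16)
    past_key_values = ((hidden.clone(), hidden.clone()),)
    return {"logits": hidden, "past_key_values": past_key_values}


outputs = toy_forward(torch.randn(1, 4, 8))
key, value = outputs["past_key_values"][0]
print(outputs["logits"].dtype, key.dtype, value.dtype)
# torch.float32 torch.float32 torch.float32
```

Upcasting the logits is harmless, but the cache is fed back into the next decoding step, where it no longer matches the bf16 query.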
@gante, I think the issue will be resolved when we get rid of the legacy cache format, so I will leave it to you here :)
UPD: Or maybe not. I just found this issue, which causes an error if a cache class is used in the prediction loop.
cc @gante and @zucchini-nlp that's something we should pay attention to!
I got the same problem when running model inference with FlashAttention-2.
Below is code to reproduce the error, @zucchini-nlp @gante @ArthurZucker:
from datasets import Dataset
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
checkpoint = "codellama/CodeLlama-7b-hf"
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
}
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")
# Create a small dataset with 5 sentences
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "I love eating pizza and watching movies.",
    "The sun is shining brightly today.",
    "She plays the piano beautifully.",
    "He enjoys reading books in his free time.",
]
encodings = tokenizer(texts, padding=True, return_tensors="pt")
# Create the dataset
dataset = Dataset.from_dict({
    "input_ids": encodings.input_ids,
    "attention_mask": encodings.attention_mask,
    "labels": encodings.input_ids,
})
# Define the training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    bf16=True,  # adding this leads to the error
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    generation_max_length=50,
    generation_num_beams=1,
)
# Create the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
)
# Generating directly with the model works
test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out))  # Okay
# Run prediction on the dataset
predictions = trainer.predict(dataset)  # Error
print(tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True))
If I set bf16=True in the training arguments, trainer.predict crashes, but model.generate works fine:
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: query and key must have the same dtype
It works fine if I remove bf16 from the arguments...
The problem is that I need bf16; otherwise inference is very slow.
@ArthurZucker sorry, I totally forgot to reply here. While I was thinking that we might need to pass a "compute_dtype" to all cache classes, I found that #30536 should actually solve the problem.
In the linked PR we never go back and forth between the legacy and new cache formats, so there is no incorrect casting to fp32. I will verify once the PR is merged and close this issue :)
Hey, this seems related to this issue. Also make sure the model is in fp16 or bfloat16.
My TrainingArguments has fp16=True. After training, I also tried converting the model, and every module inside it, to torch.float16, but it doesn't make a difference: I get the same error. I haven't found any workaround that works.
I have the same issue (generating with FA2). I checked the inputs to the FA2 module and they are all fp16. In my case, using phi-2, I tracked the issue down to the flash attention code in the modeling file:
# before >> q:torch.bfloat16, k:torch.bfloat16, v:torch.bfloat16
if past_key_value is not None:
    cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# after q:torch.bfloat16, k:torch.float32, v:torch.float32
This past_key_value.update call somehow upcasts the KVs to fp32.
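One plausible way the update produces fp32 keys, sketched under the assumption that the cache already holds entries that were upcast to fp32 earlier (for example by an output-casting wrapper). `ToyCache` is a hypothetical, heavily simplified class that appends along the sequence dimension in the spirit of `DynamicCache.update`; it is not the transformers implementation. Because `torch.cat` applies type promotion, concatenating an fp32 cache with new bf16 states yields fp32.

```python
import torch


class ToyCache:
    """Hypothetical, simplified cache: appends new key/value states along
    the sequence dimension (dim=-2), in the spirit of DynamicCache.update."""

    def __init__(self):
        self.key_cache = []
        self.value_cache = []

    def update(self, key_states, value_states, layer_idx):
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # torch.cat promotes dtypes: fp32 cache + bf16 states -> fp32 result
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]


# Prefill: cache entries that were (wrongly) upcast to fp32
cache = ToyCache()
cache.update(torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4), layer_idx=0)

# Next decoding step: the model itself produces bf16 query/key/value states
query = torch.randn(1, 2, 1, 4, dtype=torch.bfloat16)
new_k = torch.randn(1, 2, 1, 4, dtype=torch.bfloat16)
new_v = torch.randn(1, 2, 1, 4, dtype=torch.bfloat16)
key, value = cache.update(new_k, new_v, layer_idx=0)
print(query.dtype, key.dtype, value.dtype, tuple(key.shape))
# torch.bfloat16 torch.float32 torch.float32 (1, 2, 4, 4)
```

A bf16 query paired with fp32 keys is exactly the combination FA2 rejects with "query and key must have the same dtype".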
Hey @ArthurZucker!
I'm facing the same issue! It seems to come from using DynamicCache and autocast: somewhere along the way, the past key values stored in the Cache get upcast. It's probably happening where @imirzadeh identified above.
I don't see a clear way of dealing with this, as I'm not an autocast expert, but an ugly fix could be to add key_values.dtype == torch.float32 to the condition here:
transformers/src/transformers/models/mistral/modeling_mistral.py
Lines 431 to 449 in 60dea59
Let me know what you think and whether I can help further.
@antonioalegria In the case of this code snippet, the model is loaded in float32, and I believe running trainer.train() does not change the model itself to fp16. I can confirm that running your script fails for me, but the script below, which loads the model in fp16, works fine:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
checkpoint = "trainer_ckpt_path"
model_kwargs = {
"attn_implementation": "flash_attention_2",
"torch_dtype": torch.float16, # need to indicate fp16 here to work with FA2
}
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")
test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out))
@antonioalegria as @zucchini-nlp wrote :) The fp16=True flag turns on autocast for training, so it doesn't crash at train time. However, at inference time there is no autocast, so the model is still seen as FP32. FA2 does not support FP32, so it crashes.
You can also train in FP32 (with autocast) and then manually cast the trained model to FP16, or reload it in FP16.
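A small sketch of the point above, using a single `nn.Linear` in place of a full model and CPU autocast with bf16 so it runs without a GPU (on GPU you would use `device_type="cuda"`; the choice of layer, device and dtype are assumptions for illustration). Under autocast the output is half precision even though the weights are fp32; without autocast everything stays fp32; explicitly casting the module is what makes plain inference run in half precision.

```python
import torch
from torch import nn

# Weights stay in fp32, like a model loaded without torch_dtype
layer = nn.Linear(8, 8)
x = torch.randn(2, 8)

# Roughly what mixed-precision training flags enable: autocast lowers precision per op
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y_autocast = layer(x)

# Plain inference with no autocast: activations stay fp32, which FA2 would reject
y_plain = layer(x)

# Workaround discussed above: cast (or reload) the trained model in half precision
layer = layer.to(torch.bfloat16)
y_cast = layer(x.to(torch.bfloat16))

print(y_autocast.dtype, y_plain.dtype, y_cast.dtype)
# torch.bfloat16 torch.float32 torch.bfloat16
```

With a real checkpoint, reloading with `torch_dtype=torch.float16` (or bfloat16), as in the script above, achieves the same effect as the explicit `.to(...)` cast.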
Same here: I got it running with fp16, but bf16 raises the error on an A100 cluster with the Idefics2 model!
Good luck with the fix!
I'm also running into this error, but only when quantizing the model. I made an edit much like the one referenced here by @ylacombe,
but that simply ensures that all three states have the same dtype:
input_dtype = query_states.dtype
if input_dtype == torch.float32 or not (query_states.dtype == key_states.dtype == value_states.dtype):
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    elif hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype
    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )
    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)
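The edited condition can be isolated into a standalone helper to see its effect. `align_qkv_dtypes` is a hypothetical name, and unlike the modeling code it takes `target_dtype` as an argument instead of deriving it from autocast, the pre-quantization dtype, or the `q_proj` weights; this is a sketch of the idea, not a transformers API.

```python
import torch


def align_qkv_dtypes(query_states, key_states, value_states, target_dtype):
    """Hypothetical helper mirroring the edited condition: if the query is fp32
    or q/k/v disagree on dtype, cast all three to a single target dtype."""
    mismatched = not (query_states.dtype == key_states.dtype == value_states.dtype)
    if query_states.dtype == torch.float32 or mismatched:
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states


q = torch.randn(1, 2, 1, 4, dtype=torch.bfloat16)
k = torch.randn(1, 2, 4, 4)  # cache entries that were upcast to fp32
v = torch.randn(1, 2, 4, 4)
q, k, v = align_qkv_dtypes(q, k, v, target_dtype=q.dtype)
print(q.dtype, k.dtype, v.dtype)
# torch.bfloat16 torch.bfloat16 torch.bfloat16
```

If the cached values started out as bf16 and were only upcast to fp32, casting them back to bf16 recovers them exactly, so this workaround loses no precision in that case; it just hides the upstream upcast rather than fixing it.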
@zucchini-nlp are we just waiting for the deprecation of the to_legacy_cache function?