Comments (6)
Hi @ohmeow, as discussed here, I think the issue indeed arises when you:
- use DeepSpeed's `zero.init()` to shard the base model weights directly on GPU via the `zero3_init_flag` in the `accelerate` config
- try to merge the adapter weights onto the sharded base model

I don't think we saw this issue in the original release of the code because we made a goof on the `device_map` for LoRA training that was later fixed in #51.

If you have enough vRAM, you should be able to work around this by setting `zero3_init_flag: False` in the `accelerate` config.

I'm discussing this with the `peft` team and hopefully we can find a more stable solution!
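As a sketch of that workaround, the relevant part of a DeepSpeed `accelerate` config might look like the following (file name and the other values are illustrative, not the handbook's exact recipe):

```yaml
# deepspeed_zero3.yaml -- illustrative accelerate config
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  offload_optimizer_device: none
  offload_param_device: none
  zero_stage: 3
  zero3_init_flag: false   # disable zero.init() so base weights are not sharded at load time
num_processes: 2
mixed_precision: bf16
```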
from alignment-handbook.
NOTE: This only occurs if I'm using the deepspeed `accelerate` config and set `num_processes > 1`.
So I think the solution is to add `accelerator.wait_for_everyone()` before you instantiate the `DPOTrainer`. If someone can confirm that, feel free to close this out. If not, lmk :)
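Conceptually, `accelerator.wait_for_everyone()` is a barrier across ranks: no process moves on until every process has reached it. A minimal stdlib analogue of why that matters here, using `multiprocessing` in place of `accelerate` (purely illustrative; the sleep stands in for the slow adapter merge on rank 0):

```python
import multiprocessing as mp
import time


def worker(rank, barrier, results):
    if rank == 0:
        time.sleep(0.2)  # simulate the slow merge_and_unload() on rank 0
        results["adapters_merged"] = True
    barrier.wait()  # analogue of accelerator.wait_for_everyone()
    # Only after the barrier is it safe to "build the trainer":
    # every rank now observes the finished merge.
    results[f"rank{rank}_saw_merge"] = results.get("adapters_merged", False)


def run_demo(nprocs=2):
    with mp.Manager() as mgr:
        results = mgr.dict()
        barrier = mp.Barrier(nprocs)
        procs = [mp.Process(target=worker, args=(r, barrier, results))
                 for r in range(nprocs)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        return dict(results)


if __name__ == "__main__":
    print(run_demo())
```

Without the barrier, rank 1 could reach trainer construction while rank 0 is still merging, which is the race the comment above is working around.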
I think the problem might be related to using DeepSpeed on my local DL rig with 2x 3090s. I just switched to the multi-gpu.yaml file and the script ran with no problem.
The only way I was able to get training to proceed was by adding `device_map=get_kbit_device_map()` to the `model_kwargs` when loading an adapter model:
```python
if is_adapter_model(model, model_args.model_revision):
    # Load the model, merge the adapter weights and unload the adapter
    # Note: to run QLoRA, you will need to merge the base model separately,
    # as the merged model is in 16-bit
    logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}")
    peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)
    model_kwargs = dict(
        revision=model_args.base_model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map(),  # added so each rank loads onto its own GPU
    )
    base_model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path, **model_kwargs)
    model = PeftModel.from_pretrained(base_model, model_args.model_name_or_path, revision=model_args.model_revision)
    model.eval()
    model = model.merge_and_unload()
    model_kwargs = None

if model_args.use_peft is True:
    ref_model = None
    ref_model_kwargs = None
else:
    ref_model = model
    ref_model_kwargs = model_kwargs

accelerator.wait_for_everyone()
```
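For context, `get_kbit_device_map()` in the handbook's utils essentially pins the whole model to the local process's GPU. A hypothetical re-implementation sketch (the function name, signature, and exact behaviour here are assumptions for illustration, not the handbook's code):

```python
def get_kbit_device_map_sketch(local_process_index, cuda_available=True):
    """Rough sketch: map the root module ("") of the model to this rank's GPU.

    The real helper reads the local rank from `accelerate`; here it is a
    plain argument so the sketch stays self-contained.
    """
    # "" is transformers' device_map convention for "the entire model"
    return {"": local_process_index} if cuda_available else None


# Rank 0 loads the full model on cuda:0, rank 1 on cuda:1, and so on,
# instead of every rank relying on an auto-derived placement.
print(get_kbit_device_map_sketch(1))  # {'': 1}
```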
With this I can get everything running on my 2x 3090s using the multi-gpu.yaml config; GPU utilization looks even across both cards.

The deepspeed config works as well, but for some reason fails when pushing the model to the hub. I imagine this has something to do with my machine and/or with using 3090s.
Can confirm that setting `zero3_init_flag: False` helps.
Related Issues (20)
- Does QLora DPO Training support reference model?
- Cannot apply "run_dpo.py" on a trained Axolotl model
- Reward Modeling Support
- DPO loss on different datasets
- Using MT-Bench to evaluate zephyr
- About DPO formatting before fine-tuning
- system message being included in chosen & rejected when chat_template inserts system message
- Cost of Generating a Dataset for Constitutional AI
- ImportError: Flash Attention 2 is not available
- (QLoRA) DPO without previous SFT
- DPO recipe saves a float32 model
- Zephyr-dpo-full Checkpoints perform poorly on TruthfulQA.
- cannot replicate DPO results of zephyr
- Major bug: Chat template is not actually applied in run_sft.py and run_dpo.py
- Estimated Time for SFT Fine-Tuning of Mistral-7B Model
- Minor question about PAD token and EOS token.
- Downloading latest CUDA version (11.6 or above) for MacOS to use FlashAttention
- Not able to run Zephyr 7B Gemma with 4 80GB A100s
- Early Stopping Issue when used with ConstantLengthDataset
- Is there a way to freeze some layers of a model?