"RuntimeError: The size of tensor a (0) must match the size of tensor b (4096) at non-singleton dimension 1" (DPO + LoRA) about alignment-handbook HOT 6 OPEN

huggingface commented on May 14, 2024
"RuntimeError: The size of tensor a (0) must match the size of tensor b (4096) at non-singleton dimension 1" (DPO + LoRA)

from alignment-handbook.

Comments (6)

lewtun commented on May 14, 2024

Hi @ohmeow, as discussed here, I think the issue does indeed arise when trying to do the following:

  • Use DeepSpeed's zero.Init() to shard the base model weights directly on GPU via the zero3_init_flag setting in the accelerate config
  • Try to merge the adapter weights into the sharded base model
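
To see why this combination could produce the reported message, here is a minimal stdlib-only sketch of the broadcasting rule that tensor addition follows. It assumes (this is not confirmed in the thread) that each process sees a sharded parameter as an empty placeholder of shape (0,), while the LoRA update has the full (4096, 4096) shape. The helper broadcast_shape and all shapes are illustrative, not DeepSpeed or peft code.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of a and b, or raise if they are incompatible."""
    out = []
    # Compare trailing dimensions first; a missing leading dimension counts as 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"size {da} must match size {db}")
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


full_weight = (4096, 4096)  # weight shape the merge step expects
lora_delta = (4096, 4096)   # shape of the low-rank update added during the merge
placeholder = (0,)          # ASSUMPTION: local view of a sharded parameter

# Unsharded case: shapes agree, so the merge can add the update to the weight.
merged_shape = broadcast_shape(full_weight, lora_delta)

# Sharded case: the empty placeholder cannot be broadcast against the update.
try:
    broadcast_shape(placeholder, lora_delta)
    merge_failed = False
except ValueError as err:
    merge_failed = True
    print("merge would fail:", err)
```

If that assumption holds, the merge can only succeed when the full weights are materialized on each process, which is consistent with the workaround of turning the flag off.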

I don't think we saw this issue in the original release of the code, because we made a goof on the device_map for LoRA training that was later fixed in #51.

If you have enough VRAM, you should be able to work around this by setting zero3_init_flag: False in the accelerate config.
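
As a rough illustration of that workaround, the sketch below flips the flag in the text of an accelerate config. Only the flag name comes from this thread; the helper disable_zero3_init and the sample config contents are assumptions made for the example, so check them against your own config file before relying on it.

```python
import re


def disable_zero3_init(config_text: str) -> str:
    """Return config_text with zero3_init_flag set to false; other lines are untouched."""
    pattern = re.compile(r"^([ \t]*zero3_init_flag[ \t]*:[ \t]*)\S+", flags=re.MULTILINE)
    new_text, count = pattern.subn(r"\g<1>false", config_text)
    if count == 0:
        raise ValueError("zero3_init_flag not found in config")
    return new_text


# Hypothetical excerpt of a DeepSpeed accelerate config, for illustration only.
example = """\
deepspeed_config:
  zero3_init_flag: true
  zero_stage: 3
distributed_type: DEEPSPEED
"""

patched = disable_zero3_init(example)
print(patched)
```

Editing the YAML by hand works just as well; the point is only that the one flag changes while the rest of the DeepSpeed settings stay as they are.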

I'm discussing this with the peft team and hopefully can find a more stable solution!

ohmeow commented on May 14, 2024

NOTE: This only occurs when I use the DeepSpeed accelerate config and set num_processes > 1.

ohmeow commented on May 14, 2024

So I think the solution is to add accelerator.wait_for_everyone() before you instantiate the DPOTrainer.

If someone can confirm that feel free to close this out. If not, lmk :)

ohmeow commented on May 14, 2024

I think the problem might be related to using deepspeed on my local DL rig with 2x3090s. Just switched to the multi-gpu.yaml file and the script ran no problem.

ohmeow commented on May 14, 2024

The only way I was able to get training to proceed was by adding device_map=get_kbit_device_map() to the model_kwargs when loading an adapter model.

    if is_adapter_model(model, model_args.model_revision):
        # load the model, merge the adapter weights and unload the adapter
        # Note: to run QLoRA, you will need to merge the base model separately as the merged model in 16bit
        logger.info(f"Merging peft adapters for {model_args.model_name_or_path=}")

        peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path, revision=model_args.model_revision)

        model_kwargs = dict(
            revision=model_args.base_model_revision,
            trust_remote_code=model_args.trust_remote_code,
            use_flash_attention_2=model_args.use_flash_attention_2,
            torch_dtype=torch_dtype,
            use_cache=False if training_args.gradient_checkpointing else True,
            device_map=get_kbit_device_map(),
        )

        base_model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path, **model_kwargs)
        model = PeftModel.from_pretrained(base_model, model_args.model_name_or_path, revision=model_args.model_revision)
        model.eval()
        model = model.merge_and_unload()
        model_kwargs = None

    if model_args.use_peft is True:
        ref_model = None
        ref_model_kwargs = None
    else:
        ref_model = model
        ref_model_kwargs = model_kwargs

    accelerator.wait_for_everyone()

With this I can get everything running on my 2x3090s using the multi-gpu.yaml. GPU utilization looks even across both cards.
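
For readers wondering what the device_map argument in the snippet above changes, here is a minimal sketch. It assumes the helper returns a mapping that pins the whole model to the current process's GPU; the function name sketch_kbit_device_map and the use of the LOCAL_RANK environment variable are illustrative assumptions, not the handbook's actual code.

```python
import os


def sketch_kbit_device_map(cuda_available: bool):
    """Hypothetical stand-in for get_kbit_device_map(): pin the whole model to this process's GPU."""
    if not cuda_available:
        return None  # let the loader fall back to its default placement
    # LOCAL_RANK is set per process by distributed launchers such as torchrun.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    # The empty-string key stands for "the entire model".
    return {"": local_rank}


os.environ["LOCAL_RANK"] = "1"  # simulate the second of two processes
print(sketch_kbit_device_map(cuda_available=True))
```

Under that assumption, each process loads a full, unsharded copy of the base model onto its own card before the adapter is merged, which would explain the even utilization across both GPUs.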

The deepspeed config works as well but for some reason fails when pushing the model to the hub. I imagine this has something to do with my machine and/or with using 3090s.

Randl commented on May 14, 2024

Can confirm that setting zero3_init_flag: False helps.
