
contextualai / halos


A library with extensible implementations of DPO, KTO, PPO, ORPO, and other human-aware loss functions (HALOs).

Home Page: https://arxiv.org/abs/2402.01306

License: Apache License 2.0

Languages: Python 94.30%, Shell 5.70%
Topics: alignment, dpo, halos, kto, ppo, rlhf

halos's People

Contributors

kawin-contextual-ai, kawine, xwinxu


halos's Issues

How to sample from HF models?

Hi, this is a great project!

I'm thinking about reproducing your scores from the paper first, specifically with the model ContextualAI/archangel_sft_llama7b. How can I sample from this HF model? Right now eval.py takes in a path under /data/..., which exists only if I trained the model myself.
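For reference, a minimal sketch of sampling from the released checkpoint with plain transformers, independently of eval.py, would look something like the following (the prompt template here is an assumption; check the model card for the expected format):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the released SFT checkpoint directly from the Hub (device_map="auto" requires `accelerate`).
model_name = "ContextualAI/archangel_sft_llama7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

# Prompt format is assumed to follow an HH-style Human/Assistant template; adjust if the model card differs.
prompt = "\n\nHuman: What is the capital of France?\n\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, top_p=0.95, temperature=0.7)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))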

Can you provide a clear description of the dataset structure we can use for our custom dataset?

Something like this?

From the Hugging Face DPOTrainer docs:

dpo_dataset_dict = {
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
}
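For completeness, a dict like the one above can be wrapped into a Hugging Face datasets.Dataset as shown below; whether HALOs consumes this TRL-style prompt/chosen/rejected format directly is exactly what this issue is asking:

from datasets import Dataset

# Purely illustrative: turn the TRL-style dict above into a Dataset object.
dpo_dataset = Dataset.from_dict(dpo_dataset_dict)
print(dpo_dataset)  # Dataset({features: ['prompt', 'chosen', 'rejected'], num_rows: 7})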

a few queries

Hello Kawin and authors,

Thanks for sharing this excellent work. I enjoyed reading the technical report; it was crisp and concise, and I appreciate the effort put into achieving this clarity.

I have a couple of questions:

  1. In the KTO loss function, the expected-KL term can be interpreted as the average reward obtained over 'rejected' responses across all input prompts (in the desirable-response case). In this sense it is similar to DPO, where the second term is the 'reward' for the rejected response of that particular x, except that here it is replaced with the average reward over 'rejected' responses from all x's.
    Is this understanding correct? (A rough sketch of this reading follows after the list.)
  2. I saw in the code that there is a KTOZero implementation where the expected KL divergence is replaced with 0. I am curious what its findings were.
  3. On a lighter note, I would like to know what font is used in the technical report.
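Here is a rough sketch of how I read the objective in question 1; the tensor names, batching, and default hyperparameters are my assumptions, not the repo's actual code:

import torch

# Sketch of the KTO value term as I understand it from the report (not the repo's exact code).
# policy_logps / ref_logps: summed log-probs of each (x, y) under the policy / reference model.
# kl_policy_logps / kl_ref_logps: log-probs of responses paired with *mismatched* prompts,
# used only to estimate the reference point z_ref (detached, so no gradient flows through it).
# desirable: bool tensor marking which examples are desirable.
def kto_loss_sketch(policy_logps, ref_logps, kl_policy_logps, kl_ref_logps,
                    desirable, beta=0.1, lambda_d=1.0, lambda_u=1.0):
    rewards = policy_logps - ref_logps                                      # implied reward log pi/pi_ref
    z_ref = (kl_policy_logps - kl_ref_logps).mean().clamp(min=0).detach()   # batch KL estimate
    values = torch.where(
        desirable,
        lambda_d * torch.sigmoid(beta * (rewards - z_ref)),   # desirable: push reward above z_ref
        lambda_u * torch.sigmoid(beta * (z_ref - rewards)),   # undesirable: push reward below z_ref
    )
    return (1.0 - values).mean()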

thanks,
Onkar

Question about the KL term in the loss function

For y ~ p_chosen, the KL term of the loss is computed as KL(p_policy(y_rejected|x') || p_reference(y_rejected|x')); however, according to the technical report, the KL divergence should be estimated over the entire dataset. Is there a reason x' above includes only rejected inputs and not inputs x from the entire batch (i.e., including y_chosen too)? My guess is that numerical stability, or ensuring the two terms of the loss are not correlated, could be the reason, but I want to make sure I am not missing something here.
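To frame the question, this is the Monte-Carlo estimate I believe is being formed; the argument names are assumptions, and only the mismatched (x', y_rejected) pairs feed into it, which is exactly what is being asked about:

import torch

# Per-batch estimate of KL(p_policy || p_reference) from responses scored against shuffled prompts x'.
# It is clamped at zero and detached, so it acts purely as a reference point with no gradient.
def kl_reference_estimate(policy_logps_mismatched, ref_logps_mismatched):
    return (policy_logps_mismatched - ref_logps_mismatched).mean().clamp(min=0).detach()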

Error fix for evaluation script `eval.py`

Thank you for maintaining such an important repository. While using it, I ran into some issues with eval.py and wanted to check whether the following changes are indeed valid.

  1. In line 86 of eval.py, we need to change reference_model.resize_token_embeddings(len(tokenizer)) to the following:
if config.loss.name == 'ppo':
    reference_model.resize_token_embeddings(len(tokenizer))

This is because the token embeddings of the reference model are resized only when the loss is ppo (refer to line 170 of train.py).

  2. When using the eval mode of eval.py, we need to change the trainer type from BasicTrainer to DPOTrainer; otherwise, a NotImplementedError is raised when the get_batch_metrics method is called.

Once again, thank you for maintaining an amazing repository which is very easy to use. Thank you so much!

ERROR: None of the inputs have requires_grad=True. Gradients will be None

Computing eval metrics: 0%| | 0/86 [00:00<?, ?it/s]/opt/conda/envs/python3.10.13/lib/python3.10/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")

Gradient Clipping for FSDP

Hi! Thank you for maintaining such a valuable repository.
I would like to suggest a minor fix regarding gradient clipping: for FSDP, we should not use torch.nn.utils.clip_grad_norm_ (relevant issue), but instead call the clip_grad_norm_ method of the FSDP module directly. I suggest modifying the following:

HALOs/trainers.py

Lines 453 to 455 in f9f7826

def clip_gradient(self):
    """Clip the gradient norm of the parameters of a non-FSDP policy."""
    return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.model.max_grad_norm).item()

into the following:

def clip_gradient(self):
    """Clip the gradient norm of the parameters."""
    if self.fsdp:
        return self.policy.clip_grad_norm_(self.config.model.max_grad_norm).item()
    return torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.model.max_grad_norm).item()

Thank you so much!

Is there a problem with training?

Initially, I used KTO for training, and the loss did not converge at all, as shown in the following training result graph.
[attached image: KTO training loss curve]

Later, using only llama7b and the hh data, I ran exactly the script you provided: python train.py loss=sft model=llama7b datasets=[hh] exp_name=llama7b_sft mode=train ++cache_dir=/data/models, and the training result graph is as follows:
[attached images: SFT training curves]

The only difference from your SFT setup is that I set use_flash_attention to false.

Losses list appears to be empty for loss=DPO

Hi,

I got this to run using FSDP. When I print the metrics, they are all strangely empty. I followed the PairedPreferenceTrainer class, which calls self.loss; following self.loss, the loss is defined in DPOTrainer. I added a print right after line 699, and it is oddly an empty tensor. Any ideas?

Thanks!


In the dataloader

pairs: List[Tuple[int, int]] = field(default_factory=list) # indices in responses, where i > j in pair (i,j) is a preference

In this line, does "i > j" mean "score(i) > score(j)"? I'm a bit confused; thank you for your clarification!
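If the "score(i) > score(j)" reading is right, a toy illustration would look like the following (the responses, scores, and pair-building logic here are made up purely to illustrate that interpretation):

# Hypothetical example: in each pair (i, j), response i would be the one with the higher score.
responses = ["resp_a", "resp_b", "resp_c"]
scores = [2.0, 0.5, 1.0]  # higher = preferred
pairs = [(i, j)
         for i in range(len(responses))
         for j in range(len(responses))
         if scores[i] > scores[j]]
# pairs == [(0, 1), (0, 2), (2, 1)]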
