
zhziszz / weak-to-strong-search

Weak-to-Strong Search: Align Large Language Models via Searching over Small Language Models

Home Page: https://arxiv.org/abs/2405.19262


Weak-to-Strong Search

Code release for Weak-to-Strong Search: Align Large Language Models via Searching over Small Language Models.

  • The scripts/instruction_following directory contains code and instructions for using off-the-shelf small/weak models to guide the decoding of large/strong models to better follow human instructions.

  • The scripts/controlled_sentiment_generation directory contains code and instructions for using tuned and untuned GPT-2 (124M) models to steer larger models toward writing positive movie reviews.
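The guidance signal both directories rely on can be illustrated in a few lines. As we read the paper, a (partial) response is scored by how much more likely the tuned small model finds it than its untuned counterpart, i.e. the sum of per-token log-probability differences. The function below is a minimal sketch under that assumption, not the repository's `ImplicitValueScorer`; the name `implicit_value` and the per-token probabilities are hypothetical and invented for illustration.

```python
import math
from typing import Sequence


def implicit_value(tuned_logprobs: Sequence[float], untuned_logprobs: Sequence[float]) -> float:
    """Hypothetical helper: sum over response tokens of
    log p_tuned(token | context) - log p_untuned(token | context).

    Both sequences must hold per-token log-probabilities of the SAME response
    tokens, one under the tuned small model and one under the untuned one.
    """
    if len(tuned_logprobs) != len(untuned_logprobs):
        raise ValueError("both models must score the same response tokens")
    return sum(t - u for t, u in zip(tuned_logprobs, untuned_logprobs))


# Invented per-token probabilities for a 3-token response.
tuned = [math.log(p) for p in (0.5, 0.4, 0.9)]
untuned = [math.log(p) for p in (0.25, 0.4, 0.3)]
print(round(implicit_value(tuned, untuned), 4))  # prints 1.7918 (= log 6)
```

A positive value means the tuned model prefers the response more than the untuned one does, which is what the search tries to maximize.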

Installation

conda create -n weak-to-strong-search python=3.10
conda activate weak-to-strong-search
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
# (optional) pip install flash-attn==2.3.2 --no-build-isolation
# (optional) pip install bitsandbytes==0.42.0
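Because the commands above pin specific versions, it can help to confirm what actually got installed. The helper below is a hypothetical convenience script (not part of the repository) that uses only the standard library's `importlib.metadata`; the pins are copied from the commands above, and a CUDA-suffixed build such as `2.1.0+cu118` is treated as matching.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Pins copied from the install commands (flash-attn and bitsandbytes are optional).
PINNED = {"torch": "2.1.0", "flash-attn": "2.3.2", "bitsandbytes": "0.42.0"}


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, got in installed_versions(PINNED).items():
        want = PINNED[name]
        if got is None:
            status = "missing"
        elif got.startswith(want):  # accepts local builds like "2.1.0+cu118"
            status = "ok"
        else:
            status = f"expected {want}"
        print(f"{name}: {got} ({status})")
```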

Quick Start

The following example uses HuggingFaceH4/zephyr-7b-beta and its untuned version HuggingFaceH4/mistral-7b-sft-beta to guide the decoding of meta-llama/Meta-Llama-3-8B-Instruct for better alignment.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.inference_time_alignment.decoders.cbs import CBSPosthocGenerationMixin
from src.inference_time_alignment.scorers import ImplicitValueScorer


def get_zephyr_scorer() -> ImplicitValueScorer:
    """
    Use `zephyr-7b-beta` and its untuned version `mistral-7b-sft-beta` as a scorer to guide other models.
    """
    tuned_model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta", torch_dtype=torch.bfloat16, device_map="auto")
    untuned_model = AutoModelForCausalLM.from_pretrained(
        "HuggingFaceH4/mistral-7b-sft-beta", torch_dtype=torch.bfloat16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    prompt_template = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": ""},
            {"role": "user",   "content": "{raw_prompt}"},
        ],
        tokenize=False, 
        add_generation_prompt=True,
    )
    implicit_value_scorer = ImplicitValueScorer(
        model=tuned_model,
        ref_model=untuned_model,
        tokenizer=tokenizer,
        model_prompt_template=prompt_template,
        ref_model_prompt_template=prompt_template,
    )
    return implicit_value_scorer


# the (stronger/larger) model to be guided
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt_template = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": ""},
        {"role": "user",   "content": "{raw_prompt}"},
    ],
    tokenize=False, 
    add_generation_prompt=True,
)

# chunk-level beam search wrapper
cbs_model = CBSPosthocGenerationMixin(base, tokenizer)
# implicit value scorer
scorer = get_zephyr_scorer()

# prepare prompts
raw_prompt = "Who are you?"
prompt = prompt_template.format(raw_prompt=raw_prompt)
prompt_tokenized = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
prompt_len = prompt_tokenized["input_ids"].size(1)

# search for the highest scoring response
outputs = cbs_model.search(
    input_ids=prompt_tokenized["input_ids"].cuda(),
    attention_mask=prompt_tokenized["attention_mask"].cuda(),
    scorer=scorer.set_raw_prompt(raw_prompt),
    split_by_prompt_text=False,
    w=2, k=2, l=30, # CBS args: w = beam width, k = successors sampled per beam, l = chunk length in tokens
    max_new_tokens=128,
)

print(tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True))

See scripts/instruction_following for more examples.
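To make the `w`, `k` and `l` arguments of the `cbs_model.search` call concrete, here is a toy, pure-Python sketch of chunk-level beam search as we understand it from the paper: keep `w` partial responses, extend each with `k` sampled chunks of `l` tokens, score every candidate, and keep the `w` best. This is not the repository's `CBSPosthocGenerationMixin`; `toy_propose` and `toy_score` are hypothetical stand-ins for sampling from the large model and for the small-model scorer, and tokens are plain strings.

```python
import random
from typing import Callable, List

Tokens = List[str]


def chunk_beam_search(
    prompt: Tokens,
    propose: Callable[[Tokens, int, random.Random], Tokens],  # stand-in for the large model
    score: Callable[[Tokens], float],  # stand-in for the small-model scorer
    w: int = 2,
    k: int = 2,
    l: int = 3,
    max_new_tokens: int = 9,
    seed: int = 0,
) -> Tokens:
    """Toy chunk-level beam search; returns the highest-scoring response."""
    if l <= 0:
        raise ValueError("chunk length l must be positive")
    rng = random.Random(seed)
    beams: List[Tokens] = [[]]  # start from a single empty response
    while len(beams[0]) < max_new_tokens:  # all beams share one length here
        chunk_len = min(l, max_new_tokens - len(beams[0]))
        candidates: List[Tokens] = []
        for beam in beams:
            for _ in range(k):  # sample k successors per beam
                chunk = propose(prompt + beam, chunk_len, rng)
                if len(chunk) != chunk_len:
                    raise ValueError("propose must return exactly chunk_len tokens")
                candidates.append(beam + chunk)
        candidates.sort(key=score, reverse=True)
        beams = candidates[:w]  # keep the w best partial responses
    return beams[0]


VOCAB = ["good", "bad", "fine"]


def toy_propose(context: Tokens, n: int, rng: random.Random) -> Tokens:
    # Hypothetical sampler: ignores the context and draws n random words.
    return [rng.choice(VOCAB) for _ in range(n)]


def toy_score(response: Tokens) -> float:
    # Hypothetical scorer: prefers responses containing more "good".
    return float(response.count("good"))


best = chunk_beam_search(["review:"], toy_propose, toy_score, w=2, k=4, l=3, max_new_tokens=9)
print(best)
```

Scoring whole chunks rather than single tokens keeps the number of scorer calls at roughly `w * k` per `l` generated tokens, which is the trade-off the `l` argument controls.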

Reference

@article{zhou2024weak,
  title={Weak-to-Strong Search: Align Large Language Models via Searching over Small Language Models},
  author={Zhou, Zhanhui and Liu, Zhixuan and Liu, Jie and Dong, Zhichen and Yang, Chao and Qiao, Yu},
  journal={arXiv preprint arXiv:2405.19262},
  year={2024}
}
