headless-lm: Better and Faster LM pretraining

This repository contains training and evaluation code for the paper "Headless Language Models: Learning without Predicting with Contrastive Weight Tying".

Paper abstract:

Self-supervised pre-training of language models usually consists in predicting probability distributions over extensive token vocabularies. In this study, we propose an innovative method that shifts away from probability prediction and instead focuses on reconstructing input embeddings in a contrastive fashion via Contrastive Weight Tying (CWT). We apply this approach to pretrain Headless Language Models in both monolingual and multilingual contexts. Our method offers practical advantages, substantially reducing training computational requirements by up to 20 times, while simultaneously enhancing downstream performance and data efficiency. We observe a significant +1.6 GLUE score increase and a notable +2.7 LAMBADA accuracy improvement compared to classical LMs within similar compute budgets.
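
To make the objective concrete, here is a minimal PyTorch sketch of a CWT-style loss: instead of predicting a distribution over the vocabulary, each output representation is scored against the input embeddings of the target tokens in the batch, with the embedding at the same position acting as the positive and the others as in-batch negatives. The function name, shapes, and temperature below are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn.functional as F

def cwt_loss(hidden_states, target_embeddings, temperature=0.05):
    # hidden_states:     (N, dim) transformer outputs at predicted positions
    # target_embeddings: (N, dim) input embeddings of the target tokens
    # temperature is a common contrastive-loss knob, assumed here for illustration
    logits = hidden_states @ target_embeddings.T / temperature  # (N, N) similarities
    # the i-th target embedding is the positive for the i-th output; the rest are negatives
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)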


Install environment

Make sure you have Python>=3.9 and CUDA>=11.2 installed. Then run:

pip install -r requirements.txt

Preprocess data

Adapt the config file in configs/preprocess_owt2.json to your specific case, and then run:

python preprocess.py --config=configs/your_config_file.json
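
The training scripts below take the preprocessed output via --dataset. Assuming preprocess.py saves a Hugging Face datasets dataset to disk (as the .hf suffix used below suggests), you can sanity-check it with something like:

from datasets import load_from_disk

# path is a placeholder; it should match the output location in your preprocessing config
dataset = load_from_disk("your-preprocessed-output.hf")
print(dataset)  # inspect splits, columns, and sizes before launching training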

Training

Encoder

To train an encoder model:

  1. Write/edit model-related parameters in a config file similar to configs/mlm_headless.json
  2. Run the following command with your specific arguments:
python mlm_headless.py \
    --config configs/your_config_file.json \
    --num_nodes your-gpu-node-count \
    --global_bs your-accumulated-batch-size \
    --gpu_bs your-per-device-batch-size \
    --dataset your-preprocessed-output.hf \
    --hf_tokenizer your-tokenizer \
    --hf_path path-to-your-model-arch-on-HF \
    --model_max_seq_len models-max-pos-embeddings \
    --run_name run-name-for-logging-and-ckpts \
    --saved_ckpt_path where-to-save-ckpts

Other arguments include --accelerator (hf, xformers, or flash_attention) and --ckpt_every, which sets the checkpoint frequency.

  3. Pick your checkpoint and publish it to HuggingFace:
python hf_publisher.py \
    --hf_name your_hf_id/your_model \
    --model_ckpt your_model.ckpt \
    --mode mlm
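
Assuming --mode mlm produces a checkpoint loadable as a standard masked LM (the model name below is a placeholder), you can then use it through the usual transformers API:

from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your_hf_id/your_model")
model = AutoModelForMaskedLM.from_pretrained("your_hf_id/your_model")

# assumes the tokenizer defines a mask token
text = f"Headless language models learn without {tokenizer.mask_token}."
inputs = tokenizer(text, return_tensors="pt")
logits = model(**inputs).logits  # (1, seq_len, vocab_size) scores at each position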

Decoder

To train a decoder model:

  1. Write/edit model-related parameters in a config file similar to configs/gpt_headless_70m.json
  2. Run the following command with your specific arguments:
python gpt_headless.py \
    --config configs/your_config_file.json \
    --num_nodes your-gpu-node-count \
    --global_bs your-accumulated-batch-size \
    --gpu_bs your-per-device-batch-size \
    --dataset your-preprocessed-output.hf \
    --hf_tokenizer your-tokenizer \
    --hf_path path-to-your-model-arch-on-HF \
    --model_max_seq_len models-max-pos-embeddings \
    --run_name run-name-for-logging-and-ckpts \
    --saved_ckpt_path where-to-save-ckpts

Other arguments include --accelerator (hf, xformers, or flash_attention) and --ckpt_every, which sets the checkpoint frequency.

  3. (optional) Pick your checkpoint and publish it to HuggingFace. You'll need the add_head option so the model can output tokens:
python hf_publisher.py \
    --hf_name your_hf_id/your_model \
    --model_ckpt your_model.ckpt \
    --mode add_head
  4. The resulting model will probably perform poorly at language generation. Why? Because it was not trained to do it! To turn your contrastive model into a good LM, you'll need to add a head and fine-tune it. Set up a config file in the style of configs/gpt_vanilla_ft.json and run:
python ft_gpt_headless.py \
    --ckpt_path your_headless_model.ckpt \
    --config configs/your_ft_config.json \
    --num_nodes your-gpu-nodes \
    --global_bs your-accumulated-bs \
    --gpu_bs your-device-bs \
    --dataset your-preprocessed-output.hf \
    --run_name run-name-for-logging-and-ckpts \
    --saved_ckpt_path where-to-save-finetuned-ckpts
  5. Pick your fine-tuned checkpoint and publish it to HuggingFace. You no longer need the add_head option, as you just trained a head:
python hf_publisher.py \
    --hf_name your_hf_id/your_model \
    --model_ckpt your_model.ckpt \
    --mode lm
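
Once published in lm mode, the model should behave like any causal LM on the Hub. A quick generation sanity check (the model name is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your_hf_id/your_model")
model = AutoModelForCausalLM.from_pretrained("your_hf_id/your_model")

inputs = tokenizer("Headless language models", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))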

Evaluation

You can now use any zero-shot or fine-tuning code to evaluate your models. We provide our GLUE fine-tuning script in glue_finetuning.py, and we used the LM Eval Harness for zero-shot evaluation.
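
For example, with a recent release of EleutherAI's LM Evaluation Harness (v0.4 or later; older releases use different flags), a zero-shot run on your published decoder might look like:

lm_eval --model hf \
    --model_args pretrained=your_hf_id/your_model \
    --tasks lambada_openai \
    --batch_size 8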

Citation

This repo contains the code used for the experiments in the paper "Headless Language Models: Learning without Predicting with Contrastive Weight Tying". If you use it, please cite:

@misc{godey2023headless,
      title={Headless Language Models: Learning without Predicting with Contrastive Weight Tying}, 
      author={Nathan Godey and Éric de la Clergerie and Benoît Sagot},
      year={2023},
      eprint={2309.08351},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
