arazd / ProgressivePrompts

Progressive Prompts: Continual Learning for Language Models

Home Page: https://arxiv.org/pdf/2301.12314.pdf

License: Apache License 2.0

Languages: Python 58.24%, Shell 1.57%, Jupyter Notebook 40.20%
Topics: nlp, continual-learning, prompt-tuning, llms

progressiveprompts's Introduction

Progressive Prompts

Our work on Progressive Prompts is accepted to ICLR 2023! 🎉

This repo includes the original implementation of "Progressive Prompts: Continual Learning for Language Models" (ICLR 2023) by Anastasia Razdaibiedina, Yuning Mao, Rui Hou, Madian Khabsa, Mike Lewis and Amjad Almahairi.

Table of contents

  • Introduction
  • What's in this repository
  • Installation
  • How to run
  • Questions
  • Citation

🌟 Introduction

We introduce Progressive Prompts – a novel Continual Learning (CL) approach for language models. Our method is inspired by progressive networks (A. Rusu et al., NeurIPS 2017), but is significantly more memory-efficient. In Progressive Prompts, we learn a separate set of virtual tokens, or soft prompt (B. Lester et al., EMNLP 2021), for each incoming task and sequentially concatenate it with previously learned prompts.

Our method can:

  1. alleviate catastrophic forgetting, since it preserves the knowledge acquired by previous prompts, and
  2. transfer knowledge to future tasks, since new prompts are sequentially concatenated with all prior prompts.

Progressive Prompts schematics. Figure: an illustration of our proposed method, Progressive Prompts, contrasted with a simple adaptation of progressive networks using prompt tuning. In the simple adaptation, a separate prompt is learned and the frozen input embeddings are repeated for each new task, so input tokens must be duplicated per task. In Progressive Prompts, we use the same input and progressively append a new prompt for each new task; prior task prompts are not modified by the addition of new prompts.
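To make the mechanism concrete, here is a minimal PyTorch sketch of the idea. It is an illustration only, not the repo's implementation (that lives in T5_codebase/t5_continual.py and BERT_codebase/model_utils.py), and the class and method names are made up.

import torch
import torch.nn as nn

class ProgressivePromptSketch(nn.Module):
    # One frozen prompt per finished task plus one trainable prompt for the
    # current task; all of them are prepended to the input embeddings.
    def __init__(self, embed_dim, prefix_len=10):
        super().__init__()
        self.embed_dim, self.prefix_len = embed_dim, prefix_len
        self.prompts = nn.ParameterList()  # prompts[:-1] are frozen, prompts[-1] is trained

    def add_task(self):
        for p in self.prompts:             # freeze everything learned so far
            p.requires_grad_(False)
        self.prompts.append(nn.Parameter(torch.randn(self.prefix_len, self.embed_dim)))

    def forward(self, input_embeds):       # input_embeds: (batch, seq_len, embed_dim)
        prompt = torch.cat(list(self.prompts), dim=0)                  # (num_tasks * prefix_len, embed_dim)
        prompt = prompt.unsqueeze(0).expand(input_embeds.size(0), -1, -1)
        return torch.cat([prompt, input_embeds], dim=1)                # prior prompts are reused, never modified

During training on a new task, only the newest prompt receives gradients, which is how the knowledge stored in earlier prompts is preserved.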

โ“ What's in this repository

This is our code structure:

|_T5_codebase/
      |_t5_dataset.py --> T5 Dataset class for reading and processing datasets
      |_t5_continual.py --> Model class for T5 with prompt tuning and continual learning functions
      |_train_t5_cl.py --> Code to run continual learning experiments with T5
      
|_BERT_codebase/
      |_dataset_utils.py --> BERT Dataset class for reading and processing datasets
      |_model_utils.py --> Model class for BERT with prompt tuning and fine-tuning functions
      |_continual_learning_utils.py --> Continual Learner class for Progressive Prompts (with BERT)
      |_continual_learning_one_head.py --> Continual Learner class for regularization-based CL approaches for BERT 
      |_train_cl2.py --> Code to run continual learning experiments with BERT
      
|_datasets/src/data/ --> CL datasets from Zhang et al., 2015
      |_amazon --> Amazon reviews (zip archive, since the dataset is not available through HuggingFace datasets)
      (the rest of the datasets can be either accessed through HuggingFace or downloaded following the instructions below)

Note: we access most of the datasets for our experiments through HuggingFace datasets, including the CL datasets from Zhang et al., 2015. The only CL dataset from Zhang et al. that is not available on HuggingFace is Amazon Reviews, so we uploaded its archived train / test data to datasets/src/data/amazon/. To access the rest of the CL datasets (Yelp, Yahoo, AG, DbPedia), you can either use their HuggingFace names in our training script or download them from http://goo.gl/JyCnZq to datasets/src/data/.
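As a quick sanity check before launching a long run, the HuggingFace-hosted datasets can be loaded directly. The names below are the standard hub identifiers (an assumption; double-check against the training script), while Amazon Reviews is loaded from datasets/src/data/amazon/ instead.

from datasets import load_dataset

# Small slices just to confirm access to the hub-hosted CL datasets.
for name in ["ag_news", "yelp_review_full", "dbpedia_14", "yahoo_answers_topics"]:
    ds = load_dataset(name, split="train[:100]")
    print(name, ds.column_names)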

🔧 Installation

Our implementation is based on PyTorch and HuggingFace (transformers + datasets).

Requirements:

  • Python 3.8.5
  • PyTorch 1.10.0
  • transformers 4.20.0
  • datasets 2.3.2
  • tqdm, sklearn, numpy, pandas

Step-by-step instructions to get you running Progressive Prompts:

1) Clone this repository to your local machine:

git clone https://github.com/arazd/ProgressivePrompts    

A folder called ProgressivePrompts containing the codebase should appear.

2) Install the required packages:

Make sure that you have Anaconda installed. If not, follow the Miniconda installation instructions.

To run the Progressive Prompts code on GPU, make sure that you have a CUDA-capable GPU and that your GPU drivers are up to date. In our implementation, we used CUDA 11.0.

You can re-create our conda environment from the environment.yaml file:

cd ProgressivePrompts
conda env create -f environment.yaml

Your conda should start downloading and extracting packages. This can take ~15-20 minutes.

3) Activate the environment:

Your environment should be called nlp, and you can activate it now to run the scripts:

conda activate nlp

⚡ How to run

For example, to run Progressive Prompts with T5-large on four tasks (IMDb, CB, SST-2 and DbPedia):

cd T5_codebase

python train_t5_cl.py --task_list imdb cb sst2 dbpedia_14 --select_k_per_class 1000 \
--lr 0.3 --num_epochs 10 --freeze_weights 1 --prefix_len 10 \
--model_name t5-large --early_stopping 1 \
--save_name T5_experiment --save_dir my_path_to_save_directory

In the example above, we freeze the model weights and train a prompt of size 10 (per task) for 10 epochs, limiting the data to 1000 samples per class. For other arguments and their descriptions, please check the T5_codebase/train_t5_cl.py file.
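For intuition, below is a minimal, self-contained sketch of what prompt tuning with a frozen T5 looks like. It is only an illustration under simplifying assumptions (single example, random prompt initialization), not the training loop from train_t5_cl.py.

import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Frozen T5 with a single trainable soft prompt of length 10
# (cf. --freeze_weights 1 --prefix_len 10 above).
tokenizer = T5Tokenizer.from_pretrained("t5-large")
model = T5ForConditionalGeneration.from_pretrained("t5-large")
for p in model.parameters():
    p.requires_grad = False

prefix_len = 10
soft_prompt = torch.nn.Parameter(torch.randn(prefix_len, model.config.d_model))

batch = tokenizer(["sst2 sentence: a gorgeous, witty film"], return_tensors="pt")
inputs_embeds = model.get_input_embeddings()(batch["input_ids"])
inputs_embeds = torch.cat(
    [soft_prompt.unsqueeze(0).expand(inputs_embeds.size(0), -1, -1), inputs_embeds], dim=1)
attention_mask = torch.cat(
    [torch.ones(inputs_embeds.size(0), prefix_len, dtype=batch["attention_mask"].dtype),
     batch["attention_mask"]], dim=1)

labels = tokenizer(["positive"], return_tensors="pt").input_ids
loss = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels).loss
loss.backward()  # gradients reach only soft_prompt, since the model itself is frozen

An optimizer over soft_prompt alone (with the relatively large learning rate used above, e.g. 0.3) then performs the prompt-tuning step.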

To train Progressive Prompts on the same four tasks with BERT-base:

cd BERT_codebase

python train_cl2.py --task_list imdb cb sst2 dbpedia_14  --select_k_per_class 1000 \
--lr 3e-5 --num_epochs 50 --freeze_weights 1 --freeze_except word_embeddings \
--prompt_tuning 1 --prefix_len 10 --seq_len 450 --one_head 0 \
--model_name bert-base-uncased --early_stopping 1 \
--save_name BERT_experiment --save_dir my_path_to_save_directory

Note that soft prompts for BERT need to be trained with a smaller learning rate and a larger number of epochs. There are also several BERT-specific arguments: one_head controls whether to use a separate head for each task; freeze_except allows freezing all weights except the word embeddings (since we include prompt tokens in the vocabulary in our BERT implementation); seq_len controls the maximum input length (without the prompt); and the prompt_tuning flag signals whether we are doing prompt tuning. For other arguments and their descriptions, please check the BERT_codebase/train_cl2.py file.
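Since the BERT implementation represents soft prompt tokens as extra vocabulary entries, the general technique looks roughly like the sketch below. This is an assumption-based illustration, not the repo's exact code; the [PROMPT_i] token names are hypothetical.

from transformers import BertTokenizerFast, BertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

# Add prompt tokens to the vocabulary and grow the embedding matrix accordingly.
prefix_len = 10
prompt_tokens = [f"[PROMPT_{i}]" for i in range(prefix_len)]  # hypothetical token names
tokenizer.add_tokens(prompt_tokens, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

# Mirror --freeze_weights 1 --freeze_except word_embeddings: gradients flow only
# into the word-embedding matrix, which now also holds the prompt rows.
for name, param in model.named_parameters():
    param.requires_grad = "word_embeddings" in name

# Prompt tokens are then prepended to the input string, e.g.:
text = " ".join(prompt_tokens) + " a gorgeous, witty film"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=450 + prefix_len)
outputs = model(**inputs)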

🙋 Questions

If you have any questions about the paper or code, please contact Anastasia Razdaibiedina (anastasia.razdaibiedina[at]mail.utoronto.ca) or open an issue.

📚 Citation

If you use our code in your research, please cite our work:

@inproceedings{razdaibiedina2023progressive,
   title={Progressive Prompts: Continual Learning for Language Models},
   author={Razdaibiedina, Anastasia and Mao, Yuning and Hou, Rui and Khabsa, Madian and Lewis, Mike and Almahairi, Amjad},
   booktitle={International Conference on Learning Representations},
   year={2023}
}

progressiveprompts's People

Contributors

arazd

progressiveprompts's Issues

Differences between Progressive Prompt and LFPT5

Dear authors,
Congratulations for your great effort to release source code. I really enjoyed your paper when I read it. However, I realized that the idea of progressive prompt is quite similar to the idea in Sec. 3.2.1 in LFPT 5 paper in which authors proposed to add a new set of learnable prompt tokens when a new task come, learned tokens from previous task are frozen, and all prior prompt tokens is concatenate with new ones in training a new task. The main different between these two is the use of embedding re-parameterization indicated in Sec. 3 in your paper. Am I correct? Feel free to correct me if I misunderstood.
P/s: And one more question about the ablation studies of the paper: Did you compare LFPT5 with Progressive Prompt w/o MLP and Progressive Prompt w MLP in the continual setting?
Thanks.

trouble with bert-based model

Dear Dr. Arazd @arazd,
Thanks for your great work. I'm trying to replicate your result from Table 1 for order 4 (5 tasks, bert-base-uncased model) in the CL setting (full data) of the main paper, with the following command:
python train_cl2.py --task_list ag yelp_review_full amazon yahoo dbpedia --prefix_MLP residual_MLP2 \
--lr 1e-4 --num_epochs 40 --freeze_weights 1 --freeze_except word_embeddings \
--prompt_tuning 1 --prefix_len 20 --seq_len 450 --one_head 0 \
--model_name bert-base-uncased --early_stopping 1 \
--save_name BERT_order_4_run1 --save_dir ./results

However, when the Progressive Prompts model evaluates accuracy on all datasets, it throws an error like the following when it starts evaluating on the yahoo dataset:
/opt/conda/conda-bld/pytorch_1639180487213/work/aten/src/ATen/native/cuda/Indexing.cu:699: indexSelectLargeIndex: block: [60,0,0], thread: [96,0,0] Assertion srcIndex < srcSelectDimSize failed.

This issue only appears in the evaluation of the 4th task. I tried other settings, such as shorter task sequences (2 or 3 tasks), removing the ResMLP, etc., and they work as normal. I also printed the patterns of the input_ids, token_type_ids and position_ids, but the 4th task's pattern is similar to the previous tasks'.

The error stems from this line in your repo: https://github.com/arazd/ProgressivePrompts/blob/01572d6a73c0576b070ceee00dbe4f5bc278423f/BERT_codebase/model_utils.py#L576

Could you give me some insight into this problem?
I would really appreciate it if you could help me fix it.

P.S.: After troubleshooting the source of the problem, I found that it only occurs when the task sequence is longer than 4 and when evaluating with the full validation set.
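(For readers who hit the same assertion: it typically fires when an index into an embedding table exceeds the table size. A generic check along the lines below, run on CPU or with CUDA_LAUNCH_BLOCKING=1, can reveal the offending ids; this is a sketch, not code from the repo, and check_batch is a hypothetical helper.)

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # set before CUDA init to get a usable stack trace

def check_batch(batch, model):
    # Any input id >= the embedding table size triggers the Indexing.cu assertion on GPU.
    vocab_size = model.get_input_embeddings().num_embeddings
    bad = batch["input_ids"] >= vocab_size
    if bad.any():
        print("out-of-range input_ids:", batch["input_ids"][bad].unique().tolist())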

Questions about the data preprocessing method

Hi,

Thanks for your excellent paper! I have a question about the data preprocessing in the implementation code.

A small observation about data processing: for the 'yahoo', 'amazon', 'agnews' and 'dbpedia' datasets, besides the 'content' column, they also include a 'title' column (and an 'answer' column for 'yahoo'). The implementations in IDBR and MbPA++ only use 'content' for training and discard other information such as 'title' and 'answer', whereas Progressive Prompts takes all text fields into account. I'm wondering whether this leads to some performance difference, and did you compare them?
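To make the two preprocessing variants concrete, here is a small sketch with the dbpedia_14 dataset, whose HuggingFace version exposes 'title' and 'content' columns (an assumption-based illustration, not code from any of the implementations mentioned):

from datasets import load_dataset

ds = load_dataset("dbpedia_14", split="train[:1000]")

# IDBR / MbPA++ style: keep only the 'content' field.
content_only = ds.map(lambda x: {"text": x["content"]})

# All-text-fields style, as described above.
all_fields = ds.map(lambda x: {"text": x["title"] + " " + x["content"]})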

Many thanks in advance for your clarification!

Question about ProgPrompt vs. per-task prompt and Backward Transfer

Dear authors,

Thank you for your effort to release the code. The idea of the paper is simple yet effective, I really like it!

I have 2 concerns after reading the paper and the code.

  1. Why does the Progressive Prompt achieve higher performance than the per-task prompt?

    • I buy the explanation of forward transfer. Intuitively, however, given adequate training samples and epochs for incoming tasks, per-task prompts should perform better, since they have more parameters with which to overfit each task. Based on your experiments, I can see a smaller and smaller gap between the two as the number of training samples rises.
    • On the other hand, regarding the few-shot setting, could it be that Progressive Prompts has access to M*N training samples, while each per-task prompt only sees the N samples of its own task out of the M tasks? Do you have any experimental or intuitive explanation for this?
  2. How do you test the previous tasks when new prompts are already concatenated?

    • During test time for the old tasks, do you use all the concatenated prompts for evaluation? I suppose this should have either a positive or negative effect on old tasks. For example, when learning the 2nd task, you have only 2 prompts concatenated; after learning all 15 tasks, you have 15 prompts. The performance of 15 prompts and 2 prompts on the 2nd task should be different, which results in a non-zero Backward Transfer. However, in your results, the BWT is all zero for Progressive Prompts (Fig. 8-10 in the Appendix).
    • On OpenReview I saw some reviews pointing out the usage of the task identifier. I found that the identifier is used for different outputs in T5 and different classifier heads in BERT, which seems inevitable. You also use it to select the ResMLP for different tasks. Why not store MLP(new_prompt) after training on one task, so that at least the MLP part is task-agnostic during inference or new-task learning? The code shows that you only apply MLP(task_prompt) during inference or validation and keep the other prompts as they are. This confuses me.

Feel free to correct me if I misunderstood.

Many Thanks.
