Library for soft prompt tuning

License: BSD 3-Clause "New" or "Revised" License

Language: Python (100%)
Topics: prompt-tuning, nlp, deep-learning, python, pytorch, soft-prompt-tuning

Project Overture - A Prompt-Tuning Library for Researchers

Why name it Overture? An overture is the orchestral piece at the beginning of a musical work that sets the mood and tone for what is to come. We think of prompt tuning as analogous: in the types of prompt tuning methods we consider, a prompt is prepended to the input and sets the tone for the downstream task.

Introduction

Prompt tuning has recently become an important research direction in Natural Language Processing. In contrast to classical fine-tuning, which optimizes the weights of the entire network, (one style of) prompt tuning keeps the large language model (a.k.a. the "backbone") frozen and instead prepends a few learnable vectors to each input; only these vectors are trained to accomplish the task. This brings the number of trainable parameters down from O(millions) to a few thousand while still achieving similar levels of performance. The research community has also found other benefits of prompt-tuned models compared to classically fine-tuned ones.
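
To make this concrete, here is a minimal sketch of the core mechanism, assuming a HuggingFace-style backbone that accepts inputs_embeds; the class and parameter names are illustrative and are not this library's API:

import torch
import torch.nn as nn

class SoftPromptWrapper(nn.Module):
    """Illustrative sketch: freeze a backbone, learn only n_prompts vectors."""

    def __init__(self, backbone, n_prompts, hidden_size):
        super().__init__()
        self.backbone = backbone
        for p in self.backbone.parameters():
            p.requires_grad = False  # the backbone stays frozen
        # The only trainable parameters: one vector per prompt position.
        self.prompts = nn.Parameter(torch.randn(n_prompts, hidden_size) * 0.02)

    def forward(self, inputs_embeds, attention_mask):
        batch_size = inputs_embeds.size(0)
        prompts = self.prompts.unsqueeze(0).expand(batch_size, -1, -1)
        # Prepend the learned prompts and extend the attention mask to match.
        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
        prompt_mask = attention_mask.new_ones(batch_size, self.prompts.size(0))
        attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        return self.backbone(inputs_embeds=inputs_embeds,
                             attention_mask=attention_mask)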

Methods Supported

The repository leverages the HuggingFace Transformers library. Currently, we support WARP-like prompt tuning for masked language modeling (MLM), text classification, and extractive question answering (e.g., SQuAD). We plan to add support for Seq2Seq prompt tuning soon. If there is another algorithm or method you would like us to prioritize, please write to us or file a feature request. Finally, we refer the interested reader to the excellent survey on the topic for the various types of prompt tuning methods and their history.

Some Potential Extensions

Here are some research ideas one could experiment with using our codebase. Since the community is evolving rapidly, it is entirely possible that some of these ideas have already been studied. Please file an issue if that is the case, or if you would like to contribute more ideas.

  1. Does prompt tuning on a multilingual backbone (e.g., mBERT or XLM) lead to models that can perform cross-lingual zero-shot transfer?
  2. How can we make the prompts more interpretable? Could adding a loss that keeps the prompt vectors close to existing word embeddings help? (A sketch of one such loss appears after this list.)
  3. Can prompts learned for BERT-Large help learn prompts for RoBERTa-Large?
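
For idea 2, one simple regularizer, purely illustrative and not part of this library, pulls each learned prompt vector toward its nearest word embedding in the frozen vocabulary table:

import torch

def nearest_embedding_loss(prompts, word_embeddings):
    """Illustrative interpretability regularizer (not part of this library).

    prompts:         (n_prompts, hidden) -- trainable prompt vectors
    word_embeddings: (vocab, hidden)     -- frozen backbone embedding table
    """
    # Pairwise squared distances between each prompt and every word embedding.
    dists = torch.cdist(prompts, word_embeddings) ** 2  # (n_prompts, vocab)
    # Each prompt is pulled toward its current nearest word embedding.
    return dists.min(dim=1).values.mean()

# total_loss = task_loss + lambda_interp * nearest_embedding_loss(prompts, emb)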

Design Choices & Other Similar Libraries

Fundamentally, we designed the repository so that researchers can easily experiment with ideas in the realm of prompt tuning. As such, we intentionally do not abstract away the sub-components. The repository is intended as a fork-and-edit library and is designed to be easily extensible for the kinds of projects we anticipate people will use it for.

A recently released library, OpenPrompt, is also intended for prompt tuning; we refer the interested practitioner to their repository for further exploration and comparison. OpenPrompt may be a better fit for those who seek greater abstraction.

How to Use

Inside the examples folder, we provide training code for a RoBERTa-Large model on the MNLI dataset (in the style of WARP). To start training:

CUDA_VISIBLE_DEVICES=0 python train_warp_mnli.py --save_prompts_path dir_to_save_prompts --save_classifier_path dir_to_save_classifier 

After training, users should expect the model accuracy to be 87-89%, which matches the original WARP paper's results. The training-loss curve and the evaluation accuracy on the validation set from one run are shown below.

[Figures: training loss curve (train_loss_curve) and evaluation accuracy on the validation set (eval_validation_value)]

Dev environment

  • Python 3.8.5
  • A100 GPU, CUDA Version: 11.0
  • Other dependencies: requirements.txt
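
To reproduce this environment, a typical setup might look like the following (the exact steps are our assumption; the repository does not prescribe an install flow):

git clone https://github.com/salesforce/overture.git
cd overture
pip install -r requirements.txt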

API

# importing RoBERTa based API
from models.modeling_roberta import WARPPromptedRobertaForMaskedLM, WARPPromptedRobertaForSequenceClassification, WARPPromptedRobertaForQuestionAnswering
# importing BERT based API
from models.modeling_bert import WARPPromptedBertForMaskedLM, WARPPromptedBertForSequenceClassification, WARPPromptedBertForQuestionAnswering
# importing XLM-RoBERTa based API
from models.modeling_roberta import WARPPromptedXLMRobertaForMaskedLM, WARPPromptedXLMRobertaForSequenceClassification, WARPPromptedXLMRobertaForQuestionAnswering
# importing function for randomly masking inputs
from utils import random_mask_input_ids

# additional setup so the snippet below is runnable (our addition, not in the original):
import torch
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
n_prompts = 8           # must match the model's n_prompts below
mask_token_id = 50264   # RoBERTa's "<mask>" token id

# initialize model for MNLI task
model = WARPPromptedRobertaForSequenceClassification(
    pretrained_backbone_path="roberta-large",
    n_prompts=8,
    seed_token_id_for_prompts_embeddings=50264,  # token id for "<mask>"
    mask_token_id=50264,
    token_ids_for_classification_head=[1342, 12516, 10800],  # 'ent', 'neutral', 'cont'
    pretrained_prompts_path=None,
    pretrained_classifier_path=None,
)
                                                     
# initialize model for masked language modeling (MLM)
model = WARPPromptedRobertaForMaskedLM(
    pretrained_backbone_path="roberta-large",
    n_prompts=8,
    seed_token_id_for_prompts_embeddings=50264,
    pretrained_prompts_path=None,
)
                                        
# prepad input ids with n_prompts placeholder positions before feeding into the model
# (str_1 ... str_n are the raw input strings)
features = tokenizer([str_1, str_2, ..., str_n], return_tensors='pt', truncation=True, padding=True)
features["input_ids"] = torch.cat([torch.full((features["input_ids"].shape[0], n_prompts), 0), features["input_ids"]], 1)

# randomly mask input ids for MLM task
features['input_ids'] = random_mask_input_ids(features['input_ids'], mask_token_id, prob = .15)

# initialize model for question answering (QA)
model = WARPPromptedRobertaForQuestionAnswering(
    pretrained_backbone_path="roberta-large",
    n_prompts=4,
    seed_token_id_for_prompts_embeddings=50264,
    pretrained_prompts_path=None,
    freeze_qa_outputs_layer=False,
)
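
Finally, a hedged sketch of a single optimization step; the HuggingFace-style forward call, the labels argument, and the loss attribute are assumptions here, not documented behavior of these classes:

# Collect only the parameters that require gradients (the prompts, and
# the classifier head if one exists); the frozen backbone is excluded.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=3e-3)

model.train()
outputs = model(**features, labels=labels)  # `labels`: assumed tensor of class ids
loss = outputs.loss if hasattr(outputs, "loss") else outputs[0]
loss.backward()
optimizer.step()
optimizer.zero_grad()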

Reference

Hambardzumyan, K., Khachatrian, H., & May, J. (2021). WARP: Word-level Adversarial ReProgramming. In Proceedings of ACL-IJCNLP 2021.

Contact

Please contact Jin Qu if you are interested in collaboration, internship opportunities, or discussions. Feel free to create an issue if you discover a bug or want to request a new feature for a future release.
