

Trans-Encoder

[arxiv] · [amazon.science blog] · [5min-video] · [talk@RIKEN] · [openreview]

Code repo for ICLR 2022 paper Trans-Encoder: Unsupervised sentence-pair modelling through self- and mutual-distillations
by Fangyu Liu, Yunlong Jiao, Jordan Massiah, Emine Yilmaz, Serhii Havrylov.

Trans-Encoder is a state-of-the-art unsupervised sentence similarity model. It performs self-knowledge distillation on top of pretrained language models by alternating between their bi- and cross-encoder forms.
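The two forms score a sentence pair differently: a bi-encoder embeds each sentence independently and compares the embeddings (typically with cosine similarity), while a cross-encoder reads both sentences jointly and outputs a score directly. A minimal sketch of the bi-encoder comparison step, with toy vectors standing in for model embeddings (illustration only, not the repo's actual code):

```python
import numpy as np

def cosine_similarity(u: np.ndarray, v: np.ndarray) -> float:
    """Bi-encoder-style scoring: compare two independently produced embeddings."""
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

# Toy vectors standing in for the sentence embeddings a bi-encoder would produce.
emb1 = np.array([0.2, 0.8, 0.1])
emb2 = np.array([0.25, 0.75, 0.05])
score = cosine_similarity(emb1, emb2)  # near 1.0 for similar sentences
```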

Huggingface pretrained models for STS

Base models:

| model | STS avg. |
| --- | --- |
| baseline: unsup-simcse-bert-base | 76.21 |
| trans-encoder-bi-simcse-bert-base | 80.41 |
| trans-encoder-cross-simcse-bert-base | 79.90 |
| baseline: unsup-simcse-roberta-base | 76.10 |
| trans-encoder-bi-simcse-roberta-base | 80.47 |
| trans-encoder-cross-simcse-roberta-base | 81.15 |

Large models:

| model | STS avg. |
| --- | --- |
| baseline: unsup-simcse-bert-large | 78.42 |
| trans-encoder-bi-simcse-bert-large | 82.65 |
| trans-encoder-cross-simcse-bert-large | 82.52 |
| baseline: unsup-simcse-roberta-large | 78.92 |
| trans-encoder-bi-simcse-roberta-large | 82.93 |
| trans-encoder-cross-simcse-roberta-large | 82.93 |

Dependencies

torch==1.8.1
transformers==4.9.0
sentence-transformers==2.0.0

Please view requirements.txt for more details.

Data

All training and evaluation data will be automatically downloaded when running the scripts. See src/data.py for details.

Train

--task options: sts (STS2012-2016 and STS-b), sickr, sts_sickr (STS2012-2016, STS-b, and SICK-R), qqp, qnli, mrpc, snli, custom. See src/data.py for details of each task's data. The default is sts_sickr (all STS data).

Self-distillation

>> bash train_self_distill.sh 0

0 denotes the GPU device index.

Mutual-distillation

>> bash train_mutual_distill.sh 0,1

Two GPUs are needed; by default, SimCSE BERT and RoBERTa base models are used for ensembling. Add --use_large to switch to large models.

Train with your custom corpus

>> CUDA_VISIBLE_DEVICES=0,1 python src/mutual_distill_parallel.py \
         --batch_size_bi_encoder 128 \
         --batch_size_cross_encoder 64 \
         --num_epochs_bi_encoder 10 \
         --num_epochs_cross_encoder 1 \
         --cycle 3 \
         --bi_encoder1_pooling_mode cls \
         --bi_encoder2_pooling_mode cls \
         --init_with_new_models \
         --task custom \
         --random_seed 2021 \
         --custom_corpus_path CORPUS_PATH

CORPUS_PATH should point to your custom corpus, in which every line is a sentence pair in the form sent1||sent2.
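For illustration, a short hypothetical snippet that writes and re-reads such a corpus file (the sentences and filename below are placeholders, not from the repo):

```python
# Hypothetical example: producing a corpus file in the expected sent1||sent2 format.
pairs = [
    ("A man is playing guitar.", "Someone plays an instrument."),
    ("The cat sleeps on the mat.", "A dog runs in the park."),
]

with open("custom_corpus.txt", "w", encoding="utf-8") as f:
    for sent1, sent2 in pairs:
        f.write(f"{sent1}||{sent2}\n")  # one pair per line, "||" as separator

# Reading it back: split each line on the delimiter.
with open("custom_corpus.txt", encoding="utf-8") as f:
    parsed = [line.rstrip("\n").split("||") for line in f]
```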

Evaluate

Evaluate a single model

Bi-encoder:

>> python src/eval.py \
--model_name_or_path "cambridgeltl/trans-encoder-bi-simcse-roberta-large"  \
--mode bi \
--task sts_sickr

Cross-encoder:

>> python src/eval.py \
--model_name_or_path "cambridgeltl/trans-encoder-cross-simcse-roberta-large"  \
--mode cross \
--task sts_sickr

Evaluate ensemble

Bi-encoder:

>> python src/eval.py \
--model_name_or_path1 "cambridgeltl/trans-encoder-bi-simcse-bert-large"  \
--model_name_or_path2 "cambridgeltl/trans-encoder-bi-simcse-roberta-large"  \
--mode bi \
--ensemble \
--task sts_sickr

Cross-encoder:

>> python src/eval.py \
--model_name_or_path1 "cambridgeltl/trans-encoder-cross-simcse-bert-large"  \
--model_name_or_path2 "cambridgeltl/trans-encoder-cross-simcse-roberta-large"  \
--mode cross \
--ensemble \
--task sts_sickr
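Ensembling combines the per-pair scores of the two models. A minimal sketch of one common aggregation, simple averaging (the helper below is hypothetical; the exact scheme in src/eval.py may differ):

```python
def ensemble_scores(scores1, scores2):
    """Average per-pair similarity scores from two models (hypothetical helper)."""
    return [(a + b) / 2 for a, b in zip(scores1, scores2)]

# Two models' scores for the same two sentence pairs.
combined = ensemble_scores([0.8, 0.4], [0.6, 0.8])
```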

Authors

Security

See CONTRIBUTING for more information.

License

This project is licensed under the Apache-2.0 License.
