summareranker's Introduction

SummaReranker

Source code for the paper SummaReranker: A Multi-Task Mixture-of-Experts Re-ranking Framework for Abstractive Summarization.

Mathieu Ravaut, Shafiq Joty, Nancy F. Chen.

Accepted for publication at ACL 2022.

Setup

1 - Download the code

git clone https://github.com/Ravoxsg/SummaReranker.git
cd SummaReranker

2 - Install the dependencies

conda create --name summa_reranker python=3.8.8
conda activate summa_reranker
pip install -r requirements.txt

Dataset

We use the HuggingFace datasets library to access and save each dataset. Each dataset is saved as one .txt file for the sources and another .txt file for the summaries, with one data point per line. For CNN/DM, we instead save one .txt file per data point.
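As a minimal sketch of the one-data-point-per-line layout described above (not the repository's actual dataset script), the helpers below write and read a pair of aligned .txt files. The function names, the file names `{split}_text.txt` and `{split}_summary.txt`, and the choice to collapse internal newlines are assumptions made for illustration.

```python
import tempfile
from pathlib import Path


def save_split(sources, summaries, out_dir, split):
    """Write sources and summaries to two aligned files, one data point per line."""
    if len(sources) != len(summaries):
        raise ValueError("sources and summaries must have the same length")
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    src_path = out_dir / f"{split}_text.txt"      # hypothetical file name
    tgt_path = out_dir / f"{split}_summary.txt"   # hypothetical file name
    with open(src_path, "w", encoding="utf-8") as f_src, \
            open(tgt_path, "w", encoding="utf-8") as f_tgt:
        for source, summary in zip(sources, summaries):
            # Collapse internal whitespace (including newlines) so that each
            # data point occupies exactly one line.
            f_src.write(" ".join(source.split()) + "\n")
            f_tgt.write(" ".join(summary.split()) + "\n")
    return src_path, tgt_path


def load_split(src_path, tgt_path):
    """Read the two files back as a list of (source, summary) pairs."""
    with open(src_path, encoding="utf-8") as f_src, \
            open(tgt_path, encoding="utf-8") as f_tgt:
        sources = [line.rstrip("\n") for line in f_src]
        summaries = [line.rstrip("\n") for line in f_tgt]
    return list(zip(sources, summaries))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        paths = save_split(["a long\npost"], ["short summary"], tmp, "val")
        print(load_split(*paths))
```

Keeping the two files line-aligned means the i-th summary always belongs to the i-th source, which is what later steps of a pipeline like this one would rely on.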

For instance, to download and save Reddit:

cd src/candidate_generation/
bash dataset.sh

Note that for Reddit TIFU, we make a custom 80/10/10 train/val/test split.
To match our results on Reddit TIFU, first double check that you have the following:
For the training set, size is 33,704 and the first data point summary is:
got a toy train from first grade. used an old hot wheels ramp to fling it into the air and smash my ceiling fan globe.
For the validation set, size is 4,213 and the first data point summary is:
married a redditor. created a reddit account. lost many hours to reddit.
For the test set, size is 4,222 and the first data point summary is:
laughed at baby boner...it turned into a super soaker.
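The check above can be automated. The following is a small sketch, assuming the summaries of each split are stored one per line; the expected sizes and first summaries are copied from the text above, while the function names and any file path you pass in are hypothetical.

```python
# Expected (size, first summary) for each Reddit TIFU split, taken from the README text.
EXPECTED = {
    "train": (33704, "got a toy train from first grade. used an old hot wheels ramp "
                     "to fling it into the air and smash my ceiling fan globe."),
    "val": (4213, "married a redditor. created a reddit account. lost many hours to reddit."),
    "test": (4222, "laughed at baby boner...it turned into a super soaker."),
}


def check_split(summary_lines, split):
    """Return a list of human-readable problems (empty if the split matches)."""
    expected_size, expected_first = EXPECTED[split]
    problems = []
    if len(summary_lines) != expected_size:
        problems.append(f"{split}: expected {expected_size} lines, found {len(summary_lines)}")
    if not summary_lines or summary_lines[0].strip() != expected_first:
        problems.append(f"{split}: first summary does not match the expected one")
    return problems


def check_file(path, split):
    """Same check, reading the summaries from a one-summary-per-line file."""
    with open(path, encoding="utf-8") as f:
        return check_split([line.rstrip("\n") for line in f], split)
```

For example, `check_file("path/to/val_summary.txt", "val")` (a hypothetical path) returns an empty list when the split matches.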

DEMO

If you just want a demo (in a single file) of SummaReranker on a single data point (default: CNN/DM), run:

cd src/summareranker/
CUDA_VISIBLE_DEVICES=0 python demo.py

EVALUATION pipeline (assumes an already trained SummaReranker checkpoint)

1 - Generate summary candidates

SummaReranker takes as input a set of summary candidates from a given sequence-to-sequence model (PEGASUS, BART) and a given decoding method (beam search, diverse beam search, top-p sampling, top-k sampling).
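To make the four decoding methods concrete, here is a hedged sketch that maps each one to keyword arguments for the HuggingFace transformers `generate` method. The function `generation_kwargs`, the method labels, and every numeric value are illustrative assumptions, not the settings used by `candidate_generation.sh`; the keyword names are those accepted by `generate` in transformers releases from around the time of this repository, so check them against your installed version.

```python
def generation_kwargs(method, num_candidates):
    """Return illustrative `generate` keyword arguments for one decoding method."""
    if method == "beam_search":
        return {"num_beams": num_candidates,
                "num_return_sequences": num_candidates}
    if method == "diverse_beam_search":
        # One beam per group, so every returned candidate comes from a different group.
        return {"num_beams": num_candidates,
                "num_beam_groups": num_candidates,
                "diversity_penalty": 1.0,   # illustrative value
                "num_return_sequences": num_candidates}
    if method == "top_p_sampling":
        return {"do_sample": True,
                "top_p": 0.95,              # illustrative value
                "top_k": 0,                 # 0 disables top-k filtering
                "num_return_sequences": num_candidates}
    if method == "top_k_sampling":
        return {"do_sample": True,
                "top_k": 50,                # illustrative value
                "num_return_sequences": num_candidates}
    raise ValueError(f"unknown decoding method: {method}")


# Intended use (not executed here; requires transformers and a loaded model):
#   outputs = model.generate(**inputs, **generation_kwargs("diverse_beam_search", 4))
```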

For instance, with PEGASUS on the Reddit validation set and diverse beam search:

CUDA_VISIBLE_DEVICES=0 bash candidate_generation.sh

Generating summary candidates should take a few hours on the validation or test sets of CNN/DM, XSum or Reddit.

Note that for Reddit, you need to fine-tune the model on your training split prior to generating candidates.

2 - Score the candidates

To evaluate SummaReranker, we need to score each summary candidate with all the metrics of interest (ROUGE-1/2/L, BERTScore, BARTScore, etc.).
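To illustrate what scoring a candidate means, below is a simplified ROUGE-1 F1 in plain Python. It is a sketch only: it lowercases and splits on whitespace, with no stemming, so its numbers will not match the official ROUGE implementation that a script such as `scores.sh` presumably relies on. The function names are hypothetical.

```python
from collections import Counter


def rouge1_f1(candidate, reference):
    """Simplified ROUGE-1 F1: unigram overlap between candidate and reference."""
    cand_counts = Counter(candidate.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared token.
    overlap = sum((cand_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


def score_candidates(candidates, reference):
    """Score every candidate of one data point against its reference summary."""
    return [rouge1_f1(candidate, reference) for candidate in candidates]
```

Each data point then carries one score per candidate and per metric, which is the information needed to evaluate or train a re-ranker.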

For instance, to score PEGASUS diverse beam search candidates on the Reddit validation set with ROUGE-1:

CUDA_VISIBLE_DEVICES=0 bash scores.sh

Scoring all candidates should take a few minutes with ROUGE metrics on the validation or test sets of CNN/DM, XSum or Reddit.

3 - Download the model checkpoint

CNN/DM checkpoint (trained on beam search + diverse beam search candidates, for ROUGE-1/2/L metrics): here
XSum checkpoint (trained on beam search + diverse beam search candidates, for ROUGE-1/2/L metrics): here
Reddit checkpoint (trained on beam search + diverse beam search candidates, for ROUGE-1/2/L metrics): here

4 - Run SummaReranker

For instance, to run SummaReranker trained for ROUGE-1/2/L on PEGASUS (beam search + diverse beam search candidates) on the Reddit validation set:

cd ../summareranker/
CUDA_VISIBLE_DEVICES=0 bash evaluate.sh

Make sure that the argument --load_model_path points to where you placed the SummaReranker checkpoint.
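Conceptually, re-ranking at inference time amounts to picking the candidate with the highest predicted score. The sketch below shows that selection step and a simple comparison against the base output and the best available candidate. It assumes the first candidate is the base model's top output and that a metric score (e.g. ROUGE-1) is already available for each candidate; the function names are hypothetical and this is not the code in `evaluate.sh`.

```python
def rerank(candidates, predicted_scores):
    """Return the index and text of the candidate with the highest predicted score."""
    if not candidates or len(candidates) != len(predicted_scores):
        raise ValueError("need one predicted score per candidate")
    best = max(range(len(candidates)), key=lambda i: predicted_scores[i])
    return best, candidates[best]


def compare(candidates, predicted_scores, metric_scores):
    """Compare the base output, the re-ranked output and the best candidate on one metric."""
    best, _ = rerank(candidates, predicted_scores)
    return {
        "base": metric_scores[0],        # assumption: candidate 0 is the base model's top output
        "reranked": metric_scores[best],
        "oracle": max(metric_scores),    # upper bound: best candidate in the set
    }
```

Averaging these three numbers over a validation or test set shows how much of the gap between the base model and the oracle the re-ranker closes.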

TRAINING pipeline

1 - Fine-tune base models

For training, SummaReranker follows a cross-validation approach: the training set is split in two, one model is trained on each half, and each model is then used to generate predictions on the other half. We also need a third model trained on the entire training set (for the transfer setup at inference time), which we re-train ourselves for Reddit.

For instance, with PEGASUS on Reddit:

cd ../base_model_finetuning/
CUDA_VISIBLE_DEVICES=0 bash train_base_models.sh

Note that this single script performs all of these tasks: splitting the training set, training a model on each half, and training a model on the entire set.
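The cross-validation scheme described above can be sketched as an index assignment: each base model is trained on one part of the training set and later generates candidates on the part it has not seen. The function below is an illustration only. It uses contiguous, unshuffled folds, and it accepts `num_folds` greater than 2 as a hypothetical generalization; the repository's `train_base_models.sh` uses two halves and may split the data differently.

```python
def fold_assignments(num_examples, num_folds=2):
    """For each fold, list the indices a base model trains on and the held-out
    indices on which that model later generates candidates."""
    if num_folds < 2 or num_folds > num_examples:
        raise ValueError("num_folds must be between 2 and num_examples")
    indices = list(range(num_examples))
    # Fold boundaries; the last folds absorb any remainder.
    bounds = [num_examples * k // num_folds for k in range(num_folds + 1)]
    folds = []
    for k in range(num_folds):
        held_out = indices[bounds[k]:bounds[k + 1]]
        train = indices[:bounds[k]] + indices[bounds[k + 1]:]
        folds.append({"train": train, "generate_on": held_out})
    return folds
```

With two folds, the model trained on the second half generates candidates for the first half and vice versa, so no training example receives candidates from a model that saw it during fine-tuning.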

For models trained on the entire training set of CNN/DM and XSum, I used the public HuggingFace checkpoints:
CNN/DM: https://huggingface.co/google/pegasus-cnn_dailymail
XSum: https://huggingface.co/google/pegasus-xsum
Here's a link to download the PEGASUS-large fine-tuned on 100% of my training set of Reddit: here

2 - Generate summary candidates

Then, we need to get summary candidates on the training, validation and test sets.

For instance, with PEGASUS on Reddit and diverse beam search:

cd ../candidate_generation/
CUDA_VISIBLE_DEVICES=0 bash candidate_generation_train.sh

Generating summary candidates on the entire datasets should take up to a few days.

3 - Score the candidates

Next, we need to score the summary candidates on the training, validation and test sets for each of the metrics.

For instance, to score PEGASUS diverse beam search candidates on Reddit with ROUGE-1/2/L:

CUDA_VISIBLE_DEVICES=0 bash scores_train.sh

Scoring all candidates should take a few minutes with ROUGE metrics.

4 - Train SummaReranker

For instance, to train SummaReranker for ROUGE-1/2/L on PEGASUS diverse beam search candidates on Reddit:

cd ../summareranker/
CUDA_VISIBLE_DEVICES=0 bash train.sh

Citation

If you find that our paper or this project helps your research, please consider citing our paper in your publication.

@article{ravaut2022summareranker,
  title={SummaReranker: A Multi-Task Mixture-of-Experts Re-ranking Framework for Abstractive Summarization},
  author={Ravaut, Mathieu and Joty, Shafiq and Chen, Nancy F},
  journal={arXiv preprint arXiv:2203.06569},
  year={2022}
}

summareranker's Issues

Error when running file candidate_generation.sh

Dear Mr @Ravoxsg ,
I'm trying to reproduce the evaluation results by following your suggested steps.
I modified two lines in the file main_candidate_generation.py:
Line 7: sys.path.append("/content/SummaReranker/src/") # todo: change to your folder path
Line 49: default = "/content/SummaReranker/models/summareranker_reddit_bs_dbs_rouge_1_2_l/checkpoint-1000/pytorch_model.bin") # todo: change to where you saved the finetuned checkpoint

The command !bash candidate_generation.sh runs for a while and then throws the following error:
Traceback (most recent call last):
  File "main_candidate_generation.py", line 182, in <module>
    main(args)
  File "main_candidate_generation.py", line 155, in main
    model.load_state_dict(torch.load(args.load_model_path))
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1407, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for FTModel:
    Missing key(s) in state_dict: "pretrained_model.final_logits_bias", "pretrained_model.model.shared.weight", "pretrained_model.model.encoder.embed_tokens.weight", "pretrained_model.model.encoder.embed_positions.weight", "pretrained_model.model.encoder.layers.0.self_attn.k_proj.weight", "pretrained_model.model.encoder.layers.0.self_attn.k_proj.bias", "pretrained_model.model.encoder.layers.0.self_attn.v_proj.weight", "pretrained_model.model.encoder.layers.0.self_attn.v_proj.bias", "pretrained_model.model.encoder.layers.0.self_attn.q_proj.weight", "pretrained_model.model.encoder.layers.0.self_attn.q_proj.bias", "pretrained_model.model.encoder.layers.0.self_attn.out_proj.weight", "pretrained_model.model.encoder.layers.0.self_attn.out_proj.bias", "pretrained_model.model.encoder.layers.0.self_attn_layer_norm.weight", "pretrained_model.model.encoder.layers.0.self_attn_layer_norm.bias", "pretrained_model.model.encoder.layers.0.fc1.weight", "pretrained_model.model.encoder.layers.0.fc1.bias", "pretrained_model.model.encoder.layers.0.fc2.weight", "pretrained_model.model.encoder.layers.0.fc2.bias", "pretrained_model.model.encoder.layers.0.final_layer_norm.weight", "pretrained_model.model.encoder.layers.0.final_layer_norm.bias", "pretrained_model.model.encoder.layers.1.self_attn.k_proj.weight", "pretrained_model.model.encoder.layers.1.self_attn.k_proj.bias", "pretrained_model.model.encoder.layers.1.self_attn.v_proj.weight", "pretrained_model.model.encoder.layers.1.self_attn.v_proj.bias", "pretrained_model.model.encoder.layers.1.self_attn.q_proj.weight", "pretrained_model.model.encoder.layers.1.self_attn.q_proj.bias", "pretrained_model.model.encoder.layers.1.self_attn.out_proj.weight", "pretrained_model.model.encoder.layers.1.self_attn.out_proj.bias", "pretrained_model.model.encoder.layers.1.self_attn_layer_norm.weight", "pretrained_model.model.encoder.layers.1.self_attn_layer_norm.bias", "pretrained_model.model.encoder.layers.1.fc1.weight", "pretrained_model.model.encoder.layers.1.fc1.bias", "pretrained_model.model.encoder.layers.1.fc2.weight", "pretrained_model.model.encoder.layers.1.fc2.bias", "pretrained_model.model.encoder.layers.1.final_layer_norm.weight", ... "pretrained_model.encoder.layer.23.intermediate.dense.weight", "pretrained_model.encoder.layer.23.intermediate.dense.bias", "pretrained_model.encoder.layer.23.output.dense.weight", "pretrained_model.encoder.layer.23.output.dense.bias", "pretrained_model.encoder.layer.23.output.LayerNorm.weight", "pretrained_model.encoder.layer.23.output.LayerNorm.bias", "pretrained_model.pooler.dense.weight", "pretrained_model.pooler.dense.bias".

Could you please look into this issue?
Thank you.
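One generic way to narrow down a load_state_dict error like the one above is to compare the parameter names the model expects with the names stored in the checkpoint file. The helper below is a sketch with a hypothetical name (`diff_keys`); it works on plain collections of key names so it can be tried without loading any weights. Whether the mismatch here comes from pointing --load_model_path at a re-ranker checkpoint rather than a fine-tuned base model is an assumption that such a comparison could help confirm or rule out.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Compare expected parameter names with those found in a checkpoint."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    return {
        "missing": sorted(model_keys - checkpoint_keys),      # expected by the model, absent from the file
        "unexpected": sorted(checkpoint_keys - model_keys),   # present in the file, unknown to the model
        "shared": len(model_keys & checkpoint_keys),
    }


# Intended use with PyTorch (not executed here):
#   state = torch.load(path_to_checkpoint, map_location="cpu")
#   report = diff_keys(model.state_dict().keys(), state.keys())
#   print(report["shared"], report["missing"][:5], report["unexpected"][:5])
```

If almost no keys are shared, the checkpoint most likely belongs to a different architecture than the model being instantiated.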

Could you provide a detailed instruction

Hello, congratulations on your work being accepted for ACL 2022!

I want to follow your work and reproduce the results. However, the scripts you provide did not run successfully in my environment. There seem to be some hard-coded local paths, and the expected location and format of the dataset are not clear.

Could you offer some help? Wishing you all the best!

How to train the re-ranker model?

Dear Mr. @Ravoxsg ,
In the README, you only provide the pre-trained model checkpoints to reproduce the paper's results. I couldn't find a section on how to train the re-ranker model to obtain those checkpoints. Could you please show me how to train the re-ranker model and save the checkpoint file for later evaluation?

Thank you so much!

Experiment set up

Hi, Thanks for your great work.
I am curious about Section 3.3, Tackling Training and Inference Gap. You split the training data into 2 folds and generate candidates for each half with the model trained on the other half. In theory, if you split the training data into more parts (i.e., N folds with a large N), the distribution of the re-ranker's training set would be closer to that of the test set. Have you tried such experiments? Why choose only a 2-fold split?

About the generation model

First of all, thank you very much for your help. I admire your work very much. When I tried to replicate your model, I found that the generation model fine-tuned on the dataset, which is needed to generate the candidate set, was missing. You have only published the re-ranker model. Would you mind sharing your fine-tuned generation model?
