bert-kpe's Introduction

BERT for Keyphrase Extraction (PyTorch)

This repository provides the code of the paper Capturing Global Informativeness in Open Domain Keyphrase Extraction.

In this paper, we conduct an empirical study of 5 keyphrase extraction models with 3 BERT variants, and then propose a multi-task model, BERT-JointKPE. Experiments on two KPE benchmarks, OpenKP with Bing web pages and KP20K, demonstrate JointKPE’s state-of-the-art and robust effectiveness. Our further analyses also show that JointKPE has advantages in predicting long keyphrases and non-entity keyphrases, which were challenging for previous KPE techniques.

Please cite our paper if our experimental results, analysis conclusions or the code are helpful to you ~ 😊

@article{sun2020joint,
    title={Joint Keyphrase Chunking and Salience Ranking with BERT},
    author={Si Sun and Zhenghao Liu and Chenyan Xiong and Zhiyuan Liu and Jie Bao},
    year={2020},
    eprint={2004.13639},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

CONTACT

If you have any questions, feel free to create an issue, and we will do our best to resolve it.
If the problem is urgent, you can also send me an email at the same time (I check email almost every day 😉).

NAME: Si Sun
EMAIL: [email protected]

🤠 What's New ?

  • 2020/9/5

    Compared with the OpenKP dataset we downloaded from MS MARCO in October 2019 (all our experiments are based on this version of the dataset), we found that the dataset has since been updated. We recommend downloading the latest data from the official website. For comparison, we also provide the data version we used. (The dataset version issue was raised by Yansen Wang et al. from CMU; many thanks to them!)

    ~DownLoad from Here or ~Email [email protected] for Data

  • 2021/12/7

    Our repo now supports Multilingual-KPE and an FP16 training mode. Thanks, Amit Chaulwar! Amit also shared their zero-shot results on Wikinews (French), Cacic (Spanish), Pak2018 (Polish), wicc (Spanish), and 110-PT-BN-KP (Portuguese).

Supported Model Classes

Index | Model | Descriptions
1 | BERT-JointKPE (Bert2Joint) | A multi-task model is trained jointly on the chunking task and the ranking task, balancing the estimation of keyphrase quality and salience.
2 | BERT-RankKPE (Bert2Rank) | Learn the salience phrases in the documents using a ranking network.
3 | BERT-ChunkKPE (Bert2Chunk) | Classify high quality keyphrases using a chunking network.
4 | BERT-TagKPE (Bert2Tag) | We modified the sequence tagging model to generate enough candidate keyphrases for a document.
5 | BERT-SpanKPE (Bert2Span) | We modified the span extraction model to extract multiple keyphrases from a document.
6 | DistilBERT-JointKPE (DistilBert2Joint) | A multi-task model is trained jointly on the chunking task and the ranking task, balancing the estimation of keyphrase quality and salience.

BERT Variants Tested

Requirements

python 3.8
pytorch 1.9.0
pip install -r pip-requirements.txt

QUICKSTART

1/ Download

  • First, download and decompress our data folder into this repo; the folder includes the benchmark datasets and pre-trained BERT variants.

  • We also provide 15 checkpoints (5 KPE models * 3 BERT variants) trained on the OpenKP training dataset.

2/ Preprocess

  • To preprocess the source datasets using preprocess.sh in the preprocess folder:

    source preprocess.sh
    
  • Optional arguments:

    --dataset_class         choices=['openkp', 'kp20k', 'multidata']
    --source_dataset_dir    The path to the source dataset
    --output_path           The dir to save preprocess data; default: ../data/prepro_dataset
    
  • To preprocess the multilingual datasets, download the respective datasets from https://github.com/LIAAD/KeywordExtractor-Datasets and use the script jsonify_multidata.py to preprocess them. Each dataset can be split into train, dev, and test sets using split_json.py.

3/ Train Models

  • To train a new model from scratch using train.sh in the scripts folder:

    source train.sh
    

    PS. Running the training script for the first time takes a while because it preprocesses the data (e.g., tokenization). By default, the processed features are saved under ../data/cached_features and can be loaded directly on later runs.

  • Optional arguments:

    --dataset_class         choices=['openkp', 'kp20k', 'multidata']
    --model_class           choices=['bert2span', 'bert2tag', 'bert2chunk', 'bert2rank', 'bert2joint']
    --pretrain_model_type   choices=['bert-base-cased', 'spanbert-base-cased', 'roberta-base', 'distilbert-base-cased']
    

    Complete optional arguments can be seen in config.py in the scripts folder.

  • Training Parameters:

    We always keep the following settings in all our experiments:

    args.warmup_proportion = 0.1
    args.max_train_steps = 20810 (openkp) , 73430 (kp20k)
    args.per_gpu_train_batch_size * max(1, args.n_gpu) * args.gradient_accumulation_steps = 64
    
  • Distributed Training

    We recommend using DistributedDataParallel to train models on multiple GPUs (It's faster than DataParallel, but it will take up more memory)

    CUDA_VISIBLE_DEVICES=0,1 OMP_NUM_THREADS=2 python -m torch.distributed.launch --nproc_per_node=2 --master_port=1234 train.py
    # if you use DataParallel rather than DistributedDataParallel, remember to set --local_rank=-1
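
    The fixed settings listed under Training Parameters can be sanity-checked before launching a run. The sketch below is a minimal helper, not part of the repo: the function names are hypothetical, and it assumes that the warmup length is simply warmup_proportion multiplied by max_train_steps, which is a common convention but is not confirmed by this README.

```python
def gradient_accumulation_steps(per_gpu_batch_size, n_gpu, target=64):
    """Return the accumulation steps that make the effective batch size equal `target`.

    Mirrors the README constraint:
    per_gpu_train_batch_size * max(1, n_gpu) * gradient_accumulation_steps = 64
    """
    per_step = per_gpu_batch_size * max(1, n_gpu)
    if target % per_step != 0:
        raise ValueError(
            f"{per_step} samples per step cannot be accumulated to exactly {target}"
        )
    return target // per_step


def warmup_steps(max_train_steps, warmup_proportion=0.1):
    """Assumed convention: warmup length is a fixed fraction of all training steps."""
    return round(max_train_steps * warmup_proportion)


# Two GPUs with 8 samples each, and a single-device run with 16 samples.
accum_2gpu = gradient_accumulation_steps(per_gpu_batch_size=8, n_gpu=2)
accum_single = gradient_accumulation_steps(per_gpu_batch_size=16, n_gpu=0)

# Step budgets quoted in the README for OpenKP and KP20k.
openkp_warmup = warmup_steps(20810)
kp20k_warmup = warmup_steps(73430)

print(accum_2gpu, accum_single, openkp_warmup, kp20k_warmup)
```

    If the per-step batch does not divide 64 evenly, the helper raises instead of silently rounding, since a rounded value would change the effective batch size used in the experiments.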
    

4/ Inference

  • To evaluate trained checkpoints, use test.sh in the scripts folder:

    source test.sh
    
  • Optional arguments:

    --dataset_class         choices=['openkp', 'kp20k', 'multidata']
    --model_class           choices=['bert2span', 'bert2tag', 'bert2chunk', 'bert2rank', 'bert2joint']
    --pretrain_model_type   choices=['bert-base-cased', 'spanbert-base-cased', 'roberta-base', 'distilbert-base-cased']
    --eval_checkpoint       The checkpoint file to be evaluated
    

5/ Reproduce evaluation results using our checkpoints

  • Run test.sh, and change the eval_checkpoint to the checkpoint files we provided to reproduce the following results.

    --dataset_class         openkp
    --eval_checkpoint       The filepath of our provided checkpoint
    

* RESULTS

The following results are ranked by F1@3 on the OpenKP Dev dataset; the eval results can be seen in the OpenKP Leaderboard.

* BERT (Base)

Rank | Method | F1 @1,@3,@5 | Precision @1,@3,@5 | Recall @1,@3,@5
1 | Bert2Joint | 0.371, 0.384, 0.326 | 0.504, 0.313, 0.227 | 0.315, 0.555, 0.657
2 | Bert2Rank | 0.369, 0.381, 0.325 | 0.502, 0.311, 0.227 | 0.315, 0.551, 0.655
3 | Bert2Tag | 0.370, 0.374, 0.318 | 0.502, 0.305, 0.222 | 0.315, 0.541, 0.642
4 | Bert2Chunk | 0.370, 0.370, 0.311 | 0.504, 0.302, 0.217 | 0.314, 0.533, 0.627
5 | Bert2Span | 0.341, 0.340, 0.293 | 0.466, 0.277, 0.203 | 0.289, 0.492, 0.593

* SpanBERT (Base)

Rank | Method | F1 @1,@3,@5 | Precision @1,@3,@5 | Recall @1,@3,@5
1 | Bert2Joint | 0.388, 0.393, 0.333 | 0.527, 0.321, 0.232 | 0.331, 0.567, 0.671
2 | Bert2Rank | 0.385, 0.390, 0.332 | 0.521, 0.319, 0.232 | 0.328, 0.564, 0.666
3 | Bert2Tag | 0.384, 0.385, 0.327 | 0.520, 0.315, 0.228 | 0.327, 0.555, 0.657
4 | Bert2Chunk | 0.378, 0.385, 0.326 | 0.514, 0.314, 0.228 | 0.322, 0.555, 0.656
5 | Bert2Span | 0.347, 0.359, 0.304 | 0.477, 0.294, 0.212 | 0.293, 0.518, 0.613

* RoBERTa (Base)

Rank | Method | F1 @1,@3,@5 | Precision @1,@3,@5 | Recall @1,@3,@5
1 | Bert2Joint | 0.391, 0.398, 0.338 | 0.532, 0.325, 0.235 | 0.334, 0.577, 0.681
2 | Bert2Rank | 0.388, 0.395, 0.335 | 0.526, 0.322, 0.233 | 0.330, 0.570, 0.677
3 | Bert2Tag | 0.387, 0.389, 0.330 | 0.525, 0.318, 0.230 | 0.329, 0.562, 0.666
4 | Bert2Chunk | 0.380, 0.382, 0.327 | 0.518, 0.312, 0.228 | 0.324, 0.551, 0.660
5 | Bert2Span | 0.358, 0.355, 0.306 | 0.487, 0.289, 0.213 | 0.304, 0.513, 0.619

MODEL OVERVIEW

* BERT-JointKPE, RankKPE, ChunkKPE (See Paper)

* BERT-TagKPE (See Code)

  • Word-Level Representations : We encode an input document into a sequence of WordPiece token vectors with a pretrained BERT (or one of its variants), and then pick the first sub-token vector of each word to represent the input at the word level.

  • Phrase-Level Representations : We use a soft-select method to decode phrases from the word-level vectors, instead of the hard-select method used in standard sequence tagging.

    The word-level representations are fed into a classification layer to obtain the tag probabilities of each word over 5 classes (O, B, I, E, U), and then we apply different tag patterns to extract different n-grams (1 ≤ n ≤ 5) over the whole sequence.

    The result is a collection of n-gram candidates, in which each word of an n-gram has exactly one score.

    Soft-select Example : considering all 3-grams (B I E) in a document of length L, we can extract (L-3+1) 3-grams sequentially, like a sliding window. In each 3-gram, we keep only the B score for the first word, the I score for the middle word, and the E score for the last word.

    O : Non-keyphrase ; B : Begin word of the keyphrase ; I : Middle word of the keyphrase ; E : End word of the keyphrase ; U : Uni-word keyphrase

  • Document-Level Keyphrase : At the last stage, recovering document-level keyphrases from phrase-level n-grams can be naturally formulated as a ranking task.

    Incorporating term frequency, we employ Min Pooling to get the final score of each n-gram (we call this the Buckets Effect: how much water a bucket holds depends on its shortest stave, no matter how tall the others are). Based on the final scores, we extract the 5 top-ranked keyphrase candidates for each document.
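
    The soft-select and Min Pooling steps described for BERT-TagKPE can be illustrated with a small plain-Python sketch. It is not the repo's implementation: the function names and the toy tag probabilities are hypothetical, the pattern for n ≥ 2 is assumed to be B, then (n-2) I tags, then E (with U for single words), and the term-frequency component is left out entirely.

```python
def tag_pattern(n):
    """Assumed tag pattern of an n-gram: U for one word, else B, I..., E."""
    if n == 1:
        return ["U"]
    return ["B"] + ["I"] * (n - 2) + ["E"]


def score_ngrams(tag_probs, max_n=5):
    """Score every sliding-window n-gram by Min Pooling over its word scores.

    tag_probs: one dict per word, mapping each tag (O, B, I, E, U) to a probability.
    Returns a dict {(start_index, n): score}.
    """
    scores = {}
    length = len(tag_probs)
    for n in range(1, max_n + 1):
        pattern = tag_pattern(n)
        # A document of length L yields (L - n + 1) windows of size n.
        for start in range(length - n + 1):
            word_scores = [tag_probs[start + k][pattern[k]] for k in range(n)]
            # Min Pooling: the weakest word decides the phrase score.
            scores[(start, n)] = min(word_scores)
    return scores


def top_candidates(words, tag_probs, k=5, max_n=5):
    """Return the k highest-scoring n-grams as (phrase, score) pairs."""
    scores = score_ngrams(tag_probs, max_n)
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [(" ".join(words[s:s + n]), score) for (s, n), score in ranked[:k]]


# Toy example with made-up probabilities for a 4-word document.
words = ["deep", "keyphrase", "extraction", "works"]
tag_probs = [
    {"O": 0.70, "B": 0.10, "I": 0.05, "E": 0.05, "U": 0.10},
    {"O": 0.05, "B": 0.80, "I": 0.05, "E": 0.05, "U": 0.05},
    {"O": 0.05, "B": 0.05, "I": 0.05, "E": 0.80, "U": 0.05},
    {"O": 0.85, "B": 0.05, "I": 0.03, "E": 0.02, "U": 0.05},
]

all_scores = score_ngrams(tag_probs)
best = top_candidates(words, tag_probs, k=1)
print(best)
```

    In this toy input the 2-gram "keyphrase extraction" wins because both its B and E scores are high, while every longer window contains at least one word whose pattern tag has a low probability.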

* BERT-SpanKPE (See Code)

  • Word-Level Representations : Same as BERT-TagKPE

  • Phrase-Level Representations : Traditional span extraction models cannot extract multiple important keyphrase spans from the same document. Therefore, we propose a self-attention span extraction model.

    Given the token representations {t1, t2, ..., tn}, we first calculate the probability that the token is the starting word Ps(ti), and then apply the single-head self-attention layer to calculate the ending word probability of all j>=i tokens Pe(tj).

  • Document-Level Keyphrase : We select the spans with the highest probability P = Ps(ti) * Pe(tj) as the keyphrase spans.
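
    The span selection rule P = Ps(ti) * Pe(tj) with j >= i can be sketched as follows. This is an illustration only: the function name, the max_len cap, and the layout of end_probs (one row of end probabilities per start position) are assumptions made here, and the probabilities are toy values rather than model outputs.

```python
def top_spans(start_probs, end_probs, k=2, max_len=5):
    """Rank spans (i, j), j >= i, by Ps(t_i) * Pe(t_j).

    start_probs[i]  : probability that token i starts a keyphrase.
    end_probs[i][j] : probability that token j ends the phrase started at i
                      (hypothetical layout; entries with j < i are ignored).
    max_len         : hypothetical cap on the span length in tokens.
    """
    n = len(start_probs)
    scored = []
    for i in range(n):
        for j in range(i, min(n, i + max_len)):
            scored.append(((i, j), start_probs[i] * end_probs[i][j]))
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]


# Toy probabilities for a 3-token document.
start_probs = [0.1, 0.7, 0.2]
end_probs = [
    [0.6, 0.3, 0.1],
    [0.0, 0.2, 0.8],
    [0.0, 0.0, 1.0],
]

best_spans = top_spans(start_probs, end_probs, k=2)
print(best_spans)
```

    Returning the top k spans instead of only the single best one is what allows several keyphrases to be extracted from the same document.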

bert-kpe's People

Contributors

amitchaulwar, edwardzh, sunsishining

bert-kpe's Issues

loss function is wrong

File "G:\BERT-KPE\scripts\train.py", line 52, in train
loss = model.update(step, inputs, scaler)
File "G:\BERT-KPE\scripts\model.py", line 86, in update
loss = self.network(**inputs)
File "G:\BERT-KPE\venv\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "G:\BERT-KPE\scripts..\bertkpe\networks\Roberta2Joint.py", line 356, in forward
Rank_Loss_Fct(
File "G:\BERT-KPE\venv\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "G:\BERT-KPE\venv\lib\site-packages\torch\nn\modules\loss.py", line 1333, in forward
return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction)
File "G:\BERT-KPE\venv\lib\site-packages\torch\nn\functional.py", line 3328, in margin_ranking_loss
raise RuntimeError(
RuntimeError: margin_ranking_loss : All input tensors should have same dimension but got sizes: input1: torch.Size([2, 1]), input2: torch.Size([1, 307]), target: torch.Size([1])
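
The sizes in the error message (input1 of [2, 1], input2 of [1, 307], target of [1]) suggest that the loss is meant to compare every positive score against every negative score through broadcasting. The plain-Python sketch below spells out that pairwise computation, assuming the standard margin ranking formula max(0, -y * (x1 - x2) + margin) averaged over all pairs; the function name and the margin value are hypothetical. In PyTorch itself, one plausible workaround is to give target the same number of dimensions as the two inputs (for example with unsqueeze), but that is an untested assumption rather than a confirmed fix for this repo.

```python
def pairwise_margin_ranking_loss(pos_scores, neg_scores, margin=1.0, target=1.0):
    """Mean margin ranking loss over every (positive, negative) score pair.

    Equivalent to broadcasting a column of positive scores against a row of
    negative scores; target=1.0 means positives should rank higher.
    """
    losses = [
        max(0.0, -target * (pos - neg) + margin)
        for pos in pos_scores
        for neg in neg_scores
    ]
    return sum(losses) / len(losses)


# Two positive phrases and three negative phrases (toy scores).
loss = pairwise_margin_ranking_loss([2.0, 0.5], [0.0, 1.0, 3.0])
print(loss)
```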

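Meaning of the variable max_diff_gram_num

Hi Dr. Sun. While studying the source code of the KPE models recently, I got confused about the variable max_diff_gram_num in the batchify_XXX_XXX functions. Taking /bertkpe/dataloader/bert2joint_dataloader.py as an example:

line 482: max_diff_gram_num = (1 + max([max(_mention_mask[-1]) for _mention_mask in mention_mask]))

Does max_diff_gram_num here represent the maximum number of elements in the candidate keyphrase sets across the documents in a batch?

If so, is this line equivalent to max(phrase_list_lens)? After all, phrase_list_lens is the number of elements in each document's candidate keyphrase set.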

micro-f1 or macro-f1 ?

Hi, in kp20k_evaluator.py I found this comment:

# Micro-Averaged Method

However, your approach seems more in line with the macro-averaged method (calculate the F1 of each sample first, and then average those values to get the overall F1). So which metric do you use in your final results, macro-F1 or micro-F1?

logger.info("F1:{}".format(np.mean(f1_scores[i])))
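
The difference the question is about can be shown with two toy documents. The sketch below is illustrative only and uses hypothetical function names: averaging per-document F1 scores (what taking the mean of a list of per-sample F1 values does, and what the issue calls macro) is contrasted with pooling the counts over all documents before computing a single F1 (micro).

```python
def prf(num_correct, num_pred, num_gold):
    """Precision, recall and F1 from raw counts, with 0.0 for empty denominators."""
    p = num_correct / num_pred if num_pred else 0.0
    r = num_correct / num_gold if num_gold else 0.0
    f1 = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f1


def per_document_average_f1(counts):
    """Compute F1 for each document, then average the F1 values."""
    return sum(prf(*c)[2] for c in counts) / len(counts)


def pooled_f1(counts):
    """Sum the counts over all documents, then compute one F1."""
    totals = [sum(column) for column in zip(*counts)]
    return prf(*totals)[2]


# (correct predictions, predicted keyphrases, gold keyphrases) per document.
counts = [(1, 1, 1), (1, 5, 5)]

averaged = per_document_average_f1(counts)
pooled = pooled_f1(counts)
print(averaged, pooled)
```

On this input the per-document average is about 0.6 while the pooled value is about 0.33, because pooling lets the document with more keyphrases dominate the result.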

Installing the requirements.txt file does not work

Hi, thanks for providing the code!
Unfortunately, I cannot install all the requirements with your provided command

python 3.5
pip install -r requirements.txt

Do you really use this outdated Python version?

The GitHub repository -e git+https://github.com/xaynetwork/xayn_ai_research.git@23d366ff8a05eca164718a6857eb31d439d52448#egg=xain_ai_research does not exist.

There is a version conflict of allennlp and transformers

The conflict is caused by:
    The user requested transformers==4.12.3
    allennlp 2.5.0 depends on transformers<4.7 and >=4.1

Unable to use checkpoints for inference

I was trying to use the checkpoints provided for inference on openkp dataset, but I am getting this error for bert-base-cased:

RuntimeError: Error(s) in loading state_dict for BertForChunkTFRanking:                                                         
Missing key(s) in state_dict: "bert.embeddings.position_ids".

For roberta-base:

RuntimeError: Error(s) in loading state_dict for RobertaForChunkTFRanking:                                                      
Missing key(s) in state_dict: "roberta.embeddings.position_ids".

I am using transformers==4.12.3 and pytorch==1.8 as mentioned in the requirements.txt file.

When I ran the code, some errors occurred.

true_score.unsqueeze(-1) and neg_score.unsqueeze(0) have different sizes, but margin_ranking_loss requires all input tensors to have the same number of dimensions. I think this is why the errors occurred.

Using kp20k dataset

Your dataloader code is tailored to OpenKP.
Could you provide a tutorial for using the kp20k dataset?

Error while running test.sh

Hello, thank you for this great work and for sharing your code. I am trying to run the script test.sh using your provided dataset and pre-trained model checkpoints, for bert2joint. I get the following error:

Traceback (most recent call last):
  File "test.py", line 276, in <module>
    dev_candidate = candidate_decoder(args, dev_data_loader, dev_dataset, model, test_input_refactor, pred_arranger, 'dev')
  File "test.py", line 152, in bert2rank_decoder
    for step, batch in enumerate(tqdm(data_loader)):
  File "/home/ec2-user/anaconda3/envs/bert-kpe/lib/python3.5/site-packages/tqdm/std.py", line 1165, in __iter__
    for obj in iterable:
  File "/home/ec2-user/anaconda3/envs/bert-kpe/lib/python3.5/site-packages/torch/utils/data/dataloader.py", line 346, in __next__
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/ec2-user/anaconda3/envs/bert-kpe/lib/python3.5/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/ec2-user/anaconda3/envs/bert-kpe/lib/python3.5/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "../bertkpe/dataloader/loader_utils.py", line 104, in __getitem__
    self.tokenizer, self.mode, self.max_phrase_words)
  File "../bertkpe/dataloader/bert2joint_dataloader.py", line 229, in bert2joint_converter
    src_tensor = torch.LongTensor(tokenizer.convert_tokens_to_ids(src_tokens))
TypeError: an integer is required (got type NoneType)

Please let me know if you have any suggestions on how to fix this. Thank you.

I am using Python 3.5.5 and huggingface transformers-2.5.1.

Model loading failed

We used test.sh to load model in checkpoints/bert2joint/bert2joint.openkp.bert.checkpoint and encountered the following error. According to error info, we guess that the model you provided may be inappropriate. I hope you can check and provide a model that can be used directly. Thank you.

Some weights of the model checkpoint at ../data/pretrain_model/bert-base-cased were not used when initializing BertForTFRanking: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForTFRanking from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTFRanking from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTFRanking were not initialized from the model checkpoint at ../data/pretrain_model/bert-base-cased and are newly initialized: ['cnn2gram.cnn_list.4.weight', 'cnn2gram.cnn_list.1.bias', 'classifier.bias', 'cnn2gram.cnn_list.3.weight', 'cnn2gram.cnn_list.4.bias', 'cnn2gram.cnn_list.0.weight', 'classifier.weight', 'cnn2gram.cnn_list.2.weight', 'cnn2gram.cnn_list.2.bias', 'cnn2gram.cnn_list.3.bias', 'cnn2gram.cnn_list.1.weight', 'cnn2gram.cnn_list.0.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Traceback (most recent call last):
  File "test.py", line 327, in <module>
    args.eval_checkpoint, args
  File "/home/smji/BERT-KPE-master/scripts/model.py", line 215, in load_checkpoint
    model = KeyphraseSpanExtraction(args, state_dict)
  File "/home/smji/BERT-KPE-master/scripts/model.py", line 35, in __init__
    self.network.load_state_dict(state_dict)
  File "/home/smji/anaconda3/envs/bert_kpe_up/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BertForTFRanking:
	Missing key(s) in state_dict: "bert.embeddings.position_ids".

Increasing max_phrase_words runs into memory issues

Hi, first of all, thank you very much for the repository! :)

I want to retrain the model using a larger number of keyphrases and output longer keyphrases in general.

To achieve this I:

  • increase the number of max_phrase_words from 5 to 10 in scripts/config.py
  • increase the max_gram parameter from 5 to 10 in my model (in bertkpe/networks/)

However, I see that every number bigger than 5 makes me run out of memory during the step
"start preparing (train) features for bert2joint (bert) ..."
I can also see that if I increase the numbers to 6, I run out of memory at a much later stage in the preparation step than if I increase it to something higher like 10, even though the operation is performed on batches. I suspect that there is a memory leak in one of the data loader functions.
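
One factor that may contribute to the growth, independent of any leak, is the number of candidate n-grams itself. Following the sliding-window description in the model overview, a document of length L yields (L - n + 1) windows for each n, so raising the maximum phrase length adds a full extra set of windows per step. The sketch below only counts those windows; the function name and the 512-word document length are hypothetical, and it is not a diagnosis of the reported out-of-memory failure.

```python
def count_ngram_windows(doc_len, max_phrase_words):
    """Total sliding windows of size 1..max_phrase_words over doc_len words."""
    return sum(
        max(0, doc_len - n + 1) for n in range(1, max_phrase_words + 1)
    )


# Candidate counts for a hypothetical 512-word document.
windows_5 = count_ngram_windows(512, 5)
windows_10 = count_ngram_windows(512, 10)
print(windows_5, windows_10, windows_10 / windows_5)
```

For this document length, going from 5 to 10 words roughly doubles the number of candidates per document, before counting any per-candidate features stored alongside them.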

API for key phrase extraction inference

Currently the provided code (test.py) evaluates a pretrained checkpoint on datasets such as OpenKP. Is there an API where the input is some arbitrary text and the outputs are the extracted keywords, or tutorials on how to adapt the current code to do that? Thank you very much!
