MEVI

[NeurIPS 2023] Model-enhanced Vector Index (Paper)

Environment

[Option 1] Create conda environment:

conda env create -f environment.yml
conda activate mevi

[Option 2] Use docker:

docker pull hugozhl/nci:latest

MSMARCO Passage

Data Processing

[1] Download and preprocess:

bash dataprocess/msmarco_passage/download_data.sh
python dataprocess/msmarco_passage/prepare_origin.py \
	--data_dir data/marco --origin

[2] Tokenize documents:

# tokenize for T5-ANCE and AR2

# T5-ANCE
python dataprocess/msmarco_passage/prepare_passage_tokenized.py \
	--output_dir data/marco/ance \
	--document_path data/marco/raw/corpus.tsv \
	--dataset marco --model ance
rm data/marco/ance/all_document_indices_*.pkl
rm data/marco/ance/all_document_tokens_*.bin
rm data/marco/ance/all_document_masks_*.bin

# AR2
python dataprocess/msmarco_passage/prepare_passage_tokenized.py \
	--output_dir data/marco/ar2 \
	--document_path data/marco/raw/corpus.tsv \
	--dataset marco --model ar2
rm data/marco/ar2/all_document_indices_*.pkl
rm data/marco/ar2/all_document_tokens_*.bin
rm data/marco/ar2/all_document_masks_*.bin

[3] Query generation for augmentation:

We use the same finetuned docT5query checkpoint as NCI. The generated (QG) queries are used only for training.

Please download the finetuned docT5query checkpoint to data/marco/ckpts/doc2query-t5-base-msmarco.

# MUST download the finetuned docT5query ckpt before running the scripts
python dataprocess/msmarco_passage/doc2query.py --data_dir data/marco
# if the QG data is of poor quality (e.g., empty queries or many duplicates), additionally run the script below
python dataprocess/msmarco_passage/complement_qg10.py --data_dir data/marco # Optional
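
For reference, docT5query-style query generation follows the standard HuggingFace seq2seq sampling pattern. A minimal sketch is shown below; the exact prompting, sampling parameters, and batching used by doc2query.py may differ.

# Minimal sketch of docT5query-style query generation with HuggingFace
# transformers; sampling settings here are illustrative, not those of doc2query.py.
from transformers import T5Tokenizer, T5ForConditionalGeneration

ckpt = "data/marco/ckpts/doc2query-t5-base-msmarco"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
model = T5ForConditionalGeneration.from_pretrained(ckpt).eval()

passage = "The Manhattan Project was a research and development undertaking ..."
inputs = tokenizer(passage, return_tensors="pt", truncation=True, max_length=512)

# Sample several synthetic queries per passage for training-time augmentation.
outputs = model.generate(
    **inputs,
    max_length=64,
    do_sample=True,
    top_k=10,
    num_return_sequences=5,
)
for ids in outputs:
    print(tokenizer.decode(ids, skip_special_tokens=True))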

[4] Generate document embeddings and construct RQ:

For T5-ANCE, please download the T5-ANCE checkpoint to data/marco/ckpts/t5-ance.

For AR2, please download the AR2 checkpoint to data/marco/ckpts/ar2g_marco_finetune.pkl and the coCondenser checkpoint to data/marco/ckpts/co-condenser-marco-retriever.

# MUST download the checkpoints before running the scripts
export DOCUMENT_ENCODER=ance
# export DOCUMENT_ENCODER=ar2 # use this line for ar2
bash MEVI/marco_generate_embedding_n_rq.sh
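
At a high level, residual quantization (RQ) repeatedly clusters the residuals left by the previous level, so each document ends up with a short sequence of cluster IDs. The sketch below illustrates this with faiss k-means; the number of levels, codebook size, and file layout are illustrative assumptions, not the settings of the script above.

# Minimal sketch of residual quantization over document embeddings.
# Levels and clusters per level are assumptions; see
# MEVI/marco_generate_embedding_n_rq.sh for the actual settings.
import numpy as np
import faiss

dim = 768                                  # embedding width of the encoder (assumed)
emb = np.fromfile("data/marco/ance/docemb.bin", dtype=np.float32).reshape(-1, dim)

num_levels, num_clusters = 3, 256
residual = emb.copy()
codes = []
for level in range(num_levels):
    km = faiss.Kmeans(dim, num_clusters, niter=20, verbose=True)
    km.train(residual)
    _, assign = km.index.search(residual, 1)      # nearest centroid per document
    codes.append(assign.ravel())
    residual -= km.centroids[assign.ravel()]      # quantize the residual at the next level

rq_codes = np.stack(codes, axis=1)                # (num_docs, num_levels) discrete doc IDs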

Training

Train the RQ-based NCI.

export DOCUMENT_ENCODER=ance
# export DOCUMENT_ENCODER=ar2 # use this line for ar2
export WANDB_TOKEN="your wandb token"
bash MEVI/marco_train_nci_rq.sh

Twin-tower Model Evaluation

First generate query embeddings.

# for T5-ANCE
python MEVI/generate.py \
	--query_file data/marco/origin/dev_mevi_dedup.tsv \
	--model_path data/marco/ckpts/t5-ance \
	--tokenizer_path data/marco/ckpts/t5-ance \
	--query_embedding_path data/marco/ance/query_emb.bin \
	--gpus 0,1,2,3,4,5,6,7 --gen_query

# for AR2
python MEVI/generate.py \
	--query_file data/marco/origin/dev_mevi_dedup.tsv \
	--model_path data/marco/ckpts/ar2g_marco_finetune.pkl \
	--tokenizer_path bert-base-uncased \
	--query_embedding_path data/marco/ar2/query_emb.bin \
	--gpus 0,1,2,3,4,5,6,7 --gen_query

Then use faiss for ANN search.

# for T5-ANCE; for AR2, replace the ance directory with the ar2 directory
python MEVI/faiss_search.py \
	--query_path data/marco/ance/query_emb.bin \
	--doc_path data/marco/ance/docemb.bin \
	--output_path data/marco/ance/hnsw256.txt \
	--raw_query_path data/marco/origin/dev_mevi_dedup.tsv \
	--param HNSW256
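
The HNSW256 string is a faiss index-factory spec (an HNSW graph with M=256 neighbors per node). If you want to experiment outside faiss_search.py, the core calls look roughly like the sketch below; the metric, efSearch value, and k are assumptions.

# Rough sketch of the ANN search performed with an HNSW index in faiss.
# Metric, efSearch, and top-k are illustrative assumptions.
import numpy as np
import faiss

dim = 768
docs = np.fromfile("data/marco/ance/docemb.bin", dtype=np.float32).reshape(-1, dim)
queries = np.fromfile("data/marco/ance/query_emb.bin", dtype=np.float32).reshape(-1, dim)

index = faiss.index_factory(dim, "HNSW256", faiss.METRIC_INNER_PRODUCT)
index.add(docs)                                        # build the HNSW graph over documents
faiss.ParameterSpace().set_index_parameter(index, "efSearch", 128)

scores, doc_ids = index.search(queries, 100)           # top-100 candidates per query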

Sequence-to-sequence Model Evaluation

Before evaluation, please download our MSMARCO Passage checkpoint (or train one from scratch) and put it in data/marco/ckpts. If using the downloaded checkpoint, please also download the corresponding RQ files.

# MUST download or train a ckpt before running the scripts
export DOCUMENT_ENCODER=ance
# export DOCUMENT_ENCODER=ar2 # use this line for ar2
bash MEVI/marco_eval_nci_rq.sh

Ensemble

Ensemble the results from the twin-tower model and the sequence-to-sequence model.

export DOCUMENT_ENCODER=ance
# export DOCUMENT_ENCODER=ar2 # use this line for ar2
bash MEVI/marco_ensemble.sh
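
For intuition, one common way to fuse two candidate lists is reciprocal rank fusion; the actual marco_ensemble.sh may combine model scores differently, so treat the following purely as a sketch.

# Reciprocal rank fusion sketch for combining two ranked lists per query.
# marco_ensemble.sh may use a different (e.g., score-based) fusion rule.
from collections import defaultdict

def rrf(ranked_lists, k=60, top_n=100):
    """ranked_lists: iterable of doc-id lists, best first."""
    fused = defaultdict(float)
    for docs in ranked_lists:
        for rank, doc_id in enumerate(docs):
            fused[doc_id] += 1.0 / (k + rank + 1)
    return sorted(fused, key=fused.get, reverse=True)[:top_n]

# Example: fuse twin-tower (ANN) and seq2seq (NCI) candidates for one query.
ann_docs = ["d3", "d7", "d1"]
nci_docs = ["d7", "d9", "d3"]
print(rrf([ann_docs, nci_docs], top_n=5))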

Natural Questions (DPR version)

Data Processing

[1] Download and preprocess:

bash dataprocess/NQ_dpr/download_data.sh
python dataprocess/NQ_dpr/preprocess.py --data_dir data/nq_dpr

[2] Tokenize documents:

# use AR2
python dataprocess/NQ_dpr/tokenize_passage_ar2.py \
	--output_dir data/nq_dpr \
	--document_path data/nq_dpr/corpus.tsv
rm data/nq_dpr/all_document_indices_*.pkl
rm data/nq_dpr/all_document_tokens_*.bin
rm data/nq_dpr/all_document_masks_*.bin

[3] Query generation for augmentation:

We use the same finetuned docT5query checkpoint as NCI. The generated queries are used only for training. Please refer to the QG step in the MSMARCO Passage section.

# download finetuned docT5query ckpt to data/marco/ckpts/doc2query-t5-base-msmarco
python dataprocess/NQ_dpr/doc2query.py \
	--data_dir data/nq_dpr --n_gen_query 1 \
	--ckpt_path data/marco/ckpts/doc2query-t5-base-msmarco

[4] Generate document embeddings and construct RQ:

Please download the AR2 checkpoint to data/marco/ckpts/ar2g_nq_finetune.pkl and the ERNIE checkpoint to data/marco/ckpts/ernie-2.0-base-en.

# MUST download the checkpoints before running the scripts
bash MEVI/nqdpr_generate_embedding_n_rq.sh

[5] Tokenize queries:

Since NQ has a very large number of augmented queries, we tokenize the queries ahead of time so they can be loaded via memmap, which reduces runtime memory usage (see the sketch below).

python dataprocess/NQ_dpr/tokenize_query.py \
	--output_dir data/nq_dpr \
	--tok_train 1 --tok_corpus 1 --tok_qg 1
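
For reference, tokenizing up front allows the training code to map the token files from disk instead of holding all augmented queries in RAM. A minimal np.memmap sketch follows; the file name, dtype, and sequence length are assumptions, so check tokenize_query.py for the real layout.

# Sketch of reading pre-tokenized queries via memmap instead of loading them all
# into RAM. dtype, sequence length, and file name are assumptions.
import numpy as np

max_len = 64
tokens = np.memmap("data/nq_dpr/qg_tokens.bin", dtype=np.int32, mode="r")
tokens = tokens.reshape(-1, max_len)     # (num_queries, max_len), lazily paged from disk

batch = tokens[100:132]                  # only this slice is actually read from disk
print(batch.shape)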

[6] Get answers:

We sort the answers to speed up evaluation. (This step is time-consuming; download the processed binary files instead if needed.)

python dataprocess/NQ_dpr/get_answers.py \
	--data_dir data/nq_dpr \
	--dev 1 --test 1
python dataprocess/NQ_dpr/get_inverse_answers.py \
	--data_dir data/nq_dpr \
	--dev 1 --test 1
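
Sorting the answers lets evaluation check whether a retrieved passage matches a gold answer with binary search instead of a linear scan. The toy sketch below only illustrates that idea; the real scripts operate on the binary files produced above.

# Toy sketch of why pre-sorted answers speed up evaluation: membership tests
# become binary searches. The actual on-disk format differs.
import bisect

sorted_answers = sorted(["brooklyn", "new york", "queens"])

def contains(sorted_list, item):
    i = bisect.bisect_left(sorted_list, item)
    return i < len(sorted_list) and sorted_list[i] == item

print(contains(sorted_answers, "new york"))   # True
print(contains(sorted_answers, "boston"))     # False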

Training

Train the RQ-based NCI.

export WANDB_TOKEN="your wandb token"
bash MEVI/nqdpr_train_nci_rq.sh

Twin-tower Model Evaluation

First generate query embeddings.

python MEVI/generate.py \
	--query_file data/nq_dpr/nq-test.qa.csv \
	--model_path data/marco/ckpts/ar2g_nq_finetune.pkl \
	--tokenizer_path bert-base-uncased \
	--query_embedding_path data/nq_dpr/query_emb.bin \
	--gpus 0,1,2,3,4,5,6,7 --gen_query

Then use faiss for ANN search.

python MEVI/faiss_search.py \
	--query_path data/nq_dpr/query_emb.bin \
	--doc_path data/nq_dpr/docemb.bin \
	--output_path data/nq_dpr/hnsw256.txt \
	--raw_query_path data/nq_dpr/nq-test.qa.csv \
	--param HNSW256

Sequence-to-sequence Model Evaluation

Before evaluation, please download our NQ checkpoint (or train one from scratch) and put it in data/marco/ckpts. If using the downloaded checkpoint, please also download the corresponding RQ files.

# MUST download or train a ckpt before running the scripts
bash MEVI/nqdpr_eval_nci_rq.sh

Ensemble

Ensemble the results from the twin-tower model and the sequence-to-sequence model.

bash MEVI/nqdpr_ensemble.sh

Citation

If you find this work useful, please cite our paper.

Acknowledgement

We learned a lot from, and borrowed some code from, the following projects when building MEVI.

