Retrieval as Attention (ReAtt)

This repository contains the code, models, and data for the paper Retrieval as Attention: End-to-end Learning of Retrieval and Reading within a Single Transformer by Zhengbao Jiang*, Luyu Gao*, Jun Araki, Haibo Ding, Zhiruo Wang, Jamie Callan, Graham Neubig.

Overview

ReAtt is a T5-based retrieval-augmented model for knowledge-intensive tasks (e.g., QA). It performs retrieval using attention and learns retrieval and reading fully end-to-end, relying only on end-task annotations.
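
To make the idea of "retrieval using attention" concrete, here is a minimal toy sketch. It is not the model's actual scoring code: the aggregation (largest dot-product logit per query token, summed over query tokens) and the names attention_score and retrieve are assumptions chosen for illustration, and the vectors are made up.

```python
from typing import Dict, List, Tuple

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Dot product between two token vectors (an unnormalized attention logit)."""
    return sum(a * b for a, b in zip(u, v))


def attention_score(query_tokens: List[Vector], doc_tokens: List[Vector]) -> float:
    """Toy relevance score (assumed aggregation, for illustration only):
    take each query token's largest logit over the document tokens,
    then sum those maxima over all query tokens."""
    return sum(max(dot(q, d) for d in doc_tokens) for q in query_tokens)


def retrieve(
    query_tokens: List[Vector],
    corpus: Dict[str, List[Vector]],
    top_k: int = 2,
) -> List[Tuple[str, float]]:
    """Rank documents by their attention-based score, best first."""
    scored = [(doc_id, attention_score(query_tokens, toks)) for doc_id, toks in corpus.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_k]


# Made-up 2-dimensional token vectors for one query and two documents.
query = [[1.0, 0.0], [0.0, 1.0]]
corpus = {
    "doc_a": [[0.9, 0.1], [0.2, 0.8]],
    "doc_b": [[0.1, 0.1], [0.0, 0.3]],
}
ranking = retrieve(query, corpus)
print(ranking)  # doc_a ranks above doc_b
```

The output has the same shape as the ranked list of doc tuples used in the quick start: document id first, score second.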

Install environment with Conda

Create a conda env with the name reatt using ./setup.sh.

Quick start

Models

Download pre-built embeddings

Download the pre-built embeddings from Google Drive. You can also download them programmatically with gdrive: gdrive download -r 1NXmWudqhHaS32Ebr0ch-f8Ioymcy0xyM.

Retrieval and generation on FiQA dataset

from transformers import AutoTokenizer
from beir.datasets.data_loader import GenericDataLoader
from reatt.model import ReAttConfig, ReAttForConditionalGeneration
from reatt.data import Dataset

model_name = 'neulab/reatt-large-nq-fiqa'
retrieval_corpus = 'reatt_download/reatt-large-nq-fiqa/fiqa'
fiqa_data = 'reatt_download/fiqa'
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = ReAttConfig.from_pretrained(model_name, retrieval_corpus=retrieval_corpus)
model = ReAttForConditionalGeneration.from_pretrained(model_name, config=config).cuda()

question = 'What is considered a business expense on a business trip?'
encoded = tokenizer.batch_encode_plus(
  [Dataset.get_question(question)],
  max_length=128,
  padding=True,
  truncation=True,
  return_tensors='pt')
encoded = {k: v.cuda() for k, v in encoded.items()}

# === only retrieve ===
rank = model.encoder.retriever.retrieve(**encoded)[0]  # top-100 doc tuples <doc_id, score>
corpus, queries, qrels = GenericDataLoader(data_folder=fiqa_data).load(split='test')
print(corpus[rank[0][0]]['text'], rank[0][1], sep='\n')  # content of the top-1 doc and its score
# Food is almost never a valid expense. Reason for it is simple - if you were not conducting business you would have to eat too. ...
# 224.07415771484375

# === retrieve then generate ===
prediction = model.generate(**encoded, search_kwargs={'doc_topk': 10, 'max_length': 512}, max_length=512)[0]
print(tokenizer.decode(prediction, skip_special_tokens=True))
# "It depends on what the ""true"" reason for the trip is. If you decide to deduct the trip as a business expense ...
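
The retrieval step above returns a ranked list of doc tuples, and the BEIR corpus maps each doc id to a record with a text field. As a small sketch, a helper like the following could list the top few hits instead of only the first one. The helper name top_k_texts and the stand-in data are hypothetical; in the quick start you would pass the real rank and corpus objects.

```python
from typing import Dict, List, Tuple


def top_k_texts(
    rank: List[Tuple[str, float]],
    corpus: Dict[str, Dict[str, str]],
    k: int = 3,
) -> List[Tuple[str, float, str]]:
    """Return (doc_id, score, text) for the k best entries of a ranked list."""
    return [(doc_id, score, corpus[doc_id]["text"]) for doc_id, score in rank[:k]]


# Stand-in data shaped like the quick-start variables (values are made up).
rank = [("d2", 12.5), ("d1", 7.0), ("d3", 1.25)]
corpus = {
    "d1": {"title": "", "text": "first document"},
    "d2": {"title": "", "text": "second document"},
    "d3": {"title": "", "text": "third document"},
}

for doc_id, score, text in top_k_texts(rank, corpus, k=2):
    print(f"{doc_id}\t{score:.2f}\t{text}")
```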

Experiments

Inference: retrieval

Compute the embeddings of all document tokens in the dataset with the model and save them.

python inference.py \
  --task index \
  --model neulab/reatt-large-nq-fiqa \
  --dataset reatt_download/fiqa \
  --output output/fiqa \
  --max_context_len 512

Load the model and the embeddings from the retrieval corpus, retrieve the top-100 documents for each query in the dataset, and compute retrieval metrics (nDCG, MAP, recall, precision).

python inference.py \
  --task retrieve \
  --model neulab/reatt-large-nq-fiqa \
  --retireval_corpus reatt_download/reatt-large-nq-fiqa/fiqa \
  --dataset reatt_download/fiqa \
  --doc_topk 100 \
  --max_query_len 128
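
For readers unfamiliar with the metrics named above, the following sketch shows how recall, precision, and nDCG at a cutoff k are conventionally computed for a single query. It is an illustration under stated assumptions, not the evaluation code used by inference.py: the function names are hypothetical, nDCG uses linear gain, and the example ranking and relevance judgments are made up.

```python
import math
from typing import Dict, List


def recall_at_k(ranked_ids: List[str], relevant: Dict[str, int], k: int) -> float:
    """Fraction of all relevant documents that appear in the top k."""
    hits = sum(1 for doc_id in ranked_ids[:k] if relevant.get(doc_id, 0) > 0)
    total = sum(1 for rel in relevant.values() if rel > 0)
    return hits / total if total else 0.0


def precision_at_k(ranked_ids: List[str], relevant: Dict[str, int], k: int) -> float:
    """Fraction of the top k positions occupied by relevant documents."""
    hits = sum(1 for doc_id in ranked_ids[:k] if relevant.get(doc_id, 0) > 0)
    return hits / k


def ndcg_at_k(ranked_ids: List[str], relevant: Dict[str, int], k: int) -> float:
    """nDCG with linear gain: DCG of the ranking divided by the ideal DCG."""
    dcg = sum(
        relevant.get(doc_id, 0) / math.log2(position + 2)
        for position, doc_id in enumerate(ranked_ids[:k])
    )
    ideal_gains = sorted(relevant.values(), reverse=True)[:k]
    idcg = sum(gain / math.log2(position + 2) for position, gain in enumerate(ideal_gains))
    return dcg / idcg if idcg else 0.0


# One query: two relevant documents, one of which is retrieved at rank 2.
ranked = ["d3", "d1", "d9"]
judgments = {"d1": 1, "d2": 1}

print(recall_at_k(ranked, judgments, k=3))     # 0.5
print(precision_at_k(ranked, judgments, k=3))
print(ndcg_at_k(ranked, judgments, k=3))
```

Dataset-level numbers are usually the mean of these per-query values over all queries.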

Inference: retrieval-augmented generation

Load the model and the embeddings from the retrieval corpus, retrieve the top-10 documents for each query in the dataset, and generate answers.

python inference.py \
  --task generate \
  --model neulab/reatt-large-nq-fiqa \
  --retireval_corpus reatt_download/reatt-large-nq-fiqa/fiqa \
  --dataset reatt_download/fiqa \
  --doc_topk 10 \
  --max_query_len 128 \
  --max_context_len 512 \
  --max_generation_len 512

Reference

@inproceedings{jiang-etal-2022-reatt,
    title = {Retrieval as Attention: End-to-end Learning of Retrieval and Reading within a Single Transformer},
    author = {Zhengbao Jiang and Luyu Gao and Jun Araki and Haibo Ding and Zhiruo Wang and Jamie Callan and Graham Neubig},
    booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
    address = {Abu Dhabi, UAE},
    month = {December},
    year = {2022}
}

reatt's Issues

How to produce your own embeddings?

Hi @jzbjyb, I would like to create embeddings on my own dataset with ReAtt. I couldn't find any code in the repository on how to do this. How were the embeddings in your Google Drive folder computed? Could you help me with this? Thanks!
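
The index task described under "Inference: retrieval" may cover this case. Assuming inference.py reads datasets through the same BEIR GenericDataLoader used in the quick start (an assumption, not confirmed here), a custom dataset would need the BEIR folder layout: corpus.jsonl, queries.jsonl, and qrels/test.tsv. The sketch below writes such a folder; the function name write_beir_dataset and the sample records are hypothetical.

```python
import json
import os
import tempfile
from typing import Dict


def write_beir_dataset(
    folder: str,
    corpus: Dict[str, Dict[str, str]],
    queries: Dict[str, str],
    qrels: Dict[str, Dict[str, int]],
    split: str = "test",
) -> None:
    """Write corpus, queries, and relevance judgments in the BEIR folder layout."""
    os.makedirs(os.path.join(folder, "qrels"), exist_ok=True)
    # One JSON object per line: document id, title, and text.
    with open(os.path.join(folder, "corpus.jsonl"), "w", encoding="utf-8") as f:
        for doc_id, doc in corpus.items():
            record = {"_id": doc_id, "title": doc.get("title", ""), "text": doc["text"]}
            f.write(json.dumps(record) + "\n")
    # One JSON object per line: query id and query text.
    with open(os.path.join(folder, "queries.jsonl"), "w", encoding="utf-8") as f:
        for query_id, text in queries.items():
            f.write(json.dumps({"_id": query_id, "text": text}) + "\n")
    # Tab-separated relevance judgments with a header row.
    with open(os.path.join(folder, "qrels", f"{split}.tsv"), "w", encoding="utf-8") as f:
        f.write("query-id\tcorpus-id\tscore\n")
        for query_id, judged in qrels.items():
            for doc_id, score in judged.items():
                f.write(f"{query_id}\t{doc_id}\t{score}\n")


# Made-up sample data written to a temporary folder.
dataset_dir = tempfile.mkdtemp(prefix="my_dataset_")
write_beir_dataset(
    dataset_dir,
    corpus={"doc1": {"title": "Travel", "text": "Meals on a business trip ..."}},
    queries={"q1": "What is a business expense?"},
    qrels={"q1": {"doc1": 1}},
)
print(sorted(os.listdir(dataset_dir)))
```

Under the same assumption, the resulting folder would be passed as --dataset to the index task in place of reatt_download/fiqa.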

Unable to run provided example code

Copying and pasting the code from the README into a fresh environment, after running the install script, produces:

Traceback (most recent call last):
  File "/home/mad326/reu_code/reatt/ReAtt/testing.py", line 11, in <module>
    model = ReAttForConditionalGeneration.from_pretrained(model_name, config=config).cuda()
  File "/home/mad326/anaconda3/envs/reatt/lib/python3.7/site-packages/transformers/modeling_utils.py", line 1418, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/mad326/reu_code/reatt/ReAtt/reatt/model.py", line 761, in __init__
    self.encoder = ReAttStack(encoder_config, self.shared)
  File "/home/mad326/reu_code/reatt/ReAtt/reatt/model.py", line 433, in __init__
    self.retriever = ReAttRetriever(self) if not self.is_decoder else None
  File "/home/mad326/reu_code/reatt/ReAtt/reatt/model.py", line 1054, in __init__
    self.load_index()
  File "/home/mad326/reu_code/reatt/ReAtt/reatt/model.py", line 1066, in load_index
    logging.warning(f'did not find embedding in {self.retrieval_corpus}')
AttributeError: module 'transformers.utils.logging' has no attribute 'warning'

Running Ubuntu 18, Python 3.7 if that helps.
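
The traceback suggests two things. First, the warning is only reached when no embeddings are found at the retrieval_corpus path, so it is worth checking that the pre-built embeddings were downloaded to that location. Second, the name logging in reatt/model.py appears to be bound to the transformers.utils.logging module, which in the installed version has no module-level warning function. A possible local workaround, sketched here with the standard library (an assumption about how one might patch it, not the maintainers' fix), is to call warning on a logger object. The function name warn_missing_index is hypothetical.

```python
import logging

# A named logger object; logger objects always provide .warning().
logger = logging.getLogger("reatt.model")


def warn_missing_index(retrieval_corpus: str) -> str:
    """Emit the same warning text as the failing line, via a logger object."""
    message = f"did not find embedding in {retrieval_corpus}"
    logger.warning(message)
    return message


warn_missing_index("reatt_download/reatt-large-nq-fiqa/fiqa")
```

Within transformers itself, a logger object can be obtained in a similar way with get_logger from transformers.utils.logging.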

Plans to release the training code

Hi, I'm interested in your work and have been following your released code.
It seems that only the inference code has been released so far. Are there any plans to add the training code to GitHub?

Thank you for your interesting work!
