
SELF-RAG: Learning to Retrieve, Generate and Critique through Self-reflection

This includes the original implementation of SELF-RAG: Learning to Retrieve, Generate and Critique through self-reflection (ICLR 2024, Oral top 1%) by Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi.

Website | 7B Model | 13B Model | Paper | Training data | Twitter summary | Updates

Self-RAG (figure, right) is a new framework for training an arbitrary LM to learn to retrieve, generate, and critique, enhancing the factuality and quality of generations without hurting the versatility of LLMs.

Unlike the widely adopted Retrieval-Augmented Generation (RAG; figure, left) approach, Self-RAG retrieves on demand (e.g., it can retrieve multiple times or skip retrieval entirely) for diverse queries, and critiques its own generations from multiple fine-grained aspects by predicting reflection tokens as an integral part of generation. We conduct a segment-wise beam search to select the output that maximizes utility for diverse preferences.

If you find our code, data, models, or the paper useful, please cite the paper:

@inproceedings{
asai2024selfrag,
author={Asai, Akari and Wu, Zeqiu and Wang, Yizhong and Sil, Avirup and Hajishirzi, Hannaneh},
title={Self-{RAG}: Learning to Retrieve, Generate, and Critique through Self-Reflection},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
url={https://openreview.net/forum?id=hSyW5go0v8}
}

Updates

  • 2023.10: Initial release of code, models, and the paper.

Content

  1. Installation
  2. Quick Start
  3. Retriever setup
  4. Training
  5. Inference
  6. Baselines
  7. FAQ
  8. Contact

Installation

Install dependent Python libraries by running the command below.

pip install -r requirements.txt

Please use the latest version of vllm, as older versions do not let you set skip_special_tokens via SamplingParams; this option was added in this PR.

You can also create a conda environment by running the command below.

conda env create -f environment.yml

Quick start

You can download Self-RAG from the Hugging Face Hub. For inference, we recommend using vllm, as it significantly speeds up inference.

from vllm import LLM, SamplingParams

model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
  prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
  if paragraph is not None:
    prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
  return prompt

query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

# for a query that doesn't require retrieval
preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
  print("Model prediction: {0}".format(pred.outputs[0].text))

Output:

Model prediction: Twitter, Instagram, and WhatsApp are all social media platforms. [No Retrieval]WhatsApp is the odd one out because it is a messaging app, while Twitter and # Instagram are primarily used for sharing photos and videos.[Utility:5]</s>
Model prediction: Sure![Retrieval]<paragraph><paragraph>

As you can see, Self-RAG generates a response without retrieval for the first query, which does not require it. On the other hand, Self-RAG outputs a [Retrieval] token for the second query, as this question requires more fine-grained factual grounding.

For queries that require factual grounding, you can insert a paragraph. Self-RAG can retrieve and insert paragraphs at any point during generation, and recognizes them as long as they are surrounded by the context markup special tokens <paragraph> and </paragraph>.

# for a query that needs factual grounding
prompt = format_prompt("Can you tell me the difference between llamas and alpacas?", "The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])
# ['[Relevant]Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.[Fully supported][Utility:5]</s>']

Self-RAG finds the relevant inserted document and generates answers that are fully supported by the evidence.

Run your evaluation using the online retrieval model

You can also run retrieval on-demand and use it with Self-RAG. As running retrieval over full English Wikipedia requires large RAM and multiple GPUs, we created a subset of Wikipedia, including intro paragraphs of Wikipedia articles only for demo purposes.

First, please download the corpus and embeddings (9GB in total).

git clone git@github.com:AkariAsai/self-rag.git
cd retrieval_lm
bash download_demo_corpus.sh

If the script does not work, you can download the data from Google Drive or the HF dataset. Then, you can run the script under retrieval_lm. We tested the script on one RTX 6000 with 24GB VRAM and 100GB RAM (but it should be runnable with much less RAM).

from passage_retrieval import Retriever
retriever = Retriever({})
retriever.setup_retriever_demo("facebook/contriever-msmarco", "enwiki_2020_intro_only/enwiki_2020_dec_intro_only.jsonl", "enwiki_2020_intro_only/enwiki_dec_2020_contriever_intro/*",  n_docs=5, save_or_load_index=False)
query_3 = "When does overfitting occur?"  # example query; model, sampling_params, and format_prompt are defined in the Quick Start snippet above
retrieved_documents = retriever.search_document_demo(query_3, 5)
prompts = [format_prompt(query_3, doc["title"] +"\n"+ doc["text"]) for doc in retrieved_documents]
preds = model.generate(prompts, sampling_params)
top_doc = retriever.search_document_demo(query_3, 1)[0]
print("Reference: {0}\nModel prediction: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))

Output:

Reference: Overfitting
  In statistics, overfitting is "the production of an analysis that corresponds too closely or exactly to a particular set of data, and may therefore fail to fit additional data or predict future observations reliably". An overfitted model is a statistical model that contains more parameters than can be justified by the data. The essence of overfitting is to have unknowingly extracted some of the residual variation (i.e., the noise) as if that variation represented underlying model structure. Underfitting occurs when a statistical model cannot adequately capture the underlying structure of the data. An under-fitted model is a model where some parameters or terms that would appear in a correctly specified model are
Model prediction: [Relevant]Overfitting occurs when a model has too many parameters relative to the amount of data it has been trained on, leading it to memorize the training data too closely and perform poorly on new, unseen data.[Fully supported][Utility:5]</s>

The retrieval system properly retrieves the necessary document, and Self-RAG generates a fully grounded output.

Note that this demo uses a smaller corpus and Self-RAG with the full inference algorithm. For a full evaluation, you either need to set up a retriever or download our retrieved results. Please follow instructions at Inference.

Retriever Setup

By default, we use Contriever as our retrieval component.

Download data

Download preprocessed passage data used in DPR.

cd retrieval_lm
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz

Then, download the generated passage embeddings. We use embeddings generated with Contriever-MSMARCO:

wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar

Run retriever

You can run passage retrieval by running the command below.

cd retrieval_lm
python passage_retrieval.py \
    --model_name_or_path facebook/contriever-msmarco --passages psgs_w100.tsv \
    --passages_embeddings "wikipedia_embeddings/*" \
    --data YOUR_INPUT_FILE  \
    --output_dir YOUR_OUTPUT_FILE \
    --n_docs 20

Your input file should be either JSON or JSONL. Each instance must contain either question or instruction, which will be used as the query during retrieval.
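
For example, a minimal input file could be created like this (the file name and the example questions below are hypothetical; only the question or instruction field matters):

import json

# Hypothetical example: one JSON object per line, each with a "question"
# (or "instruction") field that passage_retrieval.py uses as the query.
examples = [
    {"question": "When does overfitting occur?"},
    {"question": "Can you tell me the difference between llamas and alpacas?"},
]
with open("my_queries.jsonl", "w") as f:
    for ex in examples:
        f.write(json.dumps(ex) + "\n")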

Generate embeddings for your own data

You can generate embeddings for your own data by running the following command. (The script is adapted from the Contriever repository.) Note that generating embeddings for a large-scale corpus (>10M docs) can take time; we recommend running it on multiple GPUs.

cd retrieval_lm
for i in {0..3}; do
  export CUDA_VISIBLE_DEVICES=${i}
  python generate_passage_embeddings.py --model_name_or_path facebook/contriever-msmarco \
    --output_dir YOUR_OUTPUT_DIR \
    --passages YOUR_PASSAGE_DATA --shard_id ${i} --num_shards 4 > ./log/nohup.my_embeddings.${i} 2>&1 &
done

Training

Self-RAG trains two models, the Critic and the Generator, both of which expand the token vocabulary with reflection tokens and are trained with the standard next-token prediction objective.
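
As a rough sketch (not the actual training script), the vocabulary expansion can be done with the standard Hugging Face API; the token list below is only an illustrative subset:

from transformers import AutoTokenizer, AutoModelForCausalLM

# Illustrative subset of reflection / markup tokens; the full list is defined
# in the released training scripts.
reflection_tokens = [
    "[Retrieval]", "[No Retrieval]", "[Continue to Use Evidence]",
    "[Relevant]", "[Irrelevant]", "[Fully supported]",
    "[Utility:1]", "[Utility:5]", "<paragraph>", "</paragraph>",
]

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.add_special_tokens({"additional_special_tokens": reflection_tokens})
model.resize_token_embeddings(len(tokenizer))
# Both the Critic and the Generator are then trained with the standard
# next-token prediction objective over text interleaved with these tokens.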

Alternatively, you can download our training data consisting of 150K instances here.

Collect reflection tokens

We collect training data from GPT-4. The scripts to call GPT-4 for each special token type are available at data_creation/critic.

Alternatively, you can download our training data here.

Critic training

Once you have created or downloaded the training data, run the command below to fine-tune Llama2-7B as the Critic.

cd data_creation
torchrun --nproc_per_node=2 \
  --master_port=2568 train_special_tokens.py \
  --model_name_or_path meta-llama/Llama-2-7b-hf \
  --data_path PATH_TO_TRAIN_DATA_FILE \
  --bf16  True \
  --output_dir PATH_TO_CRITIC_MODEL \
  --num_train_epochs 3  \
  --per_device_train_batch_size 1 --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps 8 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 300 \
  --save_total_limit 1 \
  --learning_rate 2e-5 \
  --weight_decay 0. \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 10 \
  --fsdp "full_shard auto_wrap"

Generator Data Creation

The code to create Generator training data is under generator_data_creation. See the instructions at README.md.

Alternatively, you can download our training data from the HuggingFace dataset or here.

Generator Training

For generator training, we use DeepSpeed to make training more efficient. You can run training by running the script below, after setting the training data path.

cd retrieval_lm
bash script_finetune_7b.sh

For 13B model training, use training_13b. We use 8 A100s with 40 GB VRAM for 7B model training and 4 A100s with 80 GB VRAM for 13B training. The 7B model should fit on 1-2 A100s, although training can be slow.

Inference

For the task evaluation conducted in the paper, please download the data here.

Each file already comes with retrieved documents, so if you don't want to run a retriever as a part of inference, you can simply load the retrieved docs at contexts.
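
For instance, the pre-retrieved documents can be inspected like this (a minimal sketch; the key is ctxs in the released evaluation files, referred to as contexts above, so please double-check the file you downloaded):

import json

# Minimal sketch: look at the pre-retrieved passages attached to one instance.
with open("eval_data/popqa_longtail_w_gs.jsonl") as f:
    item = json.loads(f.readline())

print(item.get("question") or item.get("instruction"))
for doc in item["ctxs"][:3]:  # "ctxs" holds the retrieved documents
    print(doc["title"], "-", doc["text"][:100])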

Below, we describe Self-RAG and baselines.

  • Short-form: run evaluation for short-form generation.
  • Long-form: run evaluations for long-form generations.

Short-form (PubHealth, ARC-Challenge, TriviaQA, PopQA)

As we typically retrieve only once for a short-form generation task, we provide an easy-to-run evaluation script that leverages pre-given documents retrieved by Contriever offline. See the individual command below.

Question Answering

python run_short_form.py \
--model_name selfrag/selfrag_llama2_7b \
--input_file eval_data/popqa_longtail_w_gs.jsonl \
--mode MODE --max_new_tokens 100 \
--threshold 0.2 \
--output_file YOUR_OUTPUT_FILE \
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore \
--dtype half

mode specifies the inference-time mode among ['adaptive_retrieval', 'no_retrieval', 'always_retrieve']:

  • adaptive_retrieval retrieves based on the threshold or the Self-RAG prediction.
  • no_retrieval disables retrieval at inference time.
  • always_retrieve always retrieves.

For 13B, you may hit an OOM issue if you use a single GPU with 24 GB VRAM. You can run inference on multiple GPUs by setting --world_size.
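
For intuition, the adaptive_retrieval decision can be sketched as follows (illustrative only; the exact implementation, which operates on the predicted log probabilities of the [Retrieval] and [No Retrieval] tokens, is in run_short_form.py):

import math

def decide_retrieval(logprob_retrieval, logprob_no_retrieval, threshold=0.2):
    # Retrieve when the probability mass on [Retrieval], normalized against
    # [No Retrieval], exceeds the threshold.
    p_ret = math.exp(logprob_retrieval)
    p_no = math.exp(logprob_no_retrieval)
    return p_ret / (p_ret + p_no) > threshold

# Example: a fairly confident [Retrieval] prediction triggers retrieval.
print(decide_retrieval(-0.2, -2.0, threshold=0.2))  # True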

ARC Challenge

python run_short_form.py \
  --model_name selfrag/selfrag_llama2_7b \
  --input_file eval_data/arc_challenge_processed.jsonl \
  --max_new_tokens 50 --threshold 0.2 \
  --output_file OUTPUT_FILE_NAME \
  --metric match --ndocs 5 --use_groundness --use_utility --use_seqscore \
  --task arc_c

PubHealth

python run_short_form.py \
  --model_name selfrag/selfrag_llama2_7b \
  --input_file eval_data/health_claims_processed.jsonl \
  --max_new_tokens 50 \
  --threshold 0.2 --output_file OUTPUT_FILE_NAME \
  --metric match --ndocs 5 \
  --use_groundness --use_utility --use_seqscore \
  --task fever

Long-form (ASQA, FactScore)

For long-form QA, you can run evaluations either with a retrieval model or with pre-given passages. Currently, we are working on reducing run-time memory requirements (DPR / Contriever with the whole English Wikipedia embeddings requires 100 GB of RAM), speeding up inference for long-form generations, and releasing inference code that uses a small set of initially retrieved documents (~20) first.

Note: Our current implementation is specifically designed for evaluation on the target task datasets. We are planning to update our code base to make the interface simpler and easier to use, and will announce it when we release another version.

Run inference using pre-retrieved passages

For ASQA, please run the following command:

python run_long_form_static.py \
  --model_name selfrag/selfrag_llama2_7b \
  --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
  --use_grounding --use_utility --use_seqscore \
  --task asqa --input_file eval_data/asqa_eval_gtr_top100.json \
  --output_file YOUR_OUTPUT_FILE_NAME --max_depth 7 --mode always_retrieve

For FactScore,

python run_long_form_static.py \
  --model_name selfrag/selfrag_llama2_7b \
  --ndocs 5 --max_new_tokens 300 --threshold 0.2 \
  --use_grounding --use_utility --use_seqscore \
  --task factscore --input_file eval_data/factscore_unlabeled_alpaca_13b_retrieval.jsonl \
  --output_file YOUR_OUTPUT_FILE_NAME --max_depth 7

Key parameters for long-form generations

There are several key parameters related to the inference of Self-RAG.

  • w_rel (default 1.0): w_rel controls the emphasis on the isRel (a critique token on whether retrieved passages are relevant or not) token probability during beam search.
  • w_sup (default 1.0): w_sup controls the emphasis on the isSup (a critique token on whether the generation is supported by the document or not) token probability during beam search.
  • w_use (default 0.5): w_use controls the emphasis on the isUse (a critique token on overall quality) token probability during beam search.
  • threshold (default 0.2): this threshold controls the frequency of adaptive retrieval.
  • max_depth (default 6): this corresponds to T in the paper, which defines the maximum depth of search.
  • beam_width (default 2): this controls the size of the beam in the segment-level beam search.

For more details, please refer to the details (Section 3.3) and analysis (Section 5) in our paper.
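
As a rough sketch, each candidate segment in the beam is scored by combining the sequence score with the critique-token probabilities, weighted by w_rel, w_sup, and w_use (illustrative only; see Section 3.3 of the paper and run_long_form_static.py for the exact formulation):

def segment_score(seq_logprob, p_isrel, p_issup, p_isuse,
                  w_rel=1.0, w_sup=1.0, w_use=0.5, use_seqscore=True):
    # Score one candidate segment in the segment-level beam search by adding
    # the weighted critique-token probabilities to the sequence score.
    score = seq_logprob if use_seqscore else 0.0
    return score + w_rel * p_isrel + w_sup * p_issup + w_use * p_isuse

# Example: a segment judged relevant and fully supported is ranked higher.
print(segment_score(-1.0, p_isrel=0.9, p_issup=0.8, p_isuse=0.7))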

Run evaluation

For long-form evaluations, set up external libraries or repositories to run evaluations.

  • factscore==v0.1.5 (bio) Please follow the instructions at the FactScore official repository to set up your environment.
python -m factscore.factscorer --data_path YOUR_OUTPUT_FILE  --model_name retrieval+ChatGPT --cache_dir YOUR_CACHE_DIR --openai_key YOUR_OPEN_AI_KEY --verbose

ALCE provides a comprehensive evaluation using multiple different metrics for long-form QA. For your first evaluation, install the ALCE repo and download the data.

git clone https://github.com/princeton-nlp/ALCE.git
python3 -m alce_env
cd ALCE
bash download_data.sh

For ASQA, you can run evaluations as follows. Note that ASQA evaluations require a T5-XXL (11B)-based NLI module.

python eval.py --f YOUR_OUTPUT_FILE --citations --qa --mauve

Baselines

Code to rerun the baselines is available at run_baseline_lm.py. To run the retrieval-augmented baselines, make sure to download the task input files with retrieved passages.

Vanilla LM baselines

  • Huggingface models
python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
 --max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa --prompt_name "prompt_no_input"

e.g., PubHealth

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 20 \
--metric accuracy \
--result_fp llama2_7b_pubhealth_results.json \
--task fever

Note: for PubHealth and ARC, please pass the task names (ARC = arc_c and PubHealth = fever) to properly set the instruction.

  • OpenAI APIs

For OpenAI API models, you also need to set the organization key here. You also need a txt file containing your OpenAI API key.

python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
 --task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--prompt_name "prompt_no_input"

Retrieval-augmented baselines

  • Huggingface models
python run_baseline_refactor.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
 --max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH --task qa \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"
  • OpenAI APIs
python run_baseline_lm.py \
--model_name gpt-3.5-turbo-0301 \
--input_file INPUT_FILE_SAME_AS_SELF_RAG \
--max_new_tokens 100 --metric match \
--result_fp RESULT_FILE_PATH \
 --task qa \
--api_key YOUR_OPEN_AI_API_KEY_FILE \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval"

FAQ

Q1: How can I train a new pre-trained LM using the Self-RAG scheme? -- If you are using Hugging Face Transformers, you can simply change the model_name_or_path and tokenizer_name in our training script, script_finetune_7b.sh. If you want to use your own fine-tuning script, please make sure to add the special tokens and mask out the paragraph context, as discussed in this issue.
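
A minimal sketch of the paragraph-context masking mentioned above is shown below (the helper is hypothetical and assumes a Hugging Face tokenizer with the special tokens already added; it only illustrates setting the retrieved-passage span's labels to -100 so those tokens do not contribute to the loss):

def mask_paragraph_labels(input_ids, labels, tokenizer):
    # Hypothetical helper: set every label between <paragraph> and </paragraph>
    # (inclusive) to -100 so the retrieved passage is not a prediction target.
    start_id = tokenizer.convert_tokens_to_ids("<paragraph>")
    end_id = tokenizer.convert_tokens_to_ids("</paragraph>")
    masked = list(labels)
    inside = False
    for i, tok in enumerate(input_ids):
        if tok == start_id:
            inside = True
        if inside:
            masked[i] = -100
        if tok == end_id:
            inside = False
    return masked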

Q2: Are you planning to release Mistral-7B-based Self-RAG? -- Right now I have limited bandwidth to do so, but there is a community-trained version of Self-RAG SciPhi-Self-RAG-Mistral-7B-32k on top of Mistral-7B. We will announce if we can train Self-RAG on Mistral-7B and release the checkpoint.

Contact

If you have questions, please open an issue mentioning @AkariAsai or send an email to akari[at]cs.washington.edu.

self-rag's Issues

Why is eval logic so complicated?

Hi,

Two major questions cross my mind when looking over this repository

  1.) Would you like access to an optimized compressed FAISS index? I ran a simple optimization this morning and shrunk my index from ~60 GB to 5.5 GB without any obvious degradation. This might help others in replicating the retrieval pipeline.
2.) Will I be able to reproduce your results in production if I elect to just retrieve on '[Retrieval]' tokens, rather than performing this calculation -

        if threshold is not None:
            score_dict = {}
            for tok, id in ret_tokens.items():
                if id not in pred_log_probs[0]:
                    score_dict[tok] = -100
                prob = pred_log_probs[0][id]
                score_dict[tok] = float(prob)
            do_retrieve = score_dict["[Retrieval]"] / (
                score_dict["[Retrieval]"] + score_dict["[No Retrieval]"]) > threshold

E.g., how different is this in practice from retrieving only when the retrieval token is the sampled result?

Question about the pre-given passages

Hi @AkariAsai, great work and thanks for this repo!

I'm confused about the pre-given passages in the evaluation datasets. Are they retrieved using the initial question as the query, or did you take the union of all the passages retrieved at multiple stages?

Will you release the non-static inference code without pre-given passages or the code for generating the pre-given passages for the evaluation datasets?

Is the FT script correct?

Hi,

I ran the fine-tune script on the Mistral base model but found rather poor results on ARC Challenge (<50% with retrieval). Any ideas why? I will repeat with Mistral Instruct to see if it makes a beneficial difference, but I am not optimistic, as I have seen similarly poor results when fine-tuning this model with the self-rag dataset and script.

MODEL_SIZE=7B
NUM_GPUS=8
BATCH_SIZE_PER_GPU=1
TOTAL_BATCH_SIZE=128
GRADIENT_ACC_STEPS=$(($TOTAL_BATCH_SIZE/$NUM_GPUS/$BATCH_SIZE_PER_GPU))
echo "Training llama model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
    --mixed_precision bf16 \
    --num_machines 1 \
    --num_processes $NUM_GPUS \
    --use_deepspeed \
    --deepspeed_config_file stage3_no_offloading_accelerate.conf \
    finetune.py \
    --model_name_or_path mistralai/Mistral-7B-v0.1 \
    --use_flash_attn \
    --tokenizer_name mistralai/Mistral-7B-v0.1 \
    --use_slow_tokenizer \
    --train_file full_output_1005.jsonl \
    --max_seq_length 2048 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
    --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --weight_decay 0. \
    --num_train_epochs 5 \
    --output_dir output/mistral_root_${MODEL_SIZE}/ \
    --with_tracking \
    --report_to tensorboard \
    --logging_steps 1 \
    --use_special_tokens

EDIT: I had a chance to look into this today; I am fairly confident the issue is that this script will NOT work for a model whose tokenizer has not been independently prepared. Will confirm and close the issue. It might be nice to add some information on how to independently replicate the result when fine-tuning from scratch.

Use Self-RAG with vector databases

The ability to retrieve from a vector database is often how RAG is used in applications. Could you please look into adding langchain or llamaindex support on top to leverage e.g. a Pinecone, Milvus, or Chroma vector store?

The saved embed_tokens is empty

Hello, I tried to run this code with LLaMA-1-7B, but I found that the saved embed_tokens is empty and fails to load after training. Have you encountered this problem?

(Pdb) param_name
'model.embed_tokens.weight'
(Pdb) param
tensor([], dtype=torch.bfloat16)

Mistral 7B

Will there be a release of the Mistral 7b model?

max_depth argument in retrieval_lm/run_short_form.py

Hello,

I'm trying to reproduce paper numbers on PopQA by running the following command :

Question Answering
python run_short_form.py
--model_name selfrag/selfrag_llama2_7b
--input_file eval_data/popqa_longtail_w_gs.jsonl
--mode MODE --max_new_tokens 100
--threshold 0.2
--output_file YOUR_OUTPUT_FILE
--metric match --ndocs 10 --use_groundness --use_utility --use_seqscore
--dtype half

But, I'm getting an error :
TypeError: call_model_rerank_w_scores_batch() got an unexpected keyword argument 'max_depth'

So, should I modify the code by removing 'max_depth=args.max_depth,'?

Requirement conflicts for vllm and flash-attn

Hi, I encountered some issues while trying to run it. I attempted to install the dependency packages in requirements.txt, but found a version conflict between two packages:

flash-attn requires torch>=1.13.0 and torch<2.0.0
vllm requires torch>=2.0.0.

Besides, I also encountered the same problem as #26. May I ask if you have encountered this problem? Thank you!

Requirement Conflicts

Hi,

Thanks for the great work! I am trying to install the requirements as shown in the repo. However, I encountered the following error: factscore 0.1.5 depends on openai<0.28.0 and >=0.27.7.

I understand it could be difficult to maintain a compatible environment given that the dependencies update so quickly. It would be great if you could pin the versions a bit more precisely in requirements.txt to make the installation smoother.

Thanks!

Are these two code files the same thing?

Hi!
Thanks for your great work.

In the Baselines description, the code is run_baseline_lm.py, but in the Python running scripts, the code is run_baseline_refactor.py.
So are these two code files the same thing?

Thanks for your answer.

RAG Baselines

Hey @AkariAsai, great work and thank you very much for putting together a nice hugging face model, dataset, and repo for reproducing and extending the results :)

In terms of the baselines for RAG, I see the paper describe:

Baselines with retrievals. We evaluate models augmented with retrieval at test time or during training.
The first category includes standard RAG baselines, where an LM (Llama2, Alpaca) generates output
given the query prepended with the top retrieved documents using the same retriever as in our system.

I was wondering if you had the scripts available for running the evaluations of Llama2 and Llama-2 chat with the described setup? It seemed to me that run_short_form.py should only work for models with the special tokens.

Thanks in advance!

Cannot approach the performance of the uploaded self-rag ckpt when finetuning meta/Llama-2 myself

Thanks for your inspiring work @AkariAsai .

I tried to run the script_finetune_7b.sh script myself (using meta-llama/Llama-2-7b-hf and your provided generator data), which I expected to produce a checkpoint matching the uploaded checkpoint in performance, since both the base checkpoint and the data are the same.

However, the resulting model shows a significant performance gap w.r.t. the uploaded self-rag checkpoint. For example, on TriviaQA, my checkpoint gets only 0.503 accuracy compared with 0.679 for the uploaded checkpoint.

I do notice that there are differences between the final checkpoint dir and the uploaded checkpoint dir:

  1. My reproduced checkpoints are saved as *.safetensors, whereas the uploaded checkpoint is in *.bin.
  2. I encounter the same issue as mentioned in #21. The checkpointing stores both a single checkpoint (but without embedding parameters) and sharded checkpoints (see the figure below). So #21 suggested that model.safetensors be removed. I guess you did not encounter such an issue.
(image)

I was wondering whether the underlying cause of this difference might result in the performance gap. Do you have any idea regarding this matter?

Cannot reproduce baseline tasks?

Hi! Thanks for your great work.

I tried to reproduce the baseline tasks, but the results were low compared to the paper. So I am not sure whether I used the correct script. Please help me.

For PopQA

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/popqa_longtail_w_gs.jsonl  \
 --max_new_tokens 100 --metric match \
--result_fp output/test_out_popqa_run_short_form_Llama-2-7b-hf_100 --task qa --prompt_name "prompt_no_input" --world_size 8

overall result: 0.09578270192994996, which is low.

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/popqa_longtail_w_gs.jsonl \
 --max_new_tokens 100 --metric match \
--result_fp output/test_out_popqa_run_short_form_Llama-2-7b-hf_100_Retrieval-augmented --task qa --mode retrieval --prompt_name "prompt_no_input_retrieval" --world_size 8

overall result: 0.3566833452466047, which is close to the paper. So this result may be correct.

For ARC Challenge

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/arc_challenge_processed.jsonl \
 --max_new_tokens 50 --metric match \
--result_fp output/test_out_arc_run_short_form_Llama-2-7b-hf_50 --task qa --prompt_name "prompt_no_input" --world_size 8

overall result: 0.11433447098976109, which is low.

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/arc_challenge_processed.jsonl \
 --max_new_tokens 50 --metric match \
--result_fp output/test_out_arc_run_short_form_Llama-2-7b-hf_50_Retrieval-augmented --task qa --mode retrieval --prompt_name "prompt_no_input_retrieval" --world_size 8

overall result: 0.09044368600682594, which is low.

For PubHealth,

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/health_claims_processed.jsonl \
--max_new_tokens 50 --metric match \
--result_fp output/test_out_pubhealth_run_short_form_Llama-2-7b-hf_50 --task qa --prompt_name "prompt_no_input" --world_size 8

overall result: 0.0060790273556231, which is low.

python run_baseline_lm.py \
--model_name meta-llama/Llama-2-7b-hf \
--input_file eval_data/health_claims_processed.jsonl \
 --max_new_tokens 100 --metric match \
--result_fp output/test_out_pubhealth_run_short_form_Llama-2-7b-hf_Retrieval-augmented --task qa \
--mode retrieval \
--prompt_name "prompt_no_input_retrieval" --world_size 8

overall result: 0.008105369807497468, which is low.

Question about the Algorithm 1 table in the Self-RAG paper downloaded from arXiv

Your work is impressive, but I am confused about the 5th line of the SELF-RAG Inference table in your paper.
What is the meaning of "M predicts ISREL given x, d and yt given x, d, y<t for each d ∈ D"? In my opinion, "given x, d and yt" is not needed; "given x, d, y<t" is right.
That is my question; I believe I have not fully understood the inference process of Self-RAG.
Waiting for your explanation, thanks a lot.

The result of direct inference without using vLLM is wrong; is it a problem with the model?

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GenerationConfig
# from vllm import LLM, SamplingParams
import torch
device = torch.device(0)


def load_tokenizer_and_model():
  tokenizer = AutoTokenizer.from_pretrained('/root/autodl-tmp/selfrag_llama2_7b')
  config = AutoConfig.from_pretrained('/root/autodl-tmp/selfrag_llama2_7b')
  model = AutoModelForCausalLM.from_pretrained(
    '/root/autodl-tmp/selfrag_llama2_7b',
    torch_dtype=torch.float16,
    config=config
  )

  model.to(device)
  model.eval()
  return tokenizer, model

def format_prompt(input, paragraph=None):
  prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
  if paragraph is not None:
    prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
  return prompt

if __name__ == "__main__":
  query_1 = "Leave odd one out: twitter, instagram, whatsapp."
  query_2 = "Can you tell me the difference between llamas and alpacas?"
  queries = [query_1, query_2]
  tokenizer, model = load_tokenizer_and_model()

  for q in queries:
    # inputs = tokenizer([format_prompt(query) for query in queries], return_tensors='pt')
    inputs = tokenizer(format_prompt(q), return_tensors='pt')
    input_ids = inputs['input_ids'].to(device)

    generation_config = GenerationConfig(
      temperature=0.0,
      top_p=1.0,
      max_tokens=100
    )
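    # Note: GenerationConfig in Transformers has no max_tokens parameter (the
    # equivalent is max_new_tokens), so generation likely falls back to the
    # default max_length and truncates the output, as in the results below.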
    with torch.no_grad():
      generation_output = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=True,
        repetition_penalty=1.2,
      )
    output = generation_output.sequences[0]
    output = tokenizer.decode(output, skip_special_tokens=True)
    print(output)

"""
'### Instruction:
Leave odd one out: twitter, instagram, whatsapp.

### Response:
Tw'


'### Instruction:
Can you tell me the difference between llamas and alpacas?

### Response:
S'
"""

OOM issue when running the quick start code under 80GB gpu

Hello, I appreciate the effort you’ve put into your work!

I’ve been trying to execute your quick start code, but I’ve run into an Out Of Memory (OOM) error, despite having an 80GB GPU at my disposal. I was under the impression that a 7B model would fit comfortably within an 80GB GPU memory, so I’m unsure why I’m still facing this OOM error. Could you possibly shed some light on this issue? Thanks!

from vllm import LLM, SamplingParams
model = LLM("selfrag/selfrag_llama2_7b", download_dir=MY_DIR, dtype="half")

and by the way, can you tell me the typical memory usage when executing this code snippet?

The selfrag_llama2_7b model does not come out as in the example

I experimented using the settings provided in the example at https://huggingface.co/selfrag/selfrag_llama2_7b, but the prediction result I got was just a series of 'Model prediction: blank result'. However, when using the model at https://huggingface.co/selfrag/self_rag_critic, the results come out as expected.


from transformers import AutoTokenizer, AutoModelForCausalLM
from vllm import LLM, SamplingParams

model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
  prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
  if paragraph is not None:
    prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
  return prompt

query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
print("Model prediction: {0}".format(pred.outputs[0].text))

Expected results are below
Model prediction: Twitter, Instagram, and WhatsApp are all social media platforms.[No Retrieval]WhatsApp is the odd one out because it is a messaging app, while Twitter and # Instagram are primarily used for sharing photos and videos.[Utility:5] (this query doesn't require factual grounding; just skip retrieval and do normal instruction-following generation)
=>But I got the blank result

Expected results are below
Model prediction: Sure![Retrieval] ... (this query requires factual grounding, call a retriever)
=>But I got the blank result

generate with retrieved passage

prompt = format_prompt("Can you tell me the difference between llamas and alpacas?", paragraph="The alpaca (Lama pacos) is a species of South American camelid mammal. It is similar to, and often confused with, the llama. Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.")
preds = model.generate([prompt], sampling_params)
print([pred.outputs[0].text for pred in preds])

Expected results are below
['[Relevant]Alpacas are considerably smaller than llamas, and unlike llamas, they were not bred to be working animals, but were bred specifically for their fiber.[Fully supported][Utility:5]']
=>But I got the blank result


Out of memory at inference in free tier Google Colab

Tried the quick start code in free tier Google Colab. Got out of memory error. Is this expected ?

from vllm import LLM, SamplingParams

model = LLM("selfrag/selfrag_llama2_7b", download_dir="/gscratch/h2lab/akari/model_cache", dtype="half")
sampling_params = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=100, skip_special_tokens=False)

def format_prompt(input, paragraph=None):
  prompt = "### Instruction:\n{0}\n\n### Response:\n".format(input)
  if paragraph is not None:
    prompt += "[Retrieval]<paragraph>{0}</paragraph>".format(paragraph)
  return prompt

query_1 = "Leave odd one out: twitter, instagram, whatsapp."
query_2 = "Can you tell me the difference between llamas and alpacas?"
queries = [query_1, query_2]

# for a query that doesn't require retrieval

preds = model.generate([format_prompt(query) for query in queries], sampling_params)
for pred in preds:
print("Model prediction: {0}".format(pred.outputs[0].text))

How to fix the bug about 'local variable 'pred' referenced before assignment'?

Hi! Thanks for the great work.

I find an error in run_short_form.py when testing with my own data about multiple choices

 if len(results) == 1:
        postprocessed_pred = postprocess_answer_option_conditioned(pred)
        return postprocessed_pred, results, do_retrieve

Traceback (most recent call last):
File "run_short_form.py", line 371, in
main()
File "run_short_form.py", line 329, in main
pred, results, do_retrieve = generate(
File "run_short_form.py", line 313, in generate
return call_model_rerank_w_scores_batch(prompt, evidences=evidences, model=model, max_new_tokens=max_new_tokens,
File "run_short_form.py", line 179, in call_model_rerank_w_scores_batch
postprocessed_pred = postprocess_answer_option_conditioned(pred)
UnboundLocalError: local variable 'pred' referenced before assignment

Could you please help solve it?

Reproducing the TriviaQA numbers

Hi @AkariAsai , great work and thank you very much for this repo!

I try to reproduce the baseline results on TriviaQA task, using Llama-2-7b downloaded from Huggingface.
But the results I get are quite different from those reported in the paper:

  • Paper - vanilla: 30.5 retrieval-augmented: 42.5
  • My - vanilla: 4.95 retrieval-augmented: 2.61

The command I used for retrieval-augmented experiment:

python run_baseline_lm.py \ 
 --model_name PATH_TO_LLAMA2-7B-HF \
 --input_file eval_data/triviaqa_test_w_gs.jsonl \
 --max_new_tokens 100 \
 --metric accuracy \
 --task qa \
 --dtype half \
 --mode retrieval \
 --batch_size 5 \
 --prompt_name "prompt_no_input_retrieval"

The gap is so huge that I'm wondering if I got some details wrong.

Thanks in advance :)

For ASQA, how to reproduce the baseline?

Hi! Thanks for your great work.
I try to reproduce the baseline for ASQA using Llama-2-7b-hf, like this:

python run_baseline_lm.py
--model_name meta-llama/Llama-2-7b-hf
--input_file eval_data/asqa_eval_gtr_top100.json
--max_new_tokens 300 --metric match
--result_fp output/test_out_ASQA_Llama-2-7b-hf_300_Retrieval-augmented --task qa --mode retrieval --prompt_name "prompt_no_input_retrieval" --world_size 8

However, the error was returned :
Traceback (most recent call last):
File "run_baseline_lm.py", line 230, in
main()
File "run_baseline_lm.py", line 128, in main
retrieval_result = item["ctxs"][:args.top_n]
KeyError: 'ctxs'

Could you please help to fix it?

Reproducing the ASQA numbers

Hi, I was unable to reproduce the ASQA numbers for long-form generation. After evaluating the output with ALCE, I see the below numbers which are very different from those reported in the paper:

  • 'str_em': 30.05098452883263
  • 'rougeLsum': 34.10838297032821
  • 'mauve': 68.43516667345226
  • 'citation_rec': 50.0210970464135
  • 'citation_prec': 63.60759493670886

The command I used:

python run_long_form_static.py 
--model_name selfrag/selfrag_llama2_7b --ndocs 5 --max_new_tokens 300 
--threshold 0.2 --use_grounding --use_utility --use_seqscore  --task asqa 
--input_file eval_data/asqa_eval_gtr_top100.json 
--output_file asqa/selfrag_llama2_7b.json --max_depth 7 --mode always_retrieve

I have also uploaded the model output file here for your reference. Just wanted to know whether I am doing anything wrong for ASQA.

Btw, I did a sanity check by evaluating on short-form generation with PopQA and I see 55.0 for accuracy, which matches the number reported in the paper.

Retrieval over my own documents

Hi there @AkariAsai ,

I was wondering how to do RAG over my own documents rather than using wikipedia, do I need to train a new critic and generator model? Or can I use the pretrained model that you have released somehow, and only update the retriever?

Thank you

Where are the "preceding sentences" from?

I want to build my own dataset to fine-tune a Critic model, but I found that the prompt in the code has a preceding_sentences field, and I don't know where it comes from or how it was generated.

PROMPT_DICT = { "context": ( "Given an instruction, please make a judgment on whether finding some external documents from the web (e.g., Wikipedia) helps to generate a better response. Please answer [Yes] or [No] and write an explanation.\n\n" "##\nInstruction: Give three tips for staying healthy.\n" "Need retrieval?: [Yes]\n" "Explanation: There might be some online sources listing three tips for staying healthy or some reliable sources to explain the effects of different behaviors on health. So retrieving documents is helpful to improve the response to this query.\n\n" "##\nInstruction: Describe a time when you had to make a difficult decision.\n" "Need retrieval?: [No]\n" "Explanation: This instruction is asking about some personal experience and thus it does not require one to find some external documents.\n\n" "##\nInstruction: Write a short story in third person narration about a protagonist who has to make an important career decision.\n" "Need retrieval?: [No]\n" "Explanation: This instruction asks us to write a short story, which does not require external evidence to verify.\n\n" "##\nInstruction: What is the capital of France?\n" "Need retrieval?: [Yes]\n" "Explanation: While the instruction simply asks us to answer the capital of France, which is a widely known fact, retrieving web documents for this question can still help.\n\n" "##\n Instruction: Find the area of a circle given its radius. Radius = 4\n" "Need retrieval?: [No]\n" "Explanation: This is a math question and although we may be able to find some documents describing a formula, it is unlikely to find a document exactly mentioning the answer.\n\n" "##\nInstruction: Arrange the words in the given sentence to form a grammatically correct sentence. quickly the brown fox jumped\n" "Need retrieval?: [No]\n" "Explanation: This task doesn't require any external evidence, as it is a simple grammatical question.\n\n" "##\nInstruction: Explain the process of cellular respiration in plants." "Need retrieval?: [Yes]\n" "Explanation: This instruction asks for a detailed description of a scientific concept, and is highly likely that we can find a reliable and useful document to support the response.\n\n" "##\nInstruction:{instruction}\n" "Need retrieval?: " ), "multi_retrieval": ( "You will be provided with an instruction, evidence, output sentence, and preceding sentences (optional). If the preceding sentence is given, the output should be the sentence that follows those preceding sentences. Your task is to determine whether the information in the output sentence can be fully verified by the evidence or if it requires further external verification. If the output sentence can be verified solely with the evidence or doesn’t require any verification, respond with [No Retrieval]. If additional information is needed to verify the output sentence, respond with [Retrieval]. Please provide explanations for your judgments.\n\n" "##\nInstruction: Explain the use of word embeddings in Natural Language Processing.\n" "Preceding sentences: Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). 
They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured.\n" "Evidence: Word embedding\nWord embedding is the collective name for a set of language modeling and feature learning techniques in natural language processing (NLP) where words or phrases from the vocabulary are mapped to vectors of real numbers. Conceptually it involves a mathematical embedding from a space with one dimension per word to a continuous vector space with a much lower dimension.\n" "Output: Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies.\n" "Rating: [Retrieval]\n" "Explanation: The output discusses the applications of word embeddings, while the evidence only discusses the definitions of word embeddings and how it works. Therefore, we need to retrieve other evidence to verify whether the output is actually correct or not.\n" "###\nInstruction: {instruction}\n" "Preceding sentences: {preceding_sentences}\n" "Evidence: {evidence}\n" "Output: {target_output}\n" "Rating: "), "multi_retrieval_no_preceding": ( "You will be provided with an instruction, evidence, output sentence, and preceding sentences (optional). If the preceding sentence is given, the output should be the sentence that follows those preceding sentences. Your task is to determine whether the information in the output sentence can be fully verified by the evidence or if it requires further external verification. If the output sentence can be verified solely with the evidence or doesn’t require any verification, respond with [No Retrieval]. If additional information is needed to verify the output sentence, respond with [Retrieval]. Please provide explanations for your judgments.\n\n" "##\nInstruction: Explain the use of word embeddings in Natural Language Processing.\n" "Preceding sentences: Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured.\n" "Evidence: Word embedding\nWord embedding is the collective name for a set of language modeling and feature learning techniques in natural language processing (NLP) where words or phrases from the vocabulary are mapped to vectors of real numbers. Conceptually it involves a mathematical embedding from a space with one dimension per word to a continuous vector space with a much lower dimension.\n" "Output: Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies.\n" "Rating: [Retrieval]\n" "Explanation: The output discusses the applications of word embeddings, while the evidence only discusses the definitions of word embeddings and how it works. Therefore, we need to retrieve other evidence to verify whether the output is actually correct or not.\n" "###\nInstruction: {instruction}\n" "Evidence: {evidence}\n" "Output: {target_output}\n" "Rating: " ), "multi_retrieval_three_way": ( "You will be provided with an instruction, evidence, output sentence, and preceding sentences (optional). If the preceding sentence is given, the output should be the sentence that follows those preceding sentences. 
Your task is to determine whether the information in the output sentence can be fully verified by the evidence or if it requires further external verification. There are three cases:\n" "- If the output sentence can be verified solely with the evidence, then respond with [Continue to Use Evidence]. \n" "- If the sentence doesn't require any factual verification (e.g., a subjective sentence or a sentence about common sense), then respond with [No Retrieval]. \n" "If additional information is needed to verify the output sentence, respond with [Retrieval]. Please provide explanations for your judgments. \n\n" "##\nInstruction: Explain the use of word embeddings in Natural Language Processing.\n" "Preceding sentences: Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. \n" "Evidence:\nWord embedding\nWord embedding is the collective name for a set of language modeling and feature learning techniques in natural language processing (NLP) where words or phrases from the vocabulary are mapped to vectors of real numbers. Conceptually it involves a mathematical embedding from a space with one dimension per word to a continuous vector space with a much lower dimension. \n" "Output: Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies.\n" "Rating: [Retrieval]\n" "Explanation: The output discusses the applications of word embeddings, while the evidence only discusses the definitions of word embeddings and how it works. Therefore, we need to retrieve other evidence to verify whether the output is correct or not.\n" "###\nInstruction: {instruction}\n" "Preceding sentences: {preceding_sentences}\n" "Evidence: {evidence}\n" "Output: {target_output}\n" "Rating: " ), "multi_retrieval_three_way_no_preceding": ( "You will be provided with an instruction, evidence, output sentence, and preceding sentences (optional). If the preceding sentence is given, the output should be the sentence that follows those preceding sentences. Your task is to determine whether the information in the output sentence can be fully verified by the evidence or if it requires further external verification. There are three cases:\n" "- If the output sentence can be verified solely with the evidence, then respond with [Continue to Use Evidence]. \n" "- If the sentence doesn't require any factual verification (e.g., a subjective sentence or a sentence about common sense), then respond with [No Retrieval]. \n" "- If additional information is needed to verify the output sentence, respond with [Retrieval]. Please provide explanations for your judgments. \n\n" "##\nInstruction: Explain the use of word embeddings in Natural Language Processing.\n" "Preceding sentences: Word embeddings are one of the most powerful tools available for Natural Language Processing (NLP). They are mathematical representations of words or phrases in a vector space, allowing similarities between words and the context in which they are used to be measured. \n" "Evidence:\nWord embedding\nWord embedding is the collective name for a set of language modeling and feature learning techniques in natural language processing (NLP) where words or phrases from the vocabulary are mapped to vectors of real numbers. 
Conceptually it involves a mathematical embedding from a space with one dimension per word to a continuous vector space with a much lower dimension. \n" "Output: Word embeddings are useful for tasks such as sentiment analysis, text classification, predicting the next word in a sequence, and understanding synonyms and analogies.\n" "Rating: [Retrieval]\n" "Explanation: The output discusses the applications of word embeddings, while the evidence only discusses the definitions of word embeddings and how it works. Therefore, we need to retrieve other evidence to verify whether the output is correct or not.\n" "###\nInstruction: {instruction}\n" "Evidence: {evidence}\n" "Output: {target_output}\n" "Rating: " ) }

About baseline's parameter 'task'

Hi!
Thanks for your great work.
I find the task parameter in all example commands (run_baseline_lm.py) is 'qa', but in Self-RAG, the task can take different values, such as arc_c, fever, asqa, factscore.
So if I'd like to run run_baseline_lm.py, should I set different values for the task parameter according to the task?
Or is just setting the parameter to 'qa' fine for all kinds of tasks?

Thanks.

the logic of NO retrieval in long form inference

Excellent work! I have a question about the logic of no retrieval. In multiple rounds of retrieval, if [No Retrieval] appears for the first time, does it mean that all retrievals have ended, or can the second and third rounds of retrieval still be carried out?
My understanding based on your paper and code is that if [No Retrieval] appears for the first time, all remaining output is generated without further retrieval.
(image)

Waiting for your detailed explanation.

If it's possible to use LoRA to train critic and generator

Hi there @AkariAsai, I'm a beginner and wondering if it is possible to use LoRA to train the critic and generator? I tried to train the generator and it needs 200+ hours (I only have 2x A100 40GB).
If it's possible, I'm also wondering how to expand the token vocabulary with reflection tokens.

Thank you very much!

Training Critic Model

Hey @AkariAsai awesome work your team has done :)

I am trying to get access to the 7B critic model mentioned in the paper and I noticed it is not released. If you have a trained model I am happy to test it as well.

At the same time, I am trying to train this critic model with your provided "gpt4_reward_all_0813_train.json". But it seems it is not compatible when directly running your "/data_creation/train_special_tokens.py". Do you happen to have a preprocessing script, or could you provide me with the training data you have processed?

EDIT:
I just realised Line 239 is the culprit, I have changed to:
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
and it runs ok.

For training the critic model, is it ok to remove prompt_no_input_paragraph and prompt_no_input_separated?

About PopQA

Hi! Thanks for the great work.
When reproducing the inference for PopQA using Self-RAG, I got the same score for adaptive_retrieval and always_retrieve.
In theory, shouldn't the adaptive_retrieval result be better than always_retrieve? I don't know why...

Obtaining my own critic data

First of all congratulations on the impressive work!
I am looking to extend SELF-RAG to long-context tasks and planning to create my own training data. I have 3 questions about obtaining my own critic data.

  1. What is the input file? Are the inputs obtained from the tasks themselves?
  2. How do we create the jsonl for the input files? Where do we get those parameters from? (There is no README in process_data as mentioned here data_creation/critic/gpt4_reward/README.md). Does each critic data need a separate input file?
  3. Once I have obtained the critic data for each of the four tokens, how do I combine them into one file for training the critic model?

beam_width argument in retrieval_lm/run_short_form.py

Hello,

I'm trying to reproduce paper numbers on arc_challenge by running the following command :

python run_short_form.py
--model_name selfrag/selfrag_llama2_7b
--input_file eval_data/arc_challenge_processed.jsonl
--max_new_tokens 50 --threshold 0.2
--output_file OUTPUT_FILE_NAME
--metric match --ndocs 5 --use_groundness --use_utility --use_seqscore
--task arc_c

But, I'm getting an error :
return call_model_rerank_w_scores_batch(prompt, evidences=evidences, model=model, max_new_tokens=max_new_tokens,
TypeError: call_model_rerank_w_scores_batch() got an unexpected keyword argument 'beam_width'

Opening : retrieval_lm/run_short_form.py

def call_model_rerank_w_scores_batch(prompt, evidences, model, max_new_tokens=15,
                                     ret_tokens=None, rel_tokens=None, grd_tokens=None, ut_tokens=None,
                                     use_seqscore=False, threshold=0.5,
                                     w_rel=1.0, w_sup=1.0, w_use=0.5, mode="adaptive_retrieval", closed=False):

def generate(prompt, evidences, max_new_tokens):
    return call_model_rerank_w_scores_batch(prompt, evidences=evidences, model=model, max_new_tokens=max_new_tokens,
                                            rel_tokens=rel_tokens, ret_tokens=ret_tokens, grd_tokens=grd_tokens, ut_tokens=ut_tokens,
                                            threshold=args.threshold, beam_width=args.beam_width, max_depth=args.max_depth, use_seqscore=args.use_seqscore,
                                            w_rel=args.w_rel, w_sup=args.w_sup, w_use=args.w_use, mode=args.mode, closed=args.task in ["fever", "arc_c"])

Maybe I'm missing something, any help would be appreciated !

Custom dataset help

Sorry, if we are using a custom dataset, does it have to be in instruction format, or does the script convert it? I have JSON files with just "text:".

About save_merged_lora_model

Hi! Thanks for the great tool.
I find that the parameter 'save_merged_lora_model' is not implemented in finetune.py. Could you please fix it?

An Error when FT Llama2

Hi @AkariAsai, thanks for open-sourcing this.
I ran the fine-tuning script based on Llama-2-7b-chat-hf and 8x A800 GPUs. I only modified the training params and did not change the training code, but I got an unexpected error.

File "/mnt/data/anaconda3/envs/baichuan2/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 848, in forward
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
RuntimeError: shape '[-1, 0]' is invalid for input of size 262018944
    shift_logits = shift_logits.view(-1, self.config.vocab_size)    
shift_logits = shift_logits.view(-1, self.config.vocab_size)
RuntimeError: RuntimeErrorshape '[-1, 0]' is invalid for input of size 262018944: 
shape '[-1, 0]' is invalid for input of size 262018944

Here is the FT script:

MODEL_SIZE=7B
NUM_GPUS=8
BATCH_SIZE_PER_GPU=8
TOTAL_BATCH_SIZE=128
GRADIENT_ACC_STEPS=$(($TOTAL_BATCH_SIZE/$NUM_GPUS/$BATCH_SIZE_PER_GPU))
echo "Training llama model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps"

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
    --mixed_precision bf16 \
    --num_machines 1 \
    --num_processes $NUM_GPUS \
    --use_deepspeed \
    --deepspeed_config_file stage3_no_offloading_accelerate.conf \
    finetune.py \
    --model_name_or_path /mnt/model/Llama-2-7b-chat-hf \
    --use_flash_attn \
    --tokenizer_name /mnt/model/Llama-2-7b-chat-hf \
    --use_slow_tokenizer \
    --train_file train.jsonl \
    --max_seq_length 1024 \
    --preprocessing_num_workers 16 \
    --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
    --gradient_accumulation_steps $GRADIENT_ACC_STEPS \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --weight_decay 0. \
    --num_train_epochs 5 \
    --output_dir output/adaptive_${MODEL_SIZE}/ \
    --with_tracking \
    --report_to tensorboard \
    --logging_steps 1 \
    --use_special_tokens

The training jsonl file was downloaded from https://huggingface.co/datasets/selfrag/selfrag_train_data/tree/main

Question regarding "Critique outputs and select best segment".

Dear author, I think this project is great and has done some very interesting work, but I have a point of confusion.

I am confused about how to implement "Critique outputs and select best segment". In the code:

prompts = [format_prompt(query_3, doc["title"] +"\n"+ doc["text"]) for doc in retrieved_documents]
preds = model.generate(prompts, sampling_params)
top_doc = retriever.search_document_demo(query_3, 1)[0]
print("Reference: {0}\nModel prediction: {1}".format(top_doc["title"] + "\n" + top_doc["text"], preds[0].outputs[0].text))

it seems that the highest-relevance document is selected by taking the [0] element of the retriever's return value and directly generating an answer based on that document with the LLM, instead of generating segments in parallel with the LLM and then critiquing and selecting the best one, all by the LLM.
