# Accelerating Inference in Retrieval-Augmented Generation Models for Long-Form Question Answering via Dynamic Token Pruning
This repository contains the implementation of the method presented in the paper "Accelerating Inference in Retrieval-Augmented Generation Models for Long-Form Question Answering via Dynamic Token Pruning". It builds on [facebookresearch/FiD](https://github.com/facebookresearch/FiD).
## Data Preparation

Follow these steps to prepare the necessary data. The instructions mirror the original FiD setup.

1. Download the required datasets and Wikipedia passages:

   ```bash
   bash data_download.sh
   ```

2. Download a pre-trained retriever model, such as the NQ retriever from the FiD project. Refer to the official FiD instructions: FiD Model Download Script.

3. Build the retrieval index for your knowledge source (e.g., the downloaded Wikipedia passages). Detailed instructions can be found at: FiD Knowledge Source Indexing.

4. Retrieve relevant passages for your question-answering dataset using the indexed knowledge source and the pre-trained retriever. Detailed instructions can be found at: FiD Passage Retrieval.
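The retrieval step produces, for each question, a ranked list of supporting passages. In the original FiD setup, the train/eval files passed to the reader are JSON lists of records with `question`, `answers`, and `ctxs` fields, where each context carries a `title` and `text`. The sketch below shows one such record together with a minimal sanity check; the concrete values are invented for illustration and are not from any real dataset file.

```python
import json

# One illustrative record in the FiD retrieval-output format.
# All field values here are made up for illustration.
sample = {
    "question": "who wrote the original dune novel",
    "answers": ["Frank Herbert"],
    "ctxs": [
        {
            "id": "12345",            # passage id in the Wikipedia dump
            "title": "Dune (novel)",
            "text": "Dune is a 1965 science fiction novel ...",
            "score": "1.52",          # retriever similarity score
        },
    ],
}

def check_fid_example(ex):
    """Minimal sanity check for one FiD-style record."""
    assert isinstance(ex["question"], str)
    assert isinstance(ex["answers"], list) and ex["answers"]
    for ctx in ex["ctxs"]:
        assert "title" in ctx and "text" in ctx
    return True

# Files passed via --train_data / --eval_data are JSON lists of such records.
assert check_fid_example(sample)
print(json.dumps(sample)[:60])
```

At training time, the `--n_context 50` flag below tells the reader to consume the top 50 entries of each record's `ctxs` list.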
## Training

The following command launches a distributed training job for the reader model with dynamic token pruning:

```bash
PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l)
NGPU=${PROC_PER_NODE} CUDA_LAUNCH_BLOCKING=1 python -m torch.distributed.launch --nproc_per_node=${PROC_PER_NODE} train_reader_tp_clapnq.py \
--seed <random_seed> \
--use_checkpoint \
--lr 1e-4 \
--optim adamw \
--scheduler linear \
--weight_decay 0.01 \
--question_maxlength 100 \
--per_gpu_batch_size 1 \
--n_context 50 \
--text_maxlength 250 \
--total_steps 20000 \
--eval_freq 2000 \
--eval_print_freq 1000 \
--save_freq 2000 \
--warmup_steps 1000 \
--min_ratio 0.8 \
--train_data <train_data_path> \
--eval_data <eval_data_path> \
--answer_maxlength 128 \
--theta1 0.9 \
--theta2 0.3 \
--last_theta 0.05 \
--gumbel_temperature 1.0 \
--pruning_scale 2.0 \
--kl_loss_scale 1.0 \
--temp_retain_steps 1000 \
--temp_reducing_steps 2000 \
--model_size base \
--checkpoint_dir <run_folder_name> \
--name <log_folder_name> \
--accumulation_steps 32
```
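The pruning-related flags above (`--theta1`, `--theta2`, `--last_theta`, `--gumbel_temperature`, `--temp_retain_steps`, `--temp_reducing_steps`) suggest two pieces of machinery: sampling differentiable keep/drop decisions per token via Gumbel-Softmax, and annealing the sampling temperature during training. Below is a minimal NumPy sketch of both; the function names, the linear decay schedule, and the thresholding step are assumptions made for illustration, not the paper's exact implementation.

```python
import numpy as np

def gumbel_softmax_keep_probs(scores, temperature, rng):
    """Sample soft keep/drop decisions per token via Gumbel-Softmax.

    scores: array of shape (n_tokens, 2) with unnormalized logits
    for [drop, keep]. Lower temperature -> closer to one-hot.
    """
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1)
    gumbel = -np.log(-np.log(rng.uniform(1e-9, 1.0, scores.shape)))
    y = (scores + gumbel) / temperature
    y = y - y.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)

def anneal_temperature(step, t0, retain_steps, reducing_steps, t_min=0.1):
    """Hold the temperature at t0 for retain_steps, then decay
    linearly to t_min over reducing_steps (assumed schedule)."""
    if step <= retain_steps:
        return t0
    frac = min(1.0, (step - retain_steps) / reducing_steps)
    return t0 + frac * (t_min - t0)

# Illustrative use: at a given layer, tokens whose keep probability
# falls below that layer's threshold (e.g. theta1 = 0.9 early,
# last_theta = 0.05 at the final layer) would be pruned.
rng = np.random.default_rng(0)
scores = rng.normal(size=(8, 2))              # per-token [drop, keep] logits
temp = anneal_temperature(step=1500, t0=1.0,  # mid-decay temperature
                          retain_steps=1000, reducing_steps=2000)
keep_probs = gumbel_softmax_keep_probs(scores, temp, rng)[:, 1]
pruned_mask = keep_probs < 0.9                # True = token dropped
```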
## Evaluation

Use the following command to evaluate a trained model on an evaluation set:

```bash
PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l)
NGPU=${PROC_PER_NODE} python -m torch.distributed.launch --nproc_per_node=${PROC_PER_NODE} test_reader.py \
--seed <random_seed> \
--model_path <model_path> \
--eval_data <eval_data_path> \
--per_gpu_batch_size 1 \
--eval_print_freq 100 \
--n_context 50 \
--question_maxlength 100 \
--text_maxlength 250 \
--answer_maxlength 128 \
--checkpoint_dir <run_folder_name> \
--name <log_folder_name> \
--write_results
```