
# Accelerating Inference in Retrieval-Augmented Generation Models for Long-Form Question Answering via Dynamic Token Pruning

This repository contains the implementation of the method presented in the paper "Accelerating Inference in Retrieval-Augmented Generation Models for Long-Form Question Answering via Dynamic Token Pruning".

The code builds on the facebookresearch/FiD repository.

## Data Preparation

Follow these steps to prepare the necessary data. These instructions are based on the original FiD setup.

### Dataset and Wikipedia Passage Download

Download the required datasets and Wikipedia passages.

```bash
bash data_download.sh
```

### Pre-trained Retrieval Model Download

Download a pre-trained retriever model, such as the NQ Retriever from the FiD project.

Refer to the official FiD instructions: FiD Model Download Script
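As a sketch, the FiD repository ships a `get-model.sh` helper; the invocation below follows the upstream FiD README, and the model name `nq_retriever` is an assumption about which retriever you want, so adjust it to your setup:

```shell
# Run from inside a clone of facebookresearch/FiD.
# Model name follows the FiD README; change it if you need a different retriever.
bash get-model.sh -m nq_retriever
```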

### Passage Retrieval Index

Build the retrieval index for your knowledge source (e.g., the downloaded Wikipedia passages).

Detailed instructions can be found at: FiD Knowledge Source Indexing
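For orientation, FiD builds its index by embedding every passage with `generate_passage_embeddings.py`. The flag names below follow the FiD README and all paths are placeholders, so treat this as a hedged sketch rather than a verified command:

```shell
# Embed all passages with the downloaded retriever (paths are placeholders).
# Increase --num_shards and vary --shard_id to split the work across jobs.
python generate_passage_embeddings.py \
        --model_path pretrained_models/nq_retriever \
        --passages open_domain_data/psgs_w100.tsv \
        --output_path wikipedia_embeddings \
        --shard_id 0 \
        --num_shards 1 \
        --per_gpu_batch_size 500
```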

### Passage Retrieval

Retrieve relevant passages for your question-answering dataset using the indexed knowledge source and the pre-trained retriever.

Detailed instructions can be found at: FiD Passage Retrieval
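A hedged sketch of this step, modeled on FiD's `passage_retrieval.py` as documented in the upstream README (flag names and paths here are assumptions and may differ in your checkout):

```shell
# Retrieve the top passages for each question in the dataset (paths are placeholders).
python passage_retrieval.py \
        --model_path pretrained_models/nq_retriever \
        --passages open_domain_data/psgs_w100.tsv \
        --data open_domain_data/NQ/dev.json \
        --passages_embeddings "wikipedia_embeddings/*" \
        --output_path retrieved_data.json \
        --n-docs 100
```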

## Training

The following command launches a distributed training job for the reader model with dynamic token pruning.

```bash
PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l)

NGPU=${PROC_PER_NODE} CUDA_LAUNCH_BLOCKING=1 python -m torch.distributed.launch --nproc_per_node=${PROC_PER_NODE} train_reader_tp_clapnq.py \
        --seed <random_seed> \
        --use_checkpoint \
        --lr 1e-4 \
        --optim adamw \
        --scheduler linear \
        --weight_decay 0.01 \
        --question_maxlength 100 \
        --per_gpu_batch_size 1 \
        --n_context 50 \
        --text_maxlength 250 \
        --total_steps 20000 \
        --eval_freq 2000 \
        --eval_print_freq 1000 \
        --save_freq 2000 \
        --warmup_steps 1000 \
        --min_ratio 0.8 \
        --train_data <train_data_path> \
        --eval_data <eval_data_path> \
        --answer_maxlength 128 \
        --theta1 0.9 \
        --theta2 0.3 \
        --last_theta 0.05 \
        --gumbel_temperature 1.0 \
        --pruning_scale 2.0 \
        --kl_loss_scale 1.0 \
        --temp_retain_steps 1000 \
        --temp_reducing_steps 2000 \
        --model_size base \
        --checkpoint_dir <run_folder_name> \
        --name <log_folder_name> \
        --accumulation_steps 32
```

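Both the training and evaluation commands derive the process count from `nvidia-smi`, which yields 0 on a machine without that tool (e.g. a CPU-only login node) and makes the launcher fail. A small hypothetical guard such as the following can make the scripts more robust; the fallback value of 1 is an assumption, not part of the original commands:

```shell
# Hypothetical helper: fall back to a single process when nvidia-smi
# is unavailable or reports no GPUs.
if command -v nvidia-smi >/dev/null 2>&1; then
    PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l)
else
    PROC_PER_NODE=0
fi
# Guard against an empty or zero count.
[ "${PROC_PER_NODE:-0}" -ge 1 ] || PROC_PER_NODE=1
echo "${PROC_PER_NODE}"
```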
## Test

Use the following command to evaluate a trained model on an evaluation set.

```bash
PROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l)

NGPU=${PROC_PER_NODE} python -m torch.distributed.launch --nproc_per_node=${PROC_PER_NODE} test_reader.py \
        --seed <random_seed> \
        --model_path <model_path> \
        --eval_data <eval_data_path> \
        --per_gpu_batch_size 1 \
        --eval_print_freq 100 \
        --n_context 50 \
        --question_maxlength 100 \
        --text_maxlength 250 \
        --answer_maxlength 128 \
        --checkpoint_dir <run_folder_name> \
        --name <log_folder_name> \
        --write_results
```
