DRAG: Document Retrieval with Agentic Grounding

Train Vision-Language Models to generate better search queries through trajectory learning.

📖 Read the full tutorial →

Overview

This pipeline improves sparse search by finetuning a vision-language model (Qwen3-VL) on successful search trajectories. The key insight is that the normalized rank of the ground-truth document in the search results can serve as a reward signal, training the model to generate more effective search queries.

┌─────────────────────────────────────────────────────────────────────┐
│                         Training Pipeline                           │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  1. Collect       2. Process        3. Format         4. Finetune   │
│  Trajectories     Trajectories      Training Data     Model         │
│                                                                     │
│  ┌─────────┐     ┌─────────────┐   ┌────────────┐   ┌────────────┐ │
│  │ Run VLM │ ──► │ Filter GT   │ ─►│ SFT / DPO  │ ─►│ LoRA       │ │
│  │ Agent   │     │ in top-k    │   │ Formats    │   │ Training   │ │
│  └─────────┘     └─────────────┘   └────────────┘   └────────────┘ │
│       │                │                 │                │        │
│       ▼                ▼                 ▼                ▼        │
│  trajectories.   processed.         training_data/   checkpoints/  │
│  jsonl           jsonl              *.jsonl          final/        │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

Installation

# Clone the repository
git clone https://github.com/your-username/DRAG.git
cd DRAG

# Install dependencies with uv (recommended)
uv sync

# Or with pip
pip install -e .

# Navigate to the pipeline directory
cd agentic-retrieval-finetuning

Quick Start

Note: All pipeline scripts are in the agentic-retrieval-finetuning/ directory.

1. Start vLLM Server

# For trajectory collection and evaluation
vllm serve Qwen/Qwen3-VL-8B-Thinking \
  --port 8000 \
  --enable-auto-tool-choice \
  --tool-call-parser hermes
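
Before collecting trajectories, it is worth confirming the server responds. A minimal smoke test against vLLM's OpenAI-compatible endpoint, using the port and model name from the command above:

```python
# Smoke test: ask the freshly started vLLM server for a short completion.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="Qwen/Qwen3-VL-8B-Thinking",
    messages=[{"role": "user", "content": "Reply with OK."}],
    max_tokens=8,
)
print(resp.choices[0].message.content)
```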

2. Collect Trajectories

cd agentic-retrieval-finetuning

python trajectory_collector.py \
  --output trajectories.jsonl \
  --ocr-file /path/to/ocr_output.jsonl \
  --model Qwen/Qwen3-VL-8B-Thinking \
  --sampling-config default \
  --concurrency 8 \
  --limit 500

3. Process Trajectories

python process_trajectories.py \
  --input trajectories.jsonl \
  --output processed.jsonl \
  --stats stats.json

4. Format Training Data

python format_training_data.py \
  --input processed.jsonl \
  --output-dir training_data/ \
  --formats all

5. Finetune

# SFT on best queries
python finetune.py sft \
  --train-data training_data/sft_best_train.jsonl \
  --val-data training_data/sft_best_val.jsonl \
  --output-dir checkpoints/sft \
  --model Qwen/Qwen3-VL-8B-Thinking

# Or DPO on query pairs
python finetune.py dpo \
  --train-data training_data/dpo_train.jsonl \
  --output-dir checkpoints/dpo

# Merge adapter with base model
python finetune.py merge \
  --adapter-path checkpoints/sft/final \
  --output-path merged_model/
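
Conceptually, the merge step loads the LoRA adapter on top of the base model and folds the deltas into the base weights. A minimal sketch with PEFT, assuming the paths from the command above (finetune.py's actual loading code, and the right Auto class for your transformers version, may differ):

```python
# Fold LoRA deltas into the base weights and save a standalone model.
from peft import PeftModel
from transformers import AutoModelForVision2Seq, AutoProcessor

base = AutoModelForVision2Seq.from_pretrained("Qwen/Qwen3-VL-8B-Thinking")
merged = PeftModel.from_pretrained(base, "checkpoints/sft/final").merge_and_unload()
merged.save_pretrained("merged_model/")

# Ship the processor alongside so the merged model is usable as-is.
AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Thinking").save_pretrained("merged_model/")
```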

6. Evaluate

# Serve model with LoRA adapter
vllm serve Qwen/Qwen3-VL-8B-Thinking \
  --port 8000 \
  --enable-lora \
  --lora-modules my-adapter=checkpoints/sft/final

# Evaluate base model
python evaluate.py \
  --model Qwen/Qwen3-VL-8B-Thinking \
  --ocr-file /path/to/ocr_output.jsonl \
  --output results/eval_base.json

# Evaluate finetuned adapter
python evaluate.py \
  --model my-adapter \
  --ocr-file /path/to/ocr_output.jsonl \
  --output results/eval_adapter.json

# Compare results
python evaluate.py --compare results/eval_base.json results/eval_adapter.json

Key Concepts

Normalized Rank Score

The reward signal for training is computed as:

rank_score = (top_k - rank + 1) / top_k

| Rank | Score (k=3) | Interpretation |
|------|-------------|----------------|
| 1    | 1.00        | Perfect hit    |
| 2    | 0.67        | Good           |
| 3    | 0.33        | Found but low  |
| >3   | 0.00        | Not found      |
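
In code, the reward is a clipped linear function of the rank; a minimal sketch of the formula above:

```python
def rank_score(rank: int | None, top_k: int = 3) -> float:
    """Normalized rank reward: 1.0 for rank 1, decaying linearly to 0 outside top-k."""
    if rank is None or rank > top_k:
        return 0.0
    return (top_k - rank + 1) / top_k

assert rank_score(1) == 1.0           # perfect hit
assert round(rank_score(2), 2) == 0.67
assert rank_score(4) == 0.0           # not found in top-3
```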

Training Data Formats

| Format | Description | Use Case |
|--------|-------------|----------|
| sft_best | question → best_query | Direct query generation |
| sft_trajectory | question → full_trace | Learn reasoning patterns |
| sft_context | question + prev_attempts → better_query | Iterative improvement |
| dpo | (prompt, chosen_query, rejected_query) | Preference learning |
| reward | (query, score) | Reward model training |
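
For illustration only, records in the sft_best and dpo formats might look like the sketches below; the field names are assumptions, not the exact schema format_training_data.py emits:

```python
# Hypothetical sft_best record: map a question directly to its best-scoring query.
sft_best_example = {
    "messages": [
        {"role": "user", "content": "When was the invoice in this document issued?"},
        {"role": "assistant", "content": "invoice issue date 2021"},
    ]
}

# Hypothetical dpo record: the chosen query ranked the GT document higher
# than the rejected one did.
dpo_example = {
    "prompt": "When was the invoice in this document issued?",
    "chosen": "invoice issue date 2021",
    "rejected": "document date",
}
```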

Sampling Configurations

Experiment with different sampling parameters to generate diverse trajectories:

| Config | Temperature | Top-p | Use Case |
|--------|-------------|-------|----------|
| greedy | 0.0 | 1.0 | Deterministic baseline |
| low_temp | 0.3 | 0.9 | Focused, less random |
| default | 0.7 | 0.95 | Balanced |
| high_temp | 1.0 | 0.95 | More diverse |
| creative | 1.2 | 0.95 | Maximum diversity |
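
Read as plain data, the table corresponds to the following dictionaries (values straight from the table; trajectory_collector.py's internal representation may differ):

```python
SAMPLING_CONFIGS = {
    "greedy":    {"temperature": 0.0, "top_p": 1.0},
    "low_temp":  {"temperature": 0.3, "top_p": 0.9},
    "default":   {"temperature": 0.7, "top_p": 0.95},
    "high_temp": {"temperature": 1.0, "top_p": 0.95},
    "creative":  {"temperature": 1.2, "top_p": 0.95},
}
```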

Evaluation Metrics

  • Success Rate: % of questions where GT document was found
  • Iterations to Success: Average searches needed to find GT document
  • First-Hit Rank: GT document rank on first search
  • MRR (Mean Reciprocal Rank): Average of 1/rank across queries
  • Hit@K: % of queries where GT is in top-K results
  • ANLS*: Answer accuracy, scored by average normalized Levenshtein similarity
  • Citation F1: Precision/recall of cited documents
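
The rank-based metrics reduce to a few lines over the per-query GT ranks; a minimal sketch, where None marks a query whose GT document was never retrieved:

```python
def mrr(ranks: list[int | None]) -> float:
    """Mean Reciprocal Rank: average of 1/rank, counting misses as 0."""
    return sum(1.0 / r for r in ranks if r is not None) / len(ranks)

def hit_at_k(ranks: list[int | None], k: int) -> float:
    """Fraction of queries whose GT document appears in the top-k results."""
    return sum(1 for r in ranks if r is not None and r <= k) / len(ranks)

ranks = [1, 3, None, 2]          # hypothetical first-search GT ranks
print(mrr(ranks))                # (1 + 1/3 + 0 + 1/2) / 4 ≈ 0.458
print(hit_at_k(ranks, k=1))      # 0.25
```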

Results

After finetuning on ~200 successful trajectories:

| Metric | Base | Finetuned | Δ (relative) |
|--------|------|-----------|--------------|
| Hit@1 | 40% | 46% | +15% |
| First Search MRR | 0.893 | 0.936 | +4.8% |
| Iterations to Success | 1.97 | 1.85 | -6.1% |

Project Structure

DRAG/
├── pyproject.toml              # Project dependencies
├── README.md
│
├── agentic-retrieval-finetuning/
│   ├── trajectory_collector.py # Step 1: Collect agent trajectories
│   ├── process_trajectories.py # Step 2: Filter and score
│   ├── format_training_data.py # Step 3: Create training data
│   ├── finetune.py             # Step 4: LoRA finetuning (SFT/DPO)
│   ├── evaluate.py             # Step 5: Evaluation
│   ├── inference_test.py       # Quick inference testing
│   ├── search_engine.py        # Whoosh sparse search
│   └── utils.py                # PDF/image utilities
│
├── training_data/              # Generated training files (gitignored)
├── results/                    # Evaluation results (gitignored)
└── checkpoints/                # Model checkpoints (gitignored)
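
For context on search_engine.py: Whoosh scores with BM25F by default, so sparse retrieval reduces to indexing the OCR text and parsing the model's query against it. A minimal sketch, with a schema and field names that are assumptions rather than the project's actual code:

```python
import os
from whoosh import index
from whoosh.fields import ID, TEXT, Schema
from whoosh.qparser import QueryParser

# Index OCR'd pages under a hypothetical two-field schema.
schema = Schema(doc_id=ID(stored=True), content=TEXT)
os.makedirs("indexdir", exist_ok=True)
ix = index.create_in("indexdir", schema)
with ix.writer() as writer:
    writer.add_document(doc_id="doc-001", content="Invoice issued March 2021 ...")

# Run a sparse (BM25F-scored) search for a model-generated query.
with ix.searcher() as searcher:
    query = QueryParser("content", ix.schema).parse("invoice issue date")
    for hit in searcher.search(query, limit=3):
        print(hit["doc_id"], hit.score)
```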

Requirements

  • Python 3.10+
  • CUDA-capable GPU (8×A100 recommended for training)
  • vLLM for model serving
  • Access to a document dataset with OCR output (coming with the release of Document AI)

License

MIT
