# DRAG

Train Vision-Language Models to generate better search queries through trajectory learning.

## Overview

This pipeline improves sparse search by finetuning a vision-language model (Qwen-VL) on successful search trajectories. The key insight is that the normalized rank of the ground-truth document can serve as a reward signal, training the model to generate more effective search queries.
```
┌─────────────────────────────────────────────────────────────────────┐
│                          Training Pipeline                          │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  1. Collect      2. Process       3. Format        4. Finetune      │
│  Trajectories    Trajectories     Training Data    Model            │
│                                                                     │
│  ┌─────────┐     ┌─────────────┐  ┌────────────┐   ┌────────────┐   │
│  │ Run VLM │ ──► │ Filter GT   │ ─►│ SFT / DPO │ ─►│ LoRA       │   │
│  │ Agent   │     │ in top-k    │   │ Formats   │   │ Training   │   │
│  └─────────┘     └─────────────┘  └────────────┘   └────────────┘   │
│       │                │                │                │          │
│       ▼                ▼                ▼                ▼          │
│  trajectories.    processed.       training_data/   checkpoints/    │
│  jsonl            jsonl            *.jsonl          final/          │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
```
## Installation

```bash
# Clone the repository
git clone https://github.com/your-username/DRAG.git
cd DRAG

# Install dependencies with uv (recommended)
uv sync

# Or with pip
pip install -e .
```
```bash
# Navigate to the pipeline directory
cd agentic-retrieval-finetuning
```

Note: All pipeline scripts are in the `agentic-retrieval-finetuning/` directory.
## Usage

### Serve the model

```bash
# For trajectory collection and evaluation
vllm serve Qwen/Qwen3-VL-8B-Thinking \
    --port 8000 \
    --enable-auto-tool-choice \
    --tool-call-parser hermes
```
### Collect trajectories

```bash
cd agentic-retrieval-finetuning
python trajectory_collector.py \
    --output trajectories.jsonl \
    --ocr-file /path/to/ocr_output.jsonl \
    --model Qwen/Qwen3-VL-8B-Thinking \
    --sampling-config default \
    --concurrency 8 \
    --limit 500
```
### Process trajectories

```bash
python process_trajectories.py \
    --input trajectories.jsonl \
    --output processed.jsonl \
    --stats stats.json
```
### Format training data

```bash
python format_training_data.py \
    --input processed.jsonl \
    --output-dir training_data/ \
    --formats all
```
### Finetune the model

```bash
# SFT on best queries
python finetune.py sft \
    --train-data training_data/sft_best_train.jsonl \
    --val-data training_data/sft_best_val.jsonl \
    --output-dir checkpoints/sft \
    --model Qwen/Qwen3-VL-8B-Thinking

# Or DPO on query pairs
python finetune.py dpo \
    --train-data training_data/dpo_train.jsonl \
    --output-dir checkpoints/dpo
```
### Merge and serve the adapter

```bash
# Merge adapter with base model
python finetune.py merge \
    --adapter-path checkpoints/sft/final \
    --output-path merged_model/

# Serve model with LoRA adapter
vllm serve Qwen/Qwen3-VL-8B-Thinking \
    --port 8000 \
    --enable-lora \
    --lora-modules my-adapter=checkpoints/sft/final
```
### Evaluate

```bash
# Evaluate base model
python evaluate.py \
    --model Qwen/Qwen3-VL-8B-Thinking \
    --ocr-file /path/to/ocr_output.jsonl \
    --output results/eval_base.json

# Evaluate finetuned adapter
python evaluate.py \
    --model my-adapter \
    --ocr-file /path/to/ocr_output.jsonl \
    --output results/eval_adapter.json

# Compare results
python evaluate.py --compare results/eval_base.json results/eval_adapter.json
```

## Reward Signal

The reward signal for training is computed as:
```
rank_score = (top_k - rank + 1) / top_k
```
| Rank | Score (k=3) | Interpretation |
|---|---|---|
| 1 | 1.00 | Perfect hit |
| 2 | 0.67 | Good |
| 3 | 0.33 | Found but low |
| >3 | 0.00 | Not found |
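The scoring rule above can be sketched as follows (the function name and signature are illustrative, not the pipeline's actual API):

```python
def rank_score(rank, top_k=3):
    """Normalized reward for a ground-truth document retrieved at `rank`
    (1-indexed). Documents outside the top-k, or never retrieved
    (rank is None), score 0.0."""
    if rank is None or rank > top_k:
        return 0.0
    return (top_k - rank + 1) / top_k
```

With k=3 this reproduces the table: rank 1 scores 1.00, rank 2 scores 0.67, rank 3 scores 0.33, and anything beyond scores 0.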
## Training Data Formats

| Format | Description | Use Case |
|---|---|---|
| `sft_best` | question → best_query | Direct query generation |
| `sft_trajectory` | question → full_trace | Learn reasoning patterns |
| `sft_context` | question + prev_attempts → better_query | Iterative improvement |
| `dpo` | (prompt, chosen_query, rejected_query) | Preference learning |
| `reward` | (query, score) | Reward model training |
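For illustration, the `sft_best` and `dpo` formats might correspond to JSONL records shaped like this; the field names and example values are assumptions, not the exact schema emitted by `format_training_data.py`:

```python
import json

# Hypothetical sft_best record: the question maps directly to the
# highest-scoring search query observed in the trajectories.
sft_best = {
    "messages": [
        {"role": "user", "content": "What was the 2021 revenue?"},
        {"role": "assistant", "content": "2021 annual revenue total"},
    ]
}

# Hypothetical dpo record: a preference pair contrasting a
# high-reward query (chosen) with a low-reward one (rejected).
dpo = {
    "prompt": "What was the 2021 revenue?",
    "chosen": "2021 annual revenue total",
    "rejected": "revenue",
}

print(json.dumps(sft_best))
print(json.dumps(dpo))
```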
## Sampling Configurations

Experiment with different sampling parameters to generate diverse trajectories:

| Config | Temperature | Top-p | Use Case |
|---|---|---|---|
| `greedy` | 0.0 | 1.0 | Deterministic baseline |
| `low_temp` | 0.3 | 0.9 | Focused, less random |
| `default` | 0.7 | 0.95 | Balanced |
| `high_temp` | 1.0 | 0.95 | More diverse |
| `creative` | 1.2 | 0.95 | Maximum diversity |
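As a sketch, the table above could be expressed as a mapping from `--sampling-config` names to sampling parameters (the dict itself is illustrative, not the collector's actual code):

```python
# Named sampling presets mirroring the table above; each entry could be
# passed as temperature/top_p to the serving API when generating queries.
SAMPLING_CONFIGS = {
    "greedy":    {"temperature": 0.0, "top_p": 1.0},
    "low_temp":  {"temperature": 0.3, "top_p": 0.9},
    "default":   {"temperature": 0.7, "top_p": 0.95},
    "high_temp": {"temperature": 1.0, "top_p": 0.95},
    "creative":  {"temperature": 1.2, "top_p": 0.95},
}
```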
## Evaluation Metrics

- Success Rate: % of questions where the GT document was found
- Iterations to Success: average number of searches needed to find the GT document
- First-Hit Rank: GT document rank on the first search
- MRR (Mean Reciprocal Rank): average of 1/rank across queries
- Hit@K: % of queries where the GT document is in the top-K results
- ANLS*: answer accuracy metric
- Citation F1: precision/recall of cited documents
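The two rank-based metrics can be sketched as below, assuming ranks are 1-indexed and `None` means the GT document was never retrieved (function names are illustrative):

```python
def mrr(ranks):
    """Mean reciprocal rank over a list of GT ranks; None counts as 0."""
    return sum(1.0 / r if r else 0.0 for r in ranks) / len(ranks)

def hit_at_k(ranks, k):
    """Fraction of queries whose GT document appears in the top-k."""
    return sum(1 for r in ranks if r is not None and r <= k) / len(ranks)
```

For example, ranks `[1, 2, None]` give an MRR of (1 + 0.5 + 0) / 3 = 0.5 and a Hit@2 of 2/3.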
## Results

After finetuning on ~200 successful trajectories (Δ is the relative change):

| Metric | Base | Finetuned | Δ |
|---|---|---|---|
| Hit@1 | 40% | 46% | +15% |
| First Search MRR | 0.893 | 0.936 | +4.8% |
| Iterations to Success | 1.97 | 1.85 | -6.1% |
## Project Structure

```
DRAG/
├── pyproject.toml                   # Project dependencies
├── README.md
│
├── agentic-retrieval-finetuning/
│   ├── trajectory_collector.py      # Step 1: Collect agent trajectories
│   ├── process_trajectories.py      # Step 2: Filter and score
│   ├── format_training_data.py      # Step 3: Create training data
│   ├── finetune.py                  # Step 4: LoRA finetuning (SFT/DPO)
│   ├── evaluate.py                  # Step 5: Evaluation
│   ├── inference_test.py            # Quick inference testing
│   ├── search_engine.py             # Whoosh sparse search
│   └── utils.py                     # PDF/image utilities
│
├── training_data/                   # Generated training files (gitignored)
├── results/                         # Evaluation results (gitignored)
└── checkpoints/                     # Model checkpoints (gitignored)
```
## Requirements

- Python 3.10+
- CUDA-capable GPU (8×A100 recommended for training)
- vLLM for model serving
- Access to a document dataset with OCR (coming with the release of Document AI)
## License

MIT