Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions train/sft/download_tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ def shard_suffix(p: str) -> str:
return p


def collect_paths(retrieval_dir: Path, splits: list[str]) -> set[str]:
def collect_paths(retrieval_dir: Path, splits: list[str]) -> dict[str, str | None]:
"""Collect every unique absolute path across hit lists + gold suffixes.

Gold paths use the dataset-relative form ('images/shard_.../chunk.png');
hits are absolute '/opt/dlami/nvme/kiwix_tiles/shard_.../chunk.png'.
We normalize everything to shard-suffix for dedup across sources.
Returns set of (shard_suffix_key, preferred_abs_path_for_fetch).
Returns dict mapping shard_suffix_key to preferred_abs_path_for_fetch.
"""
by_suffix = {}
by_suffix: dict[str, str | None] = {}
for split in splits:
p = retrieval_dir / f"{split}.jsonl"
if not p.exists():
Expand Down Expand Up @@ -111,7 +111,7 @@ def fetch_tile(
return False, f"{last_err}"


def main():
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument(
"--retrieval-dir",
Expand Down Expand Up @@ -151,7 +151,7 @@ def main():
print(f" Unique tile suffixes: {len(by_suffix):,}")

# Split into linkable-from-local vs must-fetch
need_fetch = [] # (suffix, abs_path_for_api)
need_fetch: list[tuple[str, str]] = []
linked = 0
already = 0
for suffix, abs_fetch in by_suffix.items():
Expand Down Expand Up @@ -183,7 +183,7 @@ def main():
fail_paths = []
with open(failed_log, "w") as f_fail:

def _work(item):
def _work(item: tuple[str, str]) -> tuple[str, str, bool, str]:
suffix, abs_path = item
dst = mirror / suffix
success, err = fetch_tile(args.api_url, abs_path, dst)
Expand Down
42 changes: 34 additions & 8 deletions train/sft/eval_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,37 @@
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import TypedDict

import torch
from tqdm import tqdm
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


class EvalResult(TypedDict, total=False):
query: str
golden: str
predicted: str
chunk_path: str
image_missing: bool
n_images: int
judge_grade: str
judge_correct: bool


class EMCharMetrics(TypedDict):
exact_match: float
char_accuracy: float
scored: int


class JudgeMetrics(TypedDict):
llm_judge_accuracy: float
llm_judge_correct: int
llm_judge_total: int


# Reused from train_contrastors.py — SimpleQA-style grader, returns A/B/C.
_GRADER_TEMPLATE = """Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter.
Expand All @@ -49,7 +73,7 @@
Just return the letters "A", "B", or "C", with no text around it."""


def _resolve_image_path(ex: dict, images_root: str) -> str:
def _resolve_image_path(ex: dict[str, str], images_root: str) -> str:
"""chunk_path is relative to the dataset root (e.g. images/shard_000/...).
images_root is the directory that contains the `images/` subtree (compressed or original)."""
rel = ex["chunk_path"]
Expand All @@ -59,15 +83,15 @@ def _resolve_image_path(ex: dict, images_root: str) -> str:


def run_inference(
model,
processor,
examples,
model: Qwen3VLForConditionalGeneration,
processor: AutoProcessor,
examples: list[dict[str, str]],
images_root: str,
device: str,
desc: str,
max_new_tokens: int = 128,
enable_thinking: bool = False,
):
) -> list[EvalResult]:
"""Run VQA inference on a list of examples; returns list of (golden, predicted) pairs."""
results = []
for ex in tqdm(examples, desc=desc):
Expand Down Expand Up @@ -126,7 +150,7 @@ def run_inference(
return results


def compute_em_char(results):
def compute_em_char(results: list[EvalResult]) -> EMCharMetrics:
correct_em = 0
char_correct = 0
char_total = 0
Expand All @@ -150,7 +174,9 @@ def compute_em_char(results):
}


def grade_with_gpt(results, model: str, concurrency: int = 16):
def grade_with_gpt(
results: list[EvalResult], model: str, concurrency: int = 16
) -> JudgeMetrics:
"""Grade predictions with GPT-4.1. Returns list of (bool correct, raw grade)."""
from openai import OpenAI

Expand Down Expand Up @@ -202,7 +228,7 @@ def _grade(idx, r):
}


def main():
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--model", default="Qwen/Qwen3-VL-4B-Instruct")
p.add_argument(
Expand Down
44 changes: 39 additions & 5 deletions train/sft/eval_multiimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,39 @@
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import TypedDict

import torch
from tqdm import tqdm
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


class EvalResult(TypedDict, total=False):
query: str
golden: str
predicted: str
image_missing: bool
missing_paths: list[str]
n_images: int
gold_pos: int | None
gold_in_top6_pos: int | None
judge_grade: str
judge_correct: bool


class EMCharMetrics(TypedDict):
exact_match: float
char_accuracy: float
scored: int


class JudgeMetrics(TypedDict):
llm_judge_accuracy: float
llm_judge_correct: int
llm_judge_total: int


_GRADER_TEMPLATE = """Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"].
Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter.
Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions.
Expand All @@ -49,8 +75,14 @@ def strip_image_tokens(s: str) -> str:


def run_inference(
model, processor, examples, device, desc, max_new_tokens=128, enable_thinking=False
):
model: Qwen3VLForConditionalGeneration,
processor: AutoProcessor,
examples: list[dict[str, object]],
device: str,
desc: str,
max_new_tokens: int = 128,
enable_thinking: bool = False,
) -> list[EvalResult]:
results = []
for ex in tqdm(examples, desc=desc):
images = ex.get("images", [])
Expand Down Expand Up @@ -112,7 +144,7 @@ def run_inference(
return results


def compute_em_char(results):
def compute_em_char(results: list[EvalResult]) -> EMCharMetrics:
correct_em = 0
char_correct = 0
char_total = 0
Expand All @@ -136,7 +168,9 @@ def compute_em_char(results):
}


def grade_with_gpt(results, model: str, concurrency: int = 16):
def grade_with_gpt(
results: list[EvalResult], model: str, concurrency: int = 16
) -> JudgeMetrics:
from openai import OpenAI

client = OpenAI()
Expand Down Expand Up @@ -186,7 +220,7 @@ def _grade(idx, r):
}


def main():
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--model", default="Qwen/Qwen3-VL-4B-Instruct")
p.add_argument("--adapter", default=None)
Expand Down
18 changes: 14 additions & 4 deletions train/sft/fetch_top6_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@
import urllib.error
import urllib.request
from pathlib import Path
from typing import TypedDict


class SplitStats(TypedDict):
split: str
total: int
gold_in_top1: int
gold_in_top3: int
gold_in_top6: int
gold_miss: int


def shard_suffix(p: str) -> str:
Expand All @@ -40,7 +50,7 @@ def shard_suffix(p: str) -> str:

def search_batch(
api_url: str, queries: list[str], n_docs: int, timeout: int = 300, retries: int = 5
) -> list[dict]:
) -> list[dict[str, object]]:
payload = {"queries": [{"text": q} for q in queries], "n_docs": n_docs}
body = json.dumps(payload).encode()
last_err = None
Expand Down Expand Up @@ -72,7 +82,7 @@ def process_split(
api_url: str,
batch_size: int,
n_docs: int,
) -> dict:
) -> SplitStats:
# Resume: count existing lines
existing = 0
if out_path.exists():
Expand Down Expand Up @@ -185,7 +195,7 @@ def process_split(
}


def _collect_stats(path: Path) -> dict:
def _collect_stats(path: Path) -> SplitStats:
gold_in_topk = {1: 0, 3: 0, 6: 0}
gold_miss = 0
n = 0
Expand All @@ -211,7 +221,7 @@ def _collect_stats(path: Path) -> dict:
}


def main():
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument(
"--dataset-dir",
Expand Down
6 changes: 4 additions & 2 deletions train/sft/generate_think_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
Write a brief reasoning trace (2-3 sentences) showing how someone would find this answer by examining the screenshot. Mention what specific text/detail they would look for. Be natural and concise. No preamble. Output ONLY the reasoning, nothing else."""


def process_one(client, model, ex):
def process_one(
client: OpenAI, model: str, ex: dict[str, str]
) -> dict[str, str | None]:
try:
resp = client.chat.completions.create(
model=model,
Expand All @@ -48,7 +50,7 @@ def process_one(client, model, ex):
return {**ex, "reasoning": None, "_error": str(e)[:200]}


def main():
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--input", required=True)
p.add_argument("--output", required=True)
Expand Down
6 changes: 4 additions & 2 deletions train/sft/generate_think_traces_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def encode_image(path: str, max_bytes: int = 4_000_000) -> str | None:
return None


def process_one(client, model, ex, image_root):
def process_one(
client: OpenAI, model: str, ex: dict[str, str], image_root: str
) -> dict[str, str | None]:
img_path = os.path.join(image_root, ex["chunk_path"])
img_url = encode_image(img_path) if os.path.exists(img_path) else None
if img_url is None:
Expand Down Expand Up @@ -92,7 +94,7 @@ def process_one(client, model, ex, image_root):
return {**ex, "reasoning": None, "_error": str(e)[:200]}


def main():
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--input", required=True)
p.add_argument("--output", required=True)
Expand Down
6 changes: 4 additions & 2 deletions train/sft/generate_think_traces_v3_highdetail.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def encode_image(path: str, max_bytes: int = 4_000_000) -> str | None:
return None


def process_one(client, model, ex, image_root):
def process_one(
client: OpenAI, model: str, ex: dict[str, str], image_root: str
) -> dict[str, str | None]:
img_path = os.path.join(image_root, ex["chunk_path"])
img_url = encode_image(img_path) if os.path.exists(img_path) else None
if img_url is None:
Expand Down Expand Up @@ -92,7 +94,7 @@ def process_one(client, model, ex, image_root):
return {**ex, "reasoning": None, "_error": str(e)[:200]}


def main():
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--input", required=True)
p.add_argument("--output", required=True)
Expand Down
2 changes: 1 addition & 1 deletion train/sft/prepare_mixed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
BASE = "/scratch/users/zwcolin/cxr_embeds/sft_data"


def main():
def main() -> None:
p = argparse.ArgumentParser()
p.add_argument("--output-dir", default=f"{BASE}/compressed_mixed")
p.add_argument("--seed", type=int, default=42)
Expand Down
Loading