diff --git a/train/sft/download_tiles.py b/train/sft/download_tiles.py index d27761e..47fa0c6 100644 --- a/train/sft/download_tiles.py +++ b/train/sft/download_tiles.py @@ -34,15 +34,15 @@ def shard_suffix(p: str) -> str: return p -def collect_paths(retrieval_dir: Path, splits: list[str]) -> set[str]: +def collect_paths(retrieval_dir: Path, splits: list[str]) -> dict[str, str | None]: """Collect every unique absolute path across hit lists + gold suffixes. Gold paths use the dataset-relative form ('images/shard_.../chunk.png'); hits are absolute '/opt/dlami/nvme/kiwix_tiles/shard_.../chunk.png'. We normalize everything to shard-suffix for dedup across sources. - Returns set of (shard_suffix_key, preferred_abs_path_for_fetch). + Returns dict mapping shard_suffix_key to preferred_abs_path_for_fetch. """ - by_suffix = {} + by_suffix: dict[str, str | None] = {} for split in splits: p = retrieval_dir / f"{split}.jsonl" if not p.exists(): @@ -111,7 +111,7 @@ def fetch_tile( return False, f"{last_err}" -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument( "--retrieval-dir", @@ -151,7 +151,7 @@ def main(): print(f" Unique tile suffixes: {len(by_suffix):,}") # Split into linkable-from-local vs must-fetch - need_fetch = [] # (suffix, abs_path_for_api) + need_fetch: list[tuple[str, str]] = [] linked = 0 already = 0 for suffix, abs_fetch in by_suffix.items(): @@ -183,7 +183,7 @@ def main(): fail_paths = [] with open(failed_log, "w") as f_fail: - def _work(item): + def _work(item: tuple[str, str]) -> tuple[str, str, bool, str]: suffix, abs_path = item dst = mirror / suffix success, err = fetch_tile(args.api_url, abs_path, dst) diff --git a/train/sft/eval_baseline.py b/train/sft/eval_baseline.py index da09ab6..e717ed2 100644 --- a/train/sft/eval_baseline.py +++ b/train/sft/eval_baseline.py @@ -20,6 +20,7 @@ import sys from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path +from typing import TypedDict import torch from tqdm import tqdm @@ -27,6 +28,29 @@ from qwen_vl_utils import process_vision_info +class EvalResult(TypedDict, total=False): + query: str + golden: str + predicted: str + chunk_path: str + image_missing: bool + n_images: int + judge_grade: str + judge_correct: bool + + +class EMCharMetrics(TypedDict): + exact_match: float + char_accuracy: float + scored: int + + +class JudgeMetrics(TypedDict): + llm_judge_accuracy: float + llm_judge_correct: int + llm_judge_total: int + + # Reused from train_contrastors.py — SimpleQA-style grader, returns A/B/C. _GRADER_TEMPLATE = """Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"]. Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter. @@ -49,7 +73,7 @@ Just return the letters "A", "B", or "C", with no text around it.""" -def _resolve_image_path(ex: dict, images_root: str) -> str: +def _resolve_image_path(ex: dict[str, str], images_root: str) -> str: """chunk_path is relative to the dataset root (e.g. images/shard_000/...). images_root is the directory that contains the `images/` subtree (compressed or original).""" rel = ex["chunk_path"] @@ -59,15 +83,15 @@ def _resolve_image_path(ex: dict, images_root: str) -> str: def run_inference( - model, - processor, - examples, + model: Qwen3VLForConditionalGeneration, + processor: AutoProcessor, + examples: list[dict[str, str]], images_root: str, device: str, desc: str, max_new_tokens: int = 128, enable_thinking: bool = False, -): +) -> list[EvalResult]: """Run VQA inference on a list of examples; returns list of (golden, predicted) pairs.""" results = [] for ex in tqdm(examples, desc=desc): @@ -126,7 +150,7 @@ def run_inference( return results -def compute_em_char(results): +def compute_em_char(results: list[EvalResult]) -> EMCharMetrics: correct_em = 0 char_correct = 0 char_total = 0 @@ -150,7 +174,9 @@ def compute_em_char(results): } -def grade_with_gpt(results, model: str, concurrency: int = 16): +def grade_with_gpt( + results: list[EvalResult], model: str, concurrency: int = 16 +) -> JudgeMetrics: """Grade predictions with GPT-4.1. Returns list of (bool correct, raw grade).""" from openai import OpenAI @@ -202,7 +228,7 @@ def _grade(idx, r): } -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument("--model", default="Qwen/Qwen3-VL-4B-Instruct") p.add_argument( diff --git a/train/sft/eval_multiimage.py b/train/sft/eval_multiimage.py index a43028e..398586c 100644 --- a/train/sft/eval_multiimage.py +++ b/train/sft/eval_multiimage.py @@ -16,6 +16,7 @@ import sys from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path +from typing import TypedDict import torch from tqdm import tqdm @@ -23,6 +24,31 @@ from qwen_vl_utils import process_vision_info +class EvalResult(TypedDict, total=False): + query: str + golden: str + predicted: str + image_missing: bool + missing_paths: list[str] + n_images: int + gold_pos: int | None + gold_in_top6_pos: int | None + judge_grade: str + judge_correct: bool + + +class EMCharMetrics(TypedDict): + exact_match: float + char_accuracy: float + scored: int + + +class JudgeMetrics(TypedDict): + llm_judge_accuracy: float + llm_judge_correct: int + llm_judge_total: int + + _GRADER_TEMPLATE = """Your job is to look at a question, a gold target, and a predicted answer, and then assign a grade of either ["CORRECT", "INCORRECT", "NOT_ATTEMPTED"]. Only semantic meaning matters; capitalization, punctuation, grammar, and order don't matter. Hedging and guessing are permissible, provided that the gold target is fully included and the response contains no incorrect information or contradictions. @@ -49,8 +75,14 @@ def strip_image_tokens(s: str) -> str: def run_inference( - model, processor, examples, device, desc, max_new_tokens=128, enable_thinking=False -): + model: Qwen3VLForConditionalGeneration, + processor: AutoProcessor, + examples: list[dict[str, object]], + device: str, + desc: str, + max_new_tokens: int = 128, + enable_thinking: bool = False, +) -> list[EvalResult]: results = [] for ex in tqdm(examples, desc=desc): images = ex.get("images", []) @@ -112,7 +144,7 @@ def run_inference( return results -def compute_em_char(results): +def compute_em_char(results: list[EvalResult]) -> EMCharMetrics: correct_em = 0 char_correct = 0 char_total = 0 @@ -136,7 +168,9 @@ def compute_em_char(results): } -def grade_with_gpt(results, model: str, concurrency: int = 16): +def grade_with_gpt( + results: list[EvalResult], model: str, concurrency: int = 16 +) -> JudgeMetrics: from openai import OpenAI client = OpenAI() @@ -186,7 +220,7 @@ def _grade(idx, r): } -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument("--model", default="Qwen/Qwen3-VL-4B-Instruct") p.add_argument("--adapter", default=None) diff --git a/train/sft/fetch_top6_retrieval.py b/train/sft/fetch_top6_retrieval.py index aeea228..d0e1da5 100644 --- a/train/sft/fetch_top6_retrieval.py +++ b/train/sft/fetch_top6_retrieval.py @@ -28,6 +28,16 @@ import urllib.error import urllib.request from pathlib import Path +from typing import TypedDict + + +class SplitStats(TypedDict): + split: str + total: int + gold_in_top1: int + gold_in_top3: int + gold_in_top6: int + gold_miss: int def shard_suffix(p: str) -> str: @@ -40,7 +50,7 @@ def shard_suffix(p: str) -> str: def search_batch( api_url: str, queries: list[str], n_docs: int, timeout: int = 300, retries: int = 5 -) -> list[dict]: +) -> list[dict[str, object]]: payload = {"queries": [{"text": q} for q in queries], "n_docs": n_docs} body = json.dumps(payload).encode() last_err = None @@ -72,7 +82,7 @@ def process_split( api_url: str, batch_size: int, n_docs: int, -) -> dict: +) -> SplitStats: # Resume: count existing lines existing = 0 if out_path.exists(): @@ -185,7 +195,7 @@ def process_split( } -def _collect_stats(path: Path) -> dict: +def _collect_stats(path: Path) -> SplitStats: gold_in_topk = {1: 0, 3: 0, 6: 0} gold_miss = 0 n = 0 @@ -211,7 +221,7 @@ def _collect_stats(path: Path) -> dict: } -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument( "--dataset-dir", diff --git a/train/sft/generate_think_traces.py b/train/sft/generate_think_traces.py index 4901603..22cbdc8 100644 --- a/train/sft/generate_think_traces.py +++ b/train/sft/generate_think_traces.py @@ -29,7 +29,9 @@ Write a brief reasoning trace (2-3 sentences) showing how someone would find this answer by examining the screenshot. Mention what specific text/detail they would look for. Be natural and concise. No preamble. Output ONLY the reasoning, nothing else.""" -def process_one(client, model, ex): +def process_one( + client: OpenAI, model: str, ex: dict[str, str] +) -> dict[str, str | None]: try: resp = client.chat.completions.create( model=model, @@ -48,7 +50,7 @@ def process_one(client, model, ex): return {**ex, "reasoning": None, "_error": str(e)[:200]} -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument("--input", required=True) p.add_argument("--output", required=True) diff --git a/train/sft/generate_think_traces_v2.py b/train/sft/generate_think_traces_v2.py index f211fe6..e4cccb4 100644 --- a/train/sft/generate_think_traces_v2.py +++ b/train/sft/generate_think_traces_v2.py @@ -58,7 +58,9 @@ def encode_image(path: str, max_bytes: int = 4_000_000) -> str | None: return None -def process_one(client, model, ex, image_root): +def process_one( + client: OpenAI, model: str, ex: dict[str, str], image_root: str +) -> dict[str, str | None]: img_path = os.path.join(image_root, ex["chunk_path"]) img_url = encode_image(img_path) if os.path.exists(img_path) else None if img_url is None: @@ -92,7 +94,7 @@ def process_one(client, model, ex, image_root): return {**ex, "reasoning": None, "_error": str(e)[:200]} -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument("--input", required=True) p.add_argument("--output", required=True) diff --git a/train/sft/generate_think_traces_v3_highdetail.py b/train/sft/generate_think_traces_v3_highdetail.py index f211fe6..e4cccb4 100644 --- a/train/sft/generate_think_traces_v3_highdetail.py +++ b/train/sft/generate_think_traces_v3_highdetail.py @@ -58,7 +58,9 @@ def encode_image(path: str, max_bytes: int = 4_000_000) -> str | None: return None -def process_one(client, model, ex, image_root): +def process_one( + client: OpenAI, model: str, ex: dict[str, str], image_root: str +) -> dict[str, str | None]: img_path = os.path.join(image_root, ex["chunk_path"]) img_url = encode_image(img_path) if os.path.exists(img_path) else None if img_url is None: @@ -92,7 +94,7 @@ def process_one(client, model, ex, image_root): return {**ex, "reasoning": None, "_error": str(e)[:200]} -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument("--input", required=True) p.add_argument("--output", required=True) diff --git a/train/sft/prepare_mixed_data.py b/train/sft/prepare_mixed_data.py index 30201bf..087139f 100644 --- a/train/sft/prepare_mixed_data.py +++ b/train/sft/prepare_mixed_data.py @@ -20,7 +20,7 @@ BASE = "/scratch/users/zwcolin/cxr_embeds/sft_data" -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument("--output-dir", default=f"{BASE}/compressed_mixed") p.add_argument("--seed", type=int, default=42) diff --git a/train/sft/prepare_sft_data.py b/train/sft/prepare_sft_data.py index 011d9ba..84e10f4 100644 --- a/train/sft/prepare_sft_data.py +++ b/train/sft/prepare_sft_data.py @@ -23,11 +23,47 @@ import sys from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path +from typing import TypedDict from PIL import Image from tqdm import tqdm +class ShareGPTMessage(TypedDict): + role: str + content: str + + +class ShareGPTExample(TypedDict): + messages: list[ShareGPTMessage] + images: list[str] + + +class ImageInfo(TypedDict): + src: str + dst: str + dst_rel: str + + +class DatasetColumnMap(TypedDict): + messages: str + images: str + + +class DatasetTagMap(TypedDict): + role_tag: str + content_tag: str + user_tag: str + assistant_tag: str + + +class DatasetInfoEntry(TypedDict): + file_name: str + formatting: str + columns: DatasetColumnMap + tags: DatasetTagMap + + def compress_image(src: str, dst: str, scale_factor: float) -> bool: """Compress image by scale_factor per dimension. Returns True on success.""" try: @@ -45,7 +81,7 @@ def compress_image(src: str, dst: str, scale_factor: float) -> bool: return False -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( "--dataset-dir", @@ -107,7 +143,7 @@ def main(): print(f" Loaded {len(examples)} examples") # Collect unique positive image paths - unique_images = {} + unique_images: dict[str, ImageInfo] = {} for ex in examples: src_rel = ex["chunk_path"] # e.g. images/shard_760/... src_abs = str(dataset_dir / src_rel) @@ -155,7 +191,7 @@ def main(): print(" All images already cached") # Build ShareGPT format - sharegpt_data = [] + sharegpt_data: list[ShareGPTExample] = [] skipped = 0 for ex in examples: src_rel = ex["chunk_path"] @@ -184,7 +220,7 @@ def main(): ) # Write dataset_info.json for LlamaFactory - dataset_info = {} + dataset_info: dict[str, DatasetInfoEntry] = {} for split_name in splits: out_json = output_dir / f"{split_name}.json" if out_json.exists(): diff --git a/train/sft/prepare_sft_data_multiimage.py b/train/sft/prepare_sft_data_multiimage.py index 00c2727..110e558 100644 --- a/train/sft/prepare_sft_data_multiimage.py +++ b/train/sft/prepare_sft_data_multiimage.py @@ -29,11 +29,28 @@ import sys from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path +from typing import TypedDict from PIL import Image from tqdm import tqdm +class RetrievalHit(TypedDict): + path: str + score: float | None + article_id: int | None + url: str | None + + +class RetrievalRow(TypedDict): + query: str + answer: str + gold_path_rel: str + gold_suffix: str + hits: list[RetrievalHit] + gold_in_top6_pos: int + + def shard_suffix(p: str) -> str: parts = p.split("/") for i, x in enumerate(parts): @@ -58,7 +75,9 @@ def compress_image(src: str, dst: str, scale_factor: float) -> bool: return False -def build_image_set(row: dict, seed_base: int, n_images: int) -> tuple[list[str], int]: +def build_image_set( + row: RetrievalRow, seed_base: int, n_images: int +) -> tuple[list[str], int]: """Return (list of n_images shard-suffixes, gold index) with gold position shuffled. Composition: gold + (n_images-1) non-gold hits from top-6.""" gold = row["gold_suffix"] @@ -75,7 +94,7 @@ def build_image_set(row: dict, seed_base: int, n_images: int) -> tuple[list[str] return shuffled, gold_pos -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument( "--retrieval-dir", @@ -128,8 +147,8 @@ def main(): print(f"Splits: {args.splits}") # Pass 1: gather unique suffixes that need compression across all splits - all_split_rows = {} - needed = set() + all_split_rows: dict[str, list[RetrievalRow]] = {} + needed: set[str] = set() for split in args.splits: p_in = retrieval_dir / f"{split}.jsonl" if not p_in.exists(): diff --git a/train/sft/prepare_sft_data_upscale.py b/train/sft/prepare_sft_data_upscale.py index 3d769e5..633859c 100644 --- a/train/sft/prepare_sft_data_upscale.py +++ b/train/sft/prepare_sft_data_upscale.py @@ -14,11 +14,17 @@ import sys from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path +from typing import TypedDict from PIL import Image from tqdm import tqdm +class ImageInfo(TypedDict): + src: str + dst: str + + def compress_then_upscale(src: str, dst: str, scale_factor: float) -> bool: """Downscale by scale_factor/dim, then upscale back to original size.""" try: @@ -39,7 +45,7 @@ def compress_then_upscale(src: str, dst: str, scale_factor: float) -> bool: return False -def main(): +def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--dataset-dir", required=True) parser.add_argument("--output-dir", required=True) @@ -80,7 +86,7 @@ def main(): examples = examples[: args.max_examples] print(f" Loaded {len(examples)} examples") - unique_images = {} + unique_images: dict[str, ImageInfo] = {} for ex in examples: src_rel = ex["chunk_path"] src_abs = str(dataset_dir / src_rel) diff --git a/train/sft/prepare_sft_data_variable.py b/train/sft/prepare_sft_data_variable.py index 6a2cb09..3dc677f 100644 --- a/train/sft/prepare_sft_data_variable.py +++ b/train/sft/prepare_sft_data_variable.py @@ -19,6 +19,23 @@ import os import random from pathlib import Path +from typing import TypedDict + + +class RetrievalHit(TypedDict): + path: str + score: float | None + article_id: int | None + url: str | None + + +class RetrievalRow(TypedDict): + query: str + answer: str + gold_path_rel: str + gold_suffix: str + hits: list[RetrievalHit] + gold_in_top6_pos: int def shard_suffix(p: str) -> str: @@ -29,7 +46,9 @@ def shard_suffix(p: str) -> str: return p -def build_variable_image_set(row: dict, rng: random.Random, k_min: int, k_max: int): +def build_variable_image_set( + row: RetrievalRow, rng: random.Random, k_min: int, k_max: int +) -> tuple[list[str], int, int]: """Sample k ~ Uniform[k_min, k_max], return (shard_suffixes, gold_pos, k).""" gold = row["gold_suffix"] hit_sufs = [shard_suffix(h["path"]) for h in row["hits"]] @@ -45,7 +64,7 @@ def build_variable_image_set(row: dict, rng: random.Random, k_min: int, k_max: i return shuffled, gold_pos, k -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument( "--retrieval-dir", @@ -97,8 +116,6 @@ def main(): else: dataset_info = {} - {k: 0 for k in range(args.k_min, args.k_max + 1)} - for split in args.splits: p_in = retrieval_dir / f"{split}.jsonl" if not p_in.exists(): diff --git a/train/sft/prepare_think_data.py b/train/sft/prepare_think_data.py index 4055773..90641ec 100644 --- a/train/sft/prepare_think_data.py +++ b/train/sft/prepare_think_data.py @@ -26,7 +26,7 @@ def format_assistant(reasoning: str, answer: str) -> str: return f"\n{reasoning.strip()}\n\n\n{answer.strip()}" -def main(): +def main() -> None: p = argparse.ArgumentParser() p.add_argument( "--traces", diff --git a/train/sft/push_3x_v5_snapshot.py b/train/sft/push_3x_v5_snapshot.py index dae9a30..f927c1c 100644 --- a/train/sft/push_3x_v5_snapshot.py +++ b/train/sft/push_3x_v5_snapshot.py @@ -176,7 +176,7 @@ """ -def main(): +def main() -> None: assert SRC.exists(), f"missing {SRC}" print(f"=== {REPO_ID} ===") diff --git a/train/sft/push_multi3_to_hf.py b/train/sft/push_multi3_to_hf.py index e011d3f..df7223a 100644 --- a/train/sft/push_multi3_to_hf.py +++ b/train/sft/push_multi3_to_hf.py @@ -69,7 +69,7 @@ } -def build_readme(comp, judge, config, step_note): +def build_readme(comp: str, judge: float, config: str, step_note: str) -> str: n_ratio = int(comp.rstrip("x")) return f"""--- license: apache-2.0 @@ -155,7 +155,9 @@ def build_readme(comp, judge, config, step_note): """ -def push_one(comp, adapter_dir, judge, config, step_note): +def push_one( + comp: str, adapter_dir: str, judge: float, config: str, step_note: str +) -> None: src = Path(adapter_dir) assert src.exists(), f"missing {src}" @@ -187,7 +189,7 @@ def push_one(comp, adapter_dir, judge, config, step_note): print(f" uploaded → https://huggingface.co/{repo_id}") -def main(): +def main() -> None: for args in BEST: push_one(*args) print("\nAll done.") diff --git a/train/sft/push_multik_to_hf.py b/train/sft/push_multik_to_hf.py index 6043e2e..5566749 100644 --- a/train/sft/push_multik_to_hf.py +++ b/train/sft/push_multik_to_hf.py @@ -71,7 +71,12 @@ } -def build_readme(comp, adapter_dir, cutoff, scores): +def build_readme( + comp: str, + adapter_dir: str, + cutoff: int, + scores: tuple[float, float, float, float], +) -> str: n_ratio = int(comp.rstrip("x")) sk1, sk2, sk3, sk4 = scores @@ -181,7 +186,12 @@ def build_readme(comp, adapter_dir, cutoff, scores): """ -def push_one(comp, adapter_dir, cutoff, scores): +def push_one( + comp: str, + adapter_dir: str, + cutoff: int, + scores: tuple[float, float, float, float], +) -> None: src = Path(adapter_dir) assert src.exists(), f"missing {src}" @@ -213,7 +223,7 @@ def push_one(comp, adapter_dir, cutoff, scores): print(f" uploaded → https://huggingface.co/{repo_id}") -def main(): +def main() -> None: for args in BEST: push_one(*args) print("\nAll done.") diff --git a/train/sft/push_to_hf.py b/train/sft/push_to_hf.py index 24acdc0..4f411bd 100644 --- a/train/sft/push_to_hf.py +++ b/train/sft/push_to_hf.py @@ -56,7 +56,9 @@ } -def build_readme(comp, step, judge, config, base_judge): +def build_readme( + comp: str, step: int, judge: float, config: str, base_judge: float +) -> str: return f"""--- license: apache-2.0 library_name: peft @@ -128,7 +130,7 @@ def build_readme(comp, step, judge, config, base_judge): """ -def push_one(comp, run_dir, step, judge, config): +def push_one(comp: str, run_dir: str, step: int, judge: float, config: str) -> None: src = Path(BASE_CKPT_DIR) / run_dir / f"checkpoint-{step}" assert src.exists(), f"missing {src}" @@ -166,7 +168,7 @@ def push_one(comp, run_dir, step, judge, config): print(f" uploaded → https://huggingface.co/{repo_id}") -def main(): +def main() -> None: for args in BEST: push_one(*args) print("\nAll done.") diff --git a/train/sft/push_universal_to_hf.py b/train/sft/push_universal_to_hf.py index cea7e7a..0e36ee4 100644 --- a/train/sft/push_universal_to_hf.py +++ b/train/sft/push_universal_to_hf.py @@ -117,7 +117,7 @@ """ -def main(): +def main() -> None: print(f"=== {REPO_ID} ===") print(f" src: {SRC}") with tempfile.TemporaryDirectory() as tmp: