diff --git a/scripts/01_download_wiki.py b/scripts/01_download_wiki.py index 54a616d..94b753e 100644 --- a/scripts/01_download_wiki.py +++ b/scripts/01_download_wiki.py @@ -6,6 +6,7 @@ python scripts/01_download_wiki.py --lang de python scripts/01_download_wiki.py --lang de --lang en """ + import hashlib import logging import re @@ -27,9 +28,7 @@ def _validate_lang(lang: str) -> str: if not LANG_RE.match(lang): - raise ValueError( - f"Invalid language code {lang!r}. Expected 2-5 lowercase ASCII letters." - ) + raise ValueError(f"Invalid language code {lang!r}. Expected 2-5 lowercase ASCII letters.") return lang @@ -67,7 +66,6 @@ def _md5_file(path: Path) -> str: return h.hexdigest() - def download_dump(lang: str, out_dir: Path) -> None: """Download and verify the Wikipedia dump for *lang*.""" _validate_lang(lang) @@ -113,8 +111,8 @@ def download_dump(lang: str, out_dir: Path) -> None: if total_size > 0: pct = min(downloaded * 100 / total_size, 100) print( - f"\r {pct:5.1f}% {downloaded/1_048_576:,.0f}" - f" / {total_size/1_048_576:,.0f} MB", + f"\r {pct:5.1f}% {downloaded / 1_048_576:,.0f}" + f" / {total_size / 1_048_576:,.0f} MB", end="", flush=True, ) diff --git a/scripts/02_extract_wiki.py b/scripts/02_extract_wiki.py index 1915ae2..a310938 100644 --- a/scripts/02_extract_wiki.py +++ b/scripts/02_extract_wiki.py @@ -8,6 +8,7 @@ Usage: python scripts/02_extract_wiki.py --dump data/raw/dewiki-latest-pages-articles.xml.bz2 --lang de """ + import argparse import json import logging diff --git a/scripts/03_extract_pdfs.py b/scripts/03_extract_pdfs.py index 23104fb..87edd56 100644 --- a/scripts/03_extract_pdfs.py +++ b/scripts/03_extract_pdfs.py @@ -9,6 +9,7 @@ Usage: python scripts/03_extract_pdfs.py --input-dir /path/to/pdfs """ + import argparse import json import logging diff --git a/scripts/04_extract_markdown.py b/scripts/04_extract_markdown.py index 6c7a38d..a1bbbb5 100644 --- a/scripts/04_extract_markdown.py +++ b/scripts/04_extract_markdown.py @@ -10,6 +10,7 @@ Usage: python scripts/04_extract_markdown.py --input-dir /path/to/docs """ + import argparse import json import logging diff --git a/scripts/05_clean_deduplicate.py b/scripts/05_clean_deduplicate.py index d43e2a8..d2e3bed 100644 --- a/scripts/05_clean_deduplicate.py +++ b/scripts/05_clean_deduplicate.py @@ -15,6 +15,7 @@ data/processed/markdown.jsonl \ --output-file data/processed/corpus.jsonl """ + import argparse import hashlib import json @@ -126,9 +127,7 @@ def main() -> None: kept += 1 if total % 10_000 == 0: - log.info( - "%10d read | %10d kept | %8d dupes", total, kept, dupes - ) + log.info("%10d read | %10d kept | %8d dupes", total, kept, dupes) except Exception: tmp.unlink(missing_ok=True) raise diff --git a/scripts/06_tokenize.py b/scripts/06_tokenize.py index 002f701..8d4b8fc 100644 --- a/scripts/06_tokenize.py +++ b/scripts/06_tokenize.py @@ -11,6 +11,7 @@ --model-id mistralai/Ministral-3-14B-Base-2512 \ --seq-len 8192 """ + import argparse import json import logging @@ -112,7 +113,9 @@ def chunk_generator( log.info( "Tokenization complete: %d docs -> %d chunks (seq_len=%d)", - doc_count, chunk_count, seq_len, + doc_count, + chunk_count, + seq_len, ) @@ -158,7 +161,9 @@ def main() -> None: log.info( "Tokenizing: %s (seq_len=%d, batch_size=%d)", - jsonl_path, args.seq_len, args.batch_size, + jsonl_path, + args.seq_len, + args.batch_size, ) def _gen() -> Generator[dict[str, list[int]], None, None]: diff --git a/scripts/07_create_sft_data.py b/scripts/07_create_sft_data.py index 8cee4dc..058b67a 100644 --- a/scripts/07_create_sft_data.py +++ b/scripts/07_create_sft_data.py @@ -16,6 +16,7 @@ --output data/processed/sft_data.jsonl \ --max-docs 200000 """ + import argparse import json import logging diff --git a/scripts/patch_wikiextractor.py b/scripts/patch_wikiextractor.py index f8dd4ba..57dcc1d 100644 --- a/scripts/patch_wikiextractor.py +++ b/scripts/patch_wikiextractor.py @@ -6,6 +6,7 @@ This script moves them to the front of the affected patterns. Idempotent: re-running after the patch is already applied is a no-op. """ + import sys from pathlib import Path diff --git a/scripts/smoke_test.py b/scripts/smoke_test.py index 106ec75..42cf9bb 100644 --- a/scripts/smoke_test.py +++ b/scripts/smoke_test.py @@ -20,24 +20,27 @@ from transformers import AutoModelForCausalLM, AutoTokenizer DEVICE = "cuda" -MODEL = "gpt2" +MODEL = "gpt2" SEQ_LEN = 128 -STEPS = 5 +STEPS = 5 + def check(label): print(f" [OK] {label}") + def fail(label, exc): print(f" [FAIL] {label}: {exc}", file=sys.stderr) sys.exit(1) + print("=== knowledge-lora smoke test ===\n") # ── 1. CUDA ─────────────────────────────────────────────────────────────────── try: assert torch.cuda.is_available(), "CUDA not available" name = torch.cuda.get_device_name(0) - mem = torch.cuda.get_device_properties(0).total_memory / 1024**3 + mem = torch.cuda.get_device_properties(0).total_memory / 1024**3 check(f"torch {torch.__version__} | {name} | {mem:.0f} GB") except Exception as e: fail("CUDA", e) @@ -45,6 +48,7 @@ def fail(label, exc): # ── 2. flash-attn ───────────────────────────────────────────────────────────── try: import flash_attn + check(f"flash-attn {flash_attn.__version__}") except Exception as e: fail("flash-attn import", e) @@ -67,12 +71,12 @@ def fail(label, exc): r=4, lora_alpha=8, lora_dropout=0.05, - target_modules=["c_attn"], # GPT-2 attention projection + target_modules=["c_attn"], # GPT-2 attention projection bias="none", ) model = get_peft_model(model, lora_cfg) trainable, total = model.get_nb_trainable_parameters() - check(f"LoRA applied | trainable {trainable/1e3:.1f}K / {total/1e6:.1f}M params") + check(f"LoRA applied | trainable {trainable / 1e3:.1f}K / {total / 1e6:.1f}M params") except Exception as e: fail("LoRA", e)