diff --git a/.smell-baseline.json b/.smell-baseline.json
index 507c8fb..a830d9a 100644
--- a/.smell-baseline.json
+++ b/.smell-baseline.json
@@ -21,7 +21,7 @@
       "path": "protea/core/training_dump_helpers.py",
       "name": "TrainRerankerAutoOperation",
       "line": 1105,
-      "metric": 763,
+      "metric": 766,
       "threshold": 500
     },
     {
@@ -48,7 +48,7 @@
       "path": "protea/core/training_dump_helpers.py",
       "name": "",
       "line": 0,
-      "metric": 1867,
+      "metric": 1870,
       "threshold": 800
     },
     {
@@ -416,8 +416,8 @@
       "kind": "method",
       "path": "protea/core/parquet_export.py",
       "name": "export_reranker_parquets",
-      "line": 82,
-      "metric": 180,
+      "line": 116,
+      "metric": 162,
       "threshold": 60
     },
     {
@@ -461,7 +461,7 @@
       "kind": "method",
       "path": "protea/core/training_dump_helpers.py",
       "name": "TrainRerankerAutoOperation.execute",
-      "line": 1209,
+      "line": 1212,
       "metric": 659,
       "threshold": 60
     },
@@ -636,15 +636,6 @@
       "metric": 11,
       "threshold": 6
     },
-    {
-      "key": "params::protea/core/parquet_export.py::export_reranker_parquets",
-      "kind": "params",
-      "path": "protea/core/parquet_export.py",
-      "name": "export_reranker_parquets",
-      "line": 82,
-      "metric": 16,
-      "threshold": 6
-    },
     {
       "key": "params::protea/core/retry.py::with_retry",
       "kind": "params",
diff --git a/protea/core/parquet_export.py b/protea/core/parquet_export.py
index 81c5434..738c05b 100644
--- a/protea/core/parquet_export.py
+++ b/protea/core/parquet_export.py
@@ -22,6 +22,7 @@
 import json
 import logging
 import subprocess
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
 
@@ -37,6 +38,39 @@
 _CATEGORIES = ("nk", "lk", "pk")
 
 
+@dataclass(frozen=True)
+class ParquetExportContext:
+    """Bundle of inputs ``export_reranker_parquets`` consumes.
+
+    Groups the 15 per-call inputs (identity, source shards, publishing
+    options) so the entry-point signature stays under flake8-bugbear's
+    parameter ceiling. Keep this dataclass authoritative when adding
+    new options.
+    """
+
+    # Source shards
+    stage_dir: Path
+    split_files: dict[str, list[Path]]
+    valid_split_versions: list[tuple[int, int]]
+    test_files: dict[str, Path | None]
+    test_old_v: int
+    test_new_v: int
+
+    # Dataset identity
+    name: str
+    k: int
+    embedding_config_id: str
+    ontology_snapshot_id: str
+    annotation_source: str
+
+    # Publishing
+    store: ArtifactStore | None = None
+    key_prefix: str = ""
+    producer_version: str | None = None
+    producer_git_sha: str | None = None
+    validate_with_contracts: bool = True
+
+
 def resolve_protea_git_sha() -> str | None:
     """Best-effort current HEAD sha of the PROTEA repo. Returns None
     when the code is not running inside a git checkout or git is unavailable.
@@ -79,57 +113,39 @@ def _validate_manifest_with_contracts(manifest: dict[str, Any]) -> None:
     ManifestV1.model_validate(manifest)
 
 
-def export_reranker_parquets(
-    *,
-    stage_dir: Path,
-    split_files: dict[str, list[Path]],
-    valid_split_versions: list[tuple[int, int]],
-    test_files: dict[str, Path | None],
-    test_old_v: int,
-    test_new_v: int,
-    name: str,
-    k: int,
-    embedding_config_id: str,
-    ontology_snapshot_id: str,
-    annotation_source: str,
-    store: ArtifactStore | None = None,
-    key_prefix: str = "",
-    producer_version: str | None = None,
-    producer_git_sha: str | None = None,
-    validate_with_contracts: bool = True,
-) -> dict[str, Any]:
+def export_reranker_parquets(ctx: ParquetExportContext) -> dict[str, Any]:
     """Consolidate per-cat per-split parquet shards into the frozen
     dataset layout and optionally publish via an ``ArtifactStore``.
 
-    Parameters
-    ----------
-    stage_dir
-        Directory used as the local staging area. The three output files
-        are written here regardless of ``store``; when ``store`` is given
-        they are additionally uploaded under ``key_prefix``.
-    split_files
-        Per-category list of training shard paths, parallel to
-        ``valid_split_versions``.
-    valid_split_versions
-        ``(v_old, v_new)`` pairs for each training shard position.
-    test_files
-        Per-category test shard path (may be ``None`` when the category
-        has no test rows).
-    store
-        When provided, the three consolidated files are uploaded under
-        ``f"{key_prefix}train.parquet"`` etc. The returned dict includes
-        the resulting URIs.
-    key_prefix
-        Prefix for store keys (should typically end with ``/``).
-    producer_version
-        PROTEA version string recorded in the manifest (optional).
-    producer_git_sha
-        PROTEA git HEAD at export time, recorded in the manifest
-        (optional).
-    validate_with_contracts
-        If True, best-effort validate the manifest dict against the lab's
-        ``ManifestV1`` before writing. Silent if the lab isn't installed.
+    All inputs live on :class:`ParquetExportContext`. Notable fields:
+
+    - ``stage_dir``: local staging area (always written here; uploaded
+      under ``key_prefix`` if ``store`` is set).
+    - ``split_files``: per-category training shard paths, parallel to
+      ``valid_split_versions``.
+    - ``test_files``: per-category test shard path (may be ``None``).
+    - ``store`` / ``key_prefix``: optional artifact-store upload.
+    - ``producer_version`` / ``producer_git_sha``: manifest provenance.
+    - ``validate_with_contracts``: best-effort validate against the
+      lab's ``ManifestV1`` before writing.
     """
+    stage_dir = ctx.stage_dir
+    split_files = ctx.split_files
+    valid_split_versions = ctx.valid_split_versions
+    test_files = ctx.test_files
+    test_old_v = ctx.test_old_v
+    test_new_v = ctx.test_new_v
+    name = ctx.name
+    k = ctx.k
+    embedding_config_id = ctx.embedding_config_id
+    ontology_snapshot_id = ctx.ontology_snapshot_id
+    annotation_source = ctx.annotation_source
+    store = ctx.store
+    key_prefix = ctx.key_prefix
+    producer_version = ctx.producer_version
+    producer_git_sha = ctx.producer_git_sha
+    validate_with_contracts = ctx.validate_with_contracts
+
     stage_dir.mkdir(parents=True, exist_ok=True)
 
     aspect_norm = dict(_ASPECT_NAMES)
diff --git a/protea/core/training_dump_helpers.py b/protea/core/training_dump_helpers.py
index a15f2e2..6019930 100644
--- a/protea/core/training_dump_helpers.py
+++ b/protea/core/training_dump_helpers.py
@@ -1181,25 +1181,28 @@ def _dump_frozen_dataset(
     """
     from protea import __version__ as _protea_version
     from protea.core.parquet_export import (
+        ParquetExportContext,
         export_reranker_parquets,
         resolve_protea_git_sha,
     )
 
     result = export_reranker_parquets(
-        stage_dir=dump_dir,
-        split_files=split_files,
-        valid_split_versions=valid_split_versions,
-        test_files=test_files,
-        test_old_v=test_old_v,
-        test_new_v=test_new_v,
-        name=name,
-        k=k,
-        embedding_config_id=embedding_config_id,
-        ontology_snapshot_id=ontology_snapshot_id,
-        annotation_source=annotation_source,
-        store=None,
-        producer_version=_protea_version,
-        producer_git_sha=resolve_protea_git_sha(),
+        ParquetExportContext(
+            stage_dir=dump_dir,
+            split_files=split_files,
+            valid_split_versions=valid_split_versions,
+            test_files=test_files,
+            test_old_v=test_old_v,
+            test_new_v=test_new_v,
+            name=name,
+            k=k,
+            embedding_config_id=embedding_config_id,
+            ontology_snapshot_id=ontology_snapshot_id,
+            annotation_source=annotation_source,
+            store=None,
+            producer_version=_protea_version,
+            producer_git_sha=resolve_protea_git_sha(),
+        )
     )
     # Preserve the historical return contract — callers rely on
     # ``dump_dir`` instead of ``stage_dir``.
diff --git a/tests/test_parquet_export_boundary.py b/tests/test_parquet_export_boundary.py
index 30e35f2..ed27b9b 100644
--- a/tests/test_parquet_export_boundary.py
+++ b/tests/test_parquet_export_boundary.py
@@ -51,26 +51,31 @@ def _call_export(
     train_rows: list[dict[str, object]],
     eval_rows: list[dict[str, object]],
 ) -> dict[str, object]:
-    from protea.core.parquet_export import export_reranker_parquets
+    from protea.core.parquet_export import (
+        ParquetExportContext,
+        export_reranker_parquets,
+    )
 
     train_shard = _write_shard(stage_dir / "_train_nk.parquet", train_rows)
     eval_shard = _write_shard(stage_dir / "_eval_nk.parquet", eval_rows)
     return export_reranker_parquets(
-        stage_dir=stage_dir,
-        split_files={"nk": [train_shard]},
-        valid_split_versions=[(220, 221)],
-        test_files={"nk": eval_shard},
-        test_old_v=221,
-        test_new_v=222,
-        name="t18-test",
-        k=5,
-        embedding_config_id="00000000-0000-0000-0000-000000000001",
-        ontology_snapshot_id="00000000-0000-0000-0000-000000000002",
-        annotation_source="goa",
-        store=None,
-        producer_version="t18",
-        producer_git_sha=None,
-        validate_with_contracts=False,
+        ParquetExportContext(
+            stage_dir=stage_dir,
+            split_files={"nk": [train_shard]},
+            valid_split_versions=[(220, 221)],
+            test_files={"nk": eval_shard},
+            test_old_v=221,
+            test_new_v=222,
+            name="t18-test",
+            k=5,
+            embedding_config_id="00000000-0000-0000-0000-000000000001",
+            ontology_snapshot_id="00000000-0000-0000-0000-000000000002",
+            annotation_source="goa",
+            store=None,
+            producer_version="t18",
+            producer_git_sha=None,
+            validate_with_contracts=False,
+        )
     )