Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions .smell-baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"path": "protea/core/training_dump_helpers.py",
"name": "TrainRerankerAutoOperation",
"line": 1105,
"metric": 763,
"metric": 766,
"threshold": 500
},
{
Expand All @@ -48,7 +48,7 @@
"path": "protea/core/training_dump_helpers.py",
"name": "",
"line": 0,
"metric": 1867,
"metric": 1870,
"threshold": 800
},
{
Expand Down Expand Up @@ -416,8 +416,8 @@
"kind": "method",
"path": "protea/core/parquet_export.py",
"name": "export_reranker_parquets",
"line": 82,
"metric": 180,
"line": 116,
"metric": 162,
"threshold": 60
},
{
Expand Down Expand Up @@ -461,7 +461,7 @@
"kind": "method",
"path": "protea/core/training_dump_helpers.py",
"name": "TrainRerankerAutoOperation.execute",
"line": 1209,
"line": 1212,
"metric": 659,
"threshold": 60
},
Expand Down Expand Up @@ -636,15 +636,6 @@
"metric": 11,
"threshold": 6
},
{
"key": "params::protea/core/parquet_export.py::export_reranker_parquets",
"kind": "params",
"path": "protea/core/parquet_export.py",
"name": "export_reranker_parquets",
"line": 82,
"metric": 16,
"threshold": 6
},
{
"key": "params::protea/core/retry.py::with_retry",
"kind": "params",
Expand Down
110 changes: 63 additions & 47 deletions protea/core/parquet_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
import logging
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Any

Expand All @@ -37,6 +38,39 @@
_CATEGORIES = ("nk", "lk", "pk")


@dataclass(frozen=True)
class ParquetExportContext:
    """Bundle of inputs ``export_reranker_parquets`` consumes.

    Groups the 16 per-call inputs (identity, source shards, publishing
    options) so the entry-point signature stays under flake8-bugbear's
    parameter ceiling. Keep this dataclass authoritative when adding
    new options.
    """

    # Source shards
    stage_dir: Path  # local staging area; the consolidated outputs are always written here
    split_files: dict[str, list[Path]]  # per-category training shard paths, parallel to valid_split_versions
    valid_split_versions: list[tuple[int, int]]  # (v_old, v_new) pair for each training shard position
    test_files: dict[str, Path | None]  # per-category test shard path (None when the category has no test rows)
    test_old_v: int
    test_new_v: int

    # Dataset identity
    name: str
    k: int
    embedding_config_id: str
    ontology_snapshot_id: str
    annotation_source: str

    # Publishing (all optional)
    store: ArtifactStore | None = None  # when set, outputs are also uploaded under key_prefix
    key_prefix: str = ""  # prefix for store keys (should typically end with "/")
    producer_version: str | None = None  # PROTEA version string recorded in the manifest
    producer_git_sha: str | None = None  # PROTEA git HEAD at export time, recorded in the manifest
    validate_with_contracts: bool = True  # best-effort validate the manifest against the lab's ManifestV1


def resolve_protea_git_sha() -> str | None:
"""Best-effort current HEAD sha of the PROTEA repo. Returns None when
the code is not running inside a git checkout or git is unavailable.
Expand Down Expand Up @@ -79,57 +113,39 @@ def _validate_manifest_with_contracts(manifest: dict[str, Any]) -> None:
ManifestV1.model_validate(manifest)


def export_reranker_parquets(
*,
stage_dir: Path,
split_files: dict[str, list[Path]],
valid_split_versions: list[tuple[int, int]],
test_files: dict[str, Path | None],
test_old_v: int,
test_new_v: int,
name: str,
k: int,
embedding_config_id: str,
ontology_snapshot_id: str,
annotation_source: str,
store: ArtifactStore | None = None,
key_prefix: str = "",
producer_version: str | None = None,
producer_git_sha: str | None = None,
validate_with_contracts: bool = True,
) -> dict[str, Any]:
def export_reranker_parquets(ctx: ParquetExportContext) -> dict[str, Any]:
"""Consolidate per-cat per-split parquet shards into the frozen
dataset layout and optionally publish via an ``ArtifactStore``.

Parameters
----------
stage_dir
Directory used as the local staging area. The three output files
are written here regardless of ``store``; when ``store`` is given
they are additionally uploaded under ``key_prefix``.
split_files
Per-category list of training shard paths, parallel to
``valid_split_versions``.
valid_split_versions
``(v_old, v_new)`` pairs for each training shard position.
test_files
Per-category test shard path (may be ``None`` when the category
has no test rows).
store
When provided, the three consolidated files are uploaded under
``f"{key_prefix}train.parquet"`` etc. The returned dict includes
the resulting URIs.
key_prefix
Prefix for store keys (should typically end with ``/``).
producer_version
PROTEA version string recorded in the manifest (optional).
producer_git_sha
PROTEA git HEAD at export time, recorded in the manifest
(optional).
validate_with_contracts
If True, best-effort validate the manifest dict against the lab's
``ManifestV1`` before writing. Silent if the lab isn't installed.
All inputs live on :class:`ParquetExportContext`. Notable fields:

- ``stage_dir``: local staging area (always written here; uploaded
under ``key_prefix`` if ``store`` is set).
- ``split_files``: per-category training shard paths, parallel to
``valid_split_versions``.
- ``test_files``: per-category test shard path (may be ``None``).
- ``store`` / ``key_prefix``: optional artifact-store upload.
- ``producer_version`` / ``producer_git_sha``: manifest provenance.
- ``validate_with_contracts``: best-effort validate against the
lab's ``ManifestV1`` before writing.
"""
stage_dir = ctx.stage_dir
split_files = ctx.split_files
valid_split_versions = ctx.valid_split_versions
test_files = ctx.test_files
test_old_v = ctx.test_old_v
test_new_v = ctx.test_new_v
name = ctx.name
k = ctx.k
embedding_config_id = ctx.embedding_config_id
ontology_snapshot_id = ctx.ontology_snapshot_id
annotation_source = ctx.annotation_source
store = ctx.store
key_prefix = ctx.key_prefix
producer_version = ctx.producer_version
producer_git_sha = ctx.producer_git_sha
validate_with_contracts = ctx.validate_with_contracts

stage_dir.mkdir(parents=True, exist_ok=True)
aspect_norm = dict(_ASPECT_NAMES)

Expand Down
31 changes: 17 additions & 14 deletions protea/core/training_dump_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,25 +1181,28 @@ def _dump_frozen_dataset(
"""
from protea import __version__ as _protea_version
from protea.core.parquet_export import (
ParquetExportContext,
export_reranker_parquets,
resolve_protea_git_sha,
)

result = export_reranker_parquets(
stage_dir=dump_dir,
split_files=split_files,
valid_split_versions=valid_split_versions,
test_files=test_files,
test_old_v=test_old_v,
test_new_v=test_new_v,
name=name,
k=k,
embedding_config_id=embedding_config_id,
ontology_snapshot_id=ontology_snapshot_id,
annotation_source=annotation_source,
store=None,
producer_version=_protea_version,
producer_git_sha=resolve_protea_git_sha(),
ParquetExportContext(
stage_dir=dump_dir,
split_files=split_files,
valid_split_versions=valid_split_versions,
test_files=test_files,
test_old_v=test_old_v,
test_new_v=test_new_v,
name=name,
k=k,
embedding_config_id=embedding_config_id,
ontology_snapshot_id=ontology_snapshot_id,
annotation_source=annotation_source,
store=None,
producer_version=_protea_version,
producer_git_sha=resolve_protea_git_sha(),
)
)
# Preserve the historical return contract — callers rely on
# ``dump_dir`` instead of ``stage_dir``.
Expand Down
37 changes: 21 additions & 16 deletions tests/test_parquet_export_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,31 @@ def _call_export(
train_rows: list[dict[str, object]],
eval_rows: list[dict[str, object]],
) -> dict[str, object]:
from protea.core.parquet_export import export_reranker_parquets
from protea.core.parquet_export import (
ParquetExportContext,
export_reranker_parquets,
)

train_shard = _write_shard(stage_dir / "_train_nk.parquet", train_rows)
eval_shard = _write_shard(stage_dir / "_eval_nk.parquet", eval_rows)
return export_reranker_parquets(
stage_dir=stage_dir,
split_files={"nk": [train_shard]},
valid_split_versions=[(220, 221)],
test_files={"nk": eval_shard},
test_old_v=221,
test_new_v=222,
name="t18-test",
k=5,
embedding_config_id="00000000-0000-0000-0000-000000000001",
ontology_snapshot_id="00000000-0000-0000-0000-000000000002",
annotation_source="goa",
store=None,
producer_version="t18",
producer_git_sha=None,
validate_with_contracts=False,
ParquetExportContext(
stage_dir=stage_dir,
split_files={"nk": [train_shard]},
valid_split_versions=[(220, 221)],
test_files={"nk": eval_shard},
test_old_v=221,
test_new_v=222,
name="t18-test",
k=5,
embedding_config_id="00000000-0000-0000-0000-000000000001",
ontology_snapshot_id="00000000-0000-0000-0000-000000000002",
annotation_source="goa",
store=None,
producer_version="t18",
producer_git_sha=None,
validate_with_contracts=False,
)
)


Expand Down
Loading