From 51e481dca49105218ea64b5724df230811f305f6 Mon Sep 17 00:00:00 2001 From: frapercan Date: Fri, 8 May 2026 18:05:00 +0200 Subject: [PATCH] =?UTF-8?q?refactor(training):=20T-CONTEXTS=20partial=20?= =?UTF-8?q?=E2=80=94=20KnnTransferContext?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces ``KnnTransferContext`` frozen dataclass in ``training_dump_helpers.py`` bundling the 12 per-call data inputs (queries, references, ontology maps, optional enrichment helpers) that ``_knn_transfer_and_label`` consumes. Signature collapses 16 args → 5: ``session``, ``p``, ``ctx``, ``sequence_context`` (existing), ``stream_output`` (existing). ``session`` and payload ``p`` stay separate as configuration; the existing ``SequenceContext`` and ``StreamOutput`` helpers continue to handle their orthogonal concerns. Three call sites updated: - ``training_dump_helpers.py`` train branch (line ~1515) - ``training_dump_helpers.py`` test stream branch (line ~1747) - ``tests/test_knn_streaming_smoke.py`` shared test runner Body destructures the context once at the top so the rest of the implementation stays diff-minimal. Sizes: - training_dump_helpers.py: +35 LOC (dataclass + destructuring) - Smell baseline: 77 → 75 (params>6: 22 → 20: ``_knn_transfer_and_label`` retired plus 1 knock-on improvement) Local-first 5 green (ruff + flake8 + pytest 1163 + check_smells). 
--- .smell-baseline.json | 29 +++---- protea/core/training_dump_helpers.py | 124 +++++++++++++++++---------- tests/test_knn_streaming_smoke.py | 45 +++++----- 3 files changed, 113 insertions(+), 85 deletions(-) diff --git a/.smell-baseline.json b/.smell-baseline.json index a830d9a..4baa081 100644 --- a/.smell-baseline.json +++ b/.smell-baseline.json @@ -20,8 +20,8 @@ "kind": "class", "path": "protea/core/training_dump_helpers.py", "name": "TrainRerankerAutoOperation", - "line": 1105, - "metric": 766, + "line": 1133, + "metric": 770, "threshold": 500 }, { @@ -48,7 +48,7 @@ "path": "protea/core/training_dump_helpers.py", "name": "", "line": 0, - "metric": 1870, + "metric": 1902, "threshold": 800 }, { @@ -434,7 +434,7 @@ "kind": "method", "path": "protea/core/training_dump_helpers.py", "name": "_preload_all_embeddings", - "line": 153, + "line": 179, "metric": 73, "threshold": 60 }, @@ -443,7 +443,7 @@ "kind": "method", "path": "protea/core/training_dump_helpers.py", "name": "_build_reference_from_cache", - "line": 228, + "line": 254, "metric": 67, "threshold": 60 }, @@ -452,8 +452,8 @@ "kind": "method", "path": "protea/core/training_dump_helpers.py", "name": "_knn_transfer_and_label", - "line": 346, - "metric": 622, + "line": 372, + "metric": 624, "threshold": 60 }, { @@ -461,8 +461,8 @@ "kind": "method", "path": "protea/core/training_dump_helpers.py", "name": "TrainRerankerAutoOperation.execute", - "line": 1212, - "metric": 659, + "line": 1240, + "metric": 663, "threshold": 60 }, { @@ -645,21 +645,12 @@ "metric": 7, "threshold": 6 }, - { - "key": "params::protea/core/training_dump_helpers.py::_knn_transfer_and_label", - "kind": "params", - "path": "protea/core/training_dump_helpers.py", - "name": "_knn_transfer_and_label", - "line": 346, - "metric": 16, - "threshold": 6 - }, { "key": "params::protea/core/training_dump_helpers.py::TrainRerankerAutoOperation._dump_frozen_dataset", "kind": "params", "path": "protea/core/training_dump_helpers.py", "name": 
"TrainRerankerAutoOperation._dump_frozen_dataset", - "line": 1162, + "line": 1190, "metric": 11, "threshold": 6 }, diff --git a/protea/core/training_dump_helpers.py b/protea/core/training_dump_helpers.py index 6019930..717b0ac 100644 --- a/protea/core/training_dump_helpers.py +++ b/protea/core/training_dump_helpers.py @@ -100,6 +100,32 @@ class StreamOutput: chunk_rows: int = 100_000 +@dataclass(frozen=True) +class KnnTransferContext: + """Bundle of KNN inputs + enrichment maps for ``_knn_transfer_and_label``. + + Groups the 12 per-call data arguments (queries, references, ontology + maps, optional enrichment helpers) so the entry-point signature + stays under the ``check_smells`` parameter ceiling. ``session``, + payload ``p``, ``sequence_context``, and ``stream_output`` remain + standalone arguments because they are configuration / IO concerns, + not data. + """ + + valid_queries: list[str] + query_emb: np.ndarray + ref_by_aspect: dict[str, dict[str, Any]] + go_id_map: dict[int, str] + aspect_map: dict[int, str] + gt_pairs: set[tuple[str, str]] + query_known_gos: dict[str, set[str]] | None = None + parent_map_str: dict[str, set[str]] | None = None + ia_weights: dict[str, float] | None = None + pca_state: tuple[np.ndarray, np.ndarray] | None = None + pivot_go_ids: set[str] | frozenset[str] | None = None + embedding_pool: np.ndarray | None = None + + # --------------------------------------------------------------------------- # Payload # --------------------------------------------------------------------------- @@ -345,46 +371,48 @@ def _load_taxonomy_ids( def _knn_transfer_and_label( session: Session, - valid_queries: list[str], - query_emb: np.ndarray, - ref_by_aspect: dict[str, dict[str, Any]], - go_id_map: dict[int, str], - aspect_map: dict[int, str], - gt_pairs: set[tuple[str, str]], p: TrainRerankerAutoPayload, + ctx: KnnTransferContext, *, sequence_context: SequenceContext | None = None, - query_known_gos: dict[str, set[str]] | None = None, - parent_map_str: 
dict[str, set[str]] | None = None, - ia_weights: dict[str, float] | None = None, - pca_state: tuple[np.ndarray, np.ndarray] | None = None, - pivot_go_ids: set[str] | frozenset[str] | None = None, stream_output: StreamOutput | None = None, - embedding_pool: np.ndarray | None = None, ) -> list[dict[str, Any]] | dict[str, Any]: """Run per-aspect KNN, transfer GO terms, label, compute features. - ``query_known_gos`` is ``{protein_accession: {go_id}}`` of annotations - the query already carries before the prediction cutoff (from - ``EvaluationData.known``). Used to compute query-side Anc2Vec coherence - features — the PK-killer signal: how close is each candidate GO to the - query's existing annotation profile? + ``ctx.query_known_gos`` is ``{protein_accession: {go_id}}`` of + annotations the query already carries before the prediction cutoff + (from ``EvaluationData.known``). Used to compute query-side Anc2Vec + coherence features — the PK-killer signal: how close is each candidate + GO to the query's existing annotation profile? Streaming mode: when ``stream_output`` is given, records are written to disk in ``stream_output.chunk_rows`` chunks as they are generated (re-ordered to iterate per ``(q_acc, aspect)`` group so the ancestor-expansion stays local). In this mode the function returns ``{"parquet_path": str, "n_rows": int}`` instead - of the full list. ``pivot_go_ids`` (orthogonal to streaming) + of the full list. ``ctx.pivot_go_ids`` (orthogonal to streaming) filters records by go_id; useful in either mode. """ # Unpack the parameter objects so the body keeps using the original # local names — body untouched, only the call surface shrinks. 
- ctx = sequence_context or SequenceContext() - query_sequences = ctx.query_sequences - ref_sequences = ctx.ref_sequences - query_tax_ids = ctx.query_tax_ids - ref_tax_ids = ctx.ref_tax_ids + valid_queries = ctx.valid_queries + query_emb = ctx.query_emb + ref_by_aspect = ctx.ref_by_aspect + go_id_map = ctx.go_id_map + aspect_map = ctx.aspect_map + gt_pairs = ctx.gt_pairs + query_known_gos = ctx.query_known_gos + parent_map_str = ctx.parent_map_str + ia_weights = ctx.ia_weights + pca_state = ctx.pca_state + pivot_go_ids = ctx.pivot_go_ids + embedding_pool = ctx.embedding_pool + + seq_ctx = sequence_context or SequenceContext() + query_sequences = seq_ctx.query_sequences + ref_sequences = seq_ctx.ref_sequences + query_tax_ids = seq_ctx.query_tax_ids + ref_tax_ids = seq_ctx.ref_tax_ids if stream_output is not None: output_parquet: Path | None = stream_output.output_parquet @@ -1489,24 +1517,26 @@ def execute( session.expire_all() unlabeled_preds = _knn_transfer_and_label( session, - valid_queries, - query_emb, - ref_by_aspect, - go_id_map, - aspect_map, - set(), # empty gt → all label=0 p, + KnnTransferContext( + valid_queries=valid_queries, + query_emb=query_emb, + ref_by_aspect=ref_by_aspect, + go_id_map=go_id_map, + aspect_map=aspect_map, + gt_pairs=set(), # empty gt → all label=0 + query_known_gos=eval_data.known, + parent_map_str=parent_map if p.expand_votes_to_ancestors else None, + ia_weights=ia_weights, + pca_state=pca_state, + embedding_pool=all_embeddings, + ), sequence_context=SequenceContext( query_sequences=qs, ref_sequences=rs, query_tax_ids=qt, ref_tax_ids=rt, ), - query_known_gos=eval_data.known, - parent_map_str=parent_map if p.expand_votes_to_ancestors else None, - ia_weights=ia_weights, - pca_state=pca_state, - embedding_pool=all_embeddings, ) # Restrict predictions to terms present in the pivot universe — @@ -1720,26 +1750,28 @@ def execute( test_unlabeled_path = tmp_dir / "test_unlabeled.parquet" test_stream_info = _knn_transfer_and_label( 
session, - test_valid, - test_emb, - test_ref, - go_id_map, - aspect_map, - set(), p, + KnnTransferContext( + valid_queries=test_valid, + query_emb=test_emb, + ref_by_aspect=test_ref, + go_id_map=go_id_map, + aspect_map=aspect_map, + gt_pairs=set(), + query_known_gos=test_eval_data.known, + parent_map_str=parent_map if p.expand_votes_to_ancestors else None, + ia_weights=ia_weights, + pca_state=pca_state, + pivot_go_ids=pivot_go_ids, + embedding_pool=all_embeddings, + ), sequence_context=SequenceContext( query_sequences=test_qs, ref_sequences=test_rs, query_tax_ids=test_qt, ref_tax_ids=test_rt, ), - query_known_gos=test_eval_data.known, - parent_map_str=parent_map if p.expand_votes_to_ancestors else None, - ia_weights=ia_weights, - pca_state=pca_state, - pivot_go_ids=pivot_go_ids, stream_output=StreamOutput(output_parquet=test_unlabeled_path), - embedding_pool=all_embeddings, ) del test_ref, test_emb, test_valid, test_qs, test_rs, test_qt, test_rt gc.collect() diff --git a/tests/test_knn_streaming_smoke.py b/tests/test_knn_streaming_smoke.py index c9af7ea..c0776e1 100644 --- a/tests/test_knn_streaming_smoke.py +++ b/tests/test_knn_streaming_smoke.py @@ -17,7 +17,11 @@ import pyarrow.parquet as pq import pytest -from protea.core.training_dump_helpers import StreamOutput, _knn_transfer_and_label +from protea.core.training_dump_helpers import ( + KnnTransferContext, + StreamOutput, + _knn_transfer_and_label, +) class _StubAnc2Vec: @@ -123,18 +127,24 @@ def _run(mode: str, tmp_path: Path | None = None, *, expand: bool, pivot=None): session = MagicMock() p = _mk_payload(expand=expand) - kwargs: dict = { - "query_known_gos": None, - "parent_map_str": parent_map_str if expand else None, - "ia_weights": None, - "pca_state": None, - } - kwargs["pivot_go_ids"] = pivot - if mode == "stream": - kwargs["stream_output"] = StreamOutput( - output_parquet=tmp_path / "out.parquet", - chunk_rows=3, # tiny to force multiple flushes - ) + ctx = KnnTransferContext( + 
valid_queries=valid_queries, + query_emb=query_emb, + ref_by_aspect=ref_by_aspect, + go_id_map=go_id_map, + aspect_map=aspect_map, + gt_pairs=gt_pairs, + query_known_gos=None, + parent_map_str=parent_map_str if expand else None, + ia_weights=None, + pca_state=None, + pivot_go_ids=pivot, + ) + stream_output = ( + StreamOutput(output_parquet=tmp_path / "out.parquet", chunk_rows=3) + if mode == "stream" + else None + ) with patch( "protea.core.training_dump_helpers.get_anc2vec_index", @@ -142,14 +152,9 @@ def _run(mode: str, tmp_path: Path | None = None, *, expand: bool, pivot=None): ): return _knn_transfer_and_label( session, - valid_queries, - query_emb, - ref_by_aspect, - go_id_map, - aspect_map, - gt_pairs, p, - **kwargs, + ctx, + stream_output=stream_output, )