Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions .smell-baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
"kind": "class",
"path": "protea/core/training_dump_helpers.py",
"name": "TrainRerankerAutoOperation",
"line": 1105,
"metric": 766,
"line": 1133,
"metric": 770,
"threshold": 500
},
{
Expand All @@ -48,7 +48,7 @@
"path": "protea/core/training_dump_helpers.py",
"name": "",
"line": 0,
"metric": 1870,
"metric": 1902,
"threshold": 800
},
{
Expand Down Expand Up @@ -434,7 +434,7 @@
"kind": "method",
"path": "protea/core/training_dump_helpers.py",
"name": "_preload_all_embeddings",
"line": 153,
"line": 179,
"metric": 73,
"threshold": 60
},
Expand All @@ -443,7 +443,7 @@
"kind": "method",
"path": "protea/core/training_dump_helpers.py",
"name": "_build_reference_from_cache",
"line": 228,
"line": 254,
"metric": 67,
"threshold": 60
},
Expand All @@ -452,17 +452,17 @@
"kind": "method",
"path": "protea/core/training_dump_helpers.py",
"name": "_knn_transfer_and_label",
"line": 346,
"metric": 622,
"line": 372,
"metric": 624,
"threshold": 60
},
{
"key": "method::protea/core/training_dump_helpers.py::TrainRerankerAutoOperation.execute",
"kind": "method",
"path": "protea/core/training_dump_helpers.py",
"name": "TrainRerankerAutoOperation.execute",
"line": 1212,
"metric": 659,
"line": 1240,
"metric": 663,
"threshold": 60
},
{
Expand Down Expand Up @@ -645,21 +645,12 @@
"metric": 7,
"threshold": 6
},
{
"key": "params::protea/core/training_dump_helpers.py::_knn_transfer_and_label",
"kind": "params",
"path": "protea/core/training_dump_helpers.py",
"name": "_knn_transfer_and_label",
"line": 346,
"metric": 16,
"threshold": 6
},
{
"key": "params::protea/core/training_dump_helpers.py::TrainRerankerAutoOperation._dump_frozen_dataset",
"kind": "params",
"path": "protea/core/training_dump_helpers.py",
"name": "TrainRerankerAutoOperation._dump_frozen_dataset",
"line": 1162,
"line": 1190,
"metric": 11,
"threshold": 6
},
Expand Down
124 changes: 78 additions & 46 deletions protea/core/training_dump_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,32 @@ class StreamOutput:
chunk_rows: int = 100_000


@dataclass(frozen=True)
class KnnTransferContext:
    """Bundle of KNN inputs + enrichment maps for ``_knn_transfer_and_label``.

    Groups the 12 per-call data arguments (queries, references, ontology
    maps, optional enrichment helpers) so the entry-point signature
    stays under flake8-bugbear's parameter ceiling. ``session``,
    payload ``p``, ``sequence_context``, and ``stream_output`` remain
    standalone arguments because they are configuration / IO concerns,
    not data.
    """

    # Query protein accessions; presumably row-aligned with ``query_emb``
    # (callers pass ``valid_queries`` / ``test_valid`` alongside the
    # matching embedding matrix) — TODO confirm alignment contract.
    valid_queries: list[str]
    # Embedding matrix for the queries above.
    query_emb: np.ndarray
    # Per-aspect reference bundles used for the per-aspect KNN search.
    ref_by_aspect: dict[str, dict[str, Any]]
    # Integer id -> GO term id (e.g. "GO:0008150") lookup.
    go_id_map: dict[int, str]
    # Integer id -> ontology aspect lookup.
    aspect_map: dict[int, str]
    # Ground-truth (accession, go_id) pairs for labeling; callers pass an
    # empty set to force all labels to 0 ("empty gt → all label=0").
    gt_pairs: set[tuple[str, str]]
    # ``{protein_accession: {go_id}}`` of annotations the query already
    # carries before the prediction cutoff (from ``EvaluationData.known``);
    # feeds the query-side Anc2Vec coherence features.
    query_known_gos: dict[str, set[str]] | None = None
    # GO child -> parents map; only supplied when
    # ``p.expand_votes_to_ancestors`` is set, for ancestor expansion.
    parent_map_str: dict[str, set[str]] | None = None
    # Per-GO-term float weights (named "ia" — presumably information
    # accretion; verify against the producer).
    ia_weights: dict[str, float] | None = None
    # Two-array PCA state carried between calls — exact contents
    # (components/mean?) not visible here; TODO confirm.
    pca_state: tuple[np.ndarray, np.ndarray] | None = None
    # When set, restricts emitted records to these go_ids; orthogonal to
    # streaming mode.
    pivot_go_ids: set[str] | frozenset[str] | None = None
    # Optional shared embedding pool (callers pass ``all_embeddings``).
    embedding_pool: np.ndarray | None = None


# ---------------------------------------------------------------------------
# Payload
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -345,46 +371,48 @@ def _load_taxonomy_ids(

def _knn_transfer_and_label(
session: Session,
valid_queries: list[str],
query_emb: np.ndarray,
ref_by_aspect: dict[str, dict[str, Any]],
go_id_map: dict[int, str],
aspect_map: dict[int, str],
gt_pairs: set[tuple[str, str]],
p: TrainRerankerAutoPayload,
ctx: KnnTransferContext,
*,
sequence_context: SequenceContext | None = None,
query_known_gos: dict[str, set[str]] | None = None,
parent_map_str: dict[str, set[str]] | None = None,
ia_weights: dict[str, float] | None = None,
pca_state: tuple[np.ndarray, np.ndarray] | None = None,
pivot_go_ids: set[str] | frozenset[str] | None = None,
stream_output: StreamOutput | None = None,
embedding_pool: np.ndarray | None = None,
) -> list[dict[str, Any]] | dict[str, Any]:
"""Run per-aspect KNN, transfer GO terms, label, compute features.

``query_known_gos`` is ``{protein_accession: {go_id}}`` of annotations
the query already carries before the prediction cutoff (from
``EvaluationData.known``). Used to compute query-side Anc2Vec coherence
features — the PK-killer signal: how close is each candidate GO to the
query's existing annotation profile?
``ctx.query_known_gos`` is ``{protein_accession: {go_id}}`` of
annotations the query already carries before the prediction cutoff
(from ``EvaluationData.known``). Used to compute query-side Anc2Vec
coherence features — the PK-killer signal: how close is each candidate
GO to the query's existing annotation profile?

Streaming mode: when ``stream_output`` is given, records are
written to disk in ``stream_output.chunk_rows`` chunks as they
are generated (re-ordered to iterate per ``(q_acc, aspect)``
group so the ancestor-expansion stays local). In this mode the
function returns ``{"parquet_path": str, "n_rows": int}`` instead
of the full list. ``pivot_go_ids`` (orthogonal to streaming)
of the full list. ``ctx.pivot_go_ids`` (orthogonal to streaming)
filters records by go_id; useful in either mode.
"""
# Unpack the parameter objects so the body keeps using the original
# local names — body untouched, only the call surface shrinks.
ctx = sequence_context or SequenceContext()
query_sequences = ctx.query_sequences
ref_sequences = ctx.ref_sequences
query_tax_ids = ctx.query_tax_ids
ref_tax_ids = ctx.ref_tax_ids
valid_queries = ctx.valid_queries
query_emb = ctx.query_emb
ref_by_aspect = ctx.ref_by_aspect
go_id_map = ctx.go_id_map
aspect_map = ctx.aspect_map
gt_pairs = ctx.gt_pairs
query_known_gos = ctx.query_known_gos
parent_map_str = ctx.parent_map_str
ia_weights = ctx.ia_weights
pca_state = ctx.pca_state
pivot_go_ids = ctx.pivot_go_ids
embedding_pool = ctx.embedding_pool

seq_ctx = sequence_context or SequenceContext()
query_sequences = seq_ctx.query_sequences
ref_sequences = seq_ctx.ref_sequences
query_tax_ids = seq_ctx.query_tax_ids
ref_tax_ids = seq_ctx.ref_tax_ids

if stream_output is not None:
output_parquet: Path | None = stream_output.output_parquet
Expand Down Expand Up @@ -1489,24 +1517,26 @@ def execute(
session.expire_all()
unlabeled_preds = _knn_transfer_and_label(
session,
valid_queries,
query_emb,
ref_by_aspect,
go_id_map,
aspect_map,
set(), # empty gt → all label=0
p,
KnnTransferContext(
valid_queries=valid_queries,
query_emb=query_emb,
ref_by_aspect=ref_by_aspect,
go_id_map=go_id_map,
aspect_map=aspect_map,
gt_pairs=set(), # empty gt → all label=0
query_known_gos=eval_data.known,
parent_map_str=parent_map if p.expand_votes_to_ancestors else None,
ia_weights=ia_weights,
pca_state=pca_state,
embedding_pool=all_embeddings,
),
sequence_context=SequenceContext(
query_sequences=qs,
ref_sequences=rs,
query_tax_ids=qt,
ref_tax_ids=rt,
),
query_known_gos=eval_data.known,
parent_map_str=parent_map if p.expand_votes_to_ancestors else None,
ia_weights=ia_weights,
pca_state=pca_state,
embedding_pool=all_embeddings,
)

# Restrict predictions to terms present in the pivot universe —
Expand Down Expand Up @@ -1720,26 +1750,28 @@ def execute(
test_unlabeled_path = tmp_dir / "test_unlabeled.parquet"
test_stream_info = _knn_transfer_and_label(
session,
test_valid,
test_emb,
test_ref,
go_id_map,
aspect_map,
set(),
p,
KnnTransferContext(
valid_queries=test_valid,
query_emb=test_emb,
ref_by_aspect=test_ref,
go_id_map=go_id_map,
aspect_map=aspect_map,
gt_pairs=set(),
query_known_gos=test_eval_data.known,
parent_map_str=parent_map if p.expand_votes_to_ancestors else None,
ia_weights=ia_weights,
pca_state=pca_state,
pivot_go_ids=pivot_go_ids,
embedding_pool=all_embeddings,
),
sequence_context=SequenceContext(
query_sequences=test_qs,
ref_sequences=test_rs,
query_tax_ids=test_qt,
ref_tax_ids=test_rt,
),
query_known_gos=test_eval_data.known,
parent_map_str=parent_map if p.expand_votes_to_ancestors else None,
ia_weights=ia_weights,
pca_state=pca_state,
pivot_go_ids=pivot_go_ids,
stream_output=StreamOutput(output_parquet=test_unlabeled_path),
embedding_pool=all_embeddings,
)
del test_ref, test_emb, test_valid, test_qs, test_rs, test_qt, test_rt
gc.collect()
Expand Down
45 changes: 25 additions & 20 deletions tests/test_knn_streaming_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import pyarrow.parquet as pq
import pytest

from protea.core.training_dump_helpers import StreamOutput, _knn_transfer_and_label
from protea.core.training_dump_helpers import (
KnnTransferContext,
StreamOutput,
_knn_transfer_and_label,
)


class _StubAnc2Vec:
Expand Down Expand Up @@ -123,33 +127,34 @@ def _run(mode: str, tmp_path: Path | None = None, *, expand: bool, pivot=None):
session = MagicMock()
p = _mk_payload(expand=expand)

kwargs: dict = {
"query_known_gos": None,
"parent_map_str": parent_map_str if expand else None,
"ia_weights": None,
"pca_state": None,
}
kwargs["pivot_go_ids"] = pivot
if mode == "stream":
kwargs["stream_output"] = StreamOutput(
output_parquet=tmp_path / "out.parquet",
chunk_rows=3, # tiny to force multiple flushes
)
ctx = KnnTransferContext(
valid_queries=valid_queries,
query_emb=query_emb,
ref_by_aspect=ref_by_aspect,
go_id_map=go_id_map,
aspect_map=aspect_map,
gt_pairs=gt_pairs,
query_known_gos=None,
parent_map_str=parent_map_str if expand else None,
ia_weights=None,
pca_state=None,
pivot_go_ids=pivot,
)
stream_output = (
StreamOutput(output_parquet=tmp_path / "out.parquet", chunk_rows=3)
if mode == "stream"
else None
)

with patch(
"protea.core.training_dump_helpers.get_anc2vec_index",
return_value=_StubAnc2Vec(),
):
return _knn_transfer_and_label(
session,
valid_queries,
query_emb,
ref_by_aspect,
go_id_map,
aspect_map,
gt_pairs,
p,
**kwargs,
ctx,
stream_output=stream_output,
)


Expand Down
Loading