diff --git a/.smell-baseline.json b/.smell-baseline.json index a91a42c..eda68cf 100644 --- a/.smell-baseline.json +++ b/.smell-baseline.json @@ -65,8 +65,8 @@ "kind": "method", "path": "protea/api/routers/benchmark.py", "name": "get_benchmark_matrix", - "line": 143, - "metric": 219, + "line": 185, + "metric": 209, "threshold": 60 }, { @@ -92,8 +92,8 @@ "kind": "method", "path": "protea/api/routers/reranker_models.py", "name": "_register_model", - "line": 99, - "metric": 90, + "line": 120, + "metric": 81, "threshold": 60 }, { @@ -510,21 +510,12 @@ "metric": 7, "threshold": 6 }, - { - "key": "params::protea/api/routers/reranker_models.py::_register_model", - "kind": "params", - "path": "protea/api/routers/reranker_models.py", - "name": "_register_model", - "line": 99, - "metric": 10, - "threshold": 6 - }, { "key": "params::protea/api/routers/reranker_models.py::import_reranker_model_multipart", "kind": "params", "path": "protea/api/routers/reranker_models.py", "name": "import_reranker_model_multipart", - "line": 192, + "line": 204, "metric": 10, "threshold": 6 }, diff --git a/protea/api/routers/reranker_models.py b/protea/api/routers/reranker_models.py index 2bbee08..3ba1278 100644 --- a/protea/api/routers/reranker_models.py +++ b/protea/api/routers/reranker_models.py @@ -22,6 +22,7 @@ import json import uuid +from dataclasses import dataclass from typing import Any from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile @@ -43,6 +44,26 @@ compute_feature_schema_sha = None # type: ignore[assignment] +@dataclass(frozen=True) +class _RerankerRegistration: + """Bundle of inputs ``_register_model`` accepts. + + Groups the nine non-session fields the multipart and by-reference + endpoints both fill so the registration helper signature stays + under the ``params`` threshold (6) tracked in ``.smell-baseline.json``.
+ """ + + name: str + artifact_uri: str + run: dict[str, Any] + spec_yaml_text: str + dataset_id_override: str | None + external_source: str | None + prediction_set_id: str | None + evaluation_set_id: str | None + force: bool + + router = APIRouter(prefix="/reranker-models", tags=["reranker-models"]) @@ -97,34 +118,25 @@ def _compute_feature_schema_sha(run: dict[str, Any]) -> str | None: def _register_model( - *, session: Session, - name: str, - artifact_uri: str, - run: dict[str, Any], - spec_yaml_text: str, - dataset_id_override: str | None, - external_source: str | None, - prediction_set_id: str | None, - evaluation_set_id: str | None, - force: bool, + reg: _RerankerRegistration, ) -> uuid.UUID: - existing = session.query(RerankerModel).filter(RerankerModel.name == name).first() + existing = session.query(RerankerModel).filter(RerankerModel.name == reg.name).first() if existing is not None: - if not force: + if not reg.force: raise HTTPException( status_code=409, - detail=f"RerankerModel name={name!r} already exists (id={existing.id})", + detail=f"RerankerModel name={reg.name!r} already exists (id={existing.id})", ) session.delete(existing) session.flush() - category, aspect = _parse_cell(_extract_cell_from_spec(spec_yaml_text)) + category, aspect = _parse_cell(_extract_cell_from_spec(reg.spec_yaml_text)) - dataset = run.get("dataset", {}) or {} + dataset = reg.run.get("dataset", {}) or {} dataset_uuid: uuid.UUID | None = None - if dataset_id_override: - dataset_uuid = uuid.UUID(dataset_id_override) + if reg.dataset_id_override: + dataset_uuid = uuid.UUID(reg.dataset_id_override) else: dataset_name = dataset.get("name") if dataset_name: @@ -132,7 +144,7 @@ def _register_model( if row is not None: dataset_uuid = row.id - feature_schema_sha = _compute_feature_schema_sha(run) + feature_schema_sha = _compute_feature_schema_sha(reg.run) embedding_config_id_raw = dataset.get("embedding_config_id") ontology_snapshot_id_raw = dataset.get("ontology_snapshot_id") @@ 
-153,32 +165,32 @@ def _register_model( ontology_snapshot_id = candidate model = RerankerModel( - name=name, - prediction_set_id=uuid.UUID(prediction_set_id) if prediction_set_id else None, - evaluation_set_id=uuid.UUID(evaluation_set_id) if evaluation_set_id else None, + name=reg.name, + prediction_set_id=uuid.UUID(reg.prediction_set_id) if reg.prediction_set_id else None, + evaluation_set_id=uuid.UUID(reg.evaluation_set_id) if reg.evaluation_set_id else None, category=category, aspect=aspect, model_data=None, - artifact_uri=artifact_uri, + artifact_uri=reg.artifact_uri, feature_schema_sha=feature_schema_sha, embedding_config_id=embedding_config_id, ontology_snapshot_id=ontology_snapshot_id, producer_version=dataset.get("producer_version"), producer_git_sha=dataset.get("producer_git_sha"), - spec_yaml=spec_yaml_text, - metrics=run.get("metrics", {}) or {}, - feature_importance=run.get("feature_importance", {}) or {}, + spec_yaml=reg.spec_yaml_text, + metrics=reg.run.get("metrics", {}) or {}, + feature_importance=reg.run.get("feature_importance", {}) or {}, # Categorical code maps live in metrics under a reserved key so the # predict path can replicate the lab's sorted-unique encoding instead # of falling back to ``pd.factorize`` (first-seen order, which gives # different codes than training and silently corrupts LK/PK scores). dataset_id=dataset_uuid, - external_source=external_source, + external_source=reg.external_source, ) # Stash categorical_codes in metrics if the lab supplied them. Stored as # ``metrics["__categorical_codes__"]`` to keep the column scalar-shaped # without bloating spec_yaml. 
- cat_codes = run.get("categorical_codes") + cat_codes = reg.run.get("categorical_codes") if cat_codes: m = dict(model.metrics or {}) m["__categorical_codes__"] = cat_codes @@ -223,16 +235,18 @@ async def import_reranker_model_multipart( with session_scope(factory) as session: model_id = _register_model( - session=session, - name=resolved_name, - artifact_uri=artifact_uri, - run=run, - spec_yaml_text=spec_text, - dataset_id_override=dataset_id, - external_source=external_source, - prediction_set_id=prediction_set_id, - evaluation_set_id=evaluation_set_id, - force=force, + session, + _RerankerRegistration( + name=resolved_name, + artifact_uri=artifact_uri, + run=run, + spec_yaml_text=spec_text, + dataset_id_override=dataset_id, + external_source=external_source, + prediction_set_id=prediction_set_id, + evaluation_set_id=evaluation_set_id, + force=force, + ), ) return { @@ -281,16 +295,18 @@ def import_reranker_model_by_reference( with session_scope(factory) as session: model_id = _register_model( - session=session, - name=resolved_name, - artifact_uri=body.artifact_uri, - run=body.run, - spec_yaml_text=body.spec_yaml, - dataset_id_override=body.dataset_id, - external_source=body.external_source, - prediction_set_id=body.prediction_set_id, - evaluation_set_id=body.evaluation_set_id, - force=body.force, + session, + _RerankerRegistration( + name=resolved_name, + artifact_uri=body.artifact_uri, + run=body.run, + spec_yaml_text=body.spec_yaml, + dataset_id_override=body.dataset_id, + external_source=body.external_source, + prediction_set_id=body.prediction_set_id, + evaluation_set_id=body.evaluation_set_id, + force=body.force, + ), ) return {"id": str(model_id), "name": resolved_name, "artifact_uri": body.artifact_uri}