diff --git a/protea/api/routers/annotations.py b/protea/api/routers/annotations.py index 4ddcf51..487da7e 100644 --- a/protea/api/routers/annotations.py +++ b/protea/api/routers/annotations.py @@ -22,7 +22,6 @@ from protea.core.operations.load_quickgo_annotations import LoadQuickGOAnnotationsPayload from protea.core.operations.run_cafa_evaluation import RunCafaEvaluationPayload from protea.infrastructure.benchmark_config import BenchmarkConfig -from protea.infrastructure.orm.models.annotation.evaluation_result import EvaluationResult from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet from protea.infrastructure.orm.models.annotation.go_term import GOTerm from protea.infrastructure.orm.models.annotation.go_term_relationship import GOTermRelationship @@ -34,15 +33,19 @@ AnnotationSetReferencedError, EntityNotFoundError, delete_annotation_set_data, + delete_eval_result_collect_keys, delete_evaluation_set_collect_keys, get_annotation_set_data, + get_eval_result_with_keys, get_evaluation_set_data, get_snapshot_data, iter_delta_proteins_fasta, iter_groundtruth_tsv, list_annotation_sets_data, + list_evaluation_results_data, list_evaluation_sets_data, list_snapshots_data, + render_evaluation_metrics_tsv, ) from protea.services.annotations_service import ( set_snapshot_ia_url as _set_snapshot_ia_url_service, @@ -504,30 +507,18 @@ def download_evaluation_metrics( result_id: UUID, factory: sessionmaker[Session] = Depends(get_session_factory), ) -> StreamingResponse: - with session_scope(factory) as session: - result = session.get(EvaluationResult, result_id) - if result is None or result.evaluation_set_id != eval_id: - raise HTTPException(status_code=404, detail="EvaluationResult not found") - - def _rows() -> Iterator[str]: - yield "setting\tnamespace\tfmax\tprecision\trecall\ttau\tcoverage\tn_proteins\n" - for setting in ("NK", "LK", "PK"): - ns_data = result.results.get(setting, {}) - for ns in ASPECT_CAFA_CODES: - m = ns_data.get(ns) - if m is None: - continue - yield ( - f"{setting}\t{ns}\t{m.get('fmax', '')}\t{m.get('precision', '')}\t" - f"{m.get('recall', '')}\t{m.get('tau', '')}\t{m.get('coverage', '')}\t" - f"{m.get('n_proteins', '')}\n" - ) - - return StreamingResponse( - _rows(), - media_type="text/tab-separated-values", - headers={"Content-Disposition": f'attachment; filename="metrics_{result_id}.tsv"'}, - ) + try: + with session_scope(factory) as session: + result, _ = get_eval_result_with_keys(session, eval_id, result_id) + return StreamingResponse( + render_evaluation_metrics_tsv(result, ASPECT_CAFA_CODES), + media_type="text/tab-separated-values", + headers={ + "Content-Disposition": f'attachment; filename="metrics_{result_id}.tsv"' + }, + ) + except EntityNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc @router.get( @@ -539,11 +530,11 @@ def download_evaluation_artifacts( result_id: UUID, factory: sessionmaker[Session] = Depends(get_session_factory), ) -> StreamingResponse: - with session_scope(factory) as session: - result = session.get(EvaluationResult, result_id) - if result is None or result.evaluation_set_id != eval_id: - raise HTTPException(status_code=404, detail="EvaluationResult not found") - keys = (result.results or {}).get("artifacts", {}).get("keys") or [] + try: + with session_scope(factory) as session: + _, keys = get_eval_result_with_keys(session, eval_id, result_id) + except EntityNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc if not keys: raise HTTPException(status_code=404, detail="No artifacts found for this result") @@ -578,29 +569,11 @@ def list_evaluation_results( eval_id: UUID, factory: sessionmaker[Session] = Depends(get_session_factory), ) -> list[dict[str, Any]]: - with session_scope(factory) as session: - if session.get(EvaluationSet, eval_id) is None: - raise HTTPException(status_code=404, detail="EvaluationSet not found") - rows = ( - session.query(EvaluationResult) - .filter(EvaluationResult.evaluation_set_id == eval_id) - .order_by(EvaluationResult.created_at.desc()) - .all() - ) - return [ - { - "id": str(r.id), - "evaluation_set_id": str(r.evaluation_set_id), - "prediction_set_id": str(r.prediction_set_id), - "scoring_config_id": str(r.scoring_config_id) if r.scoring_config_id else None, - "reranker_model_id": str(r.reranker_model_id) if r.reranker_model_id else None, - "reranker_config": r.reranker_config, - "job_id": str(r.job_id) if r.job_id else None, - "created_at": r.created_at.isoformat(), - "results": r.results, - } - for r in rows - ] + try: + with session_scope(factory) as session: + return list_evaluation_results_data(session, eval_id) + except EntityNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc @router.delete( @@ -616,12 +589,11 @@ def delete_evaluation_result( from protea.infrastructure.settings import load_settings from protea.infrastructure.storage import get_artifact_store - with session_scope(factory) as session: - result = session.get(EvaluationResult, result_id) - if result is None or result.evaluation_set_id != eval_id: - raise HTTPException(status_code=404, detail="EvaluationResult not found") - keys = (result.results or {}).get("artifacts", {}).get("keys") or [] - session.delete(result) + try: + with session_scope(factory) as session: + keys = delete_eval_result_collect_keys(session, eval_id, result_id) + except EntityNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc project_root = Path(__file__).resolve().parents[3] store = get_artifact_store(load_settings(project_root)) diff --git a/protea/services/annotations_service.py b/protea/services/annotations_service.py index fa6da9e..ae91018 100644 --- a/protea/services/annotations_service.py +++ b/protea/services/annotations_service.py @@ -321,6 +321,98 @@ def iter_delta_proteins_fasta( return lines +def evaluation_result_to_dict(r: EvaluationResult) -> dict[str, Any]: + """Serialise an :class:`EvaluationResult` to its API dict shape.""" + return { + "id": str(r.id), + "evaluation_set_id": str(r.evaluation_set_id), + "prediction_set_id": str(r.prediction_set_id), + "scoring_config_id": str(r.scoring_config_id) if r.scoring_config_id else None, + "reranker_model_id": str(r.reranker_model_id) if r.reranker_model_id else None, + "reranker_config": r.reranker_config, + "job_id": str(r.job_id) if r.job_id else None, + "created_at": r.created_at.isoformat(), + "results": r.results, + } + + +def list_evaluation_results_data( + session: Session, + eval_id: uuid.UUID, +) -> list[dict[str, Any]]: + """List EvaluationResult rows for one EvaluationSet (newest first). + + Raises :class:`EntityNotFoundError` when the EvaluationSet does + not resolve. + """ + if session.get(EvaluationSet, eval_id) is None: + raise EntityNotFoundError("EvaluationSet", eval_id) + rows = ( + session.query(EvaluationResult) + .filter(EvaluationResult.evaluation_set_id == eval_id) + .order_by(EvaluationResult.created_at.desc()) + .all() + ) + return [evaluation_result_to_dict(r) for r in rows] + + +def get_eval_result_with_keys( + session: Session, + eval_id: uuid.UUID, + result_id: uuid.UUID, +) -> tuple[EvaluationResult, list[str]]: + """Fetch an EvaluationResult belonging to ``eval_id``; return (row, artifact_keys). + + Raises :class:`EntityNotFoundError` ("EvaluationResult") when + the result does not exist or does not belong to ``eval_id``. + """ + result = session.get(EvaluationResult, result_id) + if result is None or result.evaluation_set_id != eval_id: + raise EntityNotFoundError("EvaluationResult", result_id) + keys: list[str] = (result.results or {}).get("artifacts", {}).get("keys") or [] + return result, keys + + +def delete_eval_result_collect_keys( + session: Session, + eval_id: uuid.UUID, + result_id: uuid.UUID, +) -> list[str]: + """Delete the EvaluationResult and return the artifact keys to clean up. + + Same split as :func:`delete_evaluation_set_collect_keys`: the + DB delete happens here; the artifact-store deletion is the + router's responsibility (it owns the ``ArtifactStore`` factory). + """ + result, keys = get_eval_result_with_keys(session, eval_id, result_id) + session.delete(result) + return keys + + +def render_evaluation_metrics_tsv( + result: EvaluationResult, + aspect_codes: tuple[str, ...], +) -> Any: + """Yield TSV rows for the per-(setting, namespace) metrics summary. + + The caller passes the aspect-codes tuple (``ASPECT_CAFA_CODES``) + so the service stays free of the domain layer. Returns a + generator suitable for ``StreamingResponse``. + """ + yield "setting\tnamespace\tfmax\tprecision\trecall\ttau\tcoverage\tn_proteins\n" + for setting in ("NK", "LK", "PK"): + ns_data = result.results.get(setting, {}) + for ns in aspect_codes: + m = ns_data.get(ns) + if m is None: + continue + yield ( + f"{setting}\t{ns}\t{m.get('fmax', '')}\t{m.get('precision', '')}\t" + f"{m.get('recall', '')}\t{m.get('tau', '')}\t{m.get('coverage', '')}\t" + f"{m.get('n_proteins', '')}\n" + ) + + def evaluation_set_to_dict(e: EvaluationSet) -> dict[str, Any]: """Serialise an :class:`EvaluationSet` to its API dict shape.""" return { @@ -395,8 +487,13 @@ def delete_evaluation_set_collect_keys( "get_annotation_set_data", "get_evaluation_set_data", "get_snapshot_data", + "delete_eval_result_collect_keys", + "evaluation_result_to_dict", + "get_eval_result_with_keys", "iter_delta_proteins_fasta", "iter_groundtruth_tsv", + "list_evaluation_results_data", + "render_evaluation_metrics_tsv", "list_annotation_sets_data", "list_evaluation_sets_data", "list_snapshots_data",