Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 31 additions & 59 deletions protea/api/routers/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from protea.core.operations.load_quickgo_annotations import LoadQuickGOAnnotationsPayload
from protea.core.operations.run_cafa_evaluation import RunCafaEvaluationPayload
from protea.infrastructure.benchmark_config import BenchmarkConfig
from protea.infrastructure.orm.models.annotation.evaluation_result import EvaluationResult
from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet
from protea.infrastructure.orm.models.annotation.go_term import GOTerm
from protea.infrastructure.orm.models.annotation.go_term_relationship import GOTermRelationship
Expand All @@ -34,15 +33,19 @@
AnnotationSetReferencedError,
EntityNotFoundError,
delete_annotation_set_data,
delete_eval_result_collect_keys,
delete_evaluation_set_collect_keys,
get_annotation_set_data,
get_eval_result_with_keys,
get_evaluation_set_data,
get_snapshot_data,
iter_delta_proteins_fasta,
iter_groundtruth_tsv,
list_annotation_sets_data,
list_evaluation_results_data,
list_evaluation_sets_data,
list_snapshots_data,
render_evaluation_metrics_tsv,
)
from protea.services.annotations_service import (
set_snapshot_ia_url as _set_snapshot_ia_url_service,
Expand Down Expand Up @@ -504,30 +507,18 @@ def download_evaluation_metrics(
result_id: UUID,
factory: sessionmaker[Session] = Depends(get_session_factory),
) -> StreamingResponse:
with session_scope(factory) as session:
result = session.get(EvaluationResult, result_id)
if result is None or result.evaluation_set_id != eval_id:
raise HTTPException(status_code=404, detail="EvaluationResult not found")

def _rows() -> Iterator[str]:
yield "setting\tnamespace\tfmax\tprecision\trecall\ttau\tcoverage\tn_proteins\n"
for setting in ("NK", "LK", "PK"):
ns_data = result.results.get(setting, {})
for ns in ASPECT_CAFA_CODES:
m = ns_data.get(ns)
if m is None:
continue
yield (
f"{setting}\t{ns}\t{m.get('fmax', '')}\t{m.get('precision', '')}\t"
f"{m.get('recall', '')}\t{m.get('tau', '')}\t{m.get('coverage', '')}\t"
f"{m.get('n_proteins', '')}\n"
)

return StreamingResponse(
_rows(),
media_type="text/tab-separated-values",
headers={"Content-Disposition": f'attachment; filename="metrics_{result_id}.tsv"'},
)
try:
with session_scope(factory) as session:
result, _ = get_eval_result_with_keys(session, eval_id, result_id)
return StreamingResponse(
render_evaluation_metrics_tsv(result, ASPECT_CAFA_CODES),
media_type="text/tab-separated-values",
headers={
"Content-Disposition": f'attachment; filename="metrics_{result_id}.tsv"'
},
)
except EntityNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc


@router.get(
Expand All @@ -539,11 +530,11 @@ def download_evaluation_artifacts(
result_id: UUID,
factory: sessionmaker[Session] = Depends(get_session_factory),
) -> StreamingResponse:
with session_scope(factory) as session:
result = session.get(EvaluationResult, result_id)
if result is None or result.evaluation_set_id != eval_id:
raise HTTPException(status_code=404, detail="EvaluationResult not found")
keys = (result.results or {}).get("artifacts", {}).get("keys") or []
try:
with session_scope(factory) as session:
_, keys = get_eval_result_with_keys(session, eval_id, result_id)
except EntityNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc

if not keys:
raise HTTPException(status_code=404, detail="No artifacts found for this result")
Expand Down Expand Up @@ -578,29 +569,11 @@ def list_evaluation_results(
eval_id: UUID,
factory: sessionmaker[Session] = Depends(get_session_factory),
) -> list[dict[str, Any]]:
with session_scope(factory) as session:
if session.get(EvaluationSet, eval_id) is None:
raise HTTPException(status_code=404, detail="EvaluationSet not found")
rows = (
session.query(EvaluationResult)
.filter(EvaluationResult.evaluation_set_id == eval_id)
.order_by(EvaluationResult.created_at.desc())
.all()
)
return [
{
"id": str(r.id),
"evaluation_set_id": str(r.evaluation_set_id),
"prediction_set_id": str(r.prediction_set_id),
"scoring_config_id": str(r.scoring_config_id) if r.scoring_config_id else None,
"reranker_model_id": str(r.reranker_model_id) if r.reranker_model_id else None,
"reranker_config": r.reranker_config,
"job_id": str(r.job_id) if r.job_id else None,
"created_at": r.created_at.isoformat(),
"results": r.results,
}
for r in rows
]
try:
with session_scope(factory) as session:
return list_evaluation_results_data(session, eval_id)
except EntityNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc


@router.delete(
Expand All @@ -616,12 +589,11 @@ def delete_evaluation_result(
from protea.infrastructure.settings import load_settings
from protea.infrastructure.storage import get_artifact_store

with session_scope(factory) as session:
result = session.get(EvaluationResult, result_id)
if result is None or result.evaluation_set_id != eval_id:
raise HTTPException(status_code=404, detail="EvaluationResult not found")
keys = (result.results or {}).get("artifacts", {}).get("keys") or []
session.delete(result)
try:
with session_scope(factory) as session:
keys = delete_eval_result_collect_keys(session, eval_id, result_id)
except EntityNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc

project_root = Path(__file__).resolve().parents[3]
store = get_artifact_store(load_settings(project_root))
Expand Down
97 changes: 97 additions & 0 deletions protea/services/annotations_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,98 @@ def iter_delta_proteins_fasta(
return lines


def evaluation_result_to_dict(r: EvaluationResult) -> dict[str, Any]:
    """Serialise an :class:`EvaluationResult` to its API dict shape."""

    def _opt(value):
        # Optional FK columns serialise to None rather than the string "None".
        return str(value) if value else None

    return {
        "id": str(r.id),
        "evaluation_set_id": str(r.evaluation_set_id),
        "prediction_set_id": str(r.prediction_set_id),
        "scoring_config_id": _opt(r.scoring_config_id),
        "reranker_model_id": _opt(r.reranker_model_id),
        "reranker_config": r.reranker_config,
        "job_id": _opt(r.job_id),
        "created_at": r.created_at.isoformat(),
        "results": r.results,
    }


def list_evaluation_results_data(
    session: Session,
    eval_id: uuid.UUID,
) -> list[dict[str, Any]]:
    """List EvaluationResult rows for one EvaluationSet (newest first).

    Raises :class:`EntityNotFoundError` when the EvaluationSet does
    not resolve.
    """
    # Validate the parent set first so a missing set is a 404, not an
    # empty list.
    if session.get(EvaluationSet, eval_id) is None:
        raise EntityNotFoundError("EvaluationSet", eval_id)
    query = (
        session.query(EvaluationResult)
        .filter(EvaluationResult.evaluation_set_id == eval_id)
        .order_by(EvaluationResult.created_at.desc())
    )
    return [evaluation_result_to_dict(row) for row in query]


def get_eval_result_with_keys(
    session: Session,
    eval_id: uuid.UUID,
    result_id: uuid.UUID,
) -> tuple[EvaluationResult, list[str]]:
    """Fetch an EvaluationResult belonging to ``eval_id``; return (row, artifact_keys).

    Raises :class:`EntityNotFoundError` ("EvaluationResult") when
    the result does not exist or does not belong to ``eval_id``.
    """
    row = session.get(EvaluationResult, result_id)
    # A result attached to a different evaluation set is treated exactly
    # like a missing one, so results never leak across sets.
    if row is None or row.evaluation_set_id != eval_id:
        raise EntityNotFoundError("EvaluationResult", result_id)
    payload = row.results or {}
    artifact_keys: list[str] = payload.get("artifacts", {}).get("keys") or []
    return row, artifact_keys


def delete_eval_result_collect_keys(
    session: Session,
    eval_id: uuid.UUID,
    result_id: uuid.UUID,
) -> list[str]:
    """Delete the EvaluationResult; hand back its artifact keys for cleanup.

    Mirrors :func:`delete_evaluation_set_collect_keys`: only the DB row
    is removed here — deleting the artifact-store objects remains the
    router's job, since it owns the ``ArtifactStore`` factory.
    """
    row, artifact_keys = get_eval_result_with_keys(session, eval_id, result_id)
    session.delete(row)
    return artifact_keys


def render_evaluation_metrics_tsv(
    result: EvaluationResult,
    aspect_codes: tuple[str, ...],
) -> Any:
    """Yield TSV rows for the per-(setting, namespace) metrics summary.

    The caller passes the aspect-codes tuple (``ASPECT_CAFA_CODES``)
    so the service stays free of the domain layer. Returns a
    generator suitable for ``StreamingResponse``.

    A result whose ``results`` payload is NULL/empty yields only the
    header row; settings or namespaces without metrics are skipped.
    """
    yield "setting\tnamespace\tfmax\tprecision\trecall\ttau\tcoverage\tn_proteins\n"
    # Guard against a NULL results column — same ``result.results or {}``
    # convention as get_eval_result_with_keys; previously this raised
    # AttributeError instead of streaming an empty summary.
    payload = result.results or {}
    for setting in ("NK", "LK", "PK"):
        ns_data = payload.get(setting, {})
        for ns in aspect_codes:
            m = ns_data.get(ns)
            if m is None:
                continue
            yield (
                f"{setting}\t{ns}\t{m.get('fmax', '')}\t{m.get('precision', '')}\t"
                f"{m.get('recall', '')}\t{m.get('tau', '')}\t{m.get('coverage', '')}\t"
                f"{m.get('n_proteins', '')}\n"
            )


def evaluation_set_to_dict(e: EvaluationSet) -> dict[str, Any]:
"""Serialise an :class:`EvaluationSet` to its API dict shape."""
return {
Expand Down Expand Up @@ -395,8 +487,13 @@ def delete_evaluation_set_collect_keys(
"get_annotation_set_data",
"get_evaluation_set_data",
"get_snapshot_data",
"delete_eval_result_collect_keys",
"evaluation_result_to_dict",
"get_eval_result_with_keys",
"iter_delta_proteins_fasta",
"iter_groundtruth_tsv",
"list_evaluation_results_data",
"render_evaluation_metrics_tsv",
"list_annotation_sets_data",
"list_evaluation_sets_data",
"list_snapshots_data",
Expand Down
Loading