diff --git a/protea/api/routers/annotations.py b/protea/api/routers/annotations.py index 487da7e..cb50a0c 100644 --- a/protea/api/routers/annotations.py +++ b/protea/api/routers/annotations.py @@ -23,9 +23,6 @@ from protea.core.operations.run_cafa_evaluation import RunCafaEvaluationPayload from protea.infrastructure.benchmark_config import BenchmarkConfig from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet -from protea.infrastructure.orm.models.annotation.go_term import GOTerm -from protea.infrastructure.orm.models.annotation.go_term_relationship import GOTermRelationship -from protea.infrastructure.orm.models.annotation.ontology_snapshot import OntologySnapshot from protea.infrastructure.orm.models.embedding.scoring_config import ScoringConfig from protea.infrastructure.queue.publisher import publish_job from protea.infrastructure.session import session_scope @@ -38,6 +35,7 @@ get_annotation_set_data, get_eval_result_with_keys, get_evaluation_set_data, + get_go_subgraph_data, get_snapshot_data, iter_delta_proteins_fasta, iter_groundtruth_tsv, @@ -612,72 +610,8 @@ def get_go_subgraph( factory: sessionmaker[Session] = Depends(get_session_factory), ) -> dict[str, Any]: """Return a subgraph of the GO DAG containing the requested terms and their ancestors up to ``depth`` levels.""" - with session_scope(factory) as session: - snap = session.get(OntologySnapshot, snapshot_id) - if snap is None: - raise HTTPException(status_code=404, detail="OntologySnapshot not found") - - query_go_ids = {g.strip() for g in go_ids.split(",") if g.strip()} - - # Resolve initial term DB ids - seed_terms = ( - session.query(GOTerm) - .filter( - GOTerm.ontology_snapshot_id == snapshot_id, - GOTerm.go_id.in_(query_go_ids), - ) - .all() - ) - - if not seed_terms: - return {"nodes": [], "edges": []} - - # BFS upward through the DAG - visited_ids: set[int] = {t.id for t in seed_terms} - frontier: set[int] = visited_ids.copy() - all_terms: dict[int, GOTerm] = {t.id: t for t in seed_terms} - all_edges: list[dict[str, Any]] = [] - - for _ in range(depth): - if not frontier: - break - rels = ( - session.query(GOTermRelationship) - .filter( - GOTermRelationship.ontology_snapshot_id == snapshot_id, - GOTermRelationship.child_go_term_id.in_(frontier), - ) - .all() - ) - - parent_ids = {r.parent_go_term_id for r in rels} - visited_ids - for r in rels: - all_edges.append( - { - "source": r.child_go_term_id, - "target": r.parent_go_term_id, - "relation_type": r.relation_type, - } - ) - - if parent_ids: - parents = session.query(GOTerm).filter(GOTerm.id.in_(parent_ids)).all() - for p in parents: - all_terms[p.id] = p - visited_ids |= parent_ids - frontier = parent_ids - else: - break - - query_db_ids = {t.id for t in seed_terms} - nodes = [ - { - "id": t.id, - "go_id": t.go_id, - "name": t.name, - "aspect": t.aspect, - "is_query": t.id in query_db_ids, - } - for t in all_terms.values() - ] - return {"nodes": nodes, "edges": all_edges} + try: + with session_scope(factory) as session: + return get_go_subgraph_data(session, snapshot_id, go_ids, depth) + except EntityNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc diff --git a/protea/services/annotations_service.py b/protea/services/annotations_service.py index ae91018..2a80eaa 100644 --- a/protea/services/annotations_service.py +++ b/protea/services/annotations_service.py @@ -321,6 +321,94 @@ def iter_delta_proteins_fasta( return lines +def get_go_subgraph_data( + session: Session, + snapshot_id: uuid.UUID, + go_ids: str, + depth: int, +) -> dict[str, Any]: + """BFS the GO DAG upward from the requested seed terms. + + Returns ``{"nodes": [...], "edges": [...]}`` ready for the API. + Each node has ``id`` (DB id), ``go_id``, ``name``, ``aspect``, + ``is_query`` (True for the seed terms). Each edge has + ``source`` (child id), ``target`` (parent id), ``relation_type``. + + Raises :class:`EntityNotFoundError` when the snapshot does not + resolve. + """ + from protea.infrastructure.orm.models.annotation.go_term_relationship import ( + GOTermRelationship, + ) + + snap = session.get(OntologySnapshot, snapshot_id) + if snap is None: + raise EntityNotFoundError("OntologySnapshot", snapshot_id) + + query_go_ids = {g.strip() for g in go_ids.split(",") if g.strip()} + + seed_terms = ( + session.query(GOTerm) + .filter( + GOTerm.ontology_snapshot_id == snapshot_id, + GOTerm.go_id.in_(query_go_ids), + ) + .all() + ) + + if not seed_terms: + return {"nodes": [], "edges": []} + + visited_ids: set[int] = {t.id for t in seed_terms} + frontier: set[int] = visited_ids.copy() + all_terms: dict[int, GOTerm] = {t.id: t for t in seed_terms} + all_edges: list[dict[str, Any]] = [] + + for _ in range(depth): + if not frontier: + break + rels = ( + session.query(GOTermRelationship) + .filter( + GOTermRelationship.ontology_snapshot_id == snapshot_id, + GOTermRelationship.child_go_term_id.in_(frontier), + ) + .all() + ) + + parent_ids = {r.parent_go_term_id for r in rels} - visited_ids + for r in rels: + all_edges.append( + { + "source": r.child_go_term_id, + "target": r.parent_go_term_id, + "relation_type": r.relation_type, + } + ) + + if parent_ids: + parents = session.query(GOTerm).filter(GOTerm.id.in_(parent_ids)).all() + for p in parents: + all_terms[p.id] = p + visited_ids |= parent_ids + frontier = parent_ids + else: + break + + query_db_ids = {t.id for t in seed_terms} + nodes = [ + { + "id": t.id, + "go_id": t.go_id, + "name": t.name, + "aspect": t.aspect, + "is_query": t.id in query_db_ids, + } + for t in all_terms.values() + ] + return {"nodes": nodes, "edges": all_edges} + + def evaluation_result_to_dict(r: EvaluationResult) -> dict[str, Any]: """Serialise an :class:`EvaluationResult` to its API dict shape.""" return { @@ -490,6 +578,7 @@ def delete_evaluation_set_collect_keys( "delete_eval_result_collect_keys", "evaluation_result_to_dict", "get_eval_result_with_keys", + "get_go_subgraph_data", "iter_delta_proteins_fasta", "iter_groundtruth_tsv", "list_evaluation_results_data",