Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 6 additions & 72 deletions protea/api/routers/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
from protea.core.operations.run_cafa_evaluation import RunCafaEvaluationPayload
from protea.infrastructure.benchmark_config import BenchmarkConfig
from protea.infrastructure.orm.models.annotation.evaluation_set import EvaluationSet
from protea.infrastructure.orm.models.annotation.go_term import GOTerm
from protea.infrastructure.orm.models.annotation.go_term_relationship import GOTermRelationship
from protea.infrastructure.orm.models.annotation.ontology_snapshot import OntologySnapshot
from protea.infrastructure.orm.models.embedding.scoring_config import ScoringConfig
from protea.infrastructure.queue.publisher import publish_job
from protea.infrastructure.session import session_scope
Expand All @@ -38,6 +35,7 @@
get_annotation_set_data,
get_eval_result_with_keys,
get_evaluation_set_data,
get_go_subgraph_data,
get_snapshot_data,
iter_delta_proteins_fasta,
iter_groundtruth_tsv,
Expand Down Expand Up @@ -612,72 +610,8 @@ def get_go_subgraph(
factory: sessionmaker[Session] = Depends(get_session_factory),
) -> dict[str, Any]:
"""Return a subgraph of the GO DAG containing the requested terms and their ancestors up to ``depth`` levels."""
with session_scope(factory) as session:
snap = session.get(OntologySnapshot, snapshot_id)
if snap is None:
raise HTTPException(status_code=404, detail="OntologySnapshot not found")

query_go_ids = {g.strip() for g in go_ids.split(",") if g.strip()}

# Resolve initial term DB ids
seed_terms = (
session.query(GOTerm)
.filter(
GOTerm.ontology_snapshot_id == snapshot_id,
GOTerm.go_id.in_(query_go_ids),
)
.all()
)

if not seed_terms:
return {"nodes": [], "edges": []}

# BFS upward through the DAG
visited_ids: set[int] = {t.id for t in seed_terms}
frontier: set[int] = visited_ids.copy()
all_terms: dict[int, GOTerm] = {t.id: t for t in seed_terms}
all_edges: list[dict[str, Any]] = []

for _ in range(depth):
if not frontier:
break
rels = (
session.query(GOTermRelationship)
.filter(
GOTermRelationship.ontology_snapshot_id == snapshot_id,
GOTermRelationship.child_go_term_id.in_(frontier),
)
.all()
)

parent_ids = {r.parent_go_term_id for r in rels} - visited_ids
for r in rels:
all_edges.append(
{
"source": r.child_go_term_id,
"target": r.parent_go_term_id,
"relation_type": r.relation_type,
}
)

if parent_ids:
parents = session.query(GOTerm).filter(GOTerm.id.in_(parent_ids)).all()
for p in parents:
all_terms[p.id] = p
visited_ids |= parent_ids
frontier = parent_ids
else:
break

query_db_ids = {t.id for t in seed_terms}
nodes = [
{
"id": t.id,
"go_id": t.go_id,
"name": t.name,
"aspect": t.aspect,
"is_query": t.id in query_db_ids,
}
for t in all_terms.values()
]
return {"nodes": nodes, "edges": all_edges}
try:
with session_scope(factory) as session:
return get_go_subgraph_data(session, snapshot_id, go_ids, depth)
except EntityNotFoundError as exc:
raise HTTPException(status_code=404, detail=str(exc)) from exc
89 changes: 89 additions & 0 deletions protea/services/annotations_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,94 @@ def iter_delta_proteins_fasta(
return lines


def get_go_subgraph_data(
    session: Session,
    snapshot_id: uuid.UUID,
    go_ids: str,
    depth: int,
) -> dict[str, Any]:
    """Walk the GO DAG upward (child -> parent) from the requested seed terms.

    Parameters
    ----------
    session:
        Open SQLAlchemy session used for all lookups.
    snapshot_id:
        Ontology snapshot the terms and relationships must belong to.
    go_ids:
        Comma-separated GO accessions; surrounding whitespace and empty
        entries are ignored.
    depth:
        Maximum number of parent levels to traverse from the seeds.

    Returns ``{"nodes": [...], "edges": [...]}`` ready for the API.
    Each node carries ``id`` (DB id), ``go_id``, ``name``, ``aspect`` and
    ``is_query`` (True for the seed terms); each edge carries ``source``
    (child DB id), ``target`` (parent DB id) and ``relation_type``.

    Raises :class:`EntityNotFoundError` when ``snapshot_id`` does not
    resolve to an :class:`OntologySnapshot`.
    """
    # Function-scope import, presumably to avoid a module-level import
    # cycle — NOTE(review): confirm before hoisting to the top of the file.
    from protea.infrastructure.orm.models.annotation.go_term_relationship import (
        GOTermRelationship,
    )

    if session.get(OntologySnapshot, snapshot_id) is None:
        raise EntityNotFoundError("OntologySnapshot", snapshot_id)

    requested = {part.strip() for part in go_ids.split(",") if part.strip()}

    seeds = (
        session.query(GOTerm)
        .filter(
            GOTerm.ontology_snapshot_id == snapshot_id,
            GOTerm.go_id.in_(requested),
        )
        .all()
    )
    if not seeds:
        return {"nodes": [], "edges": []}

    seed_db_ids = {term.id for term in seeds}
    seen: set[int] = set(seed_db_ids)
    current_level: set[int] = set(seed_db_ids)
    terms_by_id: dict[int, GOTerm] = {term.id: term for term in seeds}
    edges: list[dict[str, Any]] = []

    # Breadth-first traversal: one relationship query per level, at most
    # ``depth`` levels, stopping early once no unseen parents remain.
    level = 0
    while level < depth and current_level:
        level += 1
        relations = (
            session.query(GOTermRelationship)
            .filter(
                GOTermRelationship.ontology_snapshot_id == snapshot_id,
                GOTermRelationship.child_go_term_id.in_(current_level),
            )
            .all()
        )

        # Every relation becomes an edge, even when its parent was already
        # visited — only *new* parents feed the next level.
        edges.extend(
            {
                "source": rel.child_go_term_id,
                "target": rel.parent_go_term_id,
                "relation_type": rel.relation_type,
            }
            for rel in relations
        )

        new_parent_ids = {rel.parent_go_term_id for rel in relations} - seen
        if not new_parent_ids:
            break
        for parent in (
            session.query(GOTerm).filter(GOTerm.id.in_(new_parent_ids)).all()
        ):
            terms_by_id[parent.id] = parent
        seen |= new_parent_ids
        current_level = new_parent_ids

    nodes = [
        {
            "id": term.id,
            "go_id": term.go_id,
            "name": term.name,
            "aspect": term.aspect,
            "is_query": term.id in seed_db_ids,
        }
        for term in terms_by_id.values()
    ]
    return {"nodes": nodes, "edges": edges}


def evaluation_result_to_dict(r: EvaluationResult) -> dict[str, Any]:
"""Serialise an :class:`EvaluationResult` to its API dict shape."""
return {
Expand Down Expand Up @@ -490,6 +578,7 @@ def delete_evaluation_set_collect_keys(
"delete_eval_result_collect_keys",
"evaluation_result_to_dict",
"get_eval_result_with_keys",
"get_go_subgraph_data",
"iter_delta_proteins_fasta",
"iter_groundtruth_tsv",
"list_evaluation_results_data",
Expand Down
Loading