diff --git a/api/mcp/tools/structural.py b/api/mcp/tools/structural.py index 30540f4b..05d86ae2 100644 --- a/api/mcp/tools/structural.py +++ b/api/mcp/tools/structural.py @@ -182,3 +182,242 @@ def _payload(project) -> dict[str, Any]: } return await loop.run_in_executor(None, _do_index) + + +# --------------------------------------------------------------------------- +# T5 — get_callers / get_callees / get_dependencies +# --------------------------------------------------------------------------- + + +def _project_arg(project: str, branch: Optional[str]): + """Return an :class:`AsyncGraphQuery` for ``(project, branch)``.""" + from api.graph import AsyncGraphQuery + + return AsyncGraphQuery(project, branch=branch) + + +def _node_summary(n: Any) -> dict[str, Any]: + """Normalize a FalkorDB Node (or already-encoded dict) to a flat payload. + + ``encode_node`` returns ``{id, labels, properties: {...}}`` because Node + properties live on a nested attribute. Agents want a flat record, and + they also want a single ``label`` (the meaningful one — File, Class, + Function — not the fulltext-index marker ``Searchable``). + """ + if hasattr(n, "properties"): + props = dict(n.properties or {}) + labels = list(n.labels or []) + node_id = getattr(n, "id", None) + else: + d = dict(n) + props = dict(d.get("properties") or {}) + labels = list(d.get("labels") or []) + node_id = d.get("id") + + label = next((lbl for lbl in labels if lbl != "Searchable"), None) + return { + "id": node_id, + "name": props.get("name"), + "label": label, + "file": props.get("path"), + "line": props.get("src_start"), + } + + +def _coerce_node_id(symbol_id: Any) -> int: + """Accept int or stringified int; raise ValueError otherwise. + + The MCP wire format is JSON; agents sometimes hand back the id as a + string. Be permissive on input, strict on type after parsing. + """ + if isinstance(symbol_id, bool): # bool is an int subclass; reject loudly + raise ValueError(f"symbol_id must be an integer, got bool: {symbol_id!r}") + if isinstance(symbol_id, int): + return symbol_id + if isinstance(symbol_id, str) and symbol_id.lstrip("-").isdigit(): + return int(symbol_id) + raise ValueError(f"symbol_id must be an integer id, got: {symbol_id!r}") + + +async def _neighbors_payload( + project: str, + branch: Optional[str], + symbol_id: Any, + rel: str, + direction: str, + limit: int, +) -> list[dict[str, Any]]: + """Shared implementation for caller/callee/dependency tools. + + ``direction`` is ``IN`` (incoming edges, e.g. callers) or ``OUT`` + (outgoing edges, e.g. callees). When ``IN`` we run the inverse Cypher + ``(neighbor)-[:rel]->(target)``; ``AsyncGraphQuery.get_neighbors`` only + walks outgoing edges, so we inline the Cypher here for symmetry. + """ + node_id = _coerce_node_id(symbol_id) + g = _project_arg(project, branch) + try: + if direction == "OUT": + q = ( + f"MATCH (n)-[e:{rel}]->(dest) " + f"WHERE ID(n) = $sid " + f"RETURN dest, type(e) AS rel " + f"LIMIT $limit" + ) + elif direction == "IN": + q = ( + f"MATCH (src)-[e:{rel}]->(n) " + f"WHERE ID(n) = $sid " + f"RETURN src AS dest, type(e) AS rel " + f"LIMIT $limit" + ) + else: + raise ValueError(f"direction must be IN or OUT, got: {direction!r}") + + res = await g._query(q, {"sid": node_id, "limit": int(limit)}) + out: list[dict[str, Any]] = [] + for row in res.result_set: + entry = _node_summary(row[0]) + entry["relation"] = row[1] + entry["direction"] = direction + out.append(entry) + return out + finally: + await g.close() + + +@app.tool( + name="get_callers", + description=( + "Return functions that call the given symbol (incoming CALLS edges). " + "`symbol_id` is the integer node id returned by `search_code` or " + "other tools." + ), +) +async def get_callers( + symbol_id: Any, + project: str, + branch: Optional[str] = None, + limit: int = 50, +) -> list[dict[str, Any]]: + return await _neighbors_payload(project, branch, symbol_id, "CALLS", "IN", limit) + + +@app.tool( + name="get_callees", + description=( + "Return functions that the given symbol calls (outgoing CALLS edges)." + ), +) +async def get_callees( + symbol_id: Any, + project: str, + branch: Optional[str] = None, + limit: int = 50, +) -> list[dict[str, Any]]: + return await _neighbors_payload(project, branch, symbol_id, "CALLS", "OUT", limit) + + +@app.tool( + name="get_dependencies", + description=( + "Return outgoing neighbors of the given symbol across any of the " + "specified relation types (default: IMPORTS, CALLS, DEFINES). " + "Useful for 'what does this depend on' queries." + ), +) +async def get_dependencies( + symbol_id: Any, + project: str, + branch: Optional[str] = None, + rels: Optional[list[str]] = None, + limit: int = 50, +) -> list[dict[str, Any]]: + if rels is None: + rels = ["IMPORTS", "CALLS", "DEFINES"] + # Aggregate across relations; preserve ordering and dedupe by id. + seen: set[Any] = set() + out: list[dict[str, Any]] = [] + for rel in rels: + rows = await _neighbors_payload(project, branch, symbol_id, rel, "OUT", limit) + for row in rows: + key = (row.get("id"), row.get("relation")) + if key in seen: + continue + seen.add(key) + out.append(row) + if len(out) >= limit: + return out + return out + + +# --------------------------------------------------------------------------- +# T7 — find_path +# --------------------------------------------------------------------------- + + +@app.tool( + name="find_path", + description=( + "Return up to `max_paths` CALLS-path sequences from `source_id` to " + "`dest_id`. Useful for 'how does A reach B' questions. Returns an " + "empty list when no path exists." + ), +) +async def find_path( + source_id: Any, + dest_id: Any, + project: str, + branch: Optional[str] = None, + max_paths: int = 10, +) -> list[dict[str, Any]]: + src = _coerce_node_id(source_id) + dst = _coerce_node_id(dest_id) + g = _project_arg(project, branch) + try: + raw = await g.find_paths(src, dst) + finally: + await g.close() + + # ``AsyncGraphQuery.find_paths`` returns each path as an alternating + # [node, edge, node, edge, ..., node] list; we strip edges and surface + # only the node sequence — that's what agents typically want. + paths: list[dict[str, Any]] = [] + for entry in raw[:max_paths]: + node_seq = [ + _node_summary(x) + for x in entry + # Edges in the alternating list carry a top-level ``relation`` + # key (from ``encode_edge``); nodes carry ``properties``. + if isinstance(x, dict) and "properties" in x + ] + paths.append({"path": node_seq}) + return paths + + +# --------------------------------------------------------------------------- +# T8 — search_code +# --------------------------------------------------------------------------- + + +@app.tool( + name="search_code", + description=( + "Prefix-search for symbols (functions, classes, files) whose name " + "starts with `prefix`. Backed by FalkorDB's full-text index. The " + "agent typically calls this first to discover symbol ids for the " + "navigation tools (`get_callers`, `find_path`, ...)." + ), +) +async def search_code( + prefix: str, + project: str, + branch: Optional[str] = None, + limit: int = 20, +) -> list[dict[str, Any]]: + g = _project_arg(project, branch) + try: + raw = await g.prefix_search(prefix) + finally: + await g.close() + return [_node_summary(node) for node in raw[:limit]] diff --git a/tests/mcp/test_query_tools.py b/tests/mcp/test_query_tools.py new file mode 100644 index 00000000..ca034cf8 --- /dev/null +++ b/tests/mcp/test_query_tools.py @@ -0,0 +1,253 @@ +"""T5/T7/T8 — query MCP tools tests. + +Bundled because all three tools are thin async wrappers around existing +``AsyncGraphQuery`` operations and share the same fixture. +""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + + +pytestmark = pytest.mark.anyio + + +@pytest.fixture +def anyio_backend() -> str: + return "asyncio" + + +# --------------------------------------------------------------------------- +# search_code (T8) — runs first because callers/find_path need the ids it +# returns. +# --------------------------------------------------------------------------- + + +async def test_search_code_finds_entrypoint(indexed_fixture, expected_contract): + from api.mcp.tools.structural import search_code + + results = await search_code( + prefix="ent", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + names = {r["name"] for r in results} + for required in expected_contract["search_prefixes"]["ent"]["must_include"]: + assert required in names, f"expected {required} in {names}" + + +async def test_search_code_honors_limit(indexed_fixture): + from api.mcp.tools.structural import search_code + + results = await search_code( + prefix="r", # broad prefix + project=indexed_fixture.project, + branch=indexed_fixture.branch, + limit=1, + ) + assert len(results) <= 1 + + +async def test_search_code_empty_for_nonsense(indexed_fixture): + from api.mcp.tools.structural import search_code + + results = await search_code( + prefix="zzz_no_such_symbol_zzz", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert results == [] + + +async def test_search_code_result_serialisable(indexed_fixture): + from api.mcp.tools.structural import search_code + + results = await search_code( + prefix="serv", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + json.dumps(results) # must not raise + + +# --------------------------------------------------------------------------- +# get_callers / get_callees / get_dependencies (T5) +# --------------------------------------------------------------------------- + + +async def _find_id(indexed_fixture, name: str) -> int: + """Helper: resolve a symbol name to its int node id via search_code.""" + from api.mcp.tools.structural import search_code + + rows = await search_code( + prefix=name, + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + for r in rows: + if r["name"] == name: + return r["id"] + raise AssertionError(f"symbol {name!r} not found via search_code") + + +async def test_get_callees_of_entrypoint(indexed_fixture, expected_contract): + from api.mcp.tools.structural import get_callees + + entry_id = await _find_id(indexed_fixture, "entrypoint") + callees = await get_callees( + symbol_id=entry_id, + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + names = {c["name"] for c in callees} + expected = set(expected_contract["calls"]["entrypoint"]["callees_any_of"]) + assert names & expected, ( + f"entrypoint callees {names} disjoint from expected {expected}" + ) + + for c in callees: + assert c["relation"] == "CALLS" + assert c["direction"] == "OUT" + + +async def test_get_callers_of_service(indexed_fixture, expected_contract): + from api.mcp.tools.structural import get_callers + + service_id = await _find_id(indexed_fixture, "service") + callers = await get_callers( + symbol_id=service_id, + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + names = {c["name"] for c in callers} + for required in expected_contract["calls"]["service"]["callers"]: + assert required in names, f"expected caller {required} in {names}" + + for c in callers: + assert c["relation"] == "CALLS" + assert c["direction"] == "IN" + + +async def test_get_dependencies_aggregates_relations(indexed_fixture): + from api.mcp.tools.structural import get_dependencies + + entry_id = await _find_id(indexed_fixture, "entrypoint") + deps = await get_dependencies( + symbol_id=entry_id, + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + # Default relations include CALLS, IMPORTS, DEFINES — at minimum the + # CALLS edge to ``service`` must be present. + rels = {d["relation"] for d in deps} + assert "CALLS" in rels + + +async def test_neighbor_tools_accept_string_ids(indexed_fixture): + """Agents sometimes hand back ids as strings — must work.""" + from api.mcp.tools.structural import get_callees + + entry_id = await _find_id(indexed_fixture, "entrypoint") + callees = await get_callees( + symbol_id=str(entry_id), # ← string! + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert isinstance(callees, list) + + +async def test_neighbor_tools_reject_garbage_ids(indexed_fixture): + from api.mcp.tools.structural import get_callers + + with pytest.raises(ValueError, match="symbol_id"): + await get_callers( + symbol_id="not-a-number", + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + + +# --------------------------------------------------------------------------- +# find_path (T7) +# --------------------------------------------------------------------------- + + +async def test_find_path_entrypoint_to_db(indexed_fixture, expected_contract): + from api.mcp.tools.structural import find_path + + entry_id = await _find_id(indexed_fixture, "entrypoint") + db_id = await _find_id(indexed_fixture, "db") + + paths = await find_path( + source_id=entry_id, + dest_id=db_id, + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + # The contract requires at least one path entrypoint -> ... -> db + expected_min = next( + p["min_paths"] for p in expected_contract["paths"] + if p["source"] == "entrypoint" and p["dest"] == "db" + ) + assert len(paths) >= expected_min + + # Each path must have entrypoint first, db last, in a non-empty node + # sequence. + for entry in paths: + seq = entry["path"] + assert len(seq) >= 2 + assert seq[0]["name"] == "entrypoint" + assert seq[-1]["name"] == "db" + + +async def test_find_path_no_path_returns_empty(indexed_fixture): + """db -> entrypoint has no CALLS path (graph is acyclic).""" + from api.mcp.tools.structural import find_path + + entry_id = await _find_id(indexed_fixture, "entrypoint") + db_id = await _find_id(indexed_fixture, "db") + + paths = await find_path( + source_id=db_id, + dest_id=entry_id, + project=indexed_fixture.project, + branch=indexed_fixture.branch, + ) + assert paths == [] + + +async def test_find_path_honors_max_paths(indexed_fixture): + from api.mcp.tools.structural import find_path + + entry_id = await _find_id(indexed_fixture, "entrypoint") + db_id = await _find_id(indexed_fixture, "db") + + paths = await find_path( + source_id=entry_id, + dest_id=db_id, + project=indexed_fixture.project, + branch=indexed_fixture.branch, + max_paths=1, + ) + assert len(paths) <= 1 + + +# --------------------------------------------------------------------------- +# Protocol registration +# --------------------------------------------------------------------------- + + +async def test_all_query_tools_registered(): + from api.mcp.server import app + + tools = {t.name for t in await app.list_tools()} + assert { + "search_code", + "get_callers", + "get_callees", + "get_dependencies", + "find_path", + }.issubset(tools)