Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions api/analyzers/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,39 @@ def needs_lsp(self) -> bool:
"""
return True

def build_import_index(self, files: dict[Path, File], root: Path) -> object:
    """
    Produce the per-language lookup structure used for import resolution.

    The value returned here is handed back, untouched, to
    ``resolve_imports``; its shape is private to each analyzer. This base
    implementation opts out of import resolution by returning ``None``.

    Args:
        files (dict[Path, File]): All parsed files keyed by absolute path.
        root (Path): The analyzed repository root.

    Returns:
        object: Opaque index, or ``None`` when the language has no
        import resolution.
    """

    # No index means "import resolution unsupported" for this analyzer.
    return None

def resolve_imports(self, file: File, root: Path, index: object) -> list[File]:
    """
    Map the import statements of ``file`` onto the in-repo files it uses.

    Resolution is purely syntactic by default (no LSP). The orchestrator
    links every returned File to ``file`` with an ``IMPORTS`` edge. This
    base implementation resolves nothing.

    Args:
        file (File): The importing file (already parsed; ``file.tree`` set).
        root (Path): The analyzed repository root.
        index (object): The structure returned by ``build_import_index``.

    Returns:
        list[File]: In-repo files imported by ``file`` (deduped, self
        excluded); always empty here.
    """

    no_targets: list = []
    return no_targets

@abstractmethod
def add_dependencies(self, path: Path, files: list[Path]):
"""
Expand Down
104 changes: 104 additions & 0 deletions api/analyzers/python/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,110 @@ def add_symbols(self, entity: Entity) -> None:
def is_dependency(self, file_path: str) -> bool:
    """Return True when the text ``venv`` occurs anywhere in ``file_path``.

    This is a plain substring test on the whole path string, not a
    per-directory comparison.
    """
    marker = "venv"
    return marker in file_path

def _module_parts(self, file_path: Path, root: Path) -> Optional[list[str]]:
"""Dotted module path components for ``file_path`` relative to ``root``."""
try:
rel = file_path.relative_to(root)
except ValueError:
return None
parts = list(rel.with_suffix('').parts)
if parts and parts[-1] == '__init__':
parts = parts[:-1]
return parts

def build_import_index(self, files: dict[Path, File], root: Path) -> object:
    """Build the dotted-name lookup tables for in-repo Python files.

    The result is a dict with two maps:

    * ``exact``  - full dotted path from ``root`` -> File.
    * ``suffix`` - every trailing sub-path of that dotted path -> File;
      the first file registered under a key keeps it.

    The suffix map lets imports resolve in ``src/``/``lib/`` layouts where
    the import name (``matplotlib.axes``) differs from the path-from-root
    (``lib.matplotlib.axes``).

    Files flagged by ``is_dependency`` and files with no module components
    (outside ``root``, or a root-level ``__init__``) are left out.
    """
    by_full_name: dict[str, File] = {}
    by_tail: dict[str, File] = {}

    for source_path, source_file in files.items():
        if self.is_dependency(str(source_path)):
            continue
        components = self._module_parts(source_path, root)
        if not components:
            continue

        by_full_name.setdefault('.'.join(components), source_file)
        # Register every tail, including the full path itself (start == 0).
        for start, _ in enumerate(components):
            by_tail.setdefault('.'.join(components[start:]), source_file)

    return {'exact': by_full_name, 'suffix': by_tail}

def _resolve_dotted(self, dotted: str, index: dict) -> Optional[File]:
if not dotted:
return None
f = index['exact'].get(dotted) or index['suffix'].get(dotted)
if f is None and '.' in dotted:
# imported name may be a symbol inside a module; drop the last part.
parent = dotted.rsplit('.', 1)[0]
f = index['exact'].get(parent) or index['suffix'].get(parent)
return f

def _import_requests(self, file: File) -> list[tuple[str, int]]:
    """Collect ``(dotted, level)`` lookup requests from ``file``'s imports.

    ``level`` is 0 for absolute imports. For relative imports it is the
    length of the ``import_prefix`` text (presumably the run of leading
    dots), defaulting to 1 when no prefix node is found.

    * ``import a.b`` / ``import a.b as c`` -> ``("a.b", 0)``.
    * ``from m import x`` -> one request for ``m`` itself and one for
      ``m.x`` (or just ``x`` when the module part is empty), so that both
      "x is a submodule" and "x is a symbol of m" can be tried downstream.
    """
    def text_of(node) -> str:
        return node.text.decode('utf-8')

    def unalias(node):
        # ``... as alias``: the real imported name is in the ``name`` field.
        if node.type == 'aliased_import':
            return node.child_by_field_name('name')
        return node

    found = self._captures(
        "(import_statement) @i (import_from_statement) @f",
        file.tree.root_node,
    )
    requests: list[tuple[str, int]] = []

    # Plain ``import ...`` statements are always absolute.
    for statement in found.get('i', []):
        for child in statement.named_children:
            target = unalias(child)
            if target is not None and target.type == 'dotted_name':
                requests.append((text_of(target), 0))

    # ``from ... import ...`` statements, absolute or relative.
    for statement in found.get('f', []):
        level, base = 0, ''
        module = statement.child_by_field_name('module_name')
        if module is not None:
            if module.type != 'relative_import':
                base = text_of(module)
            else:
                prefix = next((c for c in module.children if c.type == 'import_prefix'), None)
                level = 1 if prefix is None else len(text_of(prefix))
                dotted_part = next((c for c in module.named_children if c.type == 'dotted_name'), None)
                base = '' if dotted_part is None else text_of(dotted_part)
        requests.append((base, level))

        for name_node in statement.children_by_field_name('name'):
            leaf = unalias(name_node)
            if leaf is None:
                continue
            name_txt = text_of(leaf)
            requests.append((f"{base}.{name_txt}" if base else name_txt, level))

    return requests

def resolve_imports(self, file: File, root: Path, index: object) -> list[File]:
    """Resolve the import statements of ``file`` to in-repo files.

    Each ``(dotted, level)`` request from ``_import_requests`` is turned
    into a dotted name from ``root`` and looked up with
    ``_resolve_dotted``. Relative requests (``level`` > 0) are anchored at
    the package of ``file``, climbing ``level - 1`` packages up.

    Args:
        file (File): The importing file (``file.path`` and ``file.tree``
            are read).
        root (Path): The analyzed repository root.
        index (object): The structure returned by ``build_import_index``.

    Returns:
        list[File]: Resolved files in first-seen order, deduplicated by
        path, excluding ``file`` itself and anything ``is_dependency``
        flags. Empty when ``index`` is falsy or ``file`` is outside
        ``root``.
    """
    if not index:
        return []
    module_parts = self._module_parts(file.path, root)
    if module_parts is None:
        return []

    # ``_module_parts`` already strips a trailing ``__init__``, so for a
    # package initializer the remaining parts ARE the package. Only a plain
    # module has to drop its own name to obtain the enclosing package.
    if file.path.stem == '__init__':
        package_parts = list(module_parts)
    else:
        package_parts = module_parts[:-1]

    seen: set[Path] = set()
    targets: list[File] = []
    for dotted, level in self._import_requests(file):
        if level:
            climb = level - 1
            if climb > len(package_parts):
                # Relative import reaches above the repository root. A
                # negative slice bound would wrap around and fabricate a
                # base package, so treat the import as unresolvable.
                continue
            base = package_parts[: len(package_parts) - climb]
            full = '.'.join([*base, dotted]) if dotted else '.'.join(base)
        else:
            full = dotted

        resolved = self._resolve_dotted(full, index)
        if resolved is None or resolved.path == file.path or resolved.path in seen:
            continue
        if self.is_dependency(str(resolved.path)):
            continue
        seen.add(resolved.path)
        targets.append(resolved)
    return targets

def _extract_type_target(self, node: Node) -> Optional[Node]:
if node.type == 'attribute':
return node.child_by_field_name('attribute')
Expand Down
71 changes: 45 additions & 26 deletions api/analyzers/python/ts_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,21 @@ def _captures(query, root: Node) -> dict[str, list[Node]]:
return cursor.captures(root)


def _matches(query, root: Node) -> list[tuple[int, dict[str, list[Node]]]]:
    """Run ``query`` over ``root`` and return the captures of each match.

    :func:`_captures` groups *all* captured nodes by capture name into
    separate lists, and those lists are **not** guaranteed to be
    index-aligned across capture names. Zipping them can therefore pair a
    ``@name`` with the ``@def`` of a different match and scramble the
    module symbol table. This helper instead yields one capture dict per
    match, so a ``@name`` capture is always paired with the ``@def``
    capture from the *same* match.
    """
    return QueryCursor(query).matches(root)


# ---------------------------------------------------------------------------
# Public resolver
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -217,46 +232,50 @@ def _ensure_built(self, files: dict[Path, File], project_root: Path) -> None:

def _index_file(self, mi: _ModuleIndex, root: Node) -> None:
    """Record the definitions found under ``root`` in ``mi`` and ``_by_name``.

    Four kinds of definition are indexed: top-level functions, top-level
    classes, top-level assignments and class methods. Names and definition
    nodes are taken from the *same* query match (via ``_matches``) rather
    than zipped from independently grouped capture lists, which could pair
    a name with the wrong definition node.

    Args:
        mi (_ModuleIndex): Per-module index to fill (``top_level`` and
            ``class_methods`` are written; ``file_path`` is read).
        root (Node): Root syntax node of the module.
    """
    # Top-level functions and classes share the same shape: a later
    # definition with the same name replaces the earlier ``top_level``
    # entry, while ``_by_name`` keeps every definition.
    for query, kind in (
        (self._queries.top_level_func, "func"),
        (self._queries.top_level_class, "class"),
    ):
        for _, caps in _matches(query, root):
            name_nodes = caps.get("name", [])
            def_nodes = caps.get("def", [])
            if not name_nodes or not def_nodes:
                continue
            name = name_nodes[0].text.decode("utf-8")
            d = _Definition(mi.file_path, _strip_decorator(def_nodes[0]), kind)
            mi.top_level[name] = d
            self._by_name[name].append(d)

    # Top-level assignments (for class aliases like ``Foo = OtherFoo``)
    for _, caps in _matches(self._queries.top_level_assign, root):
        name_nodes = caps.get("name", [])
        def_nodes = caps.get("def", [])
        if not name_nodes or not def_nodes:
            continue
        name = name_nodes[0].text.decode("utf-8")
        # A function/class (or earlier assignment) of this name wins.
        if name in mi.top_level:
            continue
        d = _Definition(mi.file_path, def_nodes[0], "var")
        mi.top_level[name] = d
        self._by_name[name].append(d)

    # Class methods
    for _, caps in _matches(self._queries.class_methods, root):
        class_nodes = caps.get("class_name", [])
        mname_nodes = caps.get("method_name", [])
        mdef_nodes = caps.get("method_def", [])
        if not class_nodes or not mname_nodes or not mdef_nodes:
            continue
        class_name = class_nodes[0].text.decode("utf-8")
        method_name = mname_nodes[0].text.decode("utf-8")
        d = _Definition(mi.file_path, _strip_decorator(mdef_nodes[0]), "method")
        mi.class_methods.setdefault(class_name, {})[method_name] = d
        self._by_name[method_name].append(d)

Expand Down
30 changes: 30 additions & 0 deletions api/analyzers/source_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,49 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
elif key == "parameters":
graph.connect_entities("PARAMETERS", entity.id, resolved.id)

def link_imports(self, graph: Graph, root: Path) -> None:
    """Create ``IMPORTS`` edges (File -> File) using each language's resolver.

    Resolution is purely syntactic for Python (no LSP), so this is meant
    to run after ``first_pass``, once files carry a graph id. One import
    index is built lazily per file suffix and reused for every file with
    that suffix. Suffixes without an analyzer, analyzers that return a
    falsy index, and files or targets lacking an ``id`` are silently
    skipped.
    """
    index_by_suffix: dict[str, object] = {}
    for source_path, source_file in self.files.items():
        suffix = source_path.suffix
        analyzer = analyzers.get(suffix)
        if analyzer is None:
            continue

        # Build the index the first time this suffix is encountered.
        if suffix not in index_by_suffix:
            index_by_suffix[suffix] = analyzer.build_import_index(self.files, root)
        import_index = index_by_suffix[suffix]
        if not import_index:
            continue

        for imported in analyzer.resolve_imports(source_file, root, import_index):
            # Both endpoints need a graph id before an edge can be created.
            if getattr(source_file, "id", None) is None:
                continue
            if getattr(imported, "id", None) is None:
                continue
            graph.connect_entities("IMPORTS", source_file.id, imported.id)

def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None:
    """Run the full analysis pipeline over an explicit list of files.

    The steps run in a fixed order: first pass (with an empty ignore
    list), import linking, second pass, then override derivation.
    """
    # First pass: called with no ignore patterns.
    self.first_pass(path, files, [], graph)

    # File -> File IMPORTS edges.
    self.link_imports(graph, path)

    # Second pass over the same files.
    self.second_pass(graph, files, path)

    # OVERRIDES edges derived from the graph built so far.
    graph.derive_overrides()

def analyze_sources(self, path: Path, ignore: list[str], graph: Graph) -> None:
    """Discover source files under ``path`` and run the analysis pipeline.

    Files are collected recursively in this order: ``*.java``, ``*.py``,
    ``*.cs``, ``*.js`` (skipping anything with a ``node_modules`` path
    component), ``*.kt``, ``*.kts``. The pipeline then runs first pass,
    import linking, second pass and override derivation.
    """
    path = path.resolve()

    files: list[Path] = []
    for pattern in ("*.java", "*.py", "*.cs"):
        files.extend(path.rglob(pattern))
    # JavaScript: leave out vendored packages.
    files.extend(f for f in path.rglob("*.js") if "node_modules" not in f.parts)
    for pattern in ("*.kt", "*.kts"):
        files.extend(path.rglob(pattern))

    # First pass analysis of the source code
    self.first_pass(path, files, ignore, graph)

    # Link import edges (syntactic, language-specific, no LSP)
    self.link_imports(graph, path)

    # Second pass analysis of the source code
    self.second_pass(graph, files, path)

    # Derive override edges from the resolved class hierarchy
    graph.derive_overrides()

def analyze_local_folder(self, path: str, g: Graph, ignore: Optional[list[str]] = []) -> None:
"""
Analyze path.
Expand Down
34 changes: 34 additions & 0 deletions api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,40 @@ def connect_entities(self, relation: str, src_id: int, dest_id: int, properties:
params = {'src_id': src_id, 'dest_id': dest_id, "properties": properties}
self._query(q, params)

def derive_overrides(self, max_depth: int = 3) -> int:
    """
    Add ``OVERRIDES`` edges derived from the existing class hierarchy.

    A method on a subclass is linked to a same-named method on an ancestor
    class reachable through 1..``max_depth`` ``EXTENDS`` hops. Only
    existing ``EXTENDS`` and ``DEFINES`` edges are consulted, so the
    derivation is language-agnostic. A newly created edge stores the
    inheritance distance in ``depth`` for downstream filtering.

    The derivation is best-effort: if the derivation query fails, a
    warning is logged and 0 is returned.

    Args:
        max_depth (int): Maximum inheritance distance to bridge.

    Returns:
        int: Number of OVERRIDES edges after derivation.
    """

    hops = int(max_depth)
    derive_query = f"""MATCH (sub:Class)-[x:EXTENDS*1..{hops}]->(sup:Class)
WHERE ID(sub) <> ID(sup)
WITH DISTINCT sub, sup, length(x) AS depth
MATCH (sub)-[:DEFINES]->(m:Function)
MATCH (sup)-[:DEFINES]->(m2:Function)
WHERE m.name = m2.name AND ID(m) <> ID(m2)
MERGE (m)-[e:OVERRIDES]->(m2)
ON CREATE SET e.depth = depth"""

    try:
        self._query(derive_query)
    except Exception as exc:  # noqa: BLE001 — derivation is best-effort
        logging.warning("derive_overrides failed: %s", exc)
        return 0

    rows = self._query("MATCH ()-[e:OVERRIDES]->() RETURN count(e)").result_set
    if not rows:
        return 0
    return int(rows[0][0])

def function_calls_function(self, caller_id: int, callee_id: int, pos: int) -> None:
"""
Establish a 'CALLS' relationship between two function nodes.
Expand Down
52 changes: 52 additions & 0 deletions tests/analyzers/test_ts_python_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,58 @@ def test_resolver_unknown_name_returns_empty(tmp_path: Path):
assert r.resolve(files, mod, tmp_path.resolve(), name) == []


def test_resolver_many_defs_name_def_alignment(tmp_path: Path):
    """Guard against a scrambled module symbol table for functions.

    ``lib.py`` defines ten top-level functions and ``app.py`` imports and
    calls all of them. Pairing ``@name`` and ``@def`` captures by zipping
    two independently grouped lists used to attach names to the wrong
    definitions (e.g. an imported ``arange`` call resolved to the
    ``array`` def node). Every call must resolve to exactly one
    definition, in ``lib.py``, whose name matches the call.
    """
    indices = range(10)
    lib_src = "".join(f"def fn_{i}():\n return {i}\n\n" for i in indices)
    import_line = "from lib import " + ", ".join(f"fn_{i}" for i in indices)
    call_lines = "\n".join(f" fn_{i}()" for i in indices)
    app_src = f"{import_line}\n\ndef use():\n{call_lines}\n"

    files = _make_project(tmp_path, {"lib.py": lib_src, "app.py": app_src})
    resolver = TreeSitterPythonResolver(_PY)
    project_root = tmp_path.resolve()
    app_path = (tmp_path / "app.py").resolve()
    lib_path = (tmp_path / "lib.py").resolve()
    tree_root = files[app_path].tree.root_node

    for i in indices:
        call = _find_call_node(tree_root, f"fn_{i}(")
        callee = call.child_by_field_name("function")
        out = resolver.resolve(files, app_path, project_root, callee)
        assert len(out) == 1, f"fn_{i} did not resolve uniquely"

        resolved_file, def_node = out[0]
        assert resolved_file.path == lib_path
        resolved_name = def_node.child_by_field_name("name").text.decode("utf-8")
        assert resolved_name == f"fn_{i}", (
            f"call fn_{i} resolved to wrong def {resolved_name}"
        )


def test_resolver_many_classes_name_def_alignment(tmp_path: Path):
    """Name/definition alignment regression, for top-level classes.

    Eight classes are defined in ``lib.py`` and instantiated from
    ``app.py``; each call must resolve to the single class definition
    carrying the same name.
    """
    indices = range(8)
    lib_src = "".join(f"class Cls{i}:\n pass\n\n" for i in indices)
    import_line = "from lib import " + ", ".join(f"Cls{i}" for i in indices)
    body = "\n".join(f" Cls{i}()" for i in indices)
    app_src = f"{import_line}\n\ndef use():\n{body}\n"

    files = _make_project(tmp_path, {"lib.py": lib_src, "app.py": app_src})
    resolver = TreeSitterPythonResolver(_PY)
    project_root = tmp_path.resolve()
    app_path = (tmp_path / "app.py").resolve()
    tree_root = files[app_path].tree.root_node

    for i in indices:
        call = _find_call_node(tree_root, f"Cls{i}(")
        callee = call.child_by_field_name("function")
        out = resolver.resolve(files, app_path, project_root, callee)
        assert len(out) == 1
        _, def_node = out[0]
        resolved_name = def_node.child_by_field_name("name").text.decode("utf-8")
        assert resolved_name == f"Cls{i}"


# ---------------------------------------------------------------------------
# PythonAnalyzer integration via env var
# ---------------------------------------------------------------------------
Expand Down