Skip to content

Commit 5a1a0f4

Browse files
authored
fix(native): align Rust role classification with JS (#918)
* fix(native): align Rust role classification with JS implementation

  The native Rust role classifier was producing different (less accurate) results than the JS fallback path. This fixes all divergences:

  - Include 'imports-type' edges in fan-in queries (full, incremental, medians)
  - Add reexport-chain exported detection via recursive CTE (#837)
  - Include 'imports-type' in exported IDs and production fan-in queries
  - Add constant + hasActiveFileSiblings logic (classify as leaf, not dead)
  - Add !isExported guard to test-only classification
  - Include 'imports-type' and 'reexports' in neighbour file discovery
  - Add test_file_filter_col() for arbitrary column filtering

* docs(native): add inline comments explaining reexport CTE logic (#918)

  Expand comments on the recursive prod_reachable CTEs in both the full and incremental classify paths to explain the base-case/recursive-step mechanics and how barrel reexport chains are traced.
1 parent a32d3d3 commit 5a1a0f4

1 file changed

Lines changed: 140 additions & 18 deletions

File tree

crates/codegraph-core/src/roles_db.rs

Lines changed: 140 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ fn classify_node(
115115
fan_out: u32,
116116
is_exported: bool,
117117
production_fan_in: u32,
118+
has_active_file_siblings: bool,
118119
median_fan_in: f64,
119120
median_fan_out: f64,
120121
) -> &'static str {
@@ -124,8 +125,12 @@ fn classify_node(
124125
}
125126

126127
if fan_in == 0 && !is_exported {
127-
// Test-only check: if node has test fan-in but zero total fan-in it's
128-
// classified in the dead sub-role path (JS mirrors this)
128+
// Constants consumed via identifier reference (not calls) have no
129+
// inbound call edges. If the same file has active callables, the
130+
// constant is almost certainly used locally — classify as leaf.
131+
if kind == "constant" && has_active_file_siblings {
132+
return "leaf";
133+
}
129134
return classify_dead_sub_role(name, kind, file);
130135
}
131136

@@ -134,7 +139,7 @@ fn classify_node(
134139
}
135140

136141
// Test-only: has callers but all are in test files
137-
if fan_in > 0 && production_fan_in == 0 {
142+
if fan_in > 0 && production_fan_in == 0 && !is_exported {
138143
return "test-only";
139144
}
140145

@@ -218,14 +223,15 @@ pub(crate) fn do_classify_full(conn: &Connection) -> rusqlite::Result<RoleSummar
218223
};
219224

220225
// 2. Fan-in/fan-out for callable nodes (uses JOIN approach for full scan)
226+
// Fan-in includes 'imports-type' edges to match JS classification.
221227
let rows: Vec<(i64, String, String, String, u32, u32)> = {
222228
let mut stmt = tx.prepare(
223229
"SELECT n.id, n.name, n.kind, n.file,
224230
COALESCE(fi.cnt, 0) AS fan_in,
225231
COALESCE(fo.cnt, 0) AS fan_out
226232
FROM nodes n
227233
LEFT JOIN (
228-
SELECT target_id, COUNT(*) AS cnt FROM edges WHERE kind = 'calls' GROUP BY target_id
234+
SELECT target_id, COUNT(*) AS cnt FROM edges WHERE kind IN ('calls', 'imports-type') GROUP BY target_id
229235
) fi ON n.id = fi.target_id
230236
LEFT JOIN (
231237
SELECT source_id, COUNT(*) AS cnt FROM edges WHERE kind = 'calls' GROUP BY source_id
@@ -250,26 +256,68 @@ pub(crate) fn do_classify_full(conn: &Connection) -> rusqlite::Result<RoleSummar
250256
return Ok(summary);
251257
}
252258

253-
// 3. Exported IDs (cross-file callers)
254-
let exported_ids: std::collections::HashSet<i64> = {
259+
// 3. Exported IDs (cross-file callers including imports-type)
260+
let mut exported_ids: std::collections::HashSet<i64> = {
255261
let mut stmt = tx.prepare(
256262
"SELECT DISTINCT e.target_id
257263
FROM edges e
258264
JOIN nodes caller ON e.source_id = caller.id
259265
JOIN nodes target ON e.target_id = target.id
260-
WHERE e.kind = 'calls' AND caller.file != target.file",
266+
WHERE e.kind IN ('calls', 'imports-type') AND caller.file != target.file",
261267
)?;
262268
let rows = stmt.query_map([], |row| row.get::<_, i64>(0))?;
263269
rows.filter_map(|r| r.ok()).collect()
264270
};
265271

266-
// 4. Production fan-in (excluding test files)
272+
// 3b. Mark symbols as exported when their files are targets of reexport edges
273+
// from production-reachable barrels (traces through multi-level chains) (#837).
274+
//
275+
// The recursive CTE works in two stages:
276+
// Base case: find all file nodes directly imported by production (non-test) files.
277+
// Recursive step: follow 'reexports' edges outward to discover barrel chains
278+
// (e.g. index.ts re-exports from internal.ts which re-exports from core.ts).
279+
// Then: any symbol whose file is a reexport target of a prod-reachable barrel
280+
// is considered exported (prevents false dead-code classification).
281+
{
282+
let sql = format!(
283+
"WITH RECURSIVE prod_reachable(file_id) AS (
284+
SELECT DISTINCT e.target_id
285+
FROM edges e
286+
JOIN nodes src ON e.source_id = src.id
287+
WHERE e.kind IN ('imports', 'dynamic-imports', 'imports-type')
288+
AND src.kind = 'file'
289+
{}
290+
UNION
291+
SELECT e.target_id
292+
FROM edges e
293+
JOIN prod_reachable pr ON e.source_id = pr.file_id
294+
WHERE e.kind = 'reexports'
295+
)
296+
SELECT DISTINCT n.id
297+
FROM nodes n
298+
JOIN nodes f ON f.file = n.file AND f.kind = 'file'
299+
WHERE f.id IN (
300+
SELECT e.target_id FROM edges e
301+
WHERE e.kind = 'reexports'
302+
AND e.source_id IN (SELECT file_id FROM prod_reachable)
303+
)
304+
AND n.kind NOT IN ('file', 'directory', 'parameter', 'property')",
305+
test_file_filter_col("src.file")
306+
);
307+
let mut stmt = tx.prepare(&sql)?;
308+
let reexport_rows = stmt.query_map([], |row| row.get::<_, i64>(0))?;
309+
for r in reexport_rows.flatten() {
310+
exported_ids.insert(r);
311+
}
312+
}
313+
314+
// 4. Production fan-in (excluding test files, including imports-type)
267315
let prod_fan_in: HashMap<i64, u32> = {
268316
let sql = format!(
269317
"SELECT e.target_id, COUNT(*) AS cnt
270318
FROM edges e
271319
JOIN nodes caller ON e.source_id = caller.id
272-
WHERE e.kind = 'calls' {}
320+
WHERE e.kind IN ('calls', 'imports-type') {}
273321
GROUP BY e.target_id",
274322
test_file_filter()
275323
);
@@ -287,6 +335,9 @@ pub(crate) fn do_classify_full(conn: &Connection) -> rusqlite::Result<RoleSummar
287335
let median_fan_in = median(&fan_in_vals);
288336
let median_fan_out = median(&fan_out_vals);
289337

338+
// 5b. Compute active files (files with non-constant callables connected to the graph)
339+
let active_files = compute_active_files(&rows);
340+
290341
// 6. Classify and collect IDs by role
291342
let mut ids_by_role: HashMap<&str, Vec<i64>> = HashMap::new();
292343

@@ -300,6 +351,7 @@ pub(crate) fn do_classify_full(conn: &Connection) -> rusqlite::Result<RoleSummar
300351
&rows,
301352
&exported_ids,
302353
&prod_fan_in,
354+
&active_files,
303355
median_fan_in,
304356
median_fan_out,
305357
&mut ids_by_role,
@@ -314,20 +366,38 @@ pub(crate) fn do_classify_full(conn: &Connection) -> rusqlite::Result<RoleSummar
314366
Ok(summary)
315367
}
316368

317-
/// Build the test-file exclusion filter for SQL queries.
369+
/// Build the test-file exclusion filter for SQL queries (default column: `caller.file`).
318370
fn test_file_filter() -> String {
371+
test_file_filter_col("caller.file")
372+
}
373+
374+
/// Build the test-file exclusion filter for an arbitrary column name.
375+
fn test_file_filter_col(column: &str) -> String {
319376
TEST_FILE_PATTERNS
320377
.iter()
321-
.map(|p| format!("AND caller.file NOT LIKE '{}'", p))
378+
.map(|p| format!("AND {} NOT LIKE '{}'", column, p))
322379
.collect::<Vec<_>>()
323380
.join(" ")
324381
}
325382

383+
/// Compute the set of files that have at least one non-constant callable connected to the graph.
384+
/// Constants in these files are likely consumed locally via identifier reference.
///
/// Each row is destructured as `(id, name, kind, file, fan_in, fan_out)`; only `kind`,
/// `file`, `fan_in` and `fan_out` are read here. A row marks its file as active when
/// `fan_in > 0` or `fan_out > 0` and its `kind` is not `"constant"`.
///
/// Returns an owned set of the `file` strings of all such rows; the set is empty when
/// no row qualifies.
385+
fn compute_active_files(rows: &[(i64, String, String, String, u32, u32)]) -> std::collections::HashSet<String> {
386+
let mut active = std::collections::HashSet::new();
387+
for (_id, _name, kind, file, fan_in, fan_out) in rows {
388+
// Constants never mark a file active themselves; only other kinds with a
// non-zero fan-in or fan-out count do.
if (*fan_in > 0 || *fan_out > 0) && kind != "constant" {
389+
// The set owns its keys, so the borrowed file path is cloned on insert.
active.insert(file.clone());
390+
}
391+
}
392+
active
393+
}
394+
326395
/// Compute global median fan-in and fan-out from the edge distribution.
396+
/// Fan-in includes 'imports-type' edges to match JS classification.
327397
fn compute_global_medians(tx: &rusqlite::Transaction) -> rusqlite::Result<(f64, f64)> {
328398
let median_fan_in = {
329399
let mut stmt = tx
330-
.prepare("SELECT COUNT(*) AS cnt FROM edges WHERE kind = 'calls' GROUP BY target_id")?;
400+
.prepare("SELECT COUNT(*) AS cnt FROM edges WHERE kind IN ('calls', 'imports-type') GROUP BY target_id")?;
331401
let mut vals: Vec<u32> = stmt
332402
.query_map([], |row| row.get::<_, u32>(0))?
333403
.filter_map(|r| r.ok())
@@ -389,6 +459,7 @@ fn classify_rows(
389459
rows: &[(i64, String, String, String, u32, u32)],
390460
exported_ids: &std::collections::HashSet<i64>,
391461
prod_fan_in: &HashMap<i64, u32>,
462+
active_files: &std::collections::HashSet<String>,
392463
median_fan_in: f64,
393464
median_fan_out: f64,
394465
ids_by_role: &mut HashMap<&'static str, Vec<i64>>,
@@ -397,6 +468,11 @@ fn classify_rows(
397468
for (id, name, kind, file, fan_in, fan_out) in rows {
398469
let is_exported = exported_ids.contains(id);
399470
let prod_fi = prod_fan_in.get(id).copied().unwrap_or(0);
471+
let has_active_siblings = if kind == "constant" {
472+
active_files.contains(file)
473+
} else {
474+
false
475+
};
400476
let role = classify_node(
401477
name,
402478
kind,
@@ -405,6 +481,7 @@ fn classify_rows(
405481
*fan_out,
406482
is_exported,
407483
prod_fi,
484+
has_active_siblings,
408485
median_fan_in,
409486
median_fan_out,
410487
);
@@ -413,7 +490,7 @@ fn classify_rows(
413490
}
414491
}
415492

416-
/// Find neighbouring files connected by call edges to the changed files.
493+
/// Find neighbouring files connected by call/imports-type/reexports edges to the changed files.
417494
fn find_neighbour_files(
418495
tx: &rusqlite::Transaction,
419496
changed_files: &[String],
@@ -427,7 +504,7 @@ fn find_neighbour_files(
427504
"SELECT DISTINCT n2.file FROM edges e
428505
JOIN nodes n1 ON (e.source_id = n1.id OR e.target_id = n1.id)
429506
JOIN nodes n2 ON (e.source_id = n2.id OR e.target_id = n2.id)
430-
WHERE e.kind = 'calls'
507+
WHERE e.kind IN ('calls', 'imports-type', 'reexports')
431508
AND n1.file IN ({})
432509
AND n2.file NOT IN ({})
433510
AND n2.kind NOT IN ('file', 'directory')",
@@ -477,7 +554,7 @@ fn query_nodes_for_files(
477554

478555
let rows_sql = format!(
479556
"SELECT n.id, n.name, n.kind, n.file,
480-
(SELECT COUNT(*) FROM edges WHERE kind = 'calls' AND target_id = n.id) AS fan_in,
557+
(SELECT COUNT(*) FROM edges WHERE kind IN ('calls', 'imports-type') AND target_id = n.id) AS fan_in,
481558
(SELECT COUNT(*) FROM edges WHERE kind = 'calls' AND source_id = n.id) AS fan_out
482559
FROM nodes n
483560
WHERE n.kind NOT IN ('file', 'directory', 'parameter', 'property')
@@ -542,18 +619,60 @@ pub(crate) fn do_classify_incremental(
542619
FROM edges e
543620
JOIN nodes caller ON e.source_id = caller.id
544621
JOIN nodes target ON e.target_id = target.id
545-
WHERE e.kind = 'calls' AND caller.file != target.file
622+
WHERE e.kind IN ('calls', 'imports-type') AND caller.file != target.file
546623
AND target.file IN ({})",
547624
affected_ph
548625
);
549-
let exported_ids = query_id_set(&tx, &exported_sql, &all_affected)?;
626+
let mut exported_ids = query_id_set(&tx, &exported_sql, &all_affected)?;
627+
628+
// Mark symbols as exported when their files are targets of reexport edges
629+
// from production-reachable barrels (traces through multi-level chains) (#837).
630+
// Same recursive CTE logic as the full-classify path (step 3b), but scoped
631+
// to affected files only via the additional `AND n.file IN (...)` filter.
632+
{
633+
let reexport_sql = format!(
634+
"WITH RECURSIVE prod_reachable(file_id) AS (
635+
SELECT DISTINCT e.target_id
636+
FROM edges e
637+
JOIN nodes src ON e.source_id = src.id
638+
WHERE e.kind IN ('imports', 'dynamic-imports', 'imports-type')
639+
AND src.kind = 'file'
640+
{}
641+
UNION
642+
SELECT e.target_id
643+
FROM edges e
644+
JOIN prod_reachable pr ON e.source_id = pr.file_id
645+
WHERE e.kind = 'reexports'
646+
)
647+
SELECT DISTINCT n.id
648+
FROM nodes n
649+
JOIN nodes f ON f.file = n.file AND f.kind = 'file'
650+
WHERE f.id IN (
651+
SELECT e.target_id FROM edges e
652+
WHERE e.kind = 'reexports'
653+
AND e.source_id IN (SELECT file_id FROM prod_reachable)
654+
)
655+
AND n.kind NOT IN ('file', 'directory', 'parameter', 'property')
656+
AND n.file IN ({})",
657+
test_file_filter_col("src.file"),
658+
affected_ph
659+
);
660+
let mut stmt = tx.prepare(&reexport_sql)?;
661+
for (i, f) in all_affected.iter().enumerate() {
662+
stmt.raw_bind_parameter(i + 1, *f)?;
663+
}
664+
let mut rrows = stmt.raw_query();
665+
while let Some(row) = rrows.next()? {
666+
exported_ids.insert(row.get::<_, i64>(0)?);
667+
}
668+
}
550669

551670
let prod_sql = format!(
552671
"SELECT e.target_id, COUNT(*) AS cnt
553672
FROM edges e
554673
JOIN nodes caller ON e.source_id = caller.id
555674
JOIN nodes target ON e.target_id = target.id
556-
WHERE e.kind = 'calls'
675+
WHERE e.kind IN ('calls', 'imports-type')
557676
AND target.file IN ({})
558677
{}
559678
GROUP BY e.target_id",
@@ -562,6 +681,8 @@ pub(crate) fn do_classify_incremental(
562681
);
563682
let prod_fan_in = query_id_counts(&tx, &prod_sql, &all_affected)?;
564683

684+
let active_files = compute_active_files(&rows);
685+
565686
let mut ids_by_role: HashMap<&str, Vec<i64>> = HashMap::new();
566687

567688
if !leaf_ids.is_empty() {
@@ -574,6 +695,7 @@ pub(crate) fn do_classify_incremental(
574695
&rows,
575696
&exported_ids,
576697
&prod_fan_in,
698+
&active_files,
577699
median_fan_in,
578700
median_fan_out,
579701
&mut ids_by_role,

0 commit comments

Comments
 (0)