From 2db32189cacee8bb1ed8e6e75b09f232c05f75e1 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 01:41:36 -0700 Subject: [PATCH 01/18] [CFG] Replace lambdas in control_flow_graph.cpp with named helpers Promote three closures inside CFGNode::get_store_forwarding_data to a file-local free function (may_contain_address) and two private methods (is_visible_at, update_forwarding_result). Promote the in_final_node_live_gen closure in live_variable_analysis to a file-local free function. Replace four std::any_of predicates with explicit loops, the std::sort comparator with a local function-object struct, and inline the FNV mix lambda inside FingerprintHash. No behaviour change. --- quadrants/ir/control_flow_graph.cpp | 232 ++++++++++++++++------------ quadrants/ir/control_flow_graph.h | 16 ++ 2 files changed, 148 insertions(+), 100 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index d41cd13997..856d30f83b 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -16,6 +16,55 @@ namespace quadrants::lang { +namespace { + +// Does |store_stmt| store to an address that may alias |var|? +// Matrix-pointer aliasing is handled both ways: a MatrixPtrStmt aliases its +// underlying origin, and vice-versa. +bool may_contain_address(Stmt *store_stmt, Stmt *var) { + for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) { + if (var->is() && !store_ptr->is()) { + // check for aliased address with var + if (irpass::analysis::maybe_same_address(var->as()->origin, store_ptr)) { + return true; + } + } + + if (!var->is() && store_ptr->is()) { + // check for aliased address with store_ptr + if (irpass::analysis::maybe_same_address(store_ptr->as()->origin, var)) { + return true; + } + } + + if (irpass::analysis::maybe_same_address(var, store_ptr)) { + return true; + } + } + return false; +} + +// Should |stmt| appear in the synthetic `live_gen` of the final CFG node? +// Locals (allocas, matrix-pointers into allocas) are never live past the +// kernel boundary. Global pointers are live unless SFG has marked their +// SNode as eliminable in |config_opt|. +bool in_final_node_live_gen(const Stmt *stmt, + const std::optional &config_opt) { + if (stmt->is() || stmt->is()) { + return false; + } + if (stmt->is() && stmt->cast()->origin->is()) { + return false; + } + if (auto *gptr = stmt->cast(); gptr && config_opt.has_value()) { + return config_opt->eliminable_snodes.count(gptr->snode) == 0; + } + // A global pointer that may be loaded after this kernel. + return true; +} + +} // namespace + CFGNode::CFGNode(Block *block, int begin_location, int end_location, @@ -91,8 +140,12 @@ bool CFGNode::contain_variable(const std::unordered_set &var_set, Stmt * // TODO: How to optimize this? if (var_set.find(var) != var_set.end()) return true; - return std::any_of(var_set.begin(), var_set.end(), - [&](Stmt *set_var) { return irpass::analysis::definitely_same_address(var, set_var); }); + for (Stmt *set_var : var_set) { + if (irpass::analysis::definitely_same_address(var, set_var)) { + return true; + } + } + return false; } } @@ -107,12 +160,12 @@ bool CFGNode::contain_variable(const std::unordered_map &var_set, St // TODO: How to optimize this? if (var_set.find(var) != var_set.end()) return true; - return std::any_of(var_set.begin(), var_set.end(), - [&](Stmt *set_var) { return irpass::analysis::maybe_same_address(var, set_var); }); + for (Stmt *set_var : var_set) { + if (irpass::analysis::maybe_same_address(var, set_var)) { + return true; + } + } + return false; } } @@ -145,6 +206,48 @@ bool CFGNode::reach_kill_variable(Stmt *var) const { return contain_variable(reach_kill, var); } +bool CFGNode::is_visible_at(Stmt *stmt, int position) const { + // Check if |stmt| is before |position| here. + if (stmt->parent == block) { + return stmt->parent->locate(stmt) < position; + } + // |parent_blocks_| is precomputed in the constructor of CFGNode. + // TODO: What if |stmt| appears in an ancestor of |block| but after + // |position|? + return parent_blocks_.find(stmt->parent) != parent_blocks_.end(); +} + +bool CFGNode::update_forwarding_result(Stmt *stmt, + int position, + Stmt *&result, + bool &result_visible) const { + // |stmt| is a definition in the UD-chain of the variable being forwarded. + // Fold its stored data into |result| / |result_visible|. Return false if + // forwarding must abort (the caller should propagate nullptr); true to + // continue scanning. + auto data = irpass::analysis::get_store_data(stmt); + if (!data) { // not forwardable + return false; + } + if (!result) { + result = data; + result_visible = is_visible_at(data, position); + return true; + } + if (!irpass::analysis::same_value(result, data)) { + // check the special case of alloca (initialized to 0) + if (!(result->is() && data->is() && data->as()->val.equal_value(0))) { + return false; + } + } + if (!result_visible && is_visible_at(data, position)) { + // pick the visible one for store-to-load forwarding + result = data; + result_visible = true; + } + return true; +} + // var: dest_addr Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // Return the stored data if all definitions in the UD-chain of |var| at @@ -198,30 +301,6 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { } } - // Check if store_stmt will ever influence the value of var - auto may_contain_address = [&](Stmt *store_stmt, Stmt *var) { - for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) { - if (var->is() && !store_ptr->is()) { - // check for aliased address with var - if (irpass::analysis::maybe_same_address(var->as()->origin, store_ptr)) { - return true; - } - } - - if (!var->is() && store_ptr->is()) { - // check for aliased address with store_ptr - if (irpass::analysis::maybe_same_address(store_ptr->as()->origin, var)) { - return true; - } - } - - if (irpass::analysis::maybe_same_address(var, store_ptr)) { - return true; - } - } - return false; - }; - // Check for aliased address // There's a store to the same dest_addr before this stmt if (last_def_position != -1) { @@ -247,46 +326,6 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // Search for store to the same dest_addr in reach_in and reach_gen Stmt *result = nullptr; bool result_visible = false; - auto visible = [&](Stmt *stmt) { - // Check if |stmt| is before |position| here. - if (stmt->parent == block) { - return stmt->parent->locate(stmt) < position; - } - // |parent_blocks| is precomputed in the constructor of CFGNode. - // TODO: What if |stmt| appears in an ancestor of |block| but after - // |position|? - return parent_blocks_.find(stmt->parent) != parent_blocks_.end(); - }; - /** - * |stmt| is a definition in the UD-chain of |var|. Update |result| with - * |stmt|. If either the stored data of |stmt| is not forwardable or the - * stored data of |stmt| is not definitely the same as other definitions of - * |var|, return false to show that there is no store-to-load forwardable - * data. - */ - auto update_result = [&](Stmt *stmt) { - auto data = irpass::analysis::get_store_data(stmt); - if (!data) { // not forwardable - return false; // return nullptr - } - if (!result) { - result = data; - result_visible = visible(data); - return true; // continue the following loops - } - if (!irpass::analysis::same_value(result, data)) { - // check the special case of alloca (initialized to 0) - if (!(result->is() && data->is() && data->as()->val.equal_value(0))) { - return false; // return nullptr - } - } - if (!result_visible && visible(data)) { - // pick the visible one for store-to-load forwarding - result = data; - result_visible = true; - } - return true; // continue the following loops - }; // [Global Addr only] // test whether there's a store to the same dest_addr in a previous block. @@ -296,7 +335,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // var == stmt is for the case that a global ptr is never stored. // In this case, stmt is from nodes[start_node]->reach_gen. if (var == stmt || may_contain_address(stmt, var)) { - if (!update_result(stmt)) + if (!update_forwarding_result(stmt, position, result, result_visible)) return nullptr; else last_def_position = 0; @@ -308,7 +347,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // if the store values are the same, then return the value for (auto stmt : reach_gen) { if (may_contain_address(stmt, var) && stmt->parent->locate(stmt) < position) { - if (!update_result(stmt)) + if (!update_forwarding_result(stmt, position, result, result_visible)) return nullptr; else last_def_position = stmt->parent->locate(stmt); @@ -1060,26 +1099,12 @@ void ControlFlowGraph::live_variable_analysis(bool after_lower_access, nodes[final_node]->live_gen.clear(); nodes[final_node]->live_kill.clear(); - auto in_final_node_live_gen = [&config_opt](const Stmt *stmt) -> bool { - if (stmt->is() || stmt->is()) { - return false; - } - if (stmt->is() && stmt->cast()->origin->is()) { - return false; - } - if (auto *gptr = stmt->cast(); gptr && config_opt.has_value()) { - const bool res = (config_opt->eliminable_snodes.count(gptr->snode) == 0); - return res; - } - // A global pointer that may be loaded after this kernel. - return true; - }; if (!after_lower_access) { for (int i = 0; i < num_nodes; i++) { for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { auto stmt = nodes[i]->block->statements[j].get(); for (auto store_ptr : irpass::analysis::get_store_destination(stmt, true /*get_alias*/)) { - if (in_final_node_live_gen(store_ptr)) { + if (in_final_node_live_gen(store_ptr, config_opt)) { nodes[final_node]->live_gen.insert(store_ptr); } } @@ -1495,8 +1520,14 @@ void ControlFlowGraph::determine_ad_stack_size() { } } if (scc_is_cyclic[s]) { + struct DfsFinishGreater { + const std::vector &dfs_finish; + bool operator()(int a, int b) const { + return dfs_finish[a] > dfs_finish[b]; + } + }; auto topo = nodes_in_s; - std::sort(topo.begin(), topo.end(), [&](int a, int b) { return dfs_finish[a] > dfs_finish[b]; }); + std::sort(topo.begin(), topo.end(), DfsFinishGreater{dfs_finish}); scc_topo[s] = std::move(topo); } } @@ -1510,14 +1541,15 @@ void ControlFlowGraph::determine_ad_stack_size() { struct FingerprintHash { std::size_t operator()(const Fingerprint &f) const noexcept { std::size_t h = 1469598103934665603ULL; + // FNV-1a mix of three components per fingerprint entry. + constexpr std::size_t fnv_prime = 1099511628211ULL; for (auto &[n, i, m] : f) { - auto mix = [&](std::size_t x) { - h ^= x; - h *= 1099511628211ULL; - }; - mix(static_cast(n)); - mix(static_cast(static_cast(i))); - mix(static_cast(static_cast(m))); + h ^= static_cast(n); + h *= fnv_prime; + h ^= static_cast(static_cast(i)); + h *= fnv_prime; + h ^= static_cast(static_cast(m)); + h *= fnv_prime; } return h; } diff --git a/quadrants/ir/control_flow_graph.h b/quadrants/ir/control_flow_graph.h index c1f01d7d5e..9735508ce8 100644 --- a/quadrants/ir/control_flow_graph.h +++ b/quadrants/ir/control_flow_graph.h @@ -88,6 +88,22 @@ class CFGNode { void gather_loaded_snodes(std::unordered_set &snodes) const; void live_variable_analysis(bool after_lower_access); bool dead_store_elimination(bool after_lower_access); + + private: + // Helper for get_store_forwarding_data: is |stmt| visible at |position| + // inside this node's block? A stmt is visible if it lives in the same block + // and precedes |position|, or if its parent block is an ancestor of + // |this->block|. + bool is_visible_at(Stmt *stmt, int position) const; + + // Helper for get_store_forwarding_data: incorporate |stmt|, a definition in + // the UD-chain of the variable being forwarded, into the running |result| / + // |result_visible| state. Returns false to signal that forwarding must + // abort (the caller should return nullptr), true to continue scanning. + bool update_forwarding_result(Stmt *stmt, + int position, + Stmt *&result, + bool &result_visible) const; }; class ControlFlowGraph { From 1e74a623f3e796db2d254fe8109b0a952c81b211 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 02:10:21 -0700 Subject: [PATCH 02/18] [CFG] Split determine_ad_stack_size into named phase helpers Break the 450-line monster into a 30-line top-level plus 11 helpers in a file-local anonymous namespace, each handling exactly one phase of the AD-stack sizing algorithm: collect_adaptive_ad_stacks, accumulate_per_stack_per_node_size_deltas, compute_outgoing_node_ids, tarjan_scc, classify_scc_edges, identify_cyclic_sccs_and_topo, group_stacks_by_fingerprint, plus four DP helpers (classify_cyclic_scc_fast_path, spread_max_begin_over_zero_scc, dp_mixed_sign_cyclic_scc, update_global_max_and_relax_inter_scc) driven by run_ad_stack_size_dp_for_representative, finishing with apply_ad_stack_dp_results. No behaviour change; the top-level prose docstring is preserved verbatim and per-phase rationale is moved to the helper that owns it. --- quadrants/ir/control_flow_graph.cpp | 804 ++++++++++++++++------------ 1 file changed, 471 insertions(+), 333 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 856d30f83b..70b57eecd6 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -1281,110 +1281,96 @@ std::unordered_set ControlFlowGraph::gather_loaded_snodes() { return snodes; } -void ControlFlowGraph::determine_ad_stack_size() { - /** - * Determine the necessary size of every adaptive AD-stack on the control-flow graph (CFG). For each AD-stack we - * compute the maximum running net push count along any walk from the kernel entry. AD-stacks whose forward kernel - * contains a positive cycle (pushes > pops around a loop) are left at `max_size = 0`, and the caller routes them - * through the structural bounded-loop pre-pass for a symbolic `SizeExpr`, hard-erroring if the grammar still - * cannot resolve them. There is no compile-time size fallback. - * - * Implementation notes for compile-time perf on large reverse-mode kernels: - * 1. Per-stack per-node pre-aggregates (`max_increased_size`, `increased_size`) are stored in dense - * `vector>` indexed by a contiguous int stack id, instead of an - * `unordered_map>` -- this removes hash traffic from the hot inner loop. - * 2. Stacks whose `(increased_size, max_increased_size)` row pair is bit-identical share a single dynamic - * programming run -- typical kernels generate one alloca per autodiff variable in the same loop body, so - * most rows collapse to a few representatives. - * 3. The CFG is condensed via Tarjan into strongly connected components (SCCs). DFS finish times recorded - * during the same Tarjan pass split each cyclic SCC's intra-edges into a forward set (target finishes - * before source) and a back set (target finishes at or after source). Per representative we run a - * single-pass dynamic-programming (DP) sweep over the forward edges in descending finish-time order, then - * check the back edges once for positive-cycle relaxation. Correctness: any walk inside an SCC decomposes - * into a forward path plus zero or more cycles, and an SCC with no positive cycle has the same max-walk-sum - * as the back-edge-removed DAG; a positive cycle is exactly the case where some back-edge would still relax - * after the forward DP. This drops the per-cyclic-SCC cost from O(|S| * |E_S|) to O(|S| + |E_S|). - * 4. Two sign-based fast paths short-circuit the DP for trivial cyclic SCCs: an SCC with `min_is >= 0 && max_is - * > 0` for this stack must contain a positive cycle (every node lies on some cycle, and a cycle through a - * strictly-positive node with all non-negative `is` along it sums positive); an SCC with `min_is == 0 == - * max_is` has no `is` contribution at all and is handled by spreading the max entry-side - * `max_size_at_node_begin` to every node in O(|S|). - * Per-rep cost becomes O(V + E + sum_{cyclic S} |S| * |E_S|) (with the SCC sum dropping to O(|S| + |E_S|) for - * the common autodiff push/pop pattern); overall cost is O(V + E + R * (V + E)) with R the number of distinct - * row-pair representatives. - */ - const int num_nodes = size(); +namespace { + +// === Helpers for ControlFlowGraph::determine_ad_stack_size === - // Map AdStackAllocaStmt* to a contiguous int index. Only stacks whose `max_size` is still 0 (i.e. unresolved by an - // earlier pass) participate in the DP; resolved stacks are skipped so the pass cannot clobber them. +struct AdStackIndex { + // AdStackAllocaStmt* -> contiguous int id, populated only for adaptive stacks + // (max_size == 0) that have not yet been resolved by an earlier pass. std::unordered_map stack_id; std::vector stacks; - - std::unordered_map node_ids; - for (int i = 0; i < num_nodes; i++) - node_ids[nodes[i].get()] = i; - - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - Stmt *stmt = nodes[i]->block->statements[j].get(); - if (auto *stack = stmt->cast()) { - if (stack->max_size != 0) { - continue; - } - if (stack_id.emplace(stack, static_cast(stacks.size())).second) { - stacks.push_back(stack); - } +}; + +AdStackIndex collect_adaptive_ad_stacks(const std::vector> &nodes) { + AdStackIndex idx; + for (const auto &node : nodes) { + for (int j = node->begin_location; j < node->end_location; j++) { + Stmt *stmt = node->block->statements[j].get(); + auto *stack = stmt->cast(); + if (!stack || stack->max_size != 0) { + continue; + } + if (idx.stack_id.emplace(stack, static_cast(idx.stacks.size())).second) { + idx.stacks.push_back(stack); } } } + return idx; +} - const int num_stacks = static_cast(stacks.size()); - if (num_stacks == 0) { - return; - } - - // max_increased_size[s][j] is the maximum number of (pushes - pops) of stack |s| among all prefixes of the - // CFGNode |j|. increased_size[s][j] is the net (pushes - pops) of stack |s| in the CFGNode |j|. Both are indexed - // by contiguous stack id, so the per-stack DP loop reads them with cheap vector access. - std::vector> max_increased_size(num_stacks, std::vector(num_nodes, 0)); - std::vector> increased_size(num_stacks, std::vector(num_nodes, 0)); - - // Track which stacks actually participate in any push/pop in the CFG. Stacks with no push/pop end with - // `max_size = 0` regardless of the DP (every node has zero increase), so we skip the DP for them and reproduce - // the original "Unused autodiff stack" warning here. - std::vector stack_active(num_stacks, false); - +struct AdStackPerNodeSizes { + // [stack_id][node_id]. `max_increased_size[s][j]` is the maximum (pushes - pops) of stack |s| + // among all prefixes of CFGNode |j|; `increased_size[s][j]` is the net (pushes - pops) in the + // whole node. Indexed by contiguous stack id so the per-stack DP reads cheap vectors instead of + // hashing. + std::vector> increased_size; + std::vector> max_increased_size; + // True iff the stack actually appears in any push/pop in the CFG. Inactive stacks would settle + // at `max_size = 0` regardless and are short-circuited below to reproduce the original "Unused + // autodiff stack" warning. + std::vector stack_active; +}; + +AdStackPerNodeSizes accumulate_per_stack_per_node_size_deltas( + const std::vector> &nodes, + const std::unordered_map &stack_id, + int num_stacks) { + const int num_nodes = static_cast(nodes.size()); + AdStackPerNodeSizes out; + out.increased_size.assign(num_stacks, std::vector(num_nodes, 0)); + out.max_increased_size.assign(num_stacks, std::vector(num_nodes, 0)); + out.stack_active.assign(num_stacks, false); for (int i = 0; i < num_nodes; i++) { for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { Stmt *stmt = nodes[i]->block->statements[j].get(); - if (auto *stack_push = stmt->cast()) { - auto *stack = stack_push->stack->as(); - if (stack->max_size == 0 /*adaptive*/) { - auto it = stack_id.find(stack); - QD_ASSERT(it != stack_id.end()); - const int sid = it->second; - stack_active[sid] = true; - int &cur = increased_size[sid][i]; - cur++; - if (cur > max_increased_size[sid][i]) { - max_increased_size[sid][i] = cur; - } - } - } else if (auto *stack_pop = stmt->cast()) { - auto *stack = stack_pop->stack->as(); - if (stack->max_size == 0 /*adaptive*/) { - auto it = stack_id.find(stack); - QD_ASSERT(it != stack_id.end()); - const int sid = it->second; - stack_active[sid] = true; - increased_size[sid][i]--; - } + AdStackAllocaStmt *stack = nullptr; + int delta = 0; + if (auto *push = stmt->cast()) { + stack = push->stack->as(); + delta = +1; + } else if (auto *pop = stmt->cast()) { + stack = pop->stack->as(); + delta = -1; + } else { + continue; + } + if (stack->max_size != 0 /*non-adaptive*/) { + continue; + } + auto it = stack_id.find(stack); + QD_ASSERT(it != stack_id.end()); + const int sid = it->second; + out.stack_active[sid] = true; + int &cur = out.increased_size[sid][i]; + cur += delta; + if (cur > out.max_increased_size[sid][i]) { + out.max_increased_size[sid][i] = cur; } } } + return out; +} - // Precompute outgoing-edge node ids once per node so the per-stack DP walks an int vector instead of hashing - // `node_ids[next_node]` on every traversal. +// Precompute outgoing-edge node ids once per node so the per-stack DP walks an int vector instead +// of hashing `node_ids[next_node]` on every traversal. +std::vector> compute_outgoing_node_ids(const std::vector> &nodes) { + const int num_nodes = static_cast(nodes.size()); + std::unordered_map node_ids; + node_ids.reserve(num_nodes); + for (int i = 0; i < num_nodes; i++) { + node_ids[nodes[i].get()] = i; + } std::vector> next_ids(num_nodes); for (int i = 0; i < num_nodes; i++) { auto &dst = next_ids[i]; @@ -1393,133 +1379,159 @@ void ControlFlowGraph::determine_ad_stack_size() { dst.push_back(node_ids[next_node]); } } + return next_ids; +} - // Tarjan strongly-connected-component (SCC) decomposition of the CFG, computed once and shared across all per-stack - // dynamic-programming runs. Output: - // scc_id[n] : SCC index of each node (lower index = topologically deeper / sink-side, so sources are at index - // num_sccs - 1, matching Tarjan's natural emission order which is reverse topological order). - // scc_nodes[s] : list of node ids in SCC s. - // dfs_finish[n]: DFS post-order index, used below to split each cyclic SCC's intra-edges into a forward and a - // back set without a separate edge classification pass. - // We then split each node's outgoing edges into intra-SCC and inter-SCC sets so the per-stack DP iterates each - // set without filtering inside the hot loop. Cyclic SCCs (|S| > 1 or |S| == 1 with self-loop) are flagged so - // cycle detection runs only on those, and only at SCC scope. - std::vector scc_id(num_nodes, -1); - std::vector dfs_finish(num_nodes, -1); // DFS post-order index per node; ancestors finish AFTER descendants. - std::vector> scc_nodes; - { - std::vector tarjan_index(num_nodes, -1); - std::vector tarjan_lowlink(num_nodes, 0); - std::vector on_stack(num_nodes, 0); - std::vector tarjan_stack; - tarjan_stack.reserve(num_nodes); - std::vector> dfs_stack; - dfs_stack.reserve(num_nodes); - int next_index = 0; - int next_finish = 0; - int next_scc = 0; - for (int origin = 0; origin < num_nodes; origin++) { - if (tarjan_index[origin] != -1) { - continue; - } - tarjan_index[origin] = next_index; - tarjan_lowlink[origin] = next_index; - next_index++; - tarjan_stack.push_back(origin); - on_stack[origin] = 1; - dfs_stack.emplace_back(origin, 0); - while (!dfs_stack.empty()) { - auto &frame = dfs_stack.back(); - const int u = frame.first; - const auto &nb = next_ids[u]; - if (frame.second < nb.size()) { - const int v = nb[frame.second++]; - if (tarjan_index[v] == -1) { - tarjan_index[v] = next_index; - tarjan_lowlink[v] = next_index; - next_index++; - tarjan_stack.push_back(v); - on_stack[v] = 1; - dfs_stack.emplace_back(v, 0); - } else if (on_stack[v]) { - if (tarjan_index[v] < tarjan_lowlink[u]) { - tarjan_lowlink[u] = tarjan_index[v]; - } +struct TarjanResult { + std::vector scc_id; // node -> SCC index + std::vector dfs_finish; // node -> DFS post-order index; ancestors finish AFTER descendants + std::vector> scc_nodes; // SCC -> list of node ids +}; + +// Iterative Tarjan SCC. Emits SCCs in reverse topological order (sources end up at the largest +// indices). Also records DFS post-order finish times, used downstream to split each cyclic SCC's +// intra-edges into forward vs back without a second edge-classification pass. +TarjanResult tarjan_scc(const std::vector> &next_ids) { + const int num_nodes = static_cast(next_ids.size()); + TarjanResult out; + out.scc_id.assign(num_nodes, -1); + out.dfs_finish.assign(num_nodes, -1); + + std::vector tarjan_index(num_nodes, -1); + std::vector tarjan_lowlink(num_nodes, 0); + std::vector on_stack(num_nodes, 0); + std::vector tarjan_stack; + tarjan_stack.reserve(num_nodes); + std::vector> dfs_stack; + dfs_stack.reserve(num_nodes); + int next_index = 0; + int next_finish = 0; + int next_scc = 0; + for (int origin = 0; origin < num_nodes; origin++) { + if (tarjan_index[origin] != -1) { + continue; + } + tarjan_index[origin] = next_index; + tarjan_lowlink[origin] = next_index; + next_index++; + tarjan_stack.push_back(origin); + on_stack[origin] = 1; + dfs_stack.emplace_back(origin, 0); + while (!dfs_stack.empty()) { + auto &frame = dfs_stack.back(); + const int u = frame.first; + const auto &nb = next_ids[u]; + if (frame.second < nb.size()) { + const int v = nb[frame.second++]; + if (tarjan_index[v] == -1) { + tarjan_index[v] = next_index; + tarjan_lowlink[v] = next_index; + next_index++; + tarjan_stack.push_back(v); + on_stack[v] = 1; + dfs_stack.emplace_back(v, 0); + } else if (on_stack[v]) { + if (tarjan_index[v] < tarjan_lowlink[u]) { + tarjan_lowlink[u] = tarjan_index[v]; } - } else { - if (tarjan_lowlink[u] == tarjan_index[u]) { - std::vector component; - while (true) { - const int w = tarjan_stack.back(); - tarjan_stack.pop_back(); - on_stack[w] = 0; - scc_id[w] = next_scc; - component.push_back(w); - if (w == u) { - break; - } + } + } else { + if (tarjan_lowlink[u] == tarjan_index[u]) { + std::vector component; + while (true) { + const int w = tarjan_stack.back(); + tarjan_stack.pop_back(); + on_stack[w] = 0; + out.scc_id[w] = next_scc; + component.push_back(w); + if (w == u) { + break; } - scc_nodes.push_back(std::move(component)); - next_scc++; } - dfs_finish[u] = next_finish++; - dfs_stack.pop_back(); - if (!dfs_stack.empty()) { - const int parent = dfs_stack.back().first; - if (tarjan_lowlink[u] < tarjan_lowlink[parent]) { - tarjan_lowlink[parent] = tarjan_lowlink[u]; - } + out.scc_nodes.push_back(std::move(component)); + next_scc++; + } + out.dfs_finish[u] = next_finish++; + dfs_stack.pop_back(); + if (!dfs_stack.empty()) { + const int parent = dfs_stack.back().first; + if (tarjan_lowlink[u] < tarjan_lowlink[parent]) { + tarjan_lowlink[parent] = tarjan_lowlink[u]; } } } } } - const int num_sccs = static_cast(scc_nodes.size()); + return out; +} - // Each intra-SCC edge is either a "forward" edge in the DFS spanning sense (source finishes after target, so source - // can be processed before target with a single topological pass) or a "back" edge (source finishes before target, - // closing a cycle). The forward set carries the per-stack DAG dynamic programming; back-edges only need a single - // post-DP relaxation check to detect positive cycles. Inter-SCC edges always relax forward in topological order. - std::vector> next_ids_intra_fwd(num_nodes); - std::vector> next_ids_intra_back(num_nodes); - std::vector> next_ids_inter(num_nodes); +struct SccEdgeSets { + // Each intra-SCC edge is either "forward" in the DFS spanning sense (target finishes before + // source, so source can be relaxed before target with a single topological pass) or "back" + // (target finishes at or after source, closing a cycle). The forward set drives the per-stack + // DAG dynamic programming; back-edges only need a single post-DP relaxation check to detect + // positive cycles. Inter-SCC edges always relax forward in topological order. + std::vector> next_ids_intra_fwd; + std::vector> next_ids_intra_back; + std::vector> next_ids_inter; +}; + +SccEdgeSets classify_scc_edges(const std::vector> &next_ids, + const std::vector &scc_id, + const std::vector &dfs_finish) { + const int num_nodes = static_cast(next_ids.size()); + SccEdgeSets out; + out.next_ids_intra_fwd.assign(num_nodes, {}); + out.next_ids_intra_back.assign(num_nodes, {}); + out.next_ids_inter.assign(num_nodes, {}); for (int u = 0; u < num_nodes; u++) { const int su = scc_id[u]; for (int v : next_ids[u]) { if (scc_id[v] == su) { if (dfs_finish[v] < dfs_finish[u]) { - next_ids_intra_fwd[u].push_back(v); + out.next_ids_intra_fwd[u].push_back(v); } else { - next_ids_intra_back[u].push_back(v); + out.next_ids_intra_back[u].push_back(v); } } else { - next_ids_inter[u].push_back(v); + out.next_ids_inter[u].push_back(v); } } } + return out; +} - // An SCC is cyclic iff it contains a cycle. By definition that is |S| > 1 (any two nodes lie on a cycle since the - // SCC is strongly connected) or |S| == 1 with a self-loop edge. For cyclic SCCs we also precompute a topological - // ordering of the SCC's nodes (descending DFS finish time) so the per-stack DAG DP visits each node once with all - // forward predecessors already finalized. - std::vector scc_is_cyclic(num_sccs, 0); - std::vector> scc_topo(num_sccs); +struct CyclicSccInfo { + std::vector scc_is_cyclic; // 1 iff SCC contains a cycle + std::vector> scc_topo; // for cyclic SCCs: nodes sorted by descending dfs_finish +}; + +// An SCC is cyclic iff |S| > 1 (any two nodes in a non-trivial SCC lie on a cycle) or |S| == 1 +// with a self-loop edge. For cyclic SCCs we precompute the topological ordering of their nodes +// so the per-stack DAG DP visits each node once with all forward predecessors already finalized. +CyclicSccInfo identify_cyclic_sccs_and_topo(const std::vector> &scc_nodes, + const std::vector> &next_ids_intra_back, + const std::vector &dfs_finish) { + const int num_sccs = static_cast(scc_nodes.size()); + CyclicSccInfo out; + out.scc_is_cyclic.assign(num_sccs, 0); + out.scc_topo.assign(num_sccs, {}); for (int s = 0; s < num_sccs; s++) { - auto &nodes_in_s = scc_nodes[s]; + const auto &nodes_in_s = scc_nodes[s]; if (nodes_in_s.size() > 1) { - scc_is_cyclic[s] = 1; + out.scc_is_cyclic[s] = 1; } else { + // Self-loops are classified as back-edges above, so a singleton SCC is cyclic iff it has a + // back-edge pointing at itself. const int n = nodes_in_s[0]; - // A singleton SCC is cyclic iff it has a self-loop. Self-loops are classified as back-edges above - // (dfs_finish[v] == dfs_finish[u]), so check there. for (int v : next_ids_intra_back[n]) { if (v == n) { - scc_is_cyclic[s] = 1; + out.scc_is_cyclic[s] = 1; break; } } } - if (scc_is_cyclic[s]) { + if (out.scc_is_cyclic[s]) { struct DfsFinishGreater { const std::vector &dfs_finish; bool operator()(int a, int b) const { @@ -1528,21 +1540,34 @@ void ControlFlowGraph::determine_ad_stack_size() { }; auto topo = nodes_in_s; std::sort(topo.begin(), topo.end(), DfsFinishGreater{dfs_finish}); - scc_topo[s] = std::move(topo); + out.scc_topo[s] = std::move(topo); } } + return out; +} - // Group AD-stacks whose per-node (increased_size, max_increased_size) rows are bit-identical, so that the DP runs - // once per equivalence class instead of once per stack. In practice many AD-stacks in the same kernel share their - // push/pop schedule (one alloca per autodiff variable in the same loop body), and the DP on a large CFG dwarfs - // the dedup cost. We hash a sparse fingerprint (only nodes where the stack actually has push/pop activity) and - // group by it. Worst case (no duplicates): one DP run per stack, equivalent to running it directly per stack. +struct FingerprintGroups { + std::vector stack_to_rep; // stack id -> representative stack id (whose DP run it shares) + std::vector rep_stack_ids; // representative stack ids that actually run the DP +}; + +// Group AD-stacks whose per-node (increased_size, max_increased_size) rows are bit-identical so +// the DP runs once per equivalence class instead of once per stack. In practice many AD-stacks in +// the same kernel share their push/pop schedule (one alloca per autodiff variable in the same +// loop body), and the DP on a large CFG dwarfs the dedup cost. We hash a sparse fingerprint (only +// nodes where the stack has activity) and group by it. Worst case (no duplicates): one DP run per +// stack, equivalent to running it directly per stack. +FingerprintGroups group_stacks_by_fingerprint(int num_nodes, + int num_stacks, + const std::vector &stack_active, + const std::vector> &increased_size, + const std::vector> &max_increased_size) { using Fingerprint = std::vector>; // (node_id, is, mis), sorted by node_id struct FingerprintHash { std::size_t operator()(const Fingerprint &f) const noexcept { - std::size_t h = 1469598103934665603ULL; // FNV-1a mix of three components per fingerprint entry. constexpr std::size_t fnv_prime = 1099511628211ULL; + std::size_t h = 1469598103934665603ULL; for (auto &[n, i, m] : f) { h ^= static_cast(n); h *= fnv_prime; @@ -1554,9 +1579,9 @@ void ControlFlowGraph::determine_ad_stack_size() { return h; } }; - std::unordered_map fp_to_rep; // fingerprint -> representative stack id - std::vector stack_to_rep(num_stacks, -1); - std::vector rep_stack_ids; // stack ids that actually run BF + std::unordered_map fp_to_rep; + FingerprintGroups out; + out.stack_to_rep.assign(num_stacks, -1); for (int sid = 0; sid < num_stacks; sid++) { if (!stack_active[sid]) { continue; @@ -1570,165 +1595,278 @@ void ControlFlowGraph::determine_ad_stack_size() { } } auto [it, inserted] = fp_to_rep.emplace(std::move(fp), sid); - stack_to_rep[sid] = it->second; + out.stack_to_rep[sid] = it->second; if (inserted) { - rep_stack_ids.push_back(sid); + out.rep_stack_ids.push_back(sid); } } + return out; +} - // Scratch buffer reused across representatives to avoid reallocating per iteration. - std::vector max_size_at_node_begin(num_nodes); +enum class CyclicSccFastPath { + kPositiveCycle, // proven positive cycle exists for this stack inside this SCC + kZeroSpread, // no `is` contribution in the SCC; just spread max entry-side begin value + kFallback, // mixed-sign, run the full DP +}; + +// Sign-based fast paths sidestep the cyclic-SCC dynamic programming when the stack's `is` +// contribution inside this SCC is structurally trivial: +// 1. min_is >= 0 with max_is > 0: every node in the SCC lies on some cycle (SCC property), and +// a cycle through a strictly-positive node with all non-negative `is` along it sums to a +// positive value, so a positive cycle exists for this stack. This is the autodiff +// push-only-in-SCC pattern. +// 2. min_is == max_is == 0: no `is` contribution at all in this SCC; the DP would only spread +// the maximum entry-side `max_size_at_node_begin` value to every node in O(|S|). +CyclicSccFastPath classify_cyclic_scc_fast_path(const std::vector &nodes_in_s, + const std::vector &is_for_stack) { + int min_is = INT_MAX; + int max_is = INT_MIN; + for (int u : nodes_in_s) { + const int v = is_for_stack[u]; + if (v < min_is) { + min_is = v; + } + if (v > max_is) { + max_is = v; + } + } + if (min_is >= 0 && max_is > 0) { + return CyclicSccFastPath::kPositiveCycle; + } + if (min_is == 0 && max_is == 0) { + return CyclicSccFastPath::kZeroSpread; + } + return CyclicSccFastPath::kFallback; +} - // Per-representative results, broadcast below to every stack that hashed to this representative. - struct DPResult { - int max_size; - bool has_positive_loop; - }; - std::unordered_map rep_results; - rep_results.reserve(rep_stack_ids.size()); - - // For each representative stack, walk the SCC condensation in topological order (sources first, since Tarjan emits - // SCCs in reverse-topological order so source SCCs end up at the largest indices). Inter-SCC edges only relax - // forward (predecessors are already finalized when an SCC is entered), so each is touched O(1) times per stack. - for (int rep_sid : rep_stack_ids) { - const std::vector &mis_for_stack = max_increased_size[rep_sid]; - const std::vector &is_for_stack = increased_size[rep_sid]; - - std::fill(max_size_at_node_begin.begin(), max_size_at_node_begin.end(), -1); - - int max_size = 0; - max_size_at_node_begin[start_node] = 0; - - bool has_positive_loop = false; - - for (int s = num_sccs - 1; s >= 0 && !has_positive_loop; s--) { - const auto &nodes_in_s = scc_nodes[s]; - - if (scc_is_cyclic[s]) { - // Sign-based fast paths sidestep the cyclic-SCC dynamic programming when the stack's `is` contribution - // inside this SCC is structurally trivial. Two cases short-circuit: - // 1. min_is >= 0 with max_is > 0: every node in the SCC lies on some cycle (SCC property), and a cycle - // through a strictly-positive node with all non-negative `is` along it sums to a positive value, so a - // positive cycle exists for this stack. This is the autodiff push-only-in-SCC pattern. - // 2. min_is == 0 == max_is (no `is` contribution in this SCC at all): the DP would only spread the maximum - // entry-side `max_size_at_node_begin` value to every node in S; we do that directly in O(|S|). - int min_is = INT_MAX; - int max_is = INT_MIN; - for (int u : nodes_in_s) { - const int v = is_for_stack[u]; - if (v < min_is) { - min_is = v; - } - if (v > max_is) { - max_is = v; - } - } - if (min_is >= 0 && max_is > 0) { - has_positive_loop = true; - break; - } - if (min_is == 0 && max_is == 0) { - int max_begin = -1; - for (int u : nodes_in_s) { - if (max_size_at_node_begin[u] > max_begin) { - max_begin = max_size_at_node_begin[u]; - } - } - if (max_begin >= 0) { - for (int u : nodes_in_s) { - if (max_size_at_node_begin[u] < max_begin) { - max_size_at_node_begin[u] = max_begin; - } - } - } - } else { - // Mixed-sign case: single-pass dynamic programming on the SCC's forward edges (processed in descending DFS - // finish-time so every forward predecessor is finalized before its successor relaxes), followed by one - // relaxation check on the back-edges. Correctness: every walk inside the SCC decomposes into a forward - // path (along non-back edges) plus zero or more closed cycles formed by a back-edge plus a forward path. - // For SCCs with no positive cycle, traversing a cycle adds a non-positive amount to the running size and - // so cannot improve max_size beyond what the forward DP already computed. The back-edge relaxation check - // after the DP detects the only failure mode (some back-edge would still improve a forward predecessor's - // value, which can only happen if a positive cycle exists for this stack). - for (int u : scc_topo[s]) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_intra_fwd[u]) { - if (exit_val > max_size_at_node_begin[v]) { - max_size_at_node_begin[v] = exit_val; - } - } - } - for (int u : nodes_in_s) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_intra_back[u]) { - if (exit_val > max_size_at_node_begin[v]) { - has_positive_loop = true; - break; - } - } - if (has_positive_loop) { - break; - } - } - if (has_positive_loop) { - break; - } - } +// Spread the maximum entry-side `max_size_at_node_begin` value across every node in the SCC. +// Equivalent to running the DP with all-zero `is` weights, in O(|S|). +void spread_max_begin_over_zero_scc(const std::vector &nodes_in_s, + std::vector &max_size_at_node_begin) { + int max_begin = -1; + for (int u : nodes_in_s) { + if (max_size_at_node_begin[u] > max_begin) { + max_begin = max_size_at_node_begin[u]; + } + } + if (max_begin < 0) { + return; + } + for (int u : nodes_in_s) { + if (max_size_at_node_begin[u] < max_begin) { + max_size_at_node_begin[u] = max_begin; + } + } +} + +// Mixed-sign case: single-pass dynamic programming on the SCC's forward edges (processed in +// descending DFS finish-time so every forward predecessor is finalized before its successor +// relaxes), followed by one relaxation check on the back-edges. Correctness: every walk inside +// the SCC decomposes into a forward path plus zero or more closed cycles (back-edge + forward +// path). For SCCs with no positive cycle, traversing a cycle adds a non-positive amount to the +// running size and so cannot improve max_size beyond what the forward DP already computed. The +// back-edge relaxation check after the DP detects the only failure mode (some back-edge would +// still improve a forward predecessor's value, which can only happen if a positive cycle exists +// for this stack). Returns true iff a positive cycle was detected. +bool dp_mixed_sign_cyclic_scc(const std::vector &topo, + const std::vector &nodes_in_s, + const std::vector &is_for_stack, + const std::vector> &next_ids_intra_fwd, + const std::vector> &next_ids_intra_back, + std::vector &max_size_at_node_begin) { + for (int u : topo) { + const int begin = max_size_at_node_begin[u]; + if (begin < 0) { + continue; + } + const int exit_val = begin + is_for_stack[u]; + for (int v : next_ids_intra_fwd[u]) { + if (exit_val > max_size_at_node_begin[v]) { + max_size_at_node_begin[v] = exit_val; + } + } + } + for (int u : nodes_in_s) { + const int begin = max_size_at_node_begin[u]; + if (begin < 0) { + continue; + } + const int exit_val = begin + is_for_stack[u]; + for (int v : next_ids_intra_back[u]) { + if (exit_val > max_size_at_node_begin[v]) { + return true; } + } + } + return false; +} - // SCC has converged for this stack. Update global max_size from each node's max-prefix contribution and relax - // inter-SCC outgoing edges into successor SCCs. - for (int u : nodes_in_s) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int prefix = begin + mis_for_stack[u]; - if (prefix > max_size) { - max_size = prefix; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_inter[u]) { - if (exit_val > max_size_at_node_begin[v]) { - max_size_at_node_begin[v] = exit_val; - } - } +// SCC has converged for this stack. Update global `max_size` from each node's max-prefix +// contribution, then relax inter-SCC outgoing edges into successor SCCs (predecessors are already +// finalized when an SCC is entered, so each inter-SCC edge is touched exactly once). +void update_global_max_and_relax_inter_scc(const std::vector &nodes_in_s, + const std::vector &is_for_stack, + const std::vector &mis_for_stack, + const std::vector> &next_ids_inter, + std::vector &max_size_at_node_begin, + int &max_size) { + for (int u : nodes_in_s) { + const int begin = max_size_at_node_begin[u]; + if (begin < 0) { + continue; + } + const int prefix = begin + mis_for_stack[u]; + if (prefix > max_size) { + max_size = prefix; + } + const int exit_val = begin + is_for_stack[u]; + for (int v : next_ids_inter[u]) { + if (exit_val > max_size_at_node_begin[v]) { + max_size_at_node_begin[v] = exit_val; } } + } +} - rep_results[rep_sid] = {max_size, has_positive_loop}; +struct AdStackDPResult { + int max_size; + bool has_positive_loop; +}; + +// Run the per-representative DP. Walks the SCC condensation in topological order (sources first, +// since Tarjan emits SCCs in reverse-topological order so source SCCs end up at the largest +// indices). `max_size_at_node_begin` is taken by reference as a scratch buffer reused across reps +// to avoid reallocating per iteration. +AdStackDPResult run_ad_stack_size_dp_for_representative(int start_node, + int num_sccs, + const std::vector &is_for_stack, + const std::vector &mis_for_stack, + const std::vector> &scc_nodes, + const std::vector &scc_is_cyclic, + const std::vector> &scc_topo, + const std::vector> &next_ids_intra_fwd, + const std::vector> &next_ids_intra_back, + const std::vector> &next_ids_inter, + std::vector &max_size_at_node_begin) { + std::fill(max_size_at_node_begin.begin(), max_size_at_node_begin.end(), -1); + max_size_at_node_begin[start_node] = 0; + int max_size = 0; + for (int s = num_sccs - 1; s >= 0; s--) { + const auto &nodes_in_s = scc_nodes[s]; + if (scc_is_cyclic[s]) { + switch (classify_cyclic_scc_fast_path(nodes_in_s, is_for_stack)) { + case CyclicSccFastPath::kPositiveCycle: + return {max_size, /*has_positive_loop=*/true}; + case CyclicSccFastPath::kZeroSpread: + spread_max_begin_over_zero_scc(nodes_in_s, max_size_at_node_begin); + break; + case CyclicSccFastPath::kFallback: + if (dp_mixed_sign_cyclic_scc(scc_topo[s], nodes_in_s, is_for_stack, next_ids_intra_fwd, next_ids_intra_back, + max_size_at_node_begin)) { + return {max_size, /*has_positive_loop=*/true}; + } + break; + } + } + update_global_max_and_relax_inter_scc(nodes_in_s, is_for_stack, mis_for_stack, next_ids_inter, + max_size_at_node_begin, max_size); } + return {max_size, /*has_positive_loop=*/false}; +} - // Broadcast representative results to every active stack. +// Broadcast per-representative DP results to every active stack and apply the resolved +// `max_size`. Stacks with positive cycles are left at `max_size = 0` so the structural +// bounded-loop pre-pass in `irpass::determine_ad_stack_size` gets a chance to derive a symbolic +// bound; if it also cannot, the caller emits a hard compile error (there is no compile-time +// `default_ad_stack_size` fallback). +void apply_ad_stack_dp_results(const std::vector &stacks, + const std::vector &stack_active, + const std::vector &stack_to_rep, + const std::unordered_map &rep_results) { + const int num_stacks = static_cast(stacks.size()); for (int sid = 0; sid < num_stacks; sid++) { AdStackAllocaStmt *stack = stacks[sid]; if (!stack_active[sid]) { - // No push/pop in the CFG: the DP would visit reachable nodes with all-zero edge weights and settle with - // `max_size = 0`, no positive loop. Reproduce that result directly. + // No push/pop in the CFG: the DP would visit reachable nodes with all-zero edge weights and + // settle with `max_size = 0`, no positive loop. Reproduce that result directly. QD_WARN("Unused autodiff stack {} should have been eliminated.", stack->name()); continue; } - const DPResult &res = rep_results[stack_to_rep[sid]]; + const AdStackDPResult &res = rep_results.at(stack_to_rep[sid]); if (res.has_positive_loop) { - // Leave `max_size = 0` so the structural bounded-loop pre-pass in `irpass::determine_ad_stack_size` gets a - // chance to derive a symbolic bound (a statically-bounded inner loop whose push-only body defeats the - // longest-path computation, resolved from the outer ranges). If it also cannot, the caller emits a hard - // compile error - there is no compile-time `default_ad_stack_size` fallback. - } else { - // Since we use |max_size| == 0 for adaptive sizes, we do not want stacks with maximum capacity indeed equal - // to 0. - QD_WARN_IF(res.max_size == 0, "Unused autodiff stack {} should have been eliminated.", stack->name()); - stack->max_size = res.max_size; + // Leave `max_size = 0` so the symbolic-bound pre-pass can take over. + continue; } + // Since we use |max_size| == 0 for adaptive sizes, we do not want stacks with maximum capacity + // indeed equal to 0. + QD_WARN_IF(res.max_size == 0, "Unused autodiff stack {} should have been eliminated.", stack->name()); + stack->max_size = res.max_size; + } +} + +} // namespace + +void ControlFlowGraph::determine_ad_stack_size() { + /** + * Determine the necessary size of every adaptive AD-stack on the control-flow graph (CFG). For each AD-stack we + * compute the maximum running net push count along any walk from the kernel entry. AD-stacks whose forward kernel + * contains a positive cycle (pushes > pops around a loop) are left at `max_size = 0`, and the caller routes them + * through the structural bounded-loop pre-pass for a symbolic `SizeExpr`, hard-erroring if the grammar still + * cannot resolve them. There is no compile-time size fallback. + * + * Implementation notes for compile-time perf on large reverse-mode kernels: + * 1. Per-stack per-node pre-aggregates (`max_increased_size`, `increased_size`) are stored in dense + * `vector>` indexed by a contiguous int stack id, instead of an + * `unordered_map>` -- this removes hash traffic from the hot inner loop. + * 2. Stacks whose `(increased_size, max_increased_size)` row pair is bit-identical share a single dynamic + * programming run -- typical kernels generate one alloca per autodiff variable in the same loop body, so + * most rows collapse to a few representatives. + * 3. The CFG is condensed via Tarjan into strongly connected components (SCCs). DFS finish times recorded + * during the same Tarjan pass split each cyclic SCC's intra-edges into a forward set (target finishes + * before source) and a back set (target finishes at or after source). Per representative we run a + * single-pass dynamic-programming (DP) sweep over the forward edges in descending finish-time order, then + * check the back edges once for positive-cycle relaxation. Correctness: any walk inside an SCC decomposes + * into a forward path plus zero or more cycles, and an SCC with no positive cycle has the same max-walk-sum + * as the back-edge-removed DAG; a positive cycle is exactly the case where some back-edge would still relax + * after the forward DP. This drops the per-cyclic-SCC cost from O(|S| * |E_S|) to O(|S| + |E_S|). + * 4. Two sign-based fast paths short-circuit the DP for trivial cyclic SCCs: an SCC with `min_is >= 0 && max_is + * > 0` for this stack must contain a positive cycle (every node lies on some cycle, and a cycle through a + * strictly-positive node with all non-negative `is` along it sums positive); an SCC with `min_is == 0 == + * max_is` has no `is` contribution at all and is handled by spreading the max entry-side + * `max_size_at_node_begin` to every node in O(|S|). + * Per-rep cost becomes O(V + E + sum_{cyclic S} |S| * |E_S|) (with the SCC sum dropping to O(|S| + |E_S|) for + * the common autodiff push/pop pattern); overall cost is O(V + E + R * (V + E)) with R the number of distinct + * row-pair representatives. + */ + AdStackIndex idx = collect_adaptive_ad_stacks(nodes); + const int num_stacks = static_cast(idx.stacks.size()); + if (num_stacks == 0) { + return; + } + + const int num_nodes = size(); + AdStackPerNodeSizes sizes = accumulate_per_stack_per_node_size_deltas(nodes, idx.stack_id, num_stacks); + std::vector> next_ids = compute_outgoing_node_ids(nodes); + + TarjanResult tarjan = tarjan_scc(next_ids); + const int num_sccs = static_cast(tarjan.scc_nodes.size()); + SccEdgeSets edges = classify_scc_edges(next_ids, tarjan.scc_id, tarjan.dfs_finish); + CyclicSccInfo cyc = identify_cyclic_sccs_and_topo(tarjan.scc_nodes, edges.next_ids_intra_back, tarjan.dfs_finish); + + FingerprintGroups groups = group_stacks_by_fingerprint(num_nodes, num_stacks, sizes.stack_active, sizes.increased_size, + sizes.max_increased_size); + + std::vector max_size_at_node_begin(num_nodes); + std::unordered_map rep_results; + rep_results.reserve(groups.rep_stack_ids.size()); + for (int rep_sid : groups.rep_stack_ids) { + rep_results[rep_sid] = run_ad_stack_size_dp_for_representative( + start_node, num_sccs, sizes.increased_size[rep_sid], sizes.max_increased_size[rep_sid], tarjan.scc_nodes, + cyc.scc_is_cyclic, cyc.scc_topo, edges.next_ids_intra_fwd, edges.next_ids_intra_back, edges.next_ids_inter, + max_size_at_node_begin); } + + apply_ad_stack_dp_results(idx.stacks, sizes.stack_active, groups.stack_to_rep, rep_results); } } // namespace quadrants::lang From 4eb0d3bc94d2c00638a695f8d77fe5df0bff6b66 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 02:15:16 -0700 Subject: [PATCH 03/18] [CFG] Split dead_store_elimination into per-statement helpers Break the 190-line CFGNode::dead_store_elimination into a 30-line top-level driver plus file-local helpers: build_matrix_ptr_alias_maps (alias tables), is_dse_eligible_pointer (uniform eligibility predicate), dse_store_destinations (AD-stack-aware store ptrs), is_store_dead, try_eliminate_dead_store_at (the main eliminate/weaken branch), record_weakened_atomic_as_load, try_eliminate_identical_load_at, and mark_loads_live_in_this_node. The DseAliasMaps and DseLiveState structs replace ad-hoc map pairs threaded through every helper. No behavior change: the atomic-dead-but-not-weakable case retains its original no-state-update fall-through. --- quadrants/ir/control_flow_graph.cpp | 388 ++++++++++++++++------------ 1 file changed, 221 insertions(+), 167 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 70b57eecd6..e6a55490d8 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -682,193 +682,247 @@ static void update_container_with_alias( update_aliased_stmts(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, container, key, to_erase); } -bool CFGNode::dead_store_elimination(bool after_lower_access) { - bool modified = false; - // Map a variable to its nearest load - std::unordered_map live_load_in_this_node; +namespace { - // For any stmt with TensorType'd address, the address can be either partially - // or fully stored/loaded, which will eventually influence the - // dead-store-elimination strategy - // - // Here we use CFGNode::UseDefineStatus to mark whether a TensorType'd address - // is fully or partially modified. +// === Helpers for CFGNode::dead_store_elimination === + +// Aliasing tables built once per CFGNode pass over the block, then threaded through every +// container update so that touching a `MatrixPtrStmt` propagates to its tensor origin (and back). +struct DseAliasMaps { + // MatrixPtrStmt->origin -> list of MatrixPtrStmts that share that origin + std::unordered_map> tensor_to_matrix_ptrs; + // MatrixPtrStmt -> its origin + std::unordered_map matrix_ptr_to_tensor; +}; + +// Per-pass state mutated in lockstep during the reverse-order walk. `live_in_this_node` and +// `killed_in_this_node` use `UseDefineStatus` to distinguish full vs partial modification of +// tensor-typed addresses; `live_load_in_this_node` maps a variable to its nearest later load +// for identical-load elimination. +struct DseLiveState { + std::unordered_map live_load_in_this_node; std::unordered_map live_in_this_node; std::unordered_map killed_in_this_node; +}; - // Search for aliased addresses - // tensor_to_matrix_ptrs_map: map MatrixPtrStmt->origin to list of - // MatrixPtrStmts - // matrix_ptr_to_tensor_map: map MatrixPtrStmt to - // MatrixPtrStmt->origin - std::unordered_map> tensor_to_matrix_ptrs_map; - std::unordered_map matrix_ptr_to_tensor_map; +DseAliasMaps build_matrix_ptr_alias_maps(Block *block, int begin_location, int end_location) { + DseAliasMaps alias; for (int i = begin_location; i < end_location; i++) { - auto stmt = block->statements[i].get(); - if (stmt->is()) { - auto origin = stmt->as()->origin; - if (tensor_to_matrix_ptrs_map.count(origin) == 0) { - tensor_to_matrix_ptrs_map[origin] = {stmt}; - } else { - tensor_to_matrix_ptrs_map[origin].push_back(stmt); - } - matrix_ptr_to_tensor_map[stmt] = origin; + auto *stmt = block->statements[i].get(); + if (!stmt->is()) { + continue; } + auto *origin = stmt->as()->origin; + alias.tensor_to_matrix_ptrs[origin].push_back(stmt); + alias.matrix_ptr_to_tensor[stmt] = origin; } + return alias; +} - // Reverse order traversal, starting from the last IR to the first IR - for (int i = end_location - 1; i >= begin_location; i--) { - auto stmt = block->statements[i].get(); - if (stmt->is()) { - killed_in_this_node.clear(); - live_load_in_this_node.clear(); - continue; +// Pointer eligibility predicate, applied uniformly to store ptrs, load ptrs, and live-update load +// ptrs. After `lower_access`, only local variables (allocas) and AD-stacks remain analyzable; +// before it, everything is in scope. +bool is_dse_eligible_pointer(Stmt *ptr, bool after_lower_access) { + if (!after_lower_access) { + return true; + } + if (ptr->is() || ptr->is()) { + return true; + } + if (ptr->is()) { + auto *origin = ptr->as()->origin; + if (origin->is() || origin->is()) { + return true; } - auto store_ptrs = irpass::analysis::get_store_destination(stmt); + } + return false; +} - // TODO: Consider AD-stacks in get_store_destination instead of here - // for store-to-load forwarding on AD-stacks - if (auto stack_pop = stmt->cast()) { - store_ptrs = std::vector(1, stack_pop->stack); - } else if (auto stack_push = stmt->cast()) { - store_ptrs = std::vector(1, stack_push->stack); - } else if (auto stack_acc_adj = stmt->cast()) { - store_ptrs = std::vector(1, stack_acc_adj->stack); - } else if (stmt->is()) { - store_ptrs = std::vector(1, stmt); - } +// Compute the store destinations of |stmt|, including AD-stack stmts whose store semantics +// aren't yet captured by `get_store_destination`. +// TODO: Consider AD-stacks in get_store_destination instead of here for store-to-load forwarding +// on AD-stacks. +std::vector dse_store_destinations(Stmt *stmt) { + if (auto *pop = stmt->cast()) { + return {pop->stack}; + } + if (auto *push = stmt->cast()) { + return {push->stack}; + } + if (auto *acc = stmt->cast()) { + return {acc->stack}; + } + if (stmt->is()) { + return {stmt}; + } + return irpass::analysis::get_store_destination(stmt); +} - if (store_ptrs.size() == 1) { - // Dead store elimination - auto store_ptr = *store_ptrs.begin(); +// Is |store_ptr| guaranteed dead at this point in the reverse-order walk? +// - !may_contain_variable(live_in_this_node, store_ptr): not loaded after this store in-node +// - contain_variable(killed_in_this_node, store_ptr): already overwritten in-node, OR +// - !may_contain_variable(live_out, store_ptr): not used in any successor node +bool is_store_dead(Stmt *store_ptr, + const std::unordered_set &live_out, + const DseLiveState &state) { + bool is_used_in_next_nodes = false; + for (auto *ptr : irpass::analysis::include_aliased_stmts(store_ptr)) { + is_used_in_next_nodes |= CFGNode::may_contain_variable(live_out, ptr); + } + const bool is_killed_in_current_node = CFGNode::contain_variable(state.killed_in_this_node, store_ptr); + bool is_dead = is_killed_in_current_node || !is_used_in_next_nodes; + is_dead &= !CFGNode::may_contain_variable(state.live_in_this_node, store_ptr); + return is_dead; +} - if (!after_lower_access || - (store_ptr->is() && store_ptr->as()->origin->is()) || - (store_ptr->is() && store_ptr->as()->origin->is()) || - (store_ptr->is() || store_ptr->is())) { - // !may_contain_variable(live_in_this_node, store_ptr): address is not - // loaded after this store - // contain_variable(killed_in_this_node, store_ptr): address is already - // stored by another store stmt in this node (thus killed) - // !may_contain_variable(live_out, store_ptr): address is not used - // in the next nodes - bool is_used_in_next_nodes = false; - for (auto ptr : irpass::analysis::include_aliased_stmts(store_ptr)) { - is_used_in_next_nodes |= may_contain_variable(live_out, ptr); - } +// On dead-store elimination of an `AtomicOpStmt`, the store part is dropped but the load +// (= the atomic's return value) remains; record the load and mark the dest killed/loaded. +void record_weakened_atomic_as_load(Stmt *dest, Stmt *new_load, const DseAliasMaps &alias, DseLiveState &state) { + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.live_in_this_node, dest, + false); + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.killed_in_this_node, dest, + true); + state.live_load_in_this_node[dest] = new_load; +} - bool is_killed_in_current_node = contain_variable(killed_in_this_node, store_ptr); - bool is_dead = is_killed_in_current_node || !is_used_in_next_nodes; - is_dead &= !may_contain_variable(live_in_this_node, store_ptr); - if (!stmt->is() && !stmt->is() && !stmt->is() && is_dead) { - // If an address is neither used in this node, nor used in the next - // nodes, then we can consider eliminating any stores to this address - // (it's not used anyway). There's two different scenerios though: - // 1. Any direct store stmt can be eliminated immediately (LocalStore, - // GlobalStore, AdStackPush, ...) - // 2. AtomicStmt (load + store): remove the store part, thus - // converting the AtomicStmt into a LoadStmt - if (!stmt->is()) { - // Eliminate the dead store. - erase(i); - modified = true; - continue; - } - auto atomic = stmt->cast(); - // Weaken the atomic operation to a load. - if (atomic->dest->is()) { - auto local_load = Stmt::make(atomic->dest); - local_load->ret_type = atomic->ret_type; - // Notice that we have a load here - // (the return value of AtomicOpStmt). - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, live_in_this_node, - atomic->dest, false); - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, killed_in_this_node, - atomic->dest, true); - live_load_in_this_node[atomic->dest] = local_load.get(); - - replace_with(i, std::move(local_load), true); - modified = true; - continue; - } else if (!is_parallel_executed || - (atomic->dest->is() && atomic->dest->as()->snode->is_scalar())) { - // If this node is parallel executed, we can't weaken a global - // atomic operation to a global load. - // TODO: we can weaken it if it's element-wise (i.e. never - // accessed by other threads). - auto global_load = Stmt::make(atomic->dest); - global_load->ret_type = atomic->ret_type; - // Notice that we have a load here - // (the return value of AtomicOpStmt). - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, live_in_this_node, - atomic->dest, false); - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, killed_in_this_node, - atomic->dest, true); - live_load_in_this_node[atomic->dest] = global_load.get(); - - replace_with(i, std::move(global_load), true); - modified = true; - continue; - } - } else { - // A non-eliminated store. - // Insert to killed_in_this_node if it's stored in this node. - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, killed_in_this_node, - store_ptr, false); - - // Remove the address from live_in_this_node if it's stored in this - // node. - auto old_live_in_this_node = live_in_this_node; - for (auto &var : old_live_in_this_node) { - if (irpass::analysis::definitely_same_address(store_ptr, var.first)) { - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, live_in_this_node, - store_ptr, true); - } - } - } - } +// Try to eliminate a dead store (or weaken a dead-store atomic to a pure load) at position |i|. +// If no elimination applies, fall through to update killed/live state for this non-eliminated +// store and return false; the caller will then move on to load handling for the same stmt. +bool try_eliminate_dead_store_at(CFGNode *node, + int i, + Stmt *stmt, + Stmt *store_ptr, + bool after_lower_access, + const DseAliasMaps &alias, + DseLiveState &state) { + if (!is_dse_eligible_pointer(store_ptr, after_lower_access)) { + return false; + } + const bool is_dead = is_store_dead(store_ptr, node->live_out, state); + const bool stmt_eliminable = + !stmt->is() && !stmt->is() && !stmt->is(); + if (stmt_eliminable && is_dead) { + // If an address is neither used in this node, nor in the next nodes, eliminate any stores to + // it. Two scenarios: + // 1. Direct store stmts (LocalStore, GlobalStore, AdStackPush, ...): erase outright. + // 2. AtomicOpStmt (load+store): drop the store half, leaving a pure load. + if (!stmt->is()) { + node->erase(i); + return true; } - auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty()) { - // Identical load elimination - auto load_ptr = load_ptrs.begin()[0]; + auto *atomic = stmt->cast(); + if (atomic->dest->is()) { + auto local_load = Stmt::make(atomic->dest); + local_load->ret_type = atomic->ret_type; + record_weakened_atomic_as_load(atomic->dest, local_load.get(), alias, state); + node->replace_with(i, std::move(local_load), true); + return true; + } + // If this node is parallel executed, we can't weaken a global atomic to a global load. + // TODO: we can weaken it if it's element-wise (i.e. never accessed by other threads). + const bool atomic_global_weakable = + !node->is_parallel_executed || + (atomic->dest->is() && atomic->dest->as()->snode->is_scalar()); + if (atomic_global_weakable) { + auto global_load = Stmt::make(atomic->dest); + global_load->ret_type = atomic->ret_type; + record_weakened_atomic_as_load(atomic->dest, global_load.get(), alias, state); + node->replace_with(i, std::move(global_load), true); + return true; + } + // Atomic was dead but not safely weakable (parallel global, non-scalar). The original code + // intentionally leaves state alone in this case (the atomic still executes, but state is not + // updated for it). Preserve that behavior. + return false; + } + // Non-eliminated store (not dead, or stmt not eliminable): update state. Insert into killed, + // and drop any live-in entry that aliases the same address. + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.killed_in_this_node, + store_ptr, false); + auto old_live_in = state.live_in_this_node; + for (auto &var : old_live_in) { + if (irpass::analysis::definitely_same_address(store_ptr, var.first)) { + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.live_in_this_node, + store_ptr, true); + } + } + return false; +} - if (!after_lower_access || - (load_ptr->is() && load_ptr->as()->origin->is()) || - (load_ptr->is() && load_ptr->as()->origin->is()) || - (load_ptr->is() || load_ptr->is())) { - // live_load_in_this_node[addr]: tracks the - // next load to the same address - // "!may_contain_variable(killed_in_this_node, load_ptr)": means it's - // not been stored in between the two loads - if (live_load_in_this_node.find(load_ptr) != live_load_in_this_node.end() && - !may_contain_variable(killed_in_this_node, load_ptr)) { - // Only perform identical load elimination within a CFGNode. - auto next_load_stmt = live_load_in_this_node[load_ptr]; - if (irpass::analysis::same_statements(stmt, next_load_stmt)) { - next_load_stmt->replace_usages_with(stmt); - erase(block->locate(next_load_stmt)); - modified = true; - } - } +// Try to eliminate an identical (redundant) load at |stmt|. We only do this within a single +// CFGNode, and only if no store has intervened since the prior load. The prior load (later in the +// block, since we walk in reverse) is the one erased; |stmt| takes over its usages. +bool try_eliminate_identical_load_at(CFGNode *node, + Stmt *stmt, + Stmt *load_ptr, + bool after_lower_access, + const DseAliasMaps &alias, + DseLiveState &state) { + if (!is_dse_eligible_pointer(load_ptr, after_lower_access)) { + return false; + } + bool modified = false; + auto it = state.live_load_in_this_node.find(load_ptr); + if (it != state.live_load_in_this_node.end() && + !CFGNode::may_contain_variable(state.killed_in_this_node, load_ptr)) { + auto *next_load_stmt = it->second; + if (irpass::analysis::same_statements(stmt, next_load_stmt)) { + next_load_stmt->replace_usages_with(stmt); + node->erase(node->block->locate(next_load_stmt)); + modified = true; + } + } + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.killed_in_this_node, + load_ptr, true); + state.live_load_in_this_node[load_ptr] = stmt; + return modified; +} - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, killed_in_this_node, load_ptr, - true); - live_load_in_this_node[load_ptr] = stmt; - } +// Mark all (eligible) loads of |stmt| as live-in-this-node so that earlier-in-reverse-order stores +// to the same address see them and abort their dead-store check. +void mark_loads_live_in_this_node(const std::vector &load_ptrs, + bool after_lower_access, + const DseAliasMaps &alias, + DseLiveState &state) { + for (auto *load_ptr : load_ptrs) { + if (is_dse_eligible_pointer(load_ptr, after_lower_access)) { + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.live_in_this_node, + load_ptr, false); } + } +} - // Update live_in_this_node - for (auto &load_ptr : load_ptrs) { - if (!after_lower_access || - (load_ptr->is() && load_ptr->as()->origin->is()) || - (load_ptr->is() && load_ptr->as()->origin->is()) || - (load_ptr->is() || load_ptr->is())) { - // Addr is used in this node, so it's live in this node - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, live_in_this_node, load_ptr, - false); +} // namespace + +bool CFGNode::dead_store_elimination(bool after_lower_access) { + // Reverse-order walk over this node's statements. At each statement we may (a) eliminate a dead + // store, (b) eliminate an identical load, or (c) update live/killed state for later iterations. + // `FuncCallStmt` is treated as a full barrier: it can read/write anything, so we drop the kill + // and live-load tables and start fresh. + const DseAliasMaps alias = build_matrix_ptr_alias_maps(block, begin_location, end_location); + DseLiveState state; + bool modified = false; + for (int i = end_location - 1; i >= begin_location; i--) { + auto *stmt = block->statements[i].get(); + if (stmt->is()) { + state.killed_in_this_node.clear(); + state.live_load_in_this_node.clear(); + continue; + } + const auto store_ptrs = dse_store_destinations(stmt); + if (store_ptrs.size() == 1) { + if (try_eliminate_dead_store_at(this, i, stmt, store_ptrs.front(), after_lower_access, alias, state)) { + modified = true; + continue; + } + } + const auto load_ptrs = irpass::analysis::get_load_pointers(stmt); + if (load_ptrs.size() == 1 && store_ptrs.empty()) { + if (try_eliminate_identical_load_at(this, stmt, load_ptrs.front(), after_lower_access, alias, state)) { + modified = true; } } + mark_loads_live_in_this_node(load_ptrs, after_lower_access, alias, state); } return modified; } From 3bfb38344e7f6034a72af5b54b943e99a917c15d Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 02:17:22 -0700 Subject: [PATCH 04/18] [CFG] Split get_store_forwarding_data into three phase helpers Break CFGNode::get_store_forwarding_data into three private methods plus a thin top-level driver: find_intra_block_last_def handles the intra-block scan with the quant and MatrixInitStmt special cases, find_cross_block_def folds reach_in / reach_gen definitions through update_forwarding_result and returns the last_def_position as optional (nullopt = abort, value = position or -1), any_aliased_store_breaks_forwarding deduplicates the aliased-store between-check that previously appeared verbatim at both the intra- and cross-block exit paths. The top-level body is now a ~25-line driver. No behavior change. --- quadrants/ir/control_flow_graph.cpp | 152 +++++++++++++--------------- quadrants/ir/control_flow_graph.h | 25 +++++ 2 files changed, 96 insertions(+), 81 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index e6a55490d8..1c8a32437a 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -248,16 +248,11 @@ bool CFGNode::update_forwarding_result(Stmt *stmt, return true; } -// var: dest_addr -Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { - // Return the stored data if all definitions in the UD-chain of |var| at - // this position store the same data. - // [Intra-block Search] - int last_def_position = -1; +int CFGNode::find_intra_block_last_def(Stmt *var, int position) const { for (int i = position - 1; i >= begin_location; i--) { // Find previous store stmt to the same dest_addr, stop at the closest one. // store_ptr: prev-store dest_addr - for (auto store_ptr : irpass::analysis::get_store_destination(block->statements[i].get())) { + for (auto *store_ptr : irpass::analysis::get_store_destination(block->statements[i].get())) { // Exclude `store_ptr` as a potential store destination due to mixed // semantics of store statements for quant types. The store operation // involves implicit casting before storing, which may result in a loss of @@ -273,8 +268,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // TODO: Still forward the store if the value can be statically proven to // fit into the quant type. if (!is_quant(store_ptr->ret_type.ptr_removed()) && irpass::analysis::definitely_same_address(var, store_ptr)) { - last_def_position = i; - break; + return i; } // Special case: @@ -284,107 +278,103 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // $3 = load $2 // We can forward MatrixInitStmt->values[offset] to $3 if (var->is() && var->as()->offset->is()) { - auto var_origin = var->as()->origin; - // Check for same origin address + auto *var_origin = var->as()->origin; if (irpass::analysis::definitely_same_address(var_origin, store_ptr)) { - // Check for MatrixInitStmt Stmt *store_data = irpass::analysis::get_store_data(block->statements[i].get()); if (store_data->is()) { - last_def_position = i; - break; + return i; } } } } - if (last_def_position != -1) { - break; - } } + return -1; +} - // Check for aliased address - // There's a store to the same dest_addr before this stmt - if (last_def_position != -1) { - // result: the value to store - Stmt *result = irpass::analysis::get_store_data(block->statements[last_def_position].get()); - bool is_tensor_involved = var->ret_type.ptr_removed()->is(); - if (!(var->is() && !is_tensor_involved)) { - // In between the store stmt and current stmt, - // if there's a third-stmt that "may" have stored a "different value" to - // the "same dest_addr", then we can't forward the stored data. - for (int i = last_def_position + 1; i < position; i++) { - if (!irpass::analysis::same_value(result, irpass::analysis::get_store_data(block->statements[i].get()))) { - if (may_contain_address(block->statements[i].get(), var)) { - return nullptr; - } - } +std::optional CFGNode::find_cross_block_def(Stmt *var, + int position, + Stmt *&result, + bool &result_visible) const { + int last_def_position = -1; + // [Global Addr only] Stores reaching the entry of this node from previous blocks. `var == stmt` + // is for the case that a global ptr is never stored; in that case `stmt` comes from + // nodes[start_node]->reach_gen. + for (auto *stmt : reach_in) { + if (var == stmt || may_contain_address(stmt, var)) { + if (!update_forwarding_result(stmt, position, result, result_visible)) { + return std::nullopt; } + last_def_position = 0; } - return result; } + // Stores generated within this node (in reach_gen) that precede |position|. + for (auto *stmt : reach_gen) { + if (may_contain_address(stmt, var) && stmt->parent->locate(stmt) < position) { + if (!update_forwarding_result(stmt, position, result, result_visible)) { + return std::nullopt; + } + last_def_position = stmt->parent->locate(stmt); + } + } + return last_def_position; +} - // [Cross-block search] - // Search for store to the same dest_addr in reach_in and reach_gen - Stmt *result = nullptr; - bool result_visible = false; - - // [Global Addr only] - // test whether there's a store to the same dest_addr in a previous block. - // if the store values are the same, then return the value - last_def_position = -1; - for (auto stmt : reach_in) { - // var == stmt is for the case that a global ptr is never stored. - // In this case, stmt is from nodes[start_node]->reach_gen. - if (var == stmt || may_contain_address(stmt, var)) { - if (!update_forwarding_result(stmt, position, result, result_visible)) - return nullptr; - else - last_def_position = 0; +bool CFGNode::any_aliased_store_breaks_forwarding(Stmt *result, Stmt *var, int from, int to_exclusive) const { + // Allocas without tensor type cannot be aliased through MatrixPtrStmt, so the check is moot. + const bool is_tensor_involved = var->ret_type.ptr_removed()->is(); + if (var->is() && !is_tensor_involved) { + return false; + } + for (int i = from; i < to_exclusive; i++) { + auto *s = block->statements[i].get(); + if (!irpass::analysis::same_value(result, irpass::analysis::get_store_data(s))) { + if (may_contain_address(s, var)) { + return true; + } } } + return false; +} - // test whether there's a store to the same dest_addr before this stmt (in - // reach_gen) - // if the store values are the same, then return the value - for (auto stmt : reach_gen) { - if (may_contain_address(stmt, var) && stmt->parent->locate(stmt) < position) { - if (!update_forwarding_result(stmt, position, result, result_visible)) - return nullptr; - else - last_def_position = stmt->parent->locate(stmt); +// var: dest_addr. Return the stored data if all definitions in the UD-chain of |var| at this +// position store the same data; otherwise nullptr. Looks intra-block first (cheaper, dominates +// cross-block forwarding when present), then falls back to the cross-block search over reach_in +// and reach_gen. In both cases an intervening aliased store that may write a different value +// breaks the forward. +Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { + // [Intra-block search] Walks backwards in this node's block. + if (int last_def = find_intra_block_last_def(var, position); last_def != -1) { + Stmt *result = irpass::analysis::get_store_data(block->statements[last_def].get()); + if (any_aliased_store_breaks_forwarding(result, var, last_def + 1, position)) { + return nullptr; } + return result; + } + + // [Cross-block search] Walks reach_in / reach_gen, accumulating a single forwardable result. + Stmt *result = nullptr; + bool result_visible = false; + std::optional last_def = find_cross_block_def(var, position, result, result_visible); + if (!last_def.has_value()) { + return nullptr; } if (!result) { // The UD-chain is empty. - auto offending_load = block->statements[position].get(); + auto *offending_load = block->statements[position].get(); ErrorEmitter(QuadrantsIrWarning(), offending_load, fmt::format("Loading variable {} before anything is stored to it.", var->id)); return nullptr; } if (!result_visible) { - // The data is store-to-load forwardable but not visible at the place we - // are going to forward. We cannot forward it in this case. + // The data is store-to-load forwardable but not visible at the place we are going to forward. return nullptr; } - - if (last_def_position == -1) + if (*last_def == -1) { + return nullptr; + } + if (any_aliased_store_breaks_forwarding(result, var, *last_def, position)) { return nullptr; - - // Check for aliased address - // There's a store to the same dest_addr before this stmt - bool is_tensor_involved = var->ret_type.ptr_removed()->is(); - if (!(var->is() && !is_tensor_involved)) { - // In between the store stmt and current stmt, - // if there's a third-stmt that "may" have stored a "different value" to - // the "same dest_addr", then we can't forward the stored data. - for (int i = last_def_position; i < position; i++) { - if (!irpass::analysis::same_value(result, irpass::analysis::get_store_data(block->statements[i].get()))) { - if (may_contain_address(block->statements[i].get(), var)) { - return nullptr; - } - } - } } - return result; } diff --git a/quadrants/ir/control_flow_graph.h b/quadrants/ir/control_flow_graph.h index 9735508ce8..01e532f7e5 100644 --- a/quadrants/ir/control_flow_graph.h +++ b/quadrants/ir/control_flow_graph.h @@ -104,6 +104,31 @@ class CFGNode { int position, Stmt *&result, bool &result_visible) const; + + // Helper for get_store_forwarding_data: walk this node's block backwards + // from |position| and return the index of the most recent store to |var|, + // or -1 if none is in this block. Handles the quant-store exclusion plus the + // MatrixInitStmt-via-MatrixPtrStmt forwarding special case. + int find_intra_block_last_def(Stmt *var, int position) const; + + // Helper for get_store_forwarding_data: scan |reach_in| and |reach_gen| for + // definitions of |var| reaching |position|, folding each into |result| / + // |result_visible| via update_forwarding_result. Returns nullopt if any + // visited def is unforwardable (caller must return nullptr); otherwise the + // last_def_position (0 if only reach_in matched, an in-block index if + // reach_gen matched, -1 if no eligible def was found). + std::optional find_cross_block_def(Stmt *var, + int position, + Stmt *&result, + bool &result_visible) const; + + // Helper for get_store_forwarding_data: scan block statements in + // [from, to_exclusive) for a store that may write a different value to an + // address aliasing |var|. Returns true iff such a store exists (so + // forwarding |result| must abort). The check is skipped (returns false) for + // non-tensor alloca destinations, where aliasing through MatrixPtrStmt + // cannot apply. + bool any_aliased_store_breaks_forwarding(Stmt *result, Stmt *var, int from, int to_exclusive) const; }; class ControlFlowGraph { From b1e1726e486b0e33f82526b04a161051c440dda8 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 02:19:30 -0700 Subject: [PATCH 05/18] [CFG] Split store_to_load_forwarding into forward + eliminate helpers Break the single-loop CFGNode::store_to_load_forwarding into two private methods (try_forward_load_at, try_eliminate_identical_store_at) plus a ~15-line driver. Forwarding takes precedence and short-circuits the identical-store pass for the same stmt; both helpers manage |i| and |modified| in lockstep with the legacy semantics, including the replace-with-zero-does-not-flip-modified quirk and the alloca-zero special case. No behavior change. --- quadrants/ir/control_flow_graph.cpp | 190 ++++++++++++++-------------- quadrants/ir/control_flow_graph.h | 19 +++ 2 files changed, 114 insertions(+), 95 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 1c8a32437a..78e10cb807 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -404,109 +404,109 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { } } -bool CFGNode::store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled) { - // Contains two separate parts: - // 1. Store-to-load Forwarding: for each load stmt, find the closest previous - // store stmt - // that stores to the same address as the load stmt, then replace - // load with the "val". - // 2. Identical Store Elimination: for each store stmt, find the closest - // previous store stmt - // that stores to the same address as the store stmt. If the "val"s - // are the same, then remove the store stmt. - bool modified = false; - for (int i = begin_location; i < end_location; i++) { - // Store-to-load forwarding - auto stmt = block->statements[i].get(); - - // result: the value to be store/load - Stmt *result = nullptr; - - // [get_store_forwarding_data] find the store stmt that: - // 1. stores to the same address and as the load stmt - // 2. (one value at a time) closest to the load stmt but before the load - // stmt - Stmt *load_src = nullptr; - if (auto local_load = stmt->cast()) { - result = get_store_forwarding_data(local_load->src, i); - load_src = local_load->src; - } else if (auto global_load = stmt->cast()) { - if (!after_lower_access && !autodiff_enabled) { - result = get_store_forwarding_data(global_load->src, i); - load_src = global_load->src; - } +bool CFGNode::try_forward_load_at(int &i, Stmt *stmt, bool after_lower_access, bool autodiff_enabled, bool &modified) { + // Find the closest preceding store-to-same-address. For GlobalLoadStmt the forwarder is gated: + // after lower_access or under autodiff we don't trust cross-block forwarding to be sound. + Stmt *load_src = nullptr; + Stmt *result = nullptr; + if (auto *local_load = stmt->cast()) { + load_src = local_load->src; + result = get_store_forwarding_data(load_src, i); + } else if (auto *global_load = stmt->cast()) { + if (!after_lower_access && !autodiff_enabled) { + load_src = global_load->src; + result = get_store_forwarding_data(load_src, i); } - - // [Apply Load-Store-Forwarding] - // replace load stmt with the value-"result" - if (result) { - // Forward the stored data |result|. - if (result->is()) { - // TensorType does not apply to this special case - if (result->ret_type.ptr_removed()->is()) - continue; - - // special case of alloca (initialized to 0) - auto zero = Stmt::make(TypedConstant(result->ret_type.ptr_removed(), 0)); - replace_with(i, std::move(zero), true); - } else { - if (result->ret_type.ptr_removed()->is() && !stmt->ret_type->is()) { - QD_ASSERT(load_src->is() && load_src->as()->offset->is()); - QD_ASSERT(result->is()); - - int offset = load_src->as()->offset->as()->val.val_int32(); - - result = result->as()->values[offset]; - } - - stmt->replace_usages_with(result); - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop - modified = true; - } - continue; + } + if (!result) { + return false; + } + if (result->is()) { + // TensorType does not apply to this special case; skip further handling for this stmt. + if (result->ret_type.ptr_removed()->is()) { + return true; } + // Alloca initialized to 0: replace the load with a zero const. + // Note: |modified| is intentionally NOT flipped here (preserved from legacy behavior). + auto zero = Stmt::make(TypedConstant(result->ret_type.ptr_removed(), 0)); + replace_with(i, std::move(zero), true); + return true; + } + // Non-alloca result: forward it. Extract a MatrixInitStmt element when the forwarded data is + // a TensorType-typed init but the load is for a scalar slot. + if (result->ret_type.ptr_removed()->is() && !stmt->ret_type->is()) { + QD_ASSERT(load_src->is() && load_src->as()->offset->is()); + QD_ASSERT(result->is()); + const int offset = load_src->as()->offset->as()->val.val_int32(); + result = result->as()->values[offset]; + } + stmt->replace_usages_with(result); + erase(i); // end_location-- + i--; // cancel the for-loop's i++ + modified = true; + return true; +} - // [Identical store elimination] - // find the store stmt that: - // 1. stores to the same address as the current store stmt - // 2. has the same store value as the current store stmt - // 3. (one value at a time) closest to the current store stmt but before the - // current store stmt then erase the current store stmt - if (auto local_store = stmt->cast()) { - result = get_store_forwarding_data(local_store->dest, i); - if (result && result->is() && !autodiff_enabled) { - // TensorType does not apply to this special case - if (result->ret_type.ptr_removed()->is()) { - continue; - } - - // special case of alloca (initialized to 0) - if (auto stored_data = local_store->val->cast()) { - if (stored_data->val.equal_value(0)) { - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop - modified = true; - } - } - } else { - // not alloca - if (irpass::analysis::same_value(result, local_store->val)) { - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop - modified = true; - } +void CFGNode::try_eliminate_identical_store_at(int &i, + Stmt *stmt, + bool after_lower_access, + bool autodiff_enabled, + bool &modified) { + // Eliminate a store whose value is provably identical to the closest preceding store to the + // same address. For local stores under non-autodiff there's also an alloca-zero special case: + // writing a zero to a freshly-allocated alloca is redundant. + if (auto *local_store = stmt->cast()) { + Stmt *result = get_store_forwarding_data(local_store->dest, i); + if (result && result->is() && !autodiff_enabled) { + // TensorType does not apply to this special case. + if (result->ret_type.ptr_removed()->is()) { + return; } - } else if (auto global_store = stmt->cast()) { - if (!after_lower_access) { - result = get_store_forwarding_data(global_store->dest, i); - if (irpass::analysis::same_value(result, global_store->val)) { - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop + if (auto *stored = local_store->val->cast()) { + if (stored->val.equal_value(0)) { + erase(i); + i--; modified = true; } } + return; + } + if (irpass::analysis::same_value(result, local_store->val)) { + erase(i); + i--; + modified = true; + } + return; + } + if (auto *global_store = stmt->cast()) { + if (after_lower_access) { + return; + } + Stmt *result = get_store_forwarding_data(global_store->dest, i); + if (irpass::analysis::same_value(result, global_store->val)) { + erase(i); + i--; + modified = true; + } + } +} + +bool CFGNode::store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled) { + // Two passes fused into one forward walk: + // 1. Store-to-load forwarding: replace a load with the value from the closest preceding store + // to the same address. + // 2. Identical-store elimination: erase a store whose value is identical to the closest + // preceding store to the same address. + // The forwarding pass takes precedence: if we forwarded a load, we don't also try to treat the + // same stmt as a store. erase() shrinks end_location, so per-elimination we both `erase(i)` and + // decrement `i` to keep the for-loop pointing at the right next stmt. + bool modified = false; + for (int i = begin_location; i < end_location; i++) { + auto *stmt = block->statements[i].get(); + if (try_forward_load_at(i, stmt, after_lower_access, autodiff_enabled, modified)) { + continue; } + try_eliminate_identical_store_at(i, stmt, after_lower_access, autodiff_enabled, modified); } return modified; } diff --git a/quadrants/ir/control_flow_graph.h b/quadrants/ir/control_flow_graph.h index 01e532f7e5..bd34d07817 100644 --- a/quadrants/ir/control_flow_graph.h +++ b/quadrants/ir/control_flow_graph.h @@ -129,6 +129,25 @@ class CFGNode { // non-tensor alloca destinations, where aliasing through MatrixPtrStmt // cannot apply. bool any_aliased_store_breaks_forwarding(Stmt *result, Stmt *var, int from, int to_exclusive) const; + + // Helper for store_to_load_forwarding: if |stmt| is a load whose source has + // a forwardable preceding store, replace it. Returns true iff the load was + // handled (the caller must skip identical-store elimination for this stmt + // even if no IR change happened, matching the legacy fall-through). |i| is + // decremented on erase so the for-loop's natural `i++` lands on the next + // unread stmt; |modified| is flipped on erase only (preserved from legacy: + // replace-with-zero does not flip it). + bool try_forward_load_at(int &i, Stmt *stmt, bool after_lower_access, bool autodiff_enabled, bool &modified); + + // Helper for store_to_load_forwarding: if |stmt| is a store identical to a + // preceding store to the same address, erase it. Handles the alloca-init-0 + // special case under non-autodiff. Same |i|/|modified| accounting as + // try_forward_load_at. + void try_eliminate_identical_store_at(int &i, + Stmt *stmt, + bool after_lower_access, + bool autodiff_enabled, + bool &modified); }; class ControlFlowGraph { From 86c01a62d0010deb52b21dd1e5462e537afcbb5b Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 02:20:57 -0700 Subject: [PATCH 06/18] [CFG] Split reaching_definition_analysis into seed + kill-check helpers Extract two file-local helpers (is_external_input_pointer, seed_start_node_reach_gen) for the entry-node reach_gen seeding phase and one (is_reach_in_stmt_killed_at) for the per-stmt kill test inside the worklist loop. The top-level body is now a ~40-line worklist algorithm reading top-to-bottom: seed, per-node init, propagate. The docstring is condensed but preserves the algorithmic contract. --- quadrants/ir/control_flow_graph.cpp | 143 ++++++++++++++++------------ 1 file changed, 81 insertions(+), 62 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 78e10cb807..d179ceeffb 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -1025,54 +1025,88 @@ void ControlFlowGraph::dump_graph_to_file(const CompileConfig &config, QD_INFO("CFG dumped to: {}", filename.string()); } +namespace { + +// === Helpers for ControlFlowGraph::reaching_definition_analysis === + +// Statements that carry data into this kernel from outside its CFG. Treated as definitions in the +// synthetic entry node's reach_gen so cross-block forwarding correctly sees them. After access +// lowering, only MatrixPtrStmt aliases of allocas/matrix-pointers remain in scope; before, the +// full menagerie of pointer flavors applies. +// TODO: unify them. +bool is_external_input_pointer(const Stmt *stmt, bool after_lower_access) { + if (stmt->is()) { + const auto *origin = stmt->as()->origin; + if (origin->is() || origin->is()) { + return true; + } + } + if (after_lower_access) { + return false; + } + return stmt->is() || stmt->is() || stmt->is() || + stmt->is() || stmt->is() || stmt->is() || + stmt->is() || stmt->is() || stmt->is(); +} + +// Seed the synthetic entry node's reach_gen with definitions that "exist" before this kernel +// begins executing: external pointer loads (so a load before the first store still has a +// UD-chain) plus FuncCallStmt store destinations (the callee may have written anything declared +// in its `store_dests`). +void seed_start_node_reach_gen(CFGNode *start_node, + const std::vector> &nodes, + bool after_lower_access) { + start_node->reach_gen.clear(); + start_node->reach_kill.clear(); + for (const auto &node : nodes) { + for (int j = node->begin_location; j < node->end_location; j++) { + auto *stmt = node->block->statements[j].get(); + if (is_external_input_pointer(stmt, after_lower_access)) { + start_node->reach_gen.insert(stmt); + } else if (auto *func_call = stmt->cast()) { + const auto &dests = func_call->func->store_dests; + start_node->reach_gen.insert(dests.begin(), dests.end()); + } + } + } +} + +// Is |stmt| killed at |node| (i.e., excluded from reach_out)? For stmts with explicit store +// destinations, killed iff every dest is killed at |node|. For stmts without dests (e.g. raw +// global pointers seeded into start_node's reach_gen), killed iff the stmt itself is killed. +bool is_reach_in_stmt_killed_at(CFGNode *node, Stmt *stmt) { + const auto store_ptrs = irpass::analysis::get_store_destination(stmt); + if (store_ptrs.empty()) { + return node->reach_kill_variable(stmt); + } + for (auto *store_ptr : store_ptrs) { + if (!node->reach_kill_variable(store_ptr)) { + return false; + } + } + return true; +} + +} // namespace + void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { - // Prerequisite analysis for load-store-forwarding to help determine - // cross-block use-define chain - // - // The algorithm is separated into two parts: - // 1. Determine reach_gen and reach_kill within each node - // 2. Propagate reach_in and reach_out through the graph - // - // - reach_gen: instruction that defines a variable (store stmts) in the - // current node - // - reach_kill: address (GlobalPtrStmt, AllocaStmt, ...) that's been defined - // (stored to) in the current node + // Prerequisite analysis for load-store-forwarding; computes the cross-block use-define chain. // - // In general, reach_gen and reach_kill are the same except that reach_gen - // tracks the store stmts and reach_kill tracks the address + // Per-node: + // - reach_gen: store stmts that define a variable in this node. + // - reach_kill: addresses (GlobalPtrStmt, AllocaStmt, ...) stored to in this node. + // (reach_gen tracks the defining stmts; reach_kill tracks the killed addresses.) // - // - reach_out: reach_gen + { reach_in's dest not in reach_kill } - // - reach_in: collection of all the reach_out of previous nodes - // - // reach_out and reach_in is the ultimate result that helps analyze - // cross-block use-define chain - + // Per-graph (worklist fixpoint): + // - reach_in: union of reach_out of all predecessor nodes. + // - reach_out: reach_gen + { stmts from reach_in whose dest is not in reach_kill }. QD_AUTO_PROF; const int num_nodes = size(); - std::queue to_visit; - std::unordered_map in_queue; QD_ASSERT(nodes[start_node]->empty()); - nodes[start_node]->reach_gen.clear(); - nodes[start_node]->reach_kill.clear(); - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - auto stmt = nodes[i]->block->statements[j].get(); - if ((stmt->is() && stmt->as()->origin->is()) || - (stmt->is() && stmt->as()->origin->is()) || - (!after_lower_access && - (stmt->is() || stmt->is() || stmt->is() || - stmt->is() || stmt->is() || stmt->is() || - stmt->is() || stmt->is() || stmt->is()))) { - // TODO: unify them - // A global pointer that may contain some data before this kernel. - nodes[start_node]->reach_gen.insert(stmt); - } else if (auto func_call = stmt->cast()) { - const auto &dests = func_call->func->store_dests; - nodes[start_node]->reach_gen.insert(dests.begin(), dests.end()); - } - } - } + seed_start_node_reach_gen(nodes[start_node].get(), nodes, after_lower_access); + std::queue to_visit; + std::unordered_map in_queue; for (int i = 0; i < num_nodes; i++) { if (i != start_node) { nodes[i]->reaching_definition_analysis(after_lower_access); @@ -1083,40 +1117,25 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { in_queue[nodes[i].get()] = true; } - // [The worklist algorithm] - // Determines reach_in and reach_out for each node iteratively. + // [Worklist algorithm] Converge reach_in / reach_out iteratively. while (!to_visit.empty()) { - auto now = to_visit.front(); + auto *now = to_visit.front(); to_visit.pop(); in_queue[now] = false; now->reach_in.clear(); - for (auto prev_node : now->prev) { + for (auto *prev_node : now->prev) { now->reach_in.insert(prev_node->reach_out.begin(), prev_node->reach_out.end()); } auto old_out = std::move(now->reach_out); now->reach_out = now->reach_gen; - for (auto stmt : now->reach_in) { - auto store_ptrs = irpass::analysis::get_store_destination(stmt); - bool killed; - if (store_ptrs.empty()) { // the case of a global pointer - killed = now->reach_kill_variable(stmt); - } else { - killed = true; - for (auto store_ptr : store_ptrs) { - if (!now->reach_kill_variable(store_ptr)) { - killed = false; - break; - } - } - } - if (!killed) { + for (auto *stmt : now->reach_in) { + if (!is_reach_in_stmt_killed_at(now, stmt)) { now->reach_out.insert(stmt); } } if (now->reach_out != old_out) { - // changed - for (auto next_node : now->next) { + for (auto *next_node : now->next) { if (!in_queue[next_node]) { to_visit.push(next_node); in_queue[next_node] = true; From 91190d73f51e24d6b5682833b277fca6a5a09f96 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 02:22:23 -0700 Subject: [PATCH 07/18] [CFG] Split dump_graph_to_file into per-section formatters Extract five file-local helpers (format_cfg_node_range_label, format_neighbor_indices, format_live_var_names, write_cfg_node_header, write_cfg_node_statements) that each handle one piece of the textual dump. The top-level body is now a ~20-line driver: open file, build node->index map, iterate nodes writing header and statements. No behavior change. --- quadrants/ir/control_flow_graph.cpp | 132 +++++++++++++++++----------- 1 file changed, 80 insertions(+), 52 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index d179ceeffb..a08a4f2c4e 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -949,12 +949,87 @@ CFGNode *ControlFlowGraph::back() const { return nodes.back().get(); } +namespace { + +// === Helpers for ControlFlowGraph::dump_graph_to_file === + +// Range label: "empty" or "~ (size=N)". +std::string format_cfg_node_range_label(const CFGNode *node) { + if (node->empty()) { + return "empty"; + } + return fmt::format("{}~{} (size={})", node->block->statements[node->begin_location]->name(), + node->block->statements[node->end_location - 1]->name(), node->size()); +} + +// Brace-delimited list of neighbor indices, e.g. "{0, 1, 2}". +std::string format_neighbor_indices(const std::vector &neighbors, + const std::unordered_map &to_index) { + std::vector parts; + parts.reserve(neighbors.size()); + for (auto *n : neighbors) { + parts.push_back(std::to_string(to_index.at(n))); + } + return fmt::format("{{{}}}", fmt::join(parts, ", ")); +} + +// Brace-delimited list of stmt names, e.g. "{$3, $5}". +std::string format_live_var_names(const std::unordered_set &vars) { + std::vector parts; + parts.reserve(vars.size()); + for (auto *stmt : vars) { + parts.push_back(stmt->name()); + } + return fmt::format("{{{}}}", fmt::join(parts, ", ")); +} + +// Write one node's header line: index, range label, optional prev/next/live_out lists. +void write_cfg_node_header(std::ostream &out, + int index, + const CFGNode *node, + const std::unordered_map &to_index) { + out << fmt::format("Node {} : ", index) << format_cfg_node_range_label(node); + if (!node->prev.empty()) { + out << "; prev=" << format_neighbor_indices(node->prev, to_index); + } + if (!node->next.empty()) { + out << "; next=" << format_neighbor_indices(node->next, to_index); + } + if (!node->live_out.empty()) { + out << "; live_out=" << format_live_var_names(node->live_out); + } + out << "\n"; +} + +// Write one node's statements, each line indented 4 spaces, blank line after. +void write_cfg_node_statements(std::ostream &out, const CFGNode *node) { + if (node->empty()) { + return; + } + for (int j = node->begin_location; j < node->end_location; j++) { + auto *stmt = node->block->statements[j].get(); + std::string stmt_output; + // print_kernel_wrapper=false to avoid the surrounding "kernel { }" wrapper. + irpass::print(stmt, &stmt_output, false, false); + std::istringstream iss(stmt_output); + std::string line; + while (std::getline(iss, line)) { + if (!line.empty()) { + out << " " << line << "\n"; + } + } + } + out << "\n"; +} + +} // namespace + void ControlFlowGraph::dump_graph_to_file(const CompileConfig &config, const std::string &kernel_name, const std::string &suffix) const { - std::filesystem::path ir_dump_dir = config.debug_dump_path; + const std::filesystem::path ir_dump_dir = config.debug_dump_path; std::filesystem::create_directories(ir_dump_dir); - std::filesystem::path filename = ir_dump_dir / (kernel_name + "_CFG" + suffix + ".txt"); + const std::filesystem::path filename = ir_dump_dir / (kernel_name + "_CFG" + suffix + ".txt"); std::ofstream out_file(filename.string()); if (!out_file) { @@ -962,63 +1037,16 @@ void ControlFlowGraph::dump_graph_to_file(const CompileConfig &config, return; } - // Write directly to the file using fmt::format const int num_nodes = size(); std::unordered_map to_index; + to_index.reserve(num_nodes); for (int i = 0; i < num_nodes; i++) { to_index[nodes[i].get()] = i; } for (int i = 0; i < num_nodes; i++) { - out_file << fmt::format("Node {} : ", i); - if (nodes[i]->empty()) { - out_file << "empty"; - } else { - out_file << fmt::format("{}~{} (size={})", nodes[i]->block->statements[nodes[i]->begin_location]->name(), - nodes[i]->block->statements[nodes[i]->end_location - 1]->name(), nodes[i]->size()); - } - if (!nodes[i]->prev.empty()) { - std::vector indices; - for (auto prev_node : nodes[i]->prev) { - indices.push_back(std::to_string(to_index[prev_node])); - } - out_file << fmt::format("; prev={{{}}}", fmt::join(indices, ", ")); - } - if (!nodes[i]->next.empty()) { - std::vector indices; - for (auto next_node : nodes[i]->next) { - indices.push_back(std::to_string(to_index[next_node])); - } - out_file << fmt::format("; next={{{}}}", fmt::join(indices, ", ")); - } - if (!nodes[i]->live_out.empty()) { - std::vector vars; - for (auto stmt : nodes[i]->live_out) { - vars.push_back(stmt->name()); - } - out_file << fmt::format("; live_out={{{}}}", fmt::join(vars, ", ")); - } - out_file << "\n"; - - // Print the actual statements in this node - if (!nodes[i]->empty()) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - auto stmt = nodes[i]->block->statements[j].get(); - std::string stmt_output; - // Use print_kernel_wrapper=false to avoid the "kernel { }" wrapper - irpass::print(stmt, &stmt_output, false, false); - - // Add indentation to each line - std::istringstream iss(stmt_output); - std::string line; - while (std::getline(iss, line)) { - if (!line.empty()) { - out_file << " " << line << "\n"; - } - } - } - out_file << "\n"; - } + write_cfg_node_header(out_file, i, nodes[i].get(), to_index); + write_cfg_node_statements(out_file, nodes[i].get()); } out_file.close(); From 02687308266b85a85fd5f74ca106405a38552661 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 02:23:13 -0700 Subject: [PATCH 08/18] [CFG] Split live_variable_analysis seeding into seed_final_node_live_gen Extract the kernel-wide live-out seeding of the synthetic exit node (controlled by !after_lower_access and the SFG eliminable-snodes config) into a file-local helper. The top-level body is now a ~35-line worklist algorithm: seed, per-node init, propagate. Mirrors the shape of the sibling reaching_definition_analysis. No behavior change. --- quadrants/ir/control_flow_graph.cpp | 79 ++++++++++++++++++----------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index a08a4f2c4e..23e22bbbf4 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -1173,38 +1173,56 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { } } -void ControlFlowGraph::live_variable_analysis(bool after_lower_access, - const std::optional &config_opt) { - // [live_variable_analysis] - // live_gen: address loaded with no previous stored in this node. One cannot - // load before storing so - // addrs in live_gen must come from previous nodes - // live_kill: address stored in this node - // live_in: live_gen + (live_out - live_kill) - // live_out: collection of all the live_in of next nodes - QD_AUTO_PROF; - const int num_nodes = size(); - std::queue to_visit; - std::unordered_map in_queue; - QD_ASSERT(nodes[final_node]->empty()); - nodes[final_node]->live_gen.clear(); - nodes[final_node]->live_kill.clear(); +namespace { - if (!after_lower_access) { - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - auto stmt = nodes[i]->block->statements[j].get(); - for (auto store_ptr : irpass::analysis::get_store_destination(stmt, true /*get_alias*/)) { - if (in_final_node_live_gen(store_ptr, config_opt)) { - nodes[final_node]->live_gen.insert(store_ptr); - } +// === Helpers for ControlFlowGraph::live_variable_analysis === + +// Seed the synthetic exit node's live_gen with every store destination in the kernel that may +// still be observable after the kernel exits (globals not flagged eliminable by the SFG, plus any +// aliased pointers via `get_alias`). Skipped entirely after access lowering, when only locals +// remain and nothing escapes. +void seed_final_node_live_gen(CFGNode *final_node_ptr, + const std::vector> &nodes, + bool after_lower_access, + const std::optional &config_opt) { + final_node_ptr->live_gen.clear(); + final_node_ptr->live_kill.clear(); + if (after_lower_access) { + return; + } + for (const auto &node : nodes) { + for (int j = node->begin_location; j < node->end_location; j++) { + auto *stmt = node->block->statements[j].get(); + for (auto *store_ptr : irpass::analysis::get_store_destination(stmt, /*get_alias=*/true)) { + if (in_final_node_live_gen(store_ptr, config_opt)) { + final_node_ptr->live_gen.insert(store_ptr); } } } } +} + +} // namespace +void ControlFlowGraph::live_variable_analysis(bool after_lower_access, + const std::optional &config_opt) { + // Per-node: + // - live_gen: address loaded with no preceding store in this node (must live in from + // somewhere -- you can't load before storing). + // - live_kill: address stored in this node. + // Per-graph (worklist fixpoint, propagated backwards): + // - live_in: live_gen + (live_out - live_kill). + // - live_out: union of live_in of all successor nodes. + QD_AUTO_PROF; + const int num_nodes = size(); + QD_ASSERT(nodes[final_node]->empty()); + seed_final_node_live_gen(nodes[final_node].get(), nodes, after_lower_access, config_opt); + + std::queue to_visit; + std::unordered_map in_queue; + // Push in reversed order: backwards analysis converges slightly faster when the worklist seeds + // are dequeued from the back of the graph first. for (int i = num_nodes - 1; i >= 0; i--) { - // push into the queue in reversed order to make it slightly faster if (i != final_node) { nodes[i]->live_variable_analysis(after_lower_access); } @@ -1214,26 +1232,25 @@ void ControlFlowGraph::live_variable_analysis(bool after_lower_access, in_queue[nodes[i].get()] = true; } - // The worklist algorithm. + // [Worklist algorithm] Converge live_in / live_out iteratively (backwards). while (!to_visit.empty()) { - auto now = to_visit.front(); + auto *now = to_visit.front(); to_visit.pop(); in_queue[now] = false; now->live_out.clear(); - for (auto next_node : now->next) { + for (auto *next_node : now->next) { now->live_out.insert(next_node->live_in.begin(), next_node->live_in.end()); } auto old_in = std::move(now->live_in); now->live_in = now->live_gen; - for (auto stmt : now->live_out) { + for (auto *stmt : now->live_out) { if (!CFGNode::contain_variable(now->live_kill, stmt)) { now->live_in.insert(stmt); } } if (now->live_in != old_in) { - // changed - for (auto prev_node : now->prev) { + for (auto *prev_node : now->prev) { if (!in_queue[prev_node]) { to_visit.push(prev_node); in_queue[prev_node] = true; From 451b1c9b0c1f68272aaff8c69e760d0f837158a0 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 02:32:36 -0700 Subject: [PATCH 09/18] [CFG] Fix compile errors: use stmt_refs instead of std::vector The DSE and reaching-definition splits assumed std::vector with .front(), but irpass::analysis::get_store_destination / get_load_pointers return quadrants::stmt_refs (= one_or_more), which has no .front() and a non-const begin()/end(). Switch: - dse_store_destinations now returns stmt_refs (single-element cases use stmt_refs(ptr)); mark_loads_live_in_this_node takes stmt_refs &. - Replace .front() with *begin() in the DSE driver. - Drop const on local store_ptrs/load_ptrs whose range-for iteration needs non-const begin() (DSE driver, is_reach_in_stmt_killed_at). --- quadrants/ir/control_flow_graph.cpp | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 23e22bbbf4..d4f7c6113f 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -729,21 +729,22 @@ bool is_dse_eligible_pointer(Stmt *ptr, bool after_lower_access) { } // Compute the store destinations of |stmt|, including AD-stack stmts whose store semantics -// aren't yet captured by `get_store_destination`. +// aren't yet captured by `get_store_destination`. Returns the same `stmt_refs` (= one_or_more<>) +// type the analyzer uses so iteration and size queries are uniform. // TODO: Consider AD-stacks in get_store_destination instead of here for store-to-load forwarding // on AD-stacks. -std::vector dse_store_destinations(Stmt *stmt) { +stmt_refs dse_store_destinations(Stmt *stmt) { if (auto *pop = stmt->cast()) { - return {pop->stack}; + return stmt_refs(pop->stack); } if (auto *push = stmt->cast()) { - return {push->stack}; + return stmt_refs(push->stack); } if (auto *acc = stmt->cast()) { - return {acc->stack}; + return stmt_refs(acc->stack); } if (stmt->is()) { - return {stmt}; + return stmt_refs(stmt); } return irpass::analysis::get_store_destination(stmt); } @@ -869,8 +870,9 @@ bool try_eliminate_identical_load_at(CFGNode *node, } // Mark all (eligible) loads of |stmt| as live-in-this-node so that earlier-in-reverse-order stores -// to the same address see them and abort their dead-store check. -void mark_loads_live_in_this_node(const std::vector &load_ptrs, +// to the same address see them and abort their dead-store check. Taken by non-const reference +// because `one_or_more::begin()` is non-const. +void mark_loads_live_in_this_node(stmt_refs &load_ptrs, bool after_lower_access, const DseAliasMaps &alias, DseLiveState &state) { @@ -899,16 +901,16 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { state.live_load_in_this_node.clear(); continue; } - const auto store_ptrs = dse_store_destinations(stmt); + auto store_ptrs = dse_store_destinations(stmt); if (store_ptrs.size() == 1) { - if (try_eliminate_dead_store_at(this, i, stmt, store_ptrs.front(), after_lower_access, alias, state)) { + if (try_eliminate_dead_store_at(this, i, stmt, *store_ptrs.begin(), after_lower_access, alias, state)) { modified = true; continue; } } - const auto load_ptrs = irpass::analysis::get_load_pointers(stmt); + auto load_ptrs = irpass::analysis::get_load_pointers(stmt); if (load_ptrs.size() == 1 && store_ptrs.empty()) { - if (try_eliminate_identical_load_at(this, stmt, load_ptrs.front(), after_lower_access, alias, state)) { + if (try_eliminate_identical_load_at(this, stmt, *load_ptrs.begin(), after_lower_access, alias, state)) { modified = true; } } @@ -1103,7 +1105,8 @@ void seed_start_node_reach_gen(CFGNode *start_node, // destinations, killed iff every dest is killed at |node|. For stmts without dests (e.g. raw // global pointers seeded into start_node's reach_gen), killed iff the stmt itself is killed. bool is_reach_in_stmt_killed_at(CFGNode *node, Stmt *stmt) { - const auto store_ptrs = irpass::analysis::get_store_destination(stmt); + // Not const: `one_or_more::begin()` is non-const, so the range-for below would not compile. + auto store_ptrs = irpass::analysis::get_store_destination(stmt); if (store_ptrs.empty()) { return node->reach_kill_variable(stmt); } From 45ca4a3f2b25d0e1b6cb774ac4c2e369fa156d80 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 11:45:19 -0700 Subject: [PATCH 10/18] [CFG] Add file-top orientation and glossary to control_flow_graph.cpp Open the file with a paragraph explaining the two classes (CFGNode is per-block intra-block work; ControlFlowGraph is the whole-graph driver), the five analyses/transforms implemented (RD, LV, S2L, DSE, determine_ad_stack_size), and a glossary defining reach_in/out/gen/kill, live_in/out/gen/kill, UD-chain, adaptive AD-stack, stmt_refs, and the MatrixPtrStmt aliasing rule. None of this was previously in-file; a reader had to bounce through ir.h, analysis.h, statements.h, and one_or_more.h before the first function. --- quadrants/ir/control_flow_graph.cpp | 73 +++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index d4f7c6113f..b89e7dafad 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -14,6 +14,79 @@ #include "quadrants/program/function.h" #include "quadrants/codegen/ir_dump.h" +// =========================================================================== +// Control-flow graph: analyses and transforms. +// +// This file implements two classes (declared in control_flow_graph.h): +// - CFGNode: a single basic block in the CFG. Holds a slice of an +// IR Block (begin_location .. end_location) plus prev / +// next edges and the per-node sets used by each +// analysis. Its methods do the *intra-block* work +// (compute reach_gen / reach_kill in this block, run +// store-to-load forwarding within this block, etc.). +// - ControlFlowGraph: the whole graph. Owns the vector and the +// synthetic entry / exit nodes. Its methods are +// whole-graph *drivers*: they call the per-node method +// on every node, then run the worklist fixpoint or +// post-processing. +// +// Five analyses / transforms live here. For each, the per-node method is on +// CFGNode and the whole-graph driver is the same-named method on +// ControlFlowGraph: +// 1. reaching_definition_analysis (RD): cross-block use-define chain. Output +// in each node: reach_in, reach_out (sets of Stmt*). +// Used by store-to-load forwarding. +// 2. live_variable_analysis (LV): which addresses are still loaded +// after this point? Output: live_in, live_out (sets of +// Stmt*). Used by dead-store elimination. +// 3. store_to_load_forwarding (S2L): replace each load with the value +// of the closest preceding store to the same address; +// also erase identical stores. IR-mutating. +// 4. dead_store_elimination (DSE): erase stores whose written value +// is never read, weaken atomics whose store half is +// dead, eliminate identical loads. IR-mutating. +// 5. determine_ad_stack_size: size each adaptive AD-stack (max_size==0) +// from its push/pop schedule across the CFG. Whole- +// graph only; no per-node counterpart. +// +// Glossary (these terms appear throughout; their formal definitions are +// scattered across analysis.h / ir.h / statements.h, so collected here): +// - reach_gen[node]: stmts that define some variable in `node` (store +// stmts). Plus, in the synthetic entry node, stmts +// whose value "exists" before this kernel runs +// (external pointers, FuncCall store destinations). +// - reach_kill[node]: addresses (GlobalPtrStmt / AllocaStmt / ...) that +// `node` definitely overwrites. (gen tracks the *stmts* +// that write; kill tracks the *addresses* written-to.) +// - reach_in[node]: union of reach_out of predecessors. +// - reach_out[node]: reach_gen[node] union (reach_in[node] minus killed). +// - live_gen[node]: addresses loaded in `node` with no preceding store +// in the same node (so they must live in from outside). +// - live_kill[node]: addresses stored to in `node`. +// - live_in[node]: live_gen[node] union (live_out[node] minus live_kill). +// - live_out[node]: union of live_in of successors. +// - UD-chain: "use-define chain". For a load (use), the set of +// stores (defs) whose value may flow to it. Computed +// via the RD analysis. +// - Adaptive AD-stack: an `AdStackAllocaStmt` with `max_size == 0`, meaning +// its size has not yet been fixed and must be inferred +// from the kernel's push/pop schedule. `max_size != 0` +// stacks are already sized and are skipped here. +// - stmt_refs: alias for `quadrants::one_or_more` (see +// common/one_or_more.h). A variant that holds either a +// single Stmt* or a vector; supports `.size()`, +// `.empty()`, non-const `.begin()/.end()`. Used to +// avoid heap allocation in the common "exactly one +// destination" case. +// - MatrixPtrStmt: a derived pointer of the form `&base[offset]` where +// `base` is either an `AllocaStmt` (tensor-typed local) +// or another `MatrixPtrStmt` (nested element). The +// analyses treat these as aliasing their origin: a +// store through the derived pointer partially defines +// the origin; a store to the origin kills all +// derived-pointer addresses. +// =========================================================================== + namespace quadrants::lang { namespace { From 38d7b5bf256141a107c020d4f17b695d541600ed Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 11:46:54 -0700 Subject: [PATCH 11/18] [CFG] Document per-node vs whole-graph distinction in header + .cpp Header: extend class doc comments on CFGNode and ControlFlowGraph to spell out that the two classes implement the same five passes at different scopes (per-block intra-block work on CFGNode, whole-graph driver on ControlFlowGraph). On each paired method declaration, add a one-line summary distinguishing the per-node work from the whole-graph work (RD/LV seed + worklist; S2L/DSE flat per-node loop). .cpp: add two banner section comments demarcating CFGNode method definitions from ControlFlowGraph method definitions so a reader scrolling the file can see the per-node vs whole-graph boundary without hunting for the class qualifier on each function. --- quadrants/ir/control_flow_graph.cpp | 14 ++++++ quadrants/ir/control_flow_graph.h | 72 +++++++++++++++++++++++------ 2 files changed, 71 insertions(+), 15 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index b89e7dafad..64bd6c8631 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -138,6 +138,12 @@ bool in_final_node_live_gen(const Stmt *stmt, } // namespace +// =========================================================================== +// CFGNode -- per-block (intra-block) methods. +// Each method here does the *per-node* work; the whole-graph driver with the +// same name lives on ControlFlowGraph further down this file. +// =========================================================================== + CFGNode::CFGNode(Block *block, int begin_location, int end_location, @@ -992,6 +998,14 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { return modified; } +// =========================================================================== +// ControlFlowGraph -- whole-graph methods. +// Each analysis/transform below is the driver for the same-named per-node +// method on CFGNode above. It calls the per-node method on every node, then +// runs whatever cross-node work the pass needs (worklist fixpoint for RD/LV, +// a flat per-node loop for S2L/DSE). +// =========================================================================== + void ControlFlowGraph::erase(int node_id) { // Erase an empty node. QD_ASSERT(node_id >= 0 && node_id < (int)size()); diff --git a/quadrants/ir/control_flow_graph.h b/quadrants/ir/control_flow_graph.h index bd34d07817..2470ec5d0c 100644 --- a/quadrants/ir/control_flow_graph.h +++ b/quadrants/ir/control_flow_graph.h @@ -9,12 +9,21 @@ namespace quadrants::lang { class Function; /** - * A basic block in control-flow graph. - * A CFGNode contains a reference to a part of the CHI IR, or more precisely, - * an interval of statements in a Block. - * The edges in the graph are stored in |prev| and |next|. The control flow is - * possible to go from any node in |prev| to this node, and is possible to go - * from this node to any node in |next|. + * A basic block in the control-flow graph (one node). + * + * A CFGNode references an interval of statements in a Block: + * `block->statements[i]` for `i in [begin_location, end_location)`. + * The graph edges are stored in `prev` and `next`: control may flow from any + * node in `prev` into this node, and out of this node into any node in `next`. + * + * Scope of the methods on this class: each analysis/transform method on + * CFGNode does only the *intra-block* work for that pass (e.g. + * `reaching_definition_analysis` computes `reach_gen` / `reach_kill` from this + * node's statements; `store_to_load_forwarding` rewrites loads/stores within + * this node's slice of the block). The same-named method on + * `ControlFlowGraph` is the *whole-graph driver*: it calls this per-node + * method on every node, then runs the worklist fixpoint or post-processing + * that needs cross-node state. */ class CFGNode { public: @@ -82,11 +91,23 @@ class CFGNode { bool reach_kill_variable(Stmt *var) const; Stmt *get_store_forwarding_data(Stmt *var, int position) const; - // Analyses and optimizations inside a CFGNode. + // Per-node (intra-block) analyses and transforms. Each is driven across the + // whole graph by the same-named method on ControlFlowGraph; see below. + // Per-node `reaching_definition_analysis`: populate this node's reach_gen / + // reach_kill from its statements. void reaching_definition_analysis(bool after_lower_access); + // Per-node `store_to_load_forwarding`: rewrite loads/stores within this + // node's statement range. Returns true if any IR change was made. bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); + // Per-node helper for `ControlFlowGraph::gather_loaded_snodes`: append this + // node's loaded SNodes into `snodes`. void gather_loaded_snodes(std::unordered_set &snodes) const; + // Per-node `live_variable_analysis`: populate this node's live_gen / + // live_kill from its statements. void live_variable_analysis(bool after_lower_access); + // Per-node `dead_store_elimination`: erase dead stores / weaken dead + // atomics within this node's statement range. Returns true if any IR change + // was made. bool dead_store_elimination(bool after_lower_access); private: @@ -150,6 +171,16 @@ class CFGNode { bool &modified); }; +/** + * The whole control-flow graph (all nodes plus their edges). + * + * Owns the `nodes` vector, the synthetic entry node (`start_node`, always + * empty), and the synthetic exit node (`final_node`, always empty). Each + * analysis/transform method on this class is the *whole-graph driver*: it + * invokes the same-named per-node method on `CFGNode` for every node, then + * runs the worklist fixpoint (RD / LV) or per-node IR rewrite loop (S2L / + * DSE). `determine_ad_stack_size` has no per-node counterpart. + */ class ControlFlowGraph { private: // Erase an empty node. @@ -182,22 +213,26 @@ class ControlFlowGraph { const std::string &suffix = "") const; /** - * Perform reaching definition analysis using the worklist algorithm, - * and store the results in CFGNodes. + * Whole-graph driver: reaching-definition analysis via worklist fixpoint. + * Seeds the entry node's reach_gen with external-input pointers, calls + * `CFGNode::reaching_definition_analysis` per node, then converges + * reach_in/reach_out. Results are stored on each CFGNode. * https://en.wikipedia.org/wiki/Reaching_definition * * @param after_lower_access - * When after_lower_access is true, only consider local variables (allocas). + * When true, only consider local variables (allocas). */ void reaching_definition_analysis(bool after_lower_access); /** - * Perform live variable analysis using the worklist algorithm, - * and store the results in CFGNodes. + * Whole-graph driver: live-variable analysis via worklist fixpoint (run + * backwards). Seeds the exit node's live_gen with kernel-escaping stores, + * calls `CFGNode::live_variable_analysis` per node, then converges + * live_in/live_out. Results are stored on each CFGNode. * https://en.wikipedia.org/wiki/Live_variable_analysis * * @param after_lower_access - * When after_lower_access is true, only consider local variables (allocas). + * When true, only consider local variables (allocas). * @param config_opt * The set of SNodes which is never loaded after this task. */ @@ -213,12 +248,19 @@ class ControlFlowGraph { bool unreachable_code_elimination(); /** - * Perform store-to-load forwarding and identical store elimination. + * Whole-graph driver: store-to-load forwarding + identical-store elimination. + * Calls `CFGNode::store_to_load_forwarding` on every node and ORs the + * per-node "did we change the IR" return. Caller is responsible for + * (re)running `reaching_definition_analysis` first to populate reach_in / + * reach_out. */ bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); /** - * Perform dead store elimination and identical load elimination. + * Whole-graph driver: dead-store elimination + identical-load elimination. + * Calls `CFGNode::dead_store_elimination` on every node and ORs the + * per-node return. Caller is responsible for (re)running + * `live_variable_analysis` first to populate live_in / live_out. */ bool dead_store_elimination(bool after_lower_access, const std::optional &lva_config_opt); From 1a630dd7421332fe76ee0cc929e6f0ff21015b42 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 11:47:44 -0700 Subject: [PATCH 12/18] [CFG] Rewrite determine_ad_stack_size docstring with plain-English preamble Lead with a "What problem this solves" paragraph in user-facing terms (reverse-mode autodiff records primal writes on per-variable stacks; we need to size adaptive stacks by walking pushes-minus-pops on every CFG path), then a "High-level algorithm" section that introduces SCC condensation and fingerprint-deduplicated DP in narrative form, explaining the two sign-based fast paths in one bullet each. The dense implementation/perf notes follow as a "Performance summary" instead of being the first thing a reader hits. No behavior change. --- quadrants/ir/control_flow_graph.cpp | 77 ++++++++++++++++++----------- 1 file changed, 49 insertions(+), 28 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 64bd6c8631..208e92109d 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -2006,35 +2006,56 @@ void apply_ad_stack_dp_results(const std::vector &stacks, void ControlFlowGraph::determine_ad_stack_size() { /** - * Determine the necessary size of every adaptive AD-stack on the control-flow graph (CFG). For each AD-stack we - * compute the maximum running net push count along any walk from the kernel entry. AD-stacks whose forward kernel - * contains a positive cycle (pushes > pops around a loop) are left at `max_size = 0`, and the caller routes them - * through the structural bounded-loop pre-pass for a symbolic `SizeExpr`, hard-erroring if the grammar still - * cannot resolve them. There is no compile-time size fallback. + * What problem this solves + * ------------------------ + * Reverse-mode autodiff records each primal write to a variable by pushing the value onto a + * per-variable stack (an "AD-stack"), then pops in reverse order during the backward pass. The + * stack's *capacity* must be known at allocation time. Some stacks are sized explicitly by the + * user (`max_size != 0`); the rest are "adaptive" (`max_size == 0`) and we have to figure out + * how many pushes can stack up on them at runtime by walking the CFG and bookkeeping pushes + * minus pops along every reachable path from kernel entry. * - * Implementation notes for compile-time perf on large reverse-mode kernels: - * 1. Per-stack per-node pre-aggregates (`max_increased_size`, `increased_size`) are stored in dense - * `vector>` indexed by a contiguous int stack id, instead of an - * `unordered_map>` -- this removes hash traffic from the hot inner loop. - * 2. Stacks whose `(increased_size, max_increased_size)` row pair is bit-identical share a single dynamic - * programming run -- typical kernels generate one alloca per autodiff variable in the same loop body, so - * most rows collapse to a few representatives. - * 3. The CFG is condensed via Tarjan into strongly connected components (SCCs). DFS finish times recorded - * during the same Tarjan pass split each cyclic SCC's intra-edges into a forward set (target finishes - * before source) and a back set (target finishes at or after source). Per representative we run a - * single-pass dynamic-programming (DP) sweep over the forward edges in descending finish-time order, then - * check the back edges once for positive-cycle relaxation. Correctness: any walk inside an SCC decomposes - * into a forward path plus zero or more cycles, and an SCC with no positive cycle has the same max-walk-sum - * as the back-edge-removed DAG; a positive cycle is exactly the case where some back-edge would still relax - * after the forward DP. This drops the per-cyclic-SCC cost from O(|S| * |E_S|) to O(|S| + |E_S|). - * 4. Two sign-based fast paths short-circuit the DP for trivial cyclic SCCs: an SCC with `min_is >= 0 && max_is - * > 0` for this stack must contain a positive cycle (every node lies on some cycle, and a cycle through a - * strictly-positive node with all non-negative `is` along it sums positive); an SCC with `min_is == 0 == - * max_is` has no `is` contribution at all and is handled by spreading the max entry-side - * `max_size_at_node_begin` to every node in O(|S|). - * Per-rep cost becomes O(V + E + sum_{cyclic S} |S| * |E_S|) (with the SCC sum dropping to O(|S| + |E_S|) for - * the common autodiff push/pop pattern); overall cost is O(V + E + R * (V + E)) with R the number of distinct - * row-pair representatives. + * For each adaptive AD-stack we want: the maximum value of (pushes - pops) over any walk from + * the entry node. Set that as `max_size`. If the kernel has a loop where pushes > pops on every + * iteration the maximum is unbounded; we leave `max_size = 0` and let the caller's structural + * bounded-loop pre-pass derive a symbolic `SizeExpr` from the loop ranges (and hard-error if + * even that fails -- there is no compile-time default fallback). + * + * High-level algorithm + * -------------------- + * 1. Index every adaptive AD-stack (`collect_adaptive_ad_stacks`). + * 2. For every (stack, CFG node) pair, precompute the net push/pop delta inside that node and + * the max running push-count prefix inside that node + * (`accumulate_per_stack_per_node_size_deltas`). After this point the actual stmts are + * irrelevant; everything below operates on int matrices. + * 3. Condense the CFG with Tarjan's SCC algorithm (`tarjan_scc`). Inside any cyclic SCC, the + * "max walk-sum from entry" question becomes "is there a positive cycle, and if not, what's + * the DAG longest path inside this SCC?". The condensation lets us solve those questions + * once per SCC, in topological order across SCCs. + * 4. For each *distinct* push/pop fingerprint across stacks (`group_stacks_by_fingerprint`), + * run the DP once: walk SCCs in topological order, per cyclic SCC try two cheap sign-based + * fast paths and fall back to a single-pass DAG DP + back-edge relaxation check for cycle + * detection. Broadcast the result to every stack that shares this fingerprint + * (`apply_ad_stack_dp_results`). In practice many autodiff kernels generate identical + * push/pop schedules across stacks (one alloca per AD variable in the same loop body), so + * the fingerprint dedup collapses N stacks to a handful of DP runs. + * + * The two cheap fast paths for a cyclic SCC: + * (a) min_is >= 0 and max_is > 0 => positive cycle exists for this stack (every node in an + * SCC is on a cycle; a cycle through a strictly-positive node with non-negative weights + * sums positive). Bail out. + * (b) min_is == 0 == max_is => the SCC adds nothing; just spread the maximum entry- + * side `max_size_at_node_begin` to every node in the SCC in O(|S|). + * The full mixed-sign DP runs only when neither shortcut applies. + * + * Performance summary + * ------------------- + * Per representative cost: O(V + E + sum_{cyclic SCC S} |S| * |E_S|), with the inner sum + * collapsing to O(|S| + |E_S|) for the common autodiff push/pop pattern (one DP sweep + one + * back-edge check). Overall: O(V + E + R * (V + E)), with R = number of distinct row-pair + * fingerprints (typically << number of stacks). Per-stack per-node tables are dense + * `vector>` indexed by contiguous int stack id, not `unordered_map`, to keep the + * hot inner loop branch-and-cache friendly. */ AdStackIndex idx = collect_adaptive_ad_stacks(nodes); const int num_stacks = static_cast(idx.stacks.size()); From 8baef4ca780599fc4f32fefb06709c67d1df1a80 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 11:49:20 -0700 Subject: [PATCH 13/18] [CFG] Rename three under-telling identifiers Three names that did not match what the function/method does, renamed for self-documentation: - get_store_forwarding_data -> find_forwardable_store_value Returns the forwarded *value* (Stmt *), not "data". "find" reflects that it's a UD-chain search that can fail (returns nullptr). - update_forwarding_result -> fold_definition_into_result Folds one UD-chain definition into the running (result, result_visible) state; "fold" is the standard reduce-step term. - reach_kill_variable -> is_reach_killed Question-form, signalling it is a query ("does this node kill the definition reaching it?") rather than a mutator. Aligns with the reach_kill set name. All call sites are within control_flow_graph.{h,cpp}. --- quadrants/ir/control_flow_graph.cpp | 32 ++++++++++++++--------------- quadrants/ir/control_flow_graph.h | 26 +++++++++++------------ 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 208e92109d..7edbbe3001 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -159,7 +159,7 @@ CFGNode::CFGNode(Block *block, prev_node_in_same_block->next_node_in_same_block = this; if (!empty()) { // For non-empty nodes, precompute |parent_blocks| to accelerate - // get_store_forwarding_data(). + // find_forwardable_store_value(). QD_ASSERT(begin_location >= 0); QD_ASSERT(block); auto parent_block = block; @@ -280,7 +280,7 @@ bool CFGNode::may_contain_variable(const std::unordered_set &var_set, St } } -bool CFGNode::reach_kill_variable(Stmt *var) const { +bool CFGNode::is_reach_killed(Stmt *var) const { // Does this node (definitely) kill a definition of var? return contain_variable(reach_kill, var); } @@ -296,10 +296,10 @@ bool CFGNode::is_visible_at(Stmt *stmt, int position) const { return parent_blocks_.find(stmt->parent) != parent_blocks_.end(); } -bool CFGNode::update_forwarding_result(Stmt *stmt, - int position, - Stmt *&result, - bool &result_visible) const { +bool CFGNode::fold_definition_into_result(Stmt *stmt, + int position, + Stmt *&result, + bool &result_visible) const { // |stmt| is a definition in the UD-chain of the variable being forwarded. // Fold its stored data into |result| / |result_visible|. Return false if // forwarding must abort (the caller should propagate nullptr); true to @@ -380,7 +380,7 @@ std::optional CFGNode::find_cross_block_def(Stmt *var, // nodes[start_node]->reach_gen. for (auto *stmt : reach_in) { if (var == stmt || may_contain_address(stmt, var)) { - if (!update_forwarding_result(stmt, position, result, result_visible)) { + if (!fold_definition_into_result(stmt, position, result, result_visible)) { return std::nullopt; } last_def_position = 0; @@ -389,7 +389,7 @@ std::optional CFGNode::find_cross_block_def(Stmt *var, // Stores generated within this node (in reach_gen) that precede |position|. for (auto *stmt : reach_gen) { if (may_contain_address(stmt, var) && stmt->parent->locate(stmt) < position) { - if (!update_forwarding_result(stmt, position, result, result_visible)) { + if (!fold_definition_into_result(stmt, position, result, result_visible)) { return std::nullopt; } last_def_position = stmt->parent->locate(stmt); @@ -420,7 +420,7 @@ bool CFGNode::any_aliased_store_breaks_forwarding(Stmt *result, Stmt *var, int f // cross-block forwarding when present), then falls back to the cross-block search over reach_in // and reach_gen. In both cases an intervening aliased store that may write a different value // breaks the forward. -Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { +Stmt *CFGNode::find_forwardable_store_value(Stmt *var, int position) const { // [Intra-block search] Walks backwards in this node's block. if (int last_def = find_intra_block_last_def(var, position); last_def != -1) { Stmt *result = irpass::analysis::get_store_data(block->statements[last_def].get()); @@ -475,7 +475,7 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { // After lower_access, we only analyze local variables. continue; } - if (!reach_kill_variable(data_source_ptr)) { + if (!is_reach_killed(data_source_ptr)) { reach_gen.insert(stmt); reach_kill.insert(data_source_ptr); } @@ -490,11 +490,11 @@ bool CFGNode::try_forward_load_at(int &i, Stmt *stmt, bool after_lower_access, b Stmt *result = nullptr; if (auto *local_load = stmt->cast()) { load_src = local_load->src; - result = get_store_forwarding_data(load_src, i); + result = find_forwardable_store_value(load_src, i); } else if (auto *global_load = stmt->cast()) { if (!after_lower_access && !autodiff_enabled) { load_src = global_load->src; - result = get_store_forwarding_data(load_src, i); + result = find_forwardable_store_value(load_src, i); } } if (!result) { @@ -535,7 +535,7 @@ void CFGNode::try_eliminate_identical_store_at(int &i, // same address. For local stores under non-autodiff there's also an alloca-zero special case: // writing a zero to a freshly-allocated alloca is redundant. if (auto *local_store = stmt->cast()) { - Stmt *result = get_store_forwarding_data(local_store->dest, i); + Stmt *result = find_forwardable_store_value(local_store->dest, i); if (result && result->is() && !autodiff_enabled) { // TensorType does not apply to this special case. if (result->ret_type.ptr_removed()->is()) { @@ -561,7 +561,7 @@ void CFGNode::try_eliminate_identical_store_at(int &i, if (after_lower_access) { return; } - Stmt *result = get_store_forwarding_data(global_store->dest, i); + Stmt *result = find_forwardable_store_value(global_store->dest, i); if (irpass::analysis::same_value(result, global_store->val)) { erase(i); i--; @@ -1195,10 +1195,10 @@ bool is_reach_in_stmt_killed_at(CFGNode *node, Stmt *stmt) { // Not const: `one_or_more::begin()` is non-const, so the range-for below would not compile. auto store_ptrs = irpass::analysis::get_store_destination(stmt); if (store_ptrs.empty()) { - return node->reach_kill_variable(stmt); + return node->is_reach_killed(stmt); } for (auto *store_ptr : store_ptrs) { - if (!node->reach_kill_variable(store_ptr)) { + if (!node->is_reach_killed(store_ptr)) { return false; } } diff --git a/quadrants/ir/control_flow_graph.h b/quadrants/ir/control_flow_graph.h index 2470ec5d0c..fbe5bf5381 100644 --- a/quadrants/ir/control_flow_graph.h +++ b/quadrants/ir/control_flow_graph.h @@ -37,7 +37,7 @@ class CFGNode { }; private: - // For accelerating get_store_forwarding_data() + // For accelerating find_forwardable_store_value() std::unordered_set parent_blocks_; public: @@ -88,8 +88,8 @@ class CFGNode { static bool contain_variable(const std::unordered_map &var_set, Stmt *var); static bool may_contain_variable(const std::unordered_set &var_set, Stmt *var); static bool may_contain_variable(const std::unordered_map &var_set, Stmt *var); - bool reach_kill_variable(Stmt *var) const; - Stmt *get_store_forwarding_data(Stmt *var, int position) const; + bool is_reach_killed(Stmt *var) const; + Stmt *find_forwardable_store_value(Stmt *var, int position) const; // Per-node (intra-block) analyses and transforms. Each is driven across the // whole graph by the same-named method on ControlFlowGraph; see below. @@ -111,30 +111,30 @@ class CFGNode { bool dead_store_elimination(bool after_lower_access); private: - // Helper for get_store_forwarding_data: is |stmt| visible at |position| + // Helper for find_forwardable_store_value: is |stmt| visible at |position| // inside this node's block? A stmt is visible if it lives in the same block // and precedes |position|, or if its parent block is an ancestor of // |this->block|. bool is_visible_at(Stmt *stmt, int position) const; - // Helper for get_store_forwarding_data: incorporate |stmt|, a definition in + // Helper for find_forwardable_store_value: incorporate |stmt|, a definition in // the UD-chain of the variable being forwarded, into the running |result| / // |result_visible| state. Returns false to signal that forwarding must // abort (the caller should return nullptr), true to continue scanning. - bool update_forwarding_result(Stmt *stmt, - int position, - Stmt *&result, - bool &result_visible) const; + bool fold_definition_into_result(Stmt *stmt, + int position, + Stmt *&result, + bool &result_visible) const; - // Helper for get_store_forwarding_data: walk this node's block backwards + // Helper for find_forwardable_store_value: walk this node's block backwards // from |position| and return the index of the most recent store to |var|, // or -1 if none is in this block. Handles the quant-store exclusion plus the // MatrixInitStmt-via-MatrixPtrStmt forwarding special case. int find_intra_block_last_def(Stmt *var, int position) const; - // Helper for get_store_forwarding_data: scan |reach_in| and |reach_gen| for + // Helper for find_forwardable_store_value: scan |reach_in| and |reach_gen| for // definitions of |var| reaching |position|, folding each into |result| / - // |result_visible| via update_forwarding_result. Returns nullopt if any + // |result_visible| via fold_definition_into_result. Returns nullopt if any // visited def is unforwardable (caller must return nullptr); otherwise the // last_def_position (0 if only reach_in matched, an in-block index if // reach_gen matched, -1 if no eligible def was found). @@ -143,7 +143,7 @@ class CFGNode { Stmt *&result, bool &result_visible) const; - // Helper for get_store_forwarding_data: scan block statements in + // Helper for find_forwardable_store_value: scan block statements in // [from, to_exclusive) for a store that may write a different value to an // address aliasing |var|. Returns true iff such a store exists (so // forwarding |result| must abort). The check is skipped (returns false) for From 377cd3509ecea011281789bb12f1c9e2ba1310b2 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 11:54:30 -0700 Subject: [PATCH 14/18] [CFG] Split determine_ad_stack_size into its own .cpp determine_ad_stack_size and its ~20 file-local helpers are a self-contained ~500-line algorithm (Tarjan SCC condensation + fingerprint-deduplicated DP) that has no per-node CFGNode counterpart and is not used by any other pass in control_flow_graph.cpp. Pulling it into a sibling file drops control_flow_graph.cpp from ~2100 to ~1480 lines and lets a first-time reader see at a glance that the AD-stack sizer is a distinct concern. The build picks the file up automatically through the existing `file(GLOB ... "quadrants/ir/*")` in cmake/QuadrantsCore.cmake; no CMakeLists change required. The implementation is a verbatim move (including the rewritten "What problem this solves" docstring); access to `start_node` still works because the method is still defined on ControlFlowGraph. --- quadrants/ir/control_flow_graph.cpp | 609 ---------------------- quadrants/ir/determine_ad_stack_size.cpp | 637 +++++++++++++++++++++++ 2 files changed, 637 insertions(+), 609 deletions(-) create mode 100644 quadrants/ir/determine_ad_stack_size.cpp diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 7edbbe3001..ae5c785d37 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -1479,613 +1479,4 @@ std::unordered_set ControlFlowGraph::gather_loaded_snodes() { return snodes; } -namespace { - -// === Helpers for ControlFlowGraph::determine_ad_stack_size === - -struct AdStackIndex { - // AdStackAllocaStmt* -> contiguous int id, populated only for adaptive stacks - // (max_size == 0) that have not yet been resolved by an earlier pass. - std::unordered_map stack_id; - std::vector stacks; -}; - -AdStackIndex collect_adaptive_ad_stacks(const std::vector> &nodes) { - AdStackIndex idx; - for (const auto &node : nodes) { - for (int j = node->begin_location; j < node->end_location; j++) { - Stmt *stmt = node->block->statements[j].get(); - auto *stack = stmt->cast(); - if (!stack || stack->max_size != 0) { - continue; - } - if (idx.stack_id.emplace(stack, static_cast(idx.stacks.size())).second) { - idx.stacks.push_back(stack); - } - } - } - return idx; -} - -struct AdStackPerNodeSizes { - // [stack_id][node_id]. `max_increased_size[s][j]` is the maximum (pushes - pops) of stack |s| - // among all prefixes of CFGNode |j|; `increased_size[s][j]` is the net (pushes - pops) in the - // whole node. Indexed by contiguous stack id so the per-stack DP reads cheap vectors instead of - // hashing. - std::vector> increased_size; - std::vector> max_increased_size; - // True iff the stack actually appears in any push/pop in the CFG. Inactive stacks would settle - // at `max_size = 0` regardless and are short-circuited below to reproduce the original "Unused - // autodiff stack" warning. - std::vector stack_active; -}; - -AdStackPerNodeSizes accumulate_per_stack_per_node_size_deltas( - const std::vector> &nodes, - const std::unordered_map &stack_id, - int num_stacks) { - const int num_nodes = static_cast(nodes.size()); - AdStackPerNodeSizes out; - out.increased_size.assign(num_stacks, std::vector(num_nodes, 0)); - out.max_increased_size.assign(num_stacks, std::vector(num_nodes, 0)); - out.stack_active.assign(num_stacks, false); - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - Stmt *stmt = nodes[i]->block->statements[j].get(); - AdStackAllocaStmt *stack = nullptr; - int delta = 0; - if (auto *push = stmt->cast()) { - stack = push->stack->as(); - delta = +1; - } else if (auto *pop = stmt->cast()) { - stack = pop->stack->as(); - delta = -1; - } else { - continue; - } - if (stack->max_size != 0 /*non-adaptive*/) { - continue; - } - auto it = stack_id.find(stack); - QD_ASSERT(it != stack_id.end()); - const int sid = it->second; - out.stack_active[sid] = true; - int &cur = out.increased_size[sid][i]; - cur += delta; - if (cur > out.max_increased_size[sid][i]) { - out.max_increased_size[sid][i] = cur; - } - } - } - return out; -} - -// Precompute outgoing-edge node ids once per node so the per-stack DP walks an int vector instead -// of hashing `node_ids[next_node]` on every traversal. -std::vector> compute_outgoing_node_ids(const std::vector> &nodes) { - const int num_nodes = static_cast(nodes.size()); - std::unordered_map node_ids; - node_ids.reserve(num_nodes); - for (int i = 0; i < num_nodes; i++) { - node_ids[nodes[i].get()] = i; - } - std::vector> next_ids(num_nodes); - for (int i = 0; i < num_nodes; i++) { - auto &dst = next_ids[i]; - dst.reserve(nodes[i]->next.size()); - for (auto *next_node : nodes[i]->next) { - dst.push_back(node_ids[next_node]); - } - } - return next_ids; -} - -struct TarjanResult { - std::vector scc_id; // node -> SCC index - std::vector dfs_finish; // node -> DFS post-order index; ancestors finish AFTER descendants - std::vector> scc_nodes; // SCC -> list of node ids -}; - -// Iterative Tarjan SCC. Emits SCCs in reverse topological order (sources end up at the largest -// indices). Also records DFS post-order finish times, used downstream to split each cyclic SCC's -// intra-edges into forward vs back without a second edge-classification pass. -TarjanResult tarjan_scc(const std::vector> &next_ids) { - const int num_nodes = static_cast(next_ids.size()); - TarjanResult out; - out.scc_id.assign(num_nodes, -1); - out.dfs_finish.assign(num_nodes, -1); - - std::vector tarjan_index(num_nodes, -1); - std::vector tarjan_lowlink(num_nodes, 0); - std::vector on_stack(num_nodes, 0); - std::vector tarjan_stack; - tarjan_stack.reserve(num_nodes); - std::vector> dfs_stack; - dfs_stack.reserve(num_nodes); - int next_index = 0; - int next_finish = 0; - int next_scc = 0; - for (int origin = 0; origin < num_nodes; origin++) { - if (tarjan_index[origin] != -1) { - continue; - } - tarjan_index[origin] = next_index; - tarjan_lowlink[origin] = next_index; - next_index++; - tarjan_stack.push_back(origin); - on_stack[origin] = 1; - dfs_stack.emplace_back(origin, 0); - while (!dfs_stack.empty()) { - auto &frame = dfs_stack.back(); - const int u = frame.first; - const auto &nb = next_ids[u]; - if (frame.second < nb.size()) { - const int v = nb[frame.second++]; - if (tarjan_index[v] == -1) { - tarjan_index[v] = next_index; - tarjan_lowlink[v] = next_index; - next_index++; - tarjan_stack.push_back(v); - on_stack[v] = 1; - dfs_stack.emplace_back(v, 0); - } else if (on_stack[v]) { - if (tarjan_index[v] < tarjan_lowlink[u]) { - tarjan_lowlink[u] = tarjan_index[v]; - } - } - } else { - if (tarjan_lowlink[u] == tarjan_index[u]) { - std::vector component; - while (true) { - const int w = tarjan_stack.back(); - tarjan_stack.pop_back(); - on_stack[w] = 0; - out.scc_id[w] = next_scc; - component.push_back(w); - if (w == u) { - break; - } - } - out.scc_nodes.push_back(std::move(component)); - next_scc++; - } - out.dfs_finish[u] = next_finish++; - dfs_stack.pop_back(); - if (!dfs_stack.empty()) { - const int parent = dfs_stack.back().first; - if (tarjan_lowlink[u] < tarjan_lowlink[parent]) { - tarjan_lowlink[parent] = tarjan_lowlink[u]; - } - } - } - } - } - return out; -} - -struct SccEdgeSets { - // Each intra-SCC edge is either "forward" in the DFS spanning sense (target finishes before - // source, so source can be relaxed before target with a single topological pass) or "back" - // (target finishes at or after source, closing a cycle). The forward set drives the per-stack - // DAG dynamic programming; back-edges only need a single post-DP relaxation check to detect - // positive cycles. Inter-SCC edges always relax forward in topological order. - std::vector> next_ids_intra_fwd; - std::vector> next_ids_intra_back; - std::vector> next_ids_inter; -}; - -SccEdgeSets classify_scc_edges(const std::vector> &next_ids, - const std::vector &scc_id, - const std::vector &dfs_finish) { - const int num_nodes = static_cast(next_ids.size()); - SccEdgeSets out; - out.next_ids_intra_fwd.assign(num_nodes, {}); - out.next_ids_intra_back.assign(num_nodes, {}); - out.next_ids_inter.assign(num_nodes, {}); - for (int u = 0; u < num_nodes; u++) { - const int su = scc_id[u]; - for (int v : next_ids[u]) { - if (scc_id[v] == su) { - if (dfs_finish[v] < dfs_finish[u]) { - out.next_ids_intra_fwd[u].push_back(v); - } else { - out.next_ids_intra_back[u].push_back(v); - } - } else { - out.next_ids_inter[u].push_back(v); - } - } - } - return out; -} - -struct CyclicSccInfo { - std::vector scc_is_cyclic; // 1 iff SCC contains a cycle - std::vector> scc_topo; // for cyclic SCCs: nodes sorted by descending dfs_finish -}; - -// An SCC is cyclic iff |S| > 1 (any two nodes in a non-trivial SCC lie on a cycle) or |S| == 1 -// with a self-loop edge. For cyclic SCCs we precompute the topological ordering of their nodes -// so the per-stack DAG DP visits each node once with all forward predecessors already finalized. -CyclicSccInfo identify_cyclic_sccs_and_topo(const std::vector> &scc_nodes, - const std::vector> &next_ids_intra_back, - const std::vector &dfs_finish) { - const int num_sccs = static_cast(scc_nodes.size()); - CyclicSccInfo out; - out.scc_is_cyclic.assign(num_sccs, 0); - out.scc_topo.assign(num_sccs, {}); - for (int s = 0; s < num_sccs; s++) { - const auto &nodes_in_s = scc_nodes[s]; - if (nodes_in_s.size() > 1) { - out.scc_is_cyclic[s] = 1; - } else { - // Self-loops are classified as back-edges above, so a singleton SCC is cyclic iff it has a - // back-edge pointing at itself. - const int n = nodes_in_s[0]; - for (int v : next_ids_intra_back[n]) { - if (v == n) { - out.scc_is_cyclic[s] = 1; - break; - } - } - } - if (out.scc_is_cyclic[s]) { - struct DfsFinishGreater { - const std::vector &dfs_finish; - bool operator()(int a, int b) const { - return dfs_finish[a] > dfs_finish[b]; - } - }; - auto topo = nodes_in_s; - std::sort(topo.begin(), topo.end(), DfsFinishGreater{dfs_finish}); - out.scc_topo[s] = std::move(topo); - } - } - return out; -} - -struct FingerprintGroups { - std::vector stack_to_rep; // stack id -> representative stack id (whose DP run it shares) - std::vector rep_stack_ids; // representative stack ids that actually run the DP -}; - -// Group AD-stacks whose per-node (increased_size, max_increased_size) rows are bit-identical so -// the DP runs once per equivalence class instead of once per stack. In practice many AD-stacks in -// the same kernel share their push/pop schedule (one alloca per autodiff variable in the same -// loop body), and the DP on a large CFG dwarfs the dedup cost. We hash a sparse fingerprint (only -// nodes where the stack has activity) and group by it. Worst case (no duplicates): one DP run per -// stack, equivalent to running it directly per stack. -FingerprintGroups group_stacks_by_fingerprint(int num_nodes, - int num_stacks, - const std::vector &stack_active, - const std::vector> &increased_size, - const std::vector> &max_increased_size) { - using Fingerprint = std::vector>; // (node_id, is, mis), sorted by node_id - struct FingerprintHash { - std::size_t operator()(const Fingerprint &f) const noexcept { - // FNV-1a mix of three components per fingerprint entry. - constexpr std::size_t fnv_prime = 1099511628211ULL; - std::size_t h = 1469598103934665603ULL; - for (auto &[n, i, m] : f) { - h ^= static_cast(n); - h *= fnv_prime; - h ^= static_cast(static_cast(i)); - h *= fnv_prime; - h ^= static_cast(static_cast(m)); - h *= fnv_prime; - } - return h; - } - }; - std::unordered_map fp_to_rep; - FingerprintGroups out; - out.stack_to_rep.assign(num_stacks, -1); - for (int sid = 0; sid < num_stacks; sid++) { - if (!stack_active[sid]) { - continue; - } - Fingerprint fp; - const auto &is_row = increased_size[sid]; - const auto &mis_row = max_increased_size[sid]; - for (int n = 0; n < num_nodes; n++) { - if (is_row[n] != 0 || mis_row[n] != 0) { - fp.emplace_back(n, is_row[n], mis_row[n]); - } - } - auto [it, inserted] = fp_to_rep.emplace(std::move(fp), sid); - out.stack_to_rep[sid] = it->second; - if (inserted) { - out.rep_stack_ids.push_back(sid); - } - } - return out; -} - -enum class CyclicSccFastPath { - kPositiveCycle, // proven positive cycle exists for this stack inside this SCC - kZeroSpread, // no `is` contribution in the SCC; just spread max entry-side begin value - kFallback, // mixed-sign, run the full DP -}; - -// Sign-based fast paths sidestep the cyclic-SCC dynamic programming when the stack's `is` -// contribution inside this SCC is structurally trivial: -// 1. min_is >= 0 with max_is > 0: every node in the SCC lies on some cycle (SCC property), and -// a cycle through a strictly-positive node with all non-negative `is` along it sums to a -// positive value, so a positive cycle exists for this stack. This is the autodiff -// push-only-in-SCC pattern. -// 2. min_is == max_is == 0: no `is` contribution at all in this SCC; the DP would only spread -// the maximum entry-side `max_size_at_node_begin` value to every node in O(|S|). -CyclicSccFastPath classify_cyclic_scc_fast_path(const std::vector &nodes_in_s, - const std::vector &is_for_stack) { - int min_is = INT_MAX; - int max_is = INT_MIN; - for (int u : nodes_in_s) { - const int v = is_for_stack[u]; - if (v < min_is) { - min_is = v; - } - if (v > max_is) { - max_is = v; - } - } - if (min_is >= 0 && max_is > 0) { - return CyclicSccFastPath::kPositiveCycle; - } - if (min_is == 0 && max_is == 0) { - return CyclicSccFastPath::kZeroSpread; - } - return CyclicSccFastPath::kFallback; -} - -// Spread the maximum entry-side `max_size_at_node_begin` value across every node in the SCC. -// Equivalent to running the DP with all-zero `is` weights, in O(|S|). -void spread_max_begin_over_zero_scc(const std::vector &nodes_in_s, - std::vector &max_size_at_node_begin) { - int max_begin = -1; - for (int u : nodes_in_s) { - if (max_size_at_node_begin[u] > max_begin) { - max_begin = max_size_at_node_begin[u]; - } - } - if (max_begin < 0) { - return; - } - for (int u : nodes_in_s) { - if (max_size_at_node_begin[u] < max_begin) { - max_size_at_node_begin[u] = max_begin; - } - } -} - -// Mixed-sign case: single-pass dynamic programming on the SCC's forward edges (processed in -// descending DFS finish-time so every forward predecessor is finalized before its successor -// relaxes), followed by one relaxation check on the back-edges. Correctness: every walk inside -// the SCC decomposes into a forward path plus zero or more closed cycles (back-edge + forward -// path). For SCCs with no positive cycle, traversing a cycle adds a non-positive amount to the -// running size and so cannot improve max_size beyond what the forward DP already computed. The -// back-edge relaxation check after the DP detects the only failure mode (some back-edge would -// still improve a forward predecessor's value, which can only happen if a positive cycle exists -// for this stack). Returns true iff a positive cycle was detected. -bool dp_mixed_sign_cyclic_scc(const std::vector &topo, - const std::vector &nodes_in_s, - const std::vector &is_for_stack, - const std::vector> &next_ids_intra_fwd, - const std::vector> &next_ids_intra_back, - std::vector &max_size_at_node_begin) { - for (int u : topo) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_intra_fwd[u]) { - if (exit_val > max_size_at_node_begin[v]) { - max_size_at_node_begin[v] = exit_val; - } - } - } - for (int u : nodes_in_s) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_intra_back[u]) { - if (exit_val > max_size_at_node_begin[v]) { - return true; - } - } - } - return false; -} - -// SCC has converged for this stack. Update global `max_size` from each node's max-prefix -// contribution, then relax inter-SCC outgoing edges into successor SCCs (predecessors are already -// finalized when an SCC is entered, so each inter-SCC edge is touched exactly once). -void update_global_max_and_relax_inter_scc(const std::vector &nodes_in_s, - const std::vector &is_for_stack, - const std::vector &mis_for_stack, - const std::vector> &next_ids_inter, - std::vector &max_size_at_node_begin, - int &max_size) { - for (int u : nodes_in_s) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int prefix = begin + mis_for_stack[u]; - if (prefix > max_size) { - max_size = prefix; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_inter[u]) { - if (exit_val > max_size_at_node_begin[v]) { - max_size_at_node_begin[v] = exit_val; - } - } - } -} - -struct AdStackDPResult { - int max_size; - bool has_positive_loop; -}; - -// Run the per-representative DP. Walks the SCC condensation in topological order (sources first, -// since Tarjan emits SCCs in reverse-topological order so source SCCs end up at the largest -// indices). `max_size_at_node_begin` is taken by reference as a scratch buffer reused across reps -// to avoid reallocating per iteration. -AdStackDPResult run_ad_stack_size_dp_for_representative(int start_node, - int num_sccs, - const std::vector &is_for_stack, - const std::vector &mis_for_stack, - const std::vector> &scc_nodes, - const std::vector &scc_is_cyclic, - const std::vector> &scc_topo, - const std::vector> &next_ids_intra_fwd, - const std::vector> &next_ids_intra_back, - const std::vector> &next_ids_inter, - std::vector &max_size_at_node_begin) { - std::fill(max_size_at_node_begin.begin(), max_size_at_node_begin.end(), -1); - max_size_at_node_begin[start_node] = 0; - int max_size = 0; - for (int s = num_sccs - 1; s >= 0; s--) { - const auto &nodes_in_s = scc_nodes[s]; - if (scc_is_cyclic[s]) { - switch (classify_cyclic_scc_fast_path(nodes_in_s, is_for_stack)) { - case CyclicSccFastPath::kPositiveCycle: - return {max_size, /*has_positive_loop=*/true}; - case CyclicSccFastPath::kZeroSpread: - spread_max_begin_over_zero_scc(nodes_in_s, max_size_at_node_begin); - break; - case CyclicSccFastPath::kFallback: - if (dp_mixed_sign_cyclic_scc(scc_topo[s], nodes_in_s, is_for_stack, next_ids_intra_fwd, next_ids_intra_back, - max_size_at_node_begin)) { - return {max_size, /*has_positive_loop=*/true}; - } - break; - } - } - update_global_max_and_relax_inter_scc(nodes_in_s, is_for_stack, mis_for_stack, next_ids_inter, - max_size_at_node_begin, max_size); - } - return {max_size, /*has_positive_loop=*/false}; -} - -// Broadcast per-representative DP results to every active stack and apply the resolved -// `max_size`. Stacks with positive cycles are left at `max_size = 0` so the structural -// bounded-loop pre-pass in `irpass::determine_ad_stack_size` gets a chance to derive a symbolic -// bound; if it also cannot, the caller emits a hard compile error (there is no compile-time -// `default_ad_stack_size` fallback). -void apply_ad_stack_dp_results(const std::vector &stacks, - const std::vector &stack_active, - const std::vector &stack_to_rep, - const std::unordered_map &rep_results) { - const int num_stacks = static_cast(stacks.size()); - for (int sid = 0; sid < num_stacks; sid++) { - AdStackAllocaStmt *stack = stacks[sid]; - if (!stack_active[sid]) { - // No push/pop in the CFG: the DP would visit reachable nodes with all-zero edge weights and - // settle with `max_size = 0`, no positive loop. Reproduce that result directly. - QD_WARN("Unused autodiff stack {} should have been eliminated.", stack->name()); - continue; - } - const AdStackDPResult &res = rep_results.at(stack_to_rep[sid]); - if (res.has_positive_loop) { - // Leave `max_size = 0` so the symbolic-bound pre-pass can take over. - continue; - } - // Since we use |max_size| == 0 for adaptive sizes, we do not want stacks with maximum capacity - // indeed equal to 0. - QD_WARN_IF(res.max_size == 0, "Unused autodiff stack {} should have been eliminated.", stack->name()); - stack->max_size = res.max_size; - } -} - -} // namespace - -void ControlFlowGraph::determine_ad_stack_size() { - /** - * What problem this solves - * ------------------------ - * Reverse-mode autodiff records each primal write to a variable by pushing the value onto a - * per-variable stack (an "AD-stack"), then pops in reverse order during the backward pass. The - * stack's *capacity* must be known at allocation time. Some stacks are sized explicitly by the - * user (`max_size != 0`); the rest are "adaptive" (`max_size == 0`) and we have to figure out - * how many pushes can stack up on them at runtime by walking the CFG and bookkeeping pushes - * minus pops along every reachable path from kernel entry. - * - * For each adaptive AD-stack we want: the maximum value of (pushes - pops) over any walk from - * the entry node. Set that as `max_size`. If the kernel has a loop where pushes > pops on every - * iteration the maximum is unbounded; we leave `max_size = 0` and let the caller's structural - * bounded-loop pre-pass derive a symbolic `SizeExpr` from the loop ranges (and hard-error if - * even that fails -- there is no compile-time default fallback). - * - * High-level algorithm - * -------------------- - * 1. Index every adaptive AD-stack (`collect_adaptive_ad_stacks`). - * 2. For every (stack, CFG node) pair, precompute the net push/pop delta inside that node and - * the max running push-count prefix inside that node - * (`accumulate_per_stack_per_node_size_deltas`). After this point the actual stmts are - * irrelevant; everything below operates on int matrices. - * 3. Condense the CFG with Tarjan's SCC algorithm (`tarjan_scc`). Inside any cyclic SCC, the - * "max walk-sum from entry" question becomes "is there a positive cycle, and if not, what's - * the DAG longest path inside this SCC?". The condensation lets us solve those questions - * once per SCC, in topological order across SCCs. - * 4. For each *distinct* push/pop fingerprint across stacks (`group_stacks_by_fingerprint`), - * run the DP once: walk SCCs in topological order, per cyclic SCC try two cheap sign-based - * fast paths and fall back to a single-pass DAG DP + back-edge relaxation check for cycle - * detection. Broadcast the result to every stack that shares this fingerprint - * (`apply_ad_stack_dp_results`). In practice many autodiff kernels generate identical - * push/pop schedules across stacks (one alloca per AD variable in the same loop body), so - * the fingerprint dedup collapses N stacks to a handful of DP runs. - * - * The two cheap fast paths for a cyclic SCC: - * (a) min_is >= 0 and max_is > 0 => positive cycle exists for this stack (every node in an - * SCC is on a cycle; a cycle through a strictly-positive node with non-negative weights - * sums positive). Bail out. - * (b) min_is == 0 == max_is => the SCC adds nothing; just spread the maximum entry- - * side `max_size_at_node_begin` to every node in the SCC in O(|S|). - * The full mixed-sign DP runs only when neither shortcut applies. - * - * Performance summary - * ------------------- - * Per representative cost: O(V + E + sum_{cyclic SCC S} |S| * |E_S|), with the inner sum - * collapsing to O(|S| + |E_S|) for the common autodiff push/pop pattern (one DP sweep + one - * back-edge check). Overall: O(V + E + R * (V + E)), with R = number of distinct row-pair - * fingerprints (typically << number of stacks). Per-stack per-node tables are dense - * `vector>` indexed by contiguous int stack id, not `unordered_map`, to keep the - * hot inner loop branch-and-cache friendly. - */ - AdStackIndex idx = collect_adaptive_ad_stacks(nodes); - const int num_stacks = static_cast(idx.stacks.size()); - if (num_stacks == 0) { - return; - } - - const int num_nodes = size(); - AdStackPerNodeSizes sizes = accumulate_per_stack_per_node_size_deltas(nodes, idx.stack_id, num_stacks); - std::vector> next_ids = compute_outgoing_node_ids(nodes); - - TarjanResult tarjan = tarjan_scc(next_ids); - const int num_sccs = static_cast(tarjan.scc_nodes.size()); - SccEdgeSets edges = classify_scc_edges(next_ids, tarjan.scc_id, tarjan.dfs_finish); - CyclicSccInfo cyc = identify_cyclic_sccs_and_topo(tarjan.scc_nodes, edges.next_ids_intra_back, tarjan.dfs_finish); - - FingerprintGroups groups = group_stacks_by_fingerprint(num_nodes, num_stacks, sizes.stack_active, sizes.increased_size, - sizes.max_increased_size); - - std::vector max_size_at_node_begin(num_nodes); - std::unordered_map rep_results; - rep_results.reserve(groups.rep_stack_ids.size()); - for (int rep_sid : groups.rep_stack_ids) { - rep_results[rep_sid] = run_ad_stack_size_dp_for_representative( - start_node, num_sccs, sizes.increased_size[rep_sid], sizes.max_increased_size[rep_sid], tarjan.scc_nodes, - cyc.scc_is_cyclic, cyc.scc_topo, edges.next_ids_intra_fwd, edges.next_ids_intra_back, edges.next_ids_inter, - max_size_at_node_begin); - } - - apply_ad_stack_dp_results(idx.stacks, sizes.stack_active, groups.stack_to_rep, rep_results); -} - } // namespace quadrants::lang diff --git a/quadrants/ir/determine_ad_stack_size.cpp b/quadrants/ir/determine_ad_stack_size.cpp new file mode 100644 index 0000000000..e9f9d94a23 --- /dev/null +++ b/quadrants/ir/determine_ad_stack_size.cpp @@ -0,0 +1,637 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "quadrants/common/exceptions.h" +#include "quadrants/ir/control_flow_graph.h" +#include "quadrants/ir/statements.h" + +// =========================================================================== +// Adaptive AD-stack sizing. +// +// This file implements ControlFlowGraph::determine_ad_stack_size() and its +// file-local helpers. Pulled out of control_flow_graph.cpp because it is a +// self-contained ~500-line algorithm (Tarjan SCC condensation + per-rep DAG +// DP + fingerprint dedup) that has no per-node CFGNode counterpart and is +// not used by any other pass in this file. +// +// See the docstring on ControlFlowGraph::determine_ad_stack_size() below for +// the "what problem this solves / how" overview. +// =========================================================================== + +namespace quadrants::lang { + +namespace { + +// === Helpers for ControlFlowGraph::determine_ad_stack_size === + +struct AdStackIndex { + // AdStackAllocaStmt* -> contiguous int id, populated only for adaptive stacks + // (max_size == 0) that have not yet been resolved by an earlier pass. + std::unordered_map stack_id; + std::vector stacks; +}; + +AdStackIndex collect_adaptive_ad_stacks(const std::vector> &nodes) { + AdStackIndex idx; + for (const auto &node : nodes) { + for (int j = node->begin_location; j < node->end_location; j++) { + Stmt *stmt = node->block->statements[j].get(); + auto *stack = stmt->cast(); + if (!stack || stack->max_size != 0) { + continue; + } + if (idx.stack_id.emplace(stack, static_cast(idx.stacks.size())).second) { + idx.stacks.push_back(stack); + } + } + } + return idx; +} + +struct AdStackPerNodeSizes { + // [stack_id][node_id]. `max_increased_size[s][j]` is the maximum (pushes - pops) of stack |s| + // among all prefixes of CFGNode |j|; `increased_size[s][j]` is the net (pushes - pops) in the + // whole node. Indexed by contiguous stack id so the per-stack DP reads cheap vectors instead of + // hashing. + std::vector> increased_size; + std::vector> max_increased_size; + // True iff the stack actually appears in any push/pop in the CFG. Inactive stacks would settle + // at `max_size = 0` regardless and are short-circuited below to reproduce the original "Unused + // autodiff stack" warning. + std::vector stack_active; +}; + +AdStackPerNodeSizes accumulate_per_stack_per_node_size_deltas( + const std::vector> &nodes, + const std::unordered_map &stack_id, + int num_stacks) { + const int num_nodes = static_cast(nodes.size()); + AdStackPerNodeSizes out; + out.increased_size.assign(num_stacks, std::vector(num_nodes, 0)); + out.max_increased_size.assign(num_stacks, std::vector(num_nodes, 0)); + out.stack_active.assign(num_stacks, false); + for (int i = 0; i < num_nodes; i++) { + for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { + Stmt *stmt = nodes[i]->block->statements[j].get(); + AdStackAllocaStmt *stack = nullptr; + int delta = 0; + if (auto *push = stmt->cast()) { + stack = push->stack->as(); + delta = +1; + } else if (auto *pop = stmt->cast()) { + stack = pop->stack->as(); + delta = -1; + } else { + continue; + } + if (stack->max_size != 0 /*non-adaptive*/) { + continue; + } + auto it = stack_id.find(stack); + QD_ASSERT(it != stack_id.end()); + const int sid = it->second; + out.stack_active[sid] = true; + int &cur = out.increased_size[sid][i]; + cur += delta; + if (cur > out.max_increased_size[sid][i]) { + out.max_increased_size[sid][i] = cur; + } + } + } + return out; +} + +// Precompute outgoing-edge node ids once per node so the per-stack DP walks an int vector instead +// of hashing `node_ids[next_node]` on every traversal. +std::vector> compute_outgoing_node_ids(const std::vector> &nodes) { + const int num_nodes = static_cast(nodes.size()); + std::unordered_map node_ids; + node_ids.reserve(num_nodes); + for (int i = 0; i < num_nodes; i++) { + node_ids[nodes[i].get()] = i; + } + std::vector> next_ids(num_nodes); + for (int i = 0; i < num_nodes; i++) { + auto &dst = next_ids[i]; + dst.reserve(nodes[i]->next.size()); + for (auto *next_node : nodes[i]->next) { + dst.push_back(node_ids[next_node]); + } + } + return next_ids; +} + +struct TarjanResult { + std::vector scc_id; // node -> SCC index + std::vector dfs_finish; // node -> DFS post-order index; ancestors finish AFTER descendants + std::vector> scc_nodes; // SCC -> list of node ids +}; + +// Iterative Tarjan SCC. Emits SCCs in reverse topological order (sources end up at the largest +// indices). Also records DFS post-order finish times, used downstream to split each cyclic SCC's +// intra-edges into forward vs back without a second edge-classification pass. +TarjanResult tarjan_scc(const std::vector> &next_ids) { + const int num_nodes = static_cast(next_ids.size()); + TarjanResult out; + out.scc_id.assign(num_nodes, -1); + out.dfs_finish.assign(num_nodes, -1); + + std::vector tarjan_index(num_nodes, -1); + std::vector tarjan_lowlink(num_nodes, 0); + std::vector on_stack(num_nodes, 0); + std::vector tarjan_stack; + tarjan_stack.reserve(num_nodes); + std::vector> dfs_stack; + dfs_stack.reserve(num_nodes); + int next_index = 0; + int next_finish = 0; + int next_scc = 0; + for (int origin = 0; origin < num_nodes; origin++) { + if (tarjan_index[origin] != -1) { + continue; + } + tarjan_index[origin] = next_index; + tarjan_lowlink[origin] = next_index; + next_index++; + tarjan_stack.push_back(origin); + on_stack[origin] = 1; + dfs_stack.emplace_back(origin, 0); + while (!dfs_stack.empty()) { + auto &frame = dfs_stack.back(); + const int u = frame.first; + const auto &nb = next_ids[u]; + if (frame.second < nb.size()) { + const int v = nb[frame.second++]; + if (tarjan_index[v] == -1) { + tarjan_index[v] = next_index; + tarjan_lowlink[v] = next_index; + next_index++; + tarjan_stack.push_back(v); + on_stack[v] = 1; + dfs_stack.emplace_back(v, 0); + } else if (on_stack[v]) { + if (tarjan_index[v] < tarjan_lowlink[u]) { + tarjan_lowlink[u] = tarjan_index[v]; + } + } + } else { + if (tarjan_lowlink[u] == tarjan_index[u]) { + std::vector component; + while (true) { + const int w = tarjan_stack.back(); + tarjan_stack.pop_back(); + on_stack[w] = 0; + out.scc_id[w] = next_scc; + component.push_back(w); + if (w == u) { + break; + } + } + out.scc_nodes.push_back(std::move(component)); + next_scc++; + } + out.dfs_finish[u] = next_finish++; + dfs_stack.pop_back(); + if (!dfs_stack.empty()) { + const int parent = dfs_stack.back().first; + if (tarjan_lowlink[u] < tarjan_lowlink[parent]) { + tarjan_lowlink[parent] = tarjan_lowlink[u]; + } + } + } + } + } + return out; +} + +struct SccEdgeSets { + // Each intra-SCC edge is either "forward" in the DFS spanning sense (target finishes before + // source, so source can be relaxed before target with a single topological pass) or "back" + // (target finishes at or after source, closing a cycle). The forward set drives the per-stack + // DAG dynamic programming; back-edges only need a single post-DP relaxation check to detect + // positive cycles. Inter-SCC edges always relax forward in topological order. + std::vector> next_ids_intra_fwd; + std::vector> next_ids_intra_back; + std::vector> next_ids_inter; +}; + +SccEdgeSets classify_scc_edges(const std::vector> &next_ids, + const std::vector &scc_id, + const std::vector &dfs_finish) { + const int num_nodes = static_cast(next_ids.size()); + SccEdgeSets out; + out.next_ids_intra_fwd.assign(num_nodes, {}); + out.next_ids_intra_back.assign(num_nodes, {}); + out.next_ids_inter.assign(num_nodes, {}); + for (int u = 0; u < num_nodes; u++) { + const int su = scc_id[u]; + for (int v : next_ids[u]) { + if (scc_id[v] == su) { + if (dfs_finish[v] < dfs_finish[u]) { + out.next_ids_intra_fwd[u].push_back(v); + } else { + out.next_ids_intra_back[u].push_back(v); + } + } else { + out.next_ids_inter[u].push_back(v); + } + } + } + return out; +} + +struct CyclicSccInfo { + std::vector scc_is_cyclic; // 1 iff SCC contains a cycle + std::vector> scc_topo; // for cyclic SCCs: nodes sorted by descending dfs_finish +}; + +// An SCC is cyclic iff |S| > 1 (any two nodes in a non-trivial SCC lie on a cycle) or |S| == 1 +// with a self-loop edge. For cyclic SCCs we precompute the topological ordering of their nodes +// so the per-stack DAG DP visits each node once with all forward predecessors already finalized. +CyclicSccInfo identify_cyclic_sccs_and_topo(const std::vector> &scc_nodes, + const std::vector> &next_ids_intra_back, + const std::vector &dfs_finish) { + const int num_sccs = static_cast(scc_nodes.size()); + CyclicSccInfo out; + out.scc_is_cyclic.assign(num_sccs, 0); + out.scc_topo.assign(num_sccs, {}); + for (int s = 0; s < num_sccs; s++) { + const auto &nodes_in_s = scc_nodes[s]; + if (nodes_in_s.size() > 1) { + out.scc_is_cyclic[s] = 1; + } else { + // Self-loops are classified as back-edges above, so a singleton SCC is cyclic iff it has a + // back-edge pointing at itself. + const int n = nodes_in_s[0]; + for (int v : next_ids_intra_back[n]) { + if (v == n) { + out.scc_is_cyclic[s] = 1; + break; + } + } + } + if (out.scc_is_cyclic[s]) { + struct DfsFinishGreater { + const std::vector &dfs_finish; + bool operator()(int a, int b) const { + return dfs_finish[a] > dfs_finish[b]; + } + }; + auto topo = nodes_in_s; + std::sort(topo.begin(), topo.end(), DfsFinishGreater{dfs_finish}); + out.scc_topo[s] = std::move(topo); + } + } + return out; +} + +struct FingerprintGroups { + std::vector stack_to_rep; // stack id -> representative stack id (whose DP run it shares) + std::vector rep_stack_ids; // representative stack ids that actually run the DP +}; + +// Group AD-stacks whose per-node (increased_size, max_increased_size) rows are bit-identical so +// the DP runs once per equivalence class instead of once per stack. In practice many AD-stacks in +// the same kernel share their push/pop schedule (one alloca per autodiff variable in the same +// loop body), and the DP on a large CFG dwarfs the dedup cost. We hash a sparse fingerprint (only +// nodes where the stack has activity) and group by it. Worst case (no duplicates): one DP run per +// stack, equivalent to running it directly per stack. +FingerprintGroups group_stacks_by_fingerprint(int num_nodes, + int num_stacks, + const std::vector &stack_active, + const std::vector> &increased_size, + const std::vector> &max_increased_size) { + using Fingerprint = std::vector>; // (node_id, is, mis), sorted by node_id + struct FingerprintHash { + std::size_t operator()(const Fingerprint &f) const noexcept { + // FNV-1a mix of three components per fingerprint entry. + constexpr std::size_t fnv_prime = 1099511628211ULL; + std::size_t h = 1469598103934665603ULL; + for (auto &[n, i, m] : f) { + h ^= static_cast(n); + h *= fnv_prime; + h ^= static_cast(static_cast(i)); + h *= fnv_prime; + h ^= static_cast(static_cast(m)); + h *= fnv_prime; + } + return h; + } + }; + std::unordered_map fp_to_rep; + FingerprintGroups out; + out.stack_to_rep.assign(num_stacks, -1); + for (int sid = 0; sid < num_stacks; sid++) { + if (!stack_active[sid]) { + continue; + } + Fingerprint fp; + const auto &is_row = increased_size[sid]; + const auto &mis_row = max_increased_size[sid]; + for (int n = 0; n < num_nodes; n++) { + if (is_row[n] != 0 || mis_row[n] != 0) { + fp.emplace_back(n, is_row[n], mis_row[n]); + } + } + auto [it, inserted] = fp_to_rep.emplace(std::move(fp), sid); + out.stack_to_rep[sid] = it->second; + if (inserted) { + out.rep_stack_ids.push_back(sid); + } + } + return out; +} + +enum class CyclicSccFastPath { + kPositiveCycle, // proven positive cycle exists for this stack inside this SCC + kZeroSpread, // no `is` contribution in the SCC; just spread max entry-side begin value + kFallback, // mixed-sign, run the full DP +}; + +// Sign-based fast paths sidestep the cyclic-SCC dynamic programming when the stack's `is` +// contribution inside this SCC is structurally trivial: +// 1. min_is >= 0 with max_is > 0: every node in the SCC lies on some cycle (SCC property), and +// a cycle through a strictly-positive node with all non-negative `is` along it sums to a +// positive value, so a positive cycle exists for this stack. This is the autodiff +// push-only-in-SCC pattern. +// 2. min_is == max_is == 0: no `is` contribution at all in this SCC; the DP would only spread +// the maximum entry-side `max_size_at_node_begin` value to every node in O(|S|). +CyclicSccFastPath classify_cyclic_scc_fast_path(const std::vector &nodes_in_s, + const std::vector &is_for_stack) { + int min_is = INT_MAX; + int max_is = INT_MIN; + for (int u : nodes_in_s) { + const int v = is_for_stack[u]; + if (v < min_is) { + min_is = v; + } + if (v > max_is) { + max_is = v; + } + } + if (min_is >= 0 && max_is > 0) { + return CyclicSccFastPath::kPositiveCycle; + } + if (min_is == 0 && max_is == 0) { + return CyclicSccFastPath::kZeroSpread; + } + return CyclicSccFastPath::kFallback; +} + +// Spread the maximum entry-side `max_size_at_node_begin` value across every node in the SCC. +// Equivalent to running the DP with all-zero `is` weights, in O(|S|). +void spread_max_begin_over_zero_scc(const std::vector &nodes_in_s, + std::vector &max_size_at_node_begin) { + int max_begin = -1; + for (int u : nodes_in_s) { + if (max_size_at_node_begin[u] > max_begin) { + max_begin = max_size_at_node_begin[u]; + } + } + if (max_begin < 0) { + return; + } + for (int u : nodes_in_s) { + if (max_size_at_node_begin[u] < max_begin) { + max_size_at_node_begin[u] = max_begin; + } + } +} + +// Mixed-sign case: single-pass dynamic programming on the SCC's forward edges (processed in +// descending DFS finish-time so every forward predecessor is finalized before its successor +// relaxes), followed by one relaxation check on the back-edges. Correctness: every walk inside +// the SCC decomposes into a forward path plus zero or more closed cycles (back-edge + forward +// path). For SCCs with no positive cycle, traversing a cycle adds a non-positive amount to the +// running size and so cannot improve max_size beyond what the forward DP already computed. The +// back-edge relaxation check after the DP detects the only failure mode (some back-edge would +// still improve a forward predecessor's value, which can only happen if a positive cycle exists +// for this stack). Returns true iff a positive cycle was detected. +bool dp_mixed_sign_cyclic_scc(const std::vector &topo, + const std::vector &nodes_in_s, + const std::vector &is_for_stack, + const std::vector> &next_ids_intra_fwd, + const std::vector> &next_ids_intra_back, + std::vector &max_size_at_node_begin) { + for (int u : topo) { + const int begin = max_size_at_node_begin[u]; + if (begin < 0) { + continue; + } + const int exit_val = begin + is_for_stack[u]; + for (int v : next_ids_intra_fwd[u]) { + if (exit_val > max_size_at_node_begin[v]) { + max_size_at_node_begin[v] = exit_val; + } + } + } + for (int u : nodes_in_s) { + const int begin = max_size_at_node_begin[u]; + if (begin < 0) { + continue; + } + const int exit_val = begin + is_for_stack[u]; + for (int v : next_ids_intra_back[u]) { + if (exit_val > max_size_at_node_begin[v]) { + return true; + } + } + } + return false; +} + +// SCC has converged for this stack. Update global `max_size` from each node's max-prefix +// contribution, then relax inter-SCC outgoing edges into successor SCCs (predecessors are already +// finalized when an SCC is entered, so each inter-SCC edge is touched exactly once). +void update_global_max_and_relax_inter_scc(const std::vector &nodes_in_s, + const std::vector &is_for_stack, + const std::vector &mis_for_stack, + const std::vector> &next_ids_inter, + std::vector &max_size_at_node_begin, + int &max_size) { + for (int u : nodes_in_s) { + const int begin = max_size_at_node_begin[u]; + if (begin < 0) { + continue; + } + const int prefix = begin + mis_for_stack[u]; + if (prefix > max_size) { + max_size = prefix; + } + const int exit_val = begin + is_for_stack[u]; + for (int v : next_ids_inter[u]) { + if (exit_val > max_size_at_node_begin[v]) { + max_size_at_node_begin[v] = exit_val; + } + } + } +} + +struct AdStackDPResult { + int max_size; + bool has_positive_loop; +}; + +// Run the per-representative DP. Walks the SCC condensation in topological order (sources first, +// since Tarjan emits SCCs in reverse-topological order so source SCCs end up at the largest +// indices). `max_size_at_node_begin` is taken by reference as a scratch buffer reused across reps +// to avoid reallocating per iteration. +AdStackDPResult run_ad_stack_size_dp_for_representative(int start_node, + int num_sccs, + const std::vector &is_for_stack, + const std::vector &mis_for_stack, + const std::vector> &scc_nodes, + const std::vector &scc_is_cyclic, + const std::vector> &scc_topo, + const std::vector> &next_ids_intra_fwd, + const std::vector> &next_ids_intra_back, + const std::vector> &next_ids_inter, + std::vector &max_size_at_node_begin) { + std::fill(max_size_at_node_begin.begin(), max_size_at_node_begin.end(), -1); + max_size_at_node_begin[start_node] = 0; + int max_size = 0; + for (int s = num_sccs - 1; s >= 0; s--) { + const auto &nodes_in_s = scc_nodes[s]; + if (scc_is_cyclic[s]) { + switch (classify_cyclic_scc_fast_path(nodes_in_s, is_for_stack)) { + case CyclicSccFastPath::kPositiveCycle: + return {max_size, /*has_positive_loop=*/true}; + case CyclicSccFastPath::kZeroSpread: + spread_max_begin_over_zero_scc(nodes_in_s, max_size_at_node_begin); + break; + case CyclicSccFastPath::kFallback: + if (dp_mixed_sign_cyclic_scc(scc_topo[s], nodes_in_s, is_for_stack, next_ids_intra_fwd, next_ids_intra_back, + max_size_at_node_begin)) { + return {max_size, /*has_positive_loop=*/true}; + } + break; + } + } + update_global_max_and_relax_inter_scc(nodes_in_s, is_for_stack, mis_for_stack, next_ids_inter, + max_size_at_node_begin, max_size); + } + return {max_size, /*has_positive_loop=*/false}; +} + +// Broadcast per-representative DP results to every active stack and apply the resolved +// `max_size`. Stacks with positive cycles are left at `max_size = 0` so the structural +// bounded-loop pre-pass in `irpass::determine_ad_stack_size` gets a chance to derive a symbolic +// bound; if it also cannot, the caller emits a hard compile error (there is no compile-time +// `default_ad_stack_size` fallback). +void apply_ad_stack_dp_results(const std::vector &stacks, + const std::vector &stack_active, + const std::vector &stack_to_rep, + const std::unordered_map &rep_results) { + const int num_stacks = static_cast(stacks.size()); + for (int sid = 0; sid < num_stacks; sid++) { + AdStackAllocaStmt *stack = stacks[sid]; + if (!stack_active[sid]) { + // No push/pop in the CFG: the DP would visit reachable nodes with all-zero edge weights and + // settle with `max_size = 0`, no positive loop. Reproduce that result directly. + QD_WARN("Unused autodiff stack {} should have been eliminated.", stack->name()); + continue; + } + const AdStackDPResult &res = rep_results.at(stack_to_rep[sid]); + if (res.has_positive_loop) { + // Leave `max_size = 0` so the symbolic-bound pre-pass can take over. + continue; + } + // Since we use |max_size| == 0 for adaptive sizes, we do not want stacks with maximum capacity + // indeed equal to 0. + QD_WARN_IF(res.max_size == 0, "Unused autodiff stack {} should have been eliminated.", stack->name()); + stack->max_size = res.max_size; + } +} + +} // namespace + +void ControlFlowGraph::determine_ad_stack_size() { + /** + * What problem this solves + * ------------------------ + * Reverse-mode autodiff records each primal write to a variable by pushing the value onto a + * per-variable stack (an "AD-stack"), then pops in reverse order during the backward pass. The + * stack's *capacity* must be known at allocation time. Some stacks are sized explicitly by the + * user (`max_size != 0`); the rest are "adaptive" (`max_size == 0`) and we have to figure out + * how many pushes can stack up on them at runtime by walking the CFG and bookkeeping pushes + * minus pops along every reachable path from kernel entry. + * + * For each adaptive AD-stack we want: the maximum value of (pushes - pops) over any walk from + * the entry node. Set that as `max_size`. If the kernel has a loop where pushes > pops on every + * iteration the maximum is unbounded; we leave `max_size = 0` and let the caller's structural + * bounded-loop pre-pass derive a symbolic `SizeExpr` from the loop ranges (and hard-error if + * even that fails -- there is no compile-time default fallback). + * + * High-level algorithm + * -------------------- + * 1. Index every adaptive AD-stack (`collect_adaptive_ad_stacks`). + * 2. For every (stack, CFG node) pair, precompute the net push/pop delta inside that node and + * the max running push-count prefix inside that node + * (`accumulate_per_stack_per_node_size_deltas`). After this point the actual stmts are + * irrelevant; everything below operates on int matrices. + * 3. Condense the CFG with Tarjan's SCC algorithm (`tarjan_scc`). Inside any cyclic SCC, the + * "max walk-sum from entry" question becomes "is there a positive cycle, and if not, what's + * the DAG longest path inside this SCC?". The condensation lets us solve those questions + * once per SCC, in topological order across SCCs. + * 4. For each *distinct* push/pop fingerprint across stacks (`group_stacks_by_fingerprint`), + * run the DP once: walk SCCs in topological order, per cyclic SCC try two cheap sign-based + * fast paths and fall back to a single-pass DAG DP + back-edge relaxation check for cycle + * detection. Broadcast the result to every stack that shares this fingerprint + * (`apply_ad_stack_dp_results`). In practice many autodiff kernels generate identical + * push/pop schedules across stacks (one alloca per AD variable in the same loop body), so + * the fingerprint dedup collapses N stacks to a handful of DP runs. + * + * The two cheap fast paths for a cyclic SCC: + * (a) min_is >= 0 and max_is > 0 => positive cycle exists for this stack (every node in an + * SCC is on a cycle; a cycle through a strictly-positive node with non-negative weights + * sums positive). Bail out. + * (b) min_is == 0 == max_is => the SCC adds nothing; just spread the maximum entry- + * side `max_size_at_node_begin` to every node in the SCC in O(|S|). + * The full mixed-sign DP runs only when neither shortcut applies. + * + * Performance summary + * ------------------- + * Per representative cost: O(V + E + sum_{cyclic SCC S} |S| * |E_S|), with the inner sum + * collapsing to O(|S| + |E_S|) for the common autodiff push/pop pattern (one DP sweep + one + * back-edge check). Overall: O(V + E + R * (V + E)), with R = number of distinct row-pair + * fingerprints (typically << number of stacks). Per-stack per-node tables are dense + * `vector>` indexed by contiguous int stack id, not `unordered_map`, to keep the + * hot inner loop branch-and-cache friendly. + */ + AdStackIndex idx = collect_adaptive_ad_stacks(nodes); + const int num_stacks = static_cast(idx.stacks.size()); + if (num_stacks == 0) { + return; + } + + const int num_nodes = size(); + AdStackPerNodeSizes sizes = accumulate_per_stack_per_node_size_deltas(nodes, idx.stack_id, num_stacks); + std::vector> next_ids = compute_outgoing_node_ids(nodes); + + TarjanResult tarjan = tarjan_scc(next_ids); + const int num_sccs = static_cast(tarjan.scc_nodes.size()); + SccEdgeSets edges = classify_scc_edges(next_ids, tarjan.scc_id, tarjan.dfs_finish); + CyclicSccInfo cyc = identify_cyclic_sccs_and_topo(tarjan.scc_nodes, edges.next_ids_intra_back, tarjan.dfs_finish); + + FingerprintGroups groups = group_stacks_by_fingerprint(num_nodes, num_stacks, sizes.stack_active, sizes.increased_size, + sizes.max_increased_size); + + std::vector max_size_at_node_begin(num_nodes); + std::unordered_map rep_results; + rep_results.reserve(groups.rep_stack_ids.size()); + for (int rep_sid : groups.rep_stack_ids) { + rep_results[rep_sid] = run_ad_stack_size_dp_for_representative( + start_node, num_sccs, sizes.increased_size[rep_sid], sizes.max_increased_size[rep_sid], tarjan.scc_nodes, + cyc.scc_is_cyclic, cyc.scc_topo, edges.next_ids_intra_fwd, edges.next_ids_intra_back, edges.next_ids_inter, + max_size_at_node_begin); + } + + apply_ad_stack_dp_results(idx.stacks, sizes.stack_active, groups.stack_to_rep, rep_results); +} + +} // namespace quadrants::lang From 8a1a94146c6c7e009dc6fa62c24461e0257b881f Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 13:12:53 -0700 Subject: [PATCH 15/18] [AdStack] Encapsulate the "max_size == 0 == adaptive" sentinel Add `AdStackAllocaStmt::kAdaptiveSentinel` (=0) and `AdStackAllocaStmt::is_adaptive()`, and route every "is this stack still unsized?" check through the predicate. Five call sites updated: - quadrants/ir/determine_ad_stack_size.cpp (2) - quadrants/transforms/determine_ad_stack_size.cpp (3) The local `AdStackDPResult::max_size == 0` check in quadrants/ir/determine_ad_stack_size.cpp is unrelated (compares a DP result, not the stmt) and is left alone. This makes the "0 means adaptive" magic local to one class. A future change to the sentinel (e.g. switching to std::optional or a typed state enum) now touches only the one predicate. No behavior change. --- quadrants/ir/determine_ad_stack_size.cpp | 4 ++-- quadrants/ir/statements.h | 15 ++++++++++++++- quadrants/transforms/determine_ad_stack_size.cpp | 8 ++++---- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/quadrants/ir/determine_ad_stack_size.cpp b/quadrants/ir/determine_ad_stack_size.cpp index e9f9d94a23..d7ef9890f1 100644 --- a/quadrants/ir/determine_ad_stack_size.cpp +++ b/quadrants/ir/determine_ad_stack_size.cpp @@ -42,7 +42,7 @@ AdStackIndex collect_adaptive_ad_stacks(const std::vectorbegin_location; j < node->end_location; j++) { Stmt *stmt = node->block->statements[j].get(); auto *stack = stmt->cast(); - if (!stack || stack->max_size != 0) { + if (!stack || !stack->is_adaptive()) { continue; } if (idx.stack_id.emplace(stack, static_cast(idx.stacks.size())).second) { @@ -89,7 +89,7 @@ AdStackPerNodeSizes accumulate_per_stack_per_node_size_deltas( } else { continue; } - if (stack->max_size != 0 /*non-adaptive*/) { + if (!stack->is_adaptive()) { continue; } auto it = stack_id.find(stack); diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index 49c6de35fb..a78d8d046c 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -1633,8 +1633,13 @@ class InternalFuncStmt : public Stmt { */ class AdStackAllocaStmt : public Stmt { public: + // Sentinel value of `max_size` meaning "size has not been resolved yet; treat as adaptive". + // Used as a state marker in multiple passes (see `is_adaptive()`); do not redefine to a real + // capacity even if 0-sized adstacks were valid. + static constexpr std::size_t kAdaptiveSentinel = 0; + DataType dt; - std::size_t max_size{0}; // 0 = adaptive + std::size_t max_size{kAdaptiveSentinel}; // Compile-time captured symbolic expression for `max_size`, populated by // `determine_ad_stack_size` when the bound is derivable from constants and scalar field loads. // Host-evaluated pre-launch to size the adstack heap; null until the pre-pass runs. @@ -1648,6 +1653,14 @@ class AdStackAllocaStmt : public Stmt { QD_STMT_REG_FIELDS; } + // True iff this stack's capacity has not yet been resolved by `determine_ad_stack_size` (or any + // prior pass that may seed it). Prefer this over comparing `max_size` to a literal `0` -- it + // keeps the "0 means adaptive" sentinel encapsulated and lets a future change to the sentinel + // (e.g. to a typed enum) touch only this class. + bool is_adaptive() const { + return max_size == kAdaptiveSentinel; + } + std::size_t element_size_in_bytes() const { return data_type_size(ret_type); } diff --git a/quadrants/transforms/determine_ad_stack_size.cpp b/quadrants/transforms/determine_ad_stack_size.cpp index 203d72a48b..13536e5e84 100644 --- a/quadrants/transforms/determine_ad_stack_size.cpp +++ b/quadrants/transforms/determine_ad_stack_size.cpp @@ -1313,7 +1313,7 @@ bool size_expr_contains_inner_domain_enumeration(const SizeExpr *e) { bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { auto adaptive_allocas = irpass::analysis::gather_statements(root, [&](Stmt *s) { auto *ad_stack = s->cast(); - return ad_stack != nullptr && ad_stack->max_size == 0; + return ad_stack != nullptr && ad_stack->is_adaptive(); }); if (adaptive_allocas.empty()) { return false; @@ -1339,7 +1339,7 @@ bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { std::unordered_map alloca_cycle_detected; for (Stmt *s : adaptive_allocas) { auto *alloca = s->as(); - if (alloca->max_size != 0) { + if (!alloca->is_adaptive()) { continue; // Already resolved by Bellman-Ford in phase 1. } t_cycle_detected = false; @@ -1368,7 +1368,7 @@ bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { // uniform representation regardless of which phase resolved the bound. for (Stmt *s : adaptive_allocas) { auto *alloca = s->as(); - if (!alloca->size_expr && alloca->max_size != 0) { + if (!alloca->size_expr && !alloca->is_adaptive()) { alloca->size_expr = SizeExpr::make_const(static_cast(alloca->max_size)); } } @@ -1381,7 +1381,7 @@ bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { // `/ir_adstack_unresolved/unresolved_alloca_.ll` for offline inspection. for (Stmt *s : adaptive_allocas) { auto *alloca = s->as(); - if (alloca->max_size != 0 || alloca->size_expr) { + if (!alloca->is_adaptive() || alloca->size_expr) { continue; } std::string dump_hint; From da10b16c06f0f9dcfe7a90db86a4aa46dd0851d5 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 13:14:16 -0700 Subject: [PATCH 16/18] [CFG] Bundle DP args into AdStackDPGraph + AdStackPerNodeRow The old run_ad_stack_size_dp_for_representative took 11 positional args of which 7 were vector>& (three of them being next_ids_intra_fwd / next_ids_intra_back / next_ids_inter, structurally identical and trivial to swap). Bundle them into two structs: AdStackDPGraph -- per-graph (computed once, reused per stack): start_node, num_sccs, scc_nodes, scc_is_cyclic, scc_topo, next_ids_intra_{fwd,back}, next_ids_inter. AdStackPerNodeRow -- per-stack (varies per rep): increased_size + max_increased_size. Call site uses designated initializers, so swapping e.g. `next_ids_intra_fwd` and `next_ids_intra_back` (or `increased_size` and `max_increased_size`) is now a compile error rather than a silent positive-cycle-detection regression. No behavior change. Chose this over an `enum class NodeId/StackId/SccId` strong typedef: the strong-typedef approach would have forced ~40 static_cast insertions in the inner loops for ~zero added type safety inside the helpers, where the actual swap risk is at the function boundary. --- quadrants/ir/determine_ad_stack_size.cpp | 73 ++++++++++++++++-------- 1 file changed, 50 insertions(+), 23 deletions(-) diff --git a/quadrants/ir/determine_ad_stack_size.cpp b/quadrants/ir/determine_ad_stack_size.cpp index d7ef9890f1..5e3c0c7f32 100644 --- a/quadrants/ir/determine_ad_stack_size.cpp +++ b/quadrants/ir/determine_ad_stack_size.cpp @@ -477,43 +477,59 @@ struct AdStackDPResult { bool has_positive_loop; }; +// Bundle of "graph-shape" inputs to the per-stack DP: same for every stack, computed once before +// the loop. Named-field aggregate construction (designated initializers at the call site) makes +// it a compile error to swap, e.g. `next_ids_intra_fwd` and `next_ids_intra_back` -- the highest- +// risk swap surface in the old 11-positional-int-arg signature this struct replaces. +struct AdStackDPGraph { + int start_node; + int num_sccs; + const std::vector> &scc_nodes; + const std::vector &scc_is_cyclic; + const std::vector> &scc_topo; + const std::vector> &next_ids_intra_fwd; + const std::vector> &next_ids_intra_back; + const std::vector> &next_ids_inter; +}; + +// Bundle of per-stack per-node values for one DP representative. Same shape as the graph +// (`increased_size[i]` and `max_increased_size[i]` are the deltas for CFG node i). Bundling them +// together keeps `(increased_size, max_increased_size)` named at the call site -- before this +// struct, both were `const std::vector &` and trivial to swap. +struct AdStackPerNodeRow { + const std::vector &increased_size; // net (pushes - pops) inside each node + const std::vector &max_increased_size; // max running (pushes - pops) prefix inside each node +}; + // Run the per-representative DP. Walks the SCC condensation in topological order (sources first, // since Tarjan emits SCCs in reverse-topological order so source SCCs end up at the largest // indices). `max_size_at_node_begin` is taken by reference as a scratch buffer reused across reps // to avoid reallocating per iteration. -AdStackDPResult run_ad_stack_size_dp_for_representative(int start_node, - int num_sccs, - const std::vector &is_for_stack, - const std::vector &mis_for_stack, - const std::vector> &scc_nodes, - const std::vector &scc_is_cyclic, - const std::vector> &scc_topo, - const std::vector> &next_ids_intra_fwd, - const std::vector> &next_ids_intra_back, - const std::vector> &next_ids_inter, +AdStackDPResult run_ad_stack_size_dp_for_representative(const AdStackDPGraph &graph, + const AdStackPerNodeRow &row, std::vector &max_size_at_node_begin) { std::fill(max_size_at_node_begin.begin(), max_size_at_node_begin.end(), -1); - max_size_at_node_begin[start_node] = 0; + max_size_at_node_begin[graph.start_node] = 0; int max_size = 0; - for (int s = num_sccs - 1; s >= 0; s--) { - const auto &nodes_in_s = scc_nodes[s]; - if (scc_is_cyclic[s]) { - switch (classify_cyclic_scc_fast_path(nodes_in_s, is_for_stack)) { + for (int s = graph.num_sccs - 1; s >= 0; s--) { + const auto &nodes_in_s = graph.scc_nodes[s]; + if (graph.scc_is_cyclic[s]) { + switch (classify_cyclic_scc_fast_path(nodes_in_s, row.increased_size)) { case CyclicSccFastPath::kPositiveCycle: return {max_size, /*has_positive_loop=*/true}; case CyclicSccFastPath::kZeroSpread: spread_max_begin_over_zero_scc(nodes_in_s, max_size_at_node_begin); break; case CyclicSccFastPath::kFallback: - if (dp_mixed_sign_cyclic_scc(scc_topo[s], nodes_in_s, is_for_stack, next_ids_intra_fwd, next_ids_intra_back, - max_size_at_node_begin)) { + if (dp_mixed_sign_cyclic_scc(graph.scc_topo[s], nodes_in_s, row.increased_size, graph.next_ids_intra_fwd, + graph.next_ids_intra_back, max_size_at_node_begin)) { return {max_size, /*has_positive_loop=*/true}; } break; } } - update_global_max_and_relax_inter_scc(nodes_in_s, is_for_stack, mis_for_stack, next_ids_inter, - max_size_at_node_begin, max_size); + update_global_max_and_relax_inter_scc(nodes_in_s, row.increased_size, row.max_increased_size, + graph.next_ids_inter, max_size_at_node_begin, max_size); } return {max_size, /*has_positive_loop=*/false}; } @@ -624,11 +640,22 @@ void ControlFlowGraph::determine_ad_stack_size() { std::vector max_size_at_node_begin(num_nodes); std::unordered_map rep_results; rep_results.reserve(groups.rep_stack_ids.size()); + const AdStackDPGraph graph{ + .start_node = start_node, + .num_sccs = num_sccs, + .scc_nodes = tarjan.scc_nodes, + .scc_is_cyclic = cyc.scc_is_cyclic, + .scc_topo = cyc.scc_topo, + .next_ids_intra_fwd = edges.next_ids_intra_fwd, + .next_ids_intra_back = edges.next_ids_intra_back, + .next_ids_inter = edges.next_ids_inter, + }; for (int rep_sid : groups.rep_stack_ids) { - rep_results[rep_sid] = run_ad_stack_size_dp_for_representative( - start_node, num_sccs, sizes.increased_size[rep_sid], sizes.max_increased_size[rep_sid], tarjan.scc_nodes, - cyc.scc_is_cyclic, cyc.scc_topo, edges.next_ids_intra_fwd, edges.next_ids_intra_back, edges.next_ids_inter, - max_size_at_node_begin); + const AdStackPerNodeRow row{ + .increased_size = sizes.increased_size[rep_sid], + .max_increased_size = sizes.max_increased_size[rep_sid], + }; + rep_results[rep_sid] = run_ad_stack_size_dp_for_representative(graph, row, max_size_at_node_begin); } apply_ad_stack_dp_results(idx.stacks, sizes.stack_active, groups.stack_to_rep, rep_results); From 702cb39a511727b850829749a1126757f69380b3 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 13:17:28 -0700 Subject: [PATCH 17/18] [CFG] Add structural invariant asserts at every whole-graph driver entry New private `ControlFlowGraph::assert_structural_invariants()` checks: nodes non-empty, start_node and final_node in range, and both endpoints actually allocated. Called from each whole-graph driver: - reaching_definition_analysis - live_variable_analysis - store_to_load_forwarding - dead_store_elimination - gather_loaded_snodes - simplify_graph - unreachable_code_elimination - dump_graph_to_file - determine_ad_stack_size These invariants are assumed by every helper in the file (worklist seeding via `nodes[start_node]`, dump loop indexing by start_node / final_node, DP scratch buffer indexed by start_node, ...). A violation used to segfault or silently corrupt; now it asserts at the boundary. Cross-pass ordering checks (e.g. "did you run RD before S2L?") are not added because the drivers self-bootstrap: `store_to_load_forwarding` calls `reaching_definition_analysis()` internally; `dead_store_elimination` calls `live_variable_analysis()` internally. The ordering is structural, not a per-call precondition. --- quadrants/ir/control_flow_graph.cpp | 21 +++++++++++++++++++++ quadrants/ir/control_flow_graph.h | 6 ++++++ quadrants/ir/determine_ad_stack_size.cpp | 1 + 3 files changed, 28 insertions(+) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index ae5c785d37..5f2ddd01c6 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -1006,6 +1006,19 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { // a flat per-node loop for S2L/DSE). // =========================================================================== +void ControlFlowGraph::assert_structural_invariants() const { + // These invariants are established by `analysis::build_cfg` and preserved by every + // ControlFlowGraph method below. They are also assumed by every helper in this file (e.g. the + // worklist seeding `nodes[start_node]->reach_gen.insert(...)`, the DP scratch buffer indexed by + // start_node, the dump-graph loop that skips start_node / final_node by index). If they ever + // fail, the code following will segfault or silently corrupt -- catch it here instead. + QD_ASSERT_INFO(!nodes.empty(), "ControlFlowGraph has no nodes"); + QD_ASSERT_INFO(start_node >= 0 && start_node < (int)nodes.size(), "start_node out of range"); + QD_ASSERT_INFO(final_node >= 0 && final_node < (int)nodes.size(), "final_node out of range"); + QD_ASSERT_INFO(nodes[start_node] != nullptr, "start_node entry is null"); + QD_ASSERT_INFO(nodes[final_node] != nullptr, "final_node entry is null"); +} + void ControlFlowGraph::erase(int node_id) { // Erase an empty node. QD_ASSERT(node_id >= 0 && node_id < (int)size()); @@ -1116,6 +1129,7 @@ void write_cfg_node_statements(std::ostream &out, const CFGNode *node) { void ControlFlowGraph::dump_graph_to_file(const CompileConfig &config, const std::string &kernel_name, const std::string &suffix) const { + assert_structural_invariants(); const std::filesystem::path ir_dump_dir = config.debug_dump_path; std::filesystem::create_directories(ir_dump_dir); const std::filesystem::path filename = ir_dump_dir / (kernel_name + "_CFG" + suffix + ".txt"); @@ -1219,6 +1233,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { // - reach_in: union of reach_out of all predecessor nodes. // - reach_out: reach_gen + { stmts from reach_in whose dest is not in reach_kill }. QD_AUTO_PROF; + assert_structural_invariants(); const int num_nodes = size(); QD_ASSERT(nodes[start_node]->empty()); seed_start_node_reach_gen(nodes[start_node].get(), nodes, after_lower_access); @@ -1304,6 +1319,7 @@ void ControlFlowGraph::live_variable_analysis(bool after_lower_access, // - live_in: live_gen + (live_out - live_kill). // - live_out: union of live_in of all successor nodes. QD_AUTO_PROF; + assert_structural_invariants(); const int num_nodes = size(); QD_ASSERT(nodes[final_node]->empty()); seed_final_node_live_gen(nodes[final_node].get(), nodes, after_lower_access, config_opt); @@ -1352,6 +1368,7 @@ void ControlFlowGraph::live_variable_analysis(bool after_lower_access, void ControlFlowGraph::simplify_graph() { // Simplify the graph structure, do not modify the IR. + assert_structural_invariants(); const int num_nodes = size(); while (true) { bool modified = false; @@ -1386,6 +1403,7 @@ bool ControlFlowGraph::unreachable_code_elimination() { // Note that container statements are not in the control-flow graph, so // this pass cannot eliminate container statements properly for now. QD_AUTO_PROF; + assert_structural_invariants(); std::unordered_set visited; std::queue to_visit; to_visit.push(nodes[start_node].get()); @@ -1428,6 +1446,7 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, bool au // This is done in CFGNode::store_to_load_forwarding() of each node QD_AUTO_PROF; + assert_structural_invariants(); reaching_definition_analysis(after_lower_access); const int num_nodes = size(); bool modified = false; @@ -1441,6 +1460,7 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, bool au bool ControlFlowGraph::dead_store_elimination(bool after_lower_access, const std::optional &lva_config_opt) { QD_AUTO_PROF; + assert_structural_invariants(); live_variable_analysis(after_lower_access, lva_config_opt); const int num_nodes = size(); bool modified = false; @@ -1453,6 +1473,7 @@ bool ControlFlowGraph::dead_store_elimination(bool after_lower_access, std::unordered_set ControlFlowGraph::gather_loaded_snodes() { QD_AUTO_PROF; + assert_structural_invariants(); reaching_definition_analysis(/*after_lower_access=*/false); const int num_nodes = size(); std::unordered_set snodes; diff --git a/quadrants/ir/control_flow_graph.h b/quadrants/ir/control_flow_graph.h index fbe5bf5381..a6834d9103 100644 --- a/quadrants/ir/control_flow_graph.h +++ b/quadrants/ir/control_flow_graph.h @@ -186,6 +186,12 @@ class ControlFlowGraph { // Erase an empty node. void erase(int node_id); + // Assert structural invariants that every whole-graph driver assumes: `nodes` non-empty, + // `start_node` and `final_node` in range, both endpoints actually allocated. Called from each + // public driver. Cheap (a handful of comparisons); leave on even in release builds since the + // alternative on violation is silent corruption. + void assert_structural_invariants() const; + public: struct LiveVarAnalysisConfig { // This is mostly useful for SFG task-level dead store elimination. SFG may diff --git a/quadrants/ir/determine_ad_stack_size.cpp b/quadrants/ir/determine_ad_stack_size.cpp index 5e3c0c7f32..ca416b1a3a 100644 --- a/quadrants/ir/determine_ad_stack_size.cpp +++ b/quadrants/ir/determine_ad_stack_size.cpp @@ -619,6 +619,7 @@ void ControlFlowGraph::determine_ad_stack_size() { * `vector>` indexed by contiguous int stack id, not `unordered_map`, to keep the * hot inner loop branch-and-cache friendly. */ + assert_structural_invariants(); AdStackIndex idx = collect_adaptive_ad_stacks(nodes); const int num_stacks = static_cast(idx.stacks.size()); if (num_stacks == 0) { From a5e95bda561875e77fa6bb6b7eaadf0bab0a00fa Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Fri, 15 May 2026 13:19:22 -0700 Subject: [PATCH 18/18] [CFG] Tag 5 semantic-preservation trap branches with [DO NOT REMOVE] Five branches whose behavior is load-bearing for compiler correctness but where the reason "looks like dead code / obvious cleanup" to a casual reader (human or AI). Each now opens with an explicit '[SEMANTIC TRAP -- DO NOT REMOVE ...]' header, names the legacy contract being preserved, and (where I know it) explains the failure mode if the branch is changed. 1. try_eliminate_dead_store_at: the dead-but-unweakable atomic case returns false without updating killed/live state. (A previous refactor broke exactly this; tests do not cover it directly.) 2. try_forward_load_at: replace-with-zero does NOT flip `modified`. Preserved verbatim from the pre-refactor code; reasoning not captured in the original, comment now flags that any change needs a benchmark. 3. try_eliminate_identical_store_at: the alloca-init-0 fast path is gated on !autodiff_enabled and only fires for const-0 stores. Both are load- bearing under autodiff push/pop semantics. 4. find_intra_block_last_def: the is_quant() guard skips forwarding of quant-typed stores (implicit cast may truncate, so the forwarded value may differ from what would be loaded back). 5. any_aliased_store_breaks_forwarding: the non-tensor-alloca early-out skips the aliasing scan; without it, may_contain_address would mis- report aliasing for unrelated stores. No behavior change. --- quadrants/ir/control_flow_graph.cpp | 38 +++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index 5f2ddd01c6..97ed176033 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -332,6 +332,7 @@ int CFGNode::find_intra_block_last_def(Stmt *var, int position) const { // Find previous store stmt to the same dest_addr, stop at the closest one. // store_ptr: prev-store dest_addr for (auto *store_ptr : irpass::analysis::get_store_destination(block->statements[i].get())) { + // [SEMANTIC TRAP -- DO NOT REMOVE the is_quant() guard below] // Exclude `store_ptr` as a potential store destination due to mixed // semantics of store statements for quant types. The store operation // involves implicit casting before storing, which may result in a loss of @@ -399,7 +400,11 @@ std::optional CFGNode::find_cross_block_def(Stmt *var, } bool CFGNode::any_aliased_store_breaks_forwarding(Stmt *result, Stmt *var, int from, int to_exclusive) const { - // Allocas without tensor type cannot be aliased through MatrixPtrStmt, so the check is moot. + // [SEMANTIC TRAP -- DO NOT REMOVE the non-tensor-alloca early-out] + // Allocas without tensor type cannot be aliased through MatrixPtrStmt (the only derived-pointer + // form that creates aliasing in this IR). Without this early-out, the loop below would still + // run and `may_contain_address(non_tensor_alloca, store_ptr)` would mis-report aliasing for + // unrelated stores in the same block, falsely aborting forwardings that the legacy code allowed. const bool is_tensor_involved = var->ret_type.ptr_removed()->is(); if (var->is() && !is_tensor_involved) { return false; @@ -505,8 +510,13 @@ bool CFGNode::try_forward_load_at(int &i, Stmt *stmt, bool after_lower_access, b if (result->ret_type.ptr_removed()->is()) { return true; } - // Alloca initialized to 0: replace the load with a zero const. - // Note: |modified| is intentionally NOT flipped here (preserved from legacy behavior). + // [SEMANTIC TRAP -- DO NOT "fix" by setting modified = true here] + // Alloca initialized to 0 (default): replace the load with a zero const. `modified` is + // intentionally NOT flipped -- preserved verbatim from the pre-refactor implementation. The + // exact reason was not captured in the original code; the safe assumption is that some + // upstream caller relies on `modified` strictly tracking "could a further S2L iteration find + // more work" rather than "IR changed", and flipping it here may regress that. If a future + // change needs to flip it, do it behind a regression-tested benchmark. auto zero = Stmt::make(TypedConstant(result->ret_type.ptr_removed(), 0)); replace_with(i, std::move(zero), true); return true; @@ -537,9 +547,15 @@ void CFGNode::try_eliminate_identical_store_at(int &i, if (auto *local_store = stmt->cast()) { Stmt *result = find_forwardable_store_value(local_store->dest, i); if (result && result->is() && !autodiff_enabled) { - // TensorType does not apply to this special case. + // [SEMANTIC TRAP -- DO NOT REMOVE the !autodiff_enabled gate, or generalize past const 0] + // The "default-initialized alloca implicitly holds 0" rewrite only fires under non-autodiff. + // Under autodiff, AdStack/AdStackPush semantics depend on every primal write being observed + // by the recorder; treating a store-zero-to-fresh-alloca as redundant breaks the push + // ordering in the backward pass. The const-zero check is also load-bearing: it's the only + // value we know matches the default-init contract; other constants would need a separate + // proof that the alloca has not been written yet on every path reaching here. if (result->ret_type.ptr_removed()->is()) { - return; + return; // TensorType does not apply to this special case. } if (auto *stored = local_store->val->cast()) { if (stored->val.equal_value(0)) { @@ -900,9 +916,15 @@ bool try_eliminate_dead_store_at(CFGNode *node, node->replace_with(i, std::move(global_load), true); return true; } - // Atomic was dead but not safely weakable (parallel global, non-scalar). The original code - // intentionally leaves state alone in this case (the atomic still executes, but state is not - // updated for it). Preserve that behavior. + // [SEMANTIC TRAP -- DO NOT REMOVE] + // Atomic was dead but not safely weakable (parallel global, non-scalar). + // The legacy code intentionally returns *without* updating + // `killed_in_this_node` / `live_in_this_node` here -- the atomic still executes at runtime, so + // its dest is neither newly-killed nor a fresh live entry. Falling through into the + // state-update block below would mark this address as killed-by-an-erased-store, which it + // isn't, and silently corrupt downstream DSE decisions on aliasing addresses. + // A previous refactor of this function broke exactly this branch; the existing AD/atomic test + // suites do not cover it directly. Touch with care. return false; } // Non-eliminated store (not dead, or stmt not eliminable): update state. Insert into killed,