diff --git a/quadrants/ir/control_flow_graph.cpp b/quadrants/ir/control_flow_graph.cpp index d41cd13997..97ed176033 100644 --- a/quadrants/ir/control_flow_graph.cpp +++ b/quadrants/ir/control_flow_graph.cpp @@ -14,8 +14,136 @@ #include "quadrants/program/function.h" #include "quadrants/codegen/ir_dump.h" +// =========================================================================== +// Control-flow graph: analyses and transforms. +// +// This file implements two classes (declared in control_flow_graph.h): +// - CFGNode: a single basic block in the CFG. Holds a slice of an +// IR Block (begin_location .. end_location) plus prev / +// next edges and the per-node sets used by each +// analysis. Its methods do the *intra-block* work +// (compute reach_gen / reach_kill in this block, run +// store-to-load forwarding within this block, etc.). +// - ControlFlowGraph: the whole graph. Owns the vector and the +// synthetic entry / exit nodes. Its methods are +// whole-graph *drivers*: they call the per-node method +// on every node, then run the worklist fixpoint or +// post-processing. +// +// Five analyses / transforms live here. For each, the per-node method is on +// CFGNode and the whole-graph driver is the same-named method on +// ControlFlowGraph: +// 1. reaching_definition_analysis (RD): cross-block use-define chain. Output +// in each node: reach_in, reach_out (sets of Stmt*). +// Used by store-to-load forwarding. +// 2. live_variable_analysis (LV): which addresses are still loaded +// after this point? Output: live_in, live_out (sets of +// Stmt*). Used by dead-store elimination. +// 3. store_to_load_forwarding (S2L): replace each load with the value +// of the closest preceding store to the same address; +// also erase identical stores. IR-mutating. +// 4. dead_store_elimination (DSE): erase stores whose written value +// is never read, weaken atomics whose store half is +// dead, eliminate identical loads. IR-mutating. +// 5. 
determine_ad_stack_size: size each adaptive AD-stack (max_size==0) +// from its push/pop schedule across the CFG. Whole- +// graph only; no per-node counterpart. +// +// Glossary (these terms appear throughout; their formal definitions are +// scattered across analysis.h / ir.h / statements.h, so collected here): +// - reach_gen[node]: stmts that define some variable in `node` (store +// stmts). Plus, in the synthetic entry node, stmts +// whose value "exists" before this kernel runs +// (external pointers, FuncCall store destinations). +// - reach_kill[node]: addresses (GlobalPtrStmt / AllocaStmt / ...) that +// `node` definitely overwrites. (gen tracks the *stmts* +// that write; kill tracks the *addresses* written-to.) +// - reach_in[node]: union of reach_out of predecessors. +// - reach_out[node]: reach_gen[node] union (reach_in[node] minus killed). +// - live_gen[node]: addresses loaded in `node` with no preceding store +// in the same node (so they must live in from outside). +// - live_kill[node]: addresses stored to in `node`. +// - live_in[node]: live_gen[node] union (live_out[node] minus live_kill). +// - live_out[node]: union of live_in of successors. +// - UD-chain: "use-define chain". For a load (use), the set of +// stores (defs) whose value may flow to it. Computed +// via the RD analysis. +// - Adaptive AD-stack: an `AdStackAllocaStmt` with `max_size == 0`, meaning +// its size has not yet been fixed and must be inferred +// from the kernel's push/pop schedule. `max_size != 0` +// stacks are already sized and are skipped here. +// - stmt_refs: alias for `quadrants::one_or_more` (see +// common/one_or_more.h). A variant that holds either a +// single Stmt* or a vector; supports `.size()`, +// `.empty()`, non-const `.begin()/.end()`. Used to +// avoid heap allocation in the common "exactly one +// destination" case. 
+// - MatrixPtrStmt: a derived pointer of the form `&base[offset]` where +// `base` is either an `AllocaStmt` (tensor-typed local) +// or another `MatrixPtrStmt` (nested element). The +// analyses treat these as aliasing their origin: a +// store through the derived pointer partially defines +// the origin; a store to the origin kills all +// derived-pointer addresses. +// =========================================================================== + namespace quadrants::lang { +namespace { + +// Does |store_stmt| store to an address that may alias |var|? +// Matrix-pointer aliasing is handled both ways: a MatrixPtrStmt aliases its +// underlying origin, and vice-versa. +bool may_contain_address(Stmt *store_stmt, Stmt *var) { + for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) { + if (var->is() && !store_ptr->is()) { + // check for aliased address with var + if (irpass::analysis::maybe_same_address(var->as()->origin, store_ptr)) { + return true; + } + } + + if (!var->is() && store_ptr->is()) { + // check for aliased address with store_ptr + if (irpass::analysis::maybe_same_address(store_ptr->as()->origin, var)) { + return true; + } + } + + if (irpass::analysis::maybe_same_address(var, store_ptr)) { + return true; + } + } + return false; +} + +// Should |stmt| appear in the synthetic `live_gen` of the final CFG node? +// Locals (allocas, matrix-pointers into allocas) are never live past the +// kernel boundary. Global pointers are live unless SFG has marked their +// SNode as eliminable in |config_opt|. +bool in_final_node_live_gen(const Stmt *stmt, + const std::optional &config_opt) { + if (stmt->is() || stmt->is()) { + return false; + } + if (stmt->is() && stmt->cast()->origin->is()) { + return false; + } + if (auto *gptr = stmt->cast(); gptr && config_opt.has_value()) { + return config_opt->eliminable_snodes.count(gptr->snode) == 0; + } + // A global pointer that may be loaded after this kernel. 
+ return true; +} + +} // namespace + +// =========================================================================== +// CFGNode -- per-block (intra-block) methods. +// Each method here does the *per-node* work; the whole-graph driver with the +// same name lives on ControlFlowGraph further down this file. +// =========================================================================== + CFGNode::CFGNode(Block *block, int begin_location, int end_location, @@ -31,7 +159,7 @@ CFGNode::CFGNode(Block *block, prev_node_in_same_block->next_node_in_same_block = this; if (!empty()) { // For non-empty nodes, precompute |parent_blocks| to accelerate - // get_store_forwarding_data(). + // find_forwardable_store_value(). QD_ASSERT(begin_location >= 0); QD_ASSERT(block); auto parent_block = block; @@ -91,8 +219,12 @@ bool CFGNode::contain_variable(const std::unordered_set &var_set, Stmt * // TODO: How to optimize this? if (var_set.find(var) != var_set.end()) return true; - return std::any_of(var_set.begin(), var_set.end(), - [&](Stmt *set_var) { return irpass::analysis::definitely_same_address(var, set_var); }); + for (Stmt *set_var : var_set) { + if (irpass::analysis::definitely_same_address(var, set_var)) { + return true; + } + } + return false; } } @@ -107,12 +239,12 @@ bool CFGNode::contain_variable(const std::unordered_map &var_set, St // TODO: How to optimize this? if (var_set.find(var) != var_set.end()) return true; - return std::any_of(var_set.begin(), var_set.end(), - [&](Stmt *set_var) { return irpass::analysis::maybe_same_address(var, set_var); }); + for (Stmt *set_var : var_set) { + if (irpass::analysis::maybe_same_address(var, set_var)) { + return true; + } + } + return false; } } -bool CFGNode::reach_kill_variable(Stmt *var) const { +bool CFGNode::is_reach_killed(Stmt *var) const { // Does this node (definitely) kill a definition of var? 
return contain_variable(reach_kill, var); } -// var: dest_addr -Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { - // Return the stored data if all definitions in the UD-chain of |var| at - // this position store the same data. - // [Intra-block Search] - int last_def_position = -1; +bool CFGNode::is_visible_at(Stmt *stmt, int position) const { + // Check if |stmt| is before |position| here. + if (stmt->parent == block) { + return stmt->parent->locate(stmt) < position; + } + // |parent_blocks_| is precomputed in the constructor of CFGNode. + // TODO: What if |stmt| appears in an ancestor of |block| but after + // |position|? + return parent_blocks_.find(stmt->parent) != parent_blocks_.end(); +} + +bool CFGNode::fold_definition_into_result(Stmt *stmt, + int position, + Stmt *&result, + bool &result_visible) const { + // |stmt| is a definition in the UD-chain of the variable being forwarded. + // Fold its stored data into |result| / |result_visible|. Return false if + // forwarding must abort (the caller should propagate nullptr); true to + // continue scanning. + auto data = irpass::analysis::get_store_data(stmt); + if (!data) { // not forwardable + return false; + } + if (!result) { + result = data; + result_visible = is_visible_at(data, position); + return true; + } + if (!irpass::analysis::same_value(result, data)) { + // check the special case of alloca (initialized to 0) + if (!(result->is() && data->is() && data->as()->val.equal_value(0))) { + return false; + } + } + if (!result_visible && is_visible_at(data, position)) { + // pick the visible one for store-to-load forwarding + result = data; + result_visible = true; + } + return true; +} + +int CFGNode::find_intra_block_last_def(Stmt *var, int position) const { for (int i = position - 1; i >= begin_location; i--) { // Find previous store stmt to the same dest_addr, stop at the closest one. 
// store_ptr: prev-store dest_addr - for (auto store_ptr : irpass::analysis::get_store_destination(block->statements[i].get())) { + for (auto *store_ptr : irpass::analysis::get_store_destination(block->statements[i].get())) { + // [SEMANTIC TRAP -- DO NOT REMOVE the is_quant() guard below] // Exclude `store_ptr` as a potential store destination due to mixed // semantics of store statements for quant types. The store operation // involves implicit casting before storing, which may result in a loss of @@ -170,8 +348,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // TODO: Still forward the store if the value can be statically proven to // fit into the quant type. if (!is_quant(store_ptr->ret_type.ptr_removed()) && irpass::analysis::definitely_same_address(var, store_ptr)) { - last_def_position = i; - break; + return i; } // Special case: @@ -181,171 +358,107 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // $3 = load $2 // We can forward MatrixInitStmt->values[offset] to $3 if (var->is() && var->as()->offset->is()) { - auto var_origin = var->as()->origin; - // Check for same origin address + auto *var_origin = var->as()->origin; if (irpass::analysis::definitely_same_address(var_origin, store_ptr)) { - // Check for MatrixInitStmt Stmt *store_data = irpass::analysis::get_store_data(block->statements[i].get()); if (store_data->is()) { - last_def_position = i; - break; + return i; } } } } - if (last_def_position != -1) { - break; - } } + return -1; +} - // Check if store_stmt will ever influence the value of var - auto may_contain_address = [&](Stmt *store_stmt, Stmt *var) { - for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) { - if (var->is() && !store_ptr->is()) { - // check for aliased address with var - if (irpass::analysis::maybe_same_address(var->as()->origin, store_ptr)) { - return true; - } +std::optional CFGNode::find_cross_block_def(Stmt *var, + int position, + Stmt *&result, 
+ bool &result_visible) const { + int last_def_position = -1; + // [Global Addr only] Stores reaching the entry of this node from previous blocks. `var == stmt` + // is for the case that a global ptr is never stored; in that case `stmt` comes from + // nodes[start_node]->reach_gen. + for (auto *stmt : reach_in) { + if (var == stmt || may_contain_address(stmt, var)) { + if (!fold_definition_into_result(stmt, position, result, result_visible)) { + return std::nullopt; } - - if (!var->is() && store_ptr->is()) { - // check for aliased address with store_ptr - if (irpass::analysis::maybe_same_address(store_ptr->as()->origin, var)) { - return true; - } + last_def_position = 0; + } + } + // Stores generated within this node (in reach_gen) that precede |position|. + for (auto *stmt : reach_gen) { + if (may_contain_address(stmt, var) && stmt->parent->locate(stmt) < position) { + if (!fold_definition_into_result(stmt, position, result, result_visible)) { + return std::nullopt; } + last_def_position = stmt->parent->locate(stmt); + } + } + return last_def_position; +} - if (irpass::analysis::maybe_same_address(var, store_ptr)) { +bool CFGNode::any_aliased_store_breaks_forwarding(Stmt *result, Stmt *var, int from, int to_exclusive) const { + // [SEMANTIC TRAP -- DO NOT REMOVE the non-tensor-alloca early-out] + // Allocas without tensor type cannot be aliased through MatrixPtrStmt (the only derived-pointer + // form that creates aliasing in this IR). Without this early-out, the loop below would still + // run and `may_contain_address(non_tensor_alloca, store_ptr)` would mis-report aliasing for + // unrelated stores in the same block, falsely aborting forwardings that the legacy code allowed. 
+ const bool is_tensor_involved = var->ret_type.ptr_removed()->is(); + if (var->is() && !is_tensor_involved) { + return false; + } + for (int i = from; i < to_exclusive; i++) { + auto *s = block->statements[i].get(); + if (!irpass::analysis::same_value(result, irpass::analysis::get_store_data(s))) { + if (may_contain_address(s, var)) { return true; } } - return false; - }; - - // Check for aliased address - // There's a store to the same dest_addr before this stmt - if (last_def_position != -1) { - // result: the value to store - Stmt *result = irpass::analysis::get_store_data(block->statements[last_def_position].get()); - bool is_tensor_involved = var->ret_type.ptr_removed()->is(); - if (!(var->is() && !is_tensor_involved)) { - // In between the store stmt and current stmt, - // if there's a third-stmt that "may" have stored a "different value" to - // the "same dest_addr", then we can't forward the stored data. - for (int i = last_def_position + 1; i < position; i++) { - if (!irpass::analysis::same_value(result, irpass::analysis::get_store_data(block->statements[i].get()))) { - if (may_contain_address(block->statements[i].get(), var)) { - return nullptr; - } - } - } + } + return false; +} + +// var: dest_addr. Return the stored data if all definitions in the UD-chain of |var| at this +// position store the same data; otherwise nullptr. Looks intra-block first (cheaper, dominates +// cross-block forwarding when present), then falls back to the cross-block search over reach_in +// and reach_gen. In both cases an intervening aliased store that may write a different value +// breaks the forward. +Stmt *CFGNode::find_forwardable_store_value(Stmt *var, int position) const { + // [Intra-block search] Walks backwards in this node's block. 
+ if (int last_def = find_intra_block_last_def(var, position); last_def != -1) { + Stmt *result = irpass::analysis::get_store_data(block->statements[last_def].get()); + if (any_aliased_store_breaks_forwarding(result, var, last_def + 1, position)) { + return nullptr; } return result; } - // [Cross-block search] - // Search for store to the same dest_addr in reach_in and reach_gen + // [Cross-block search] Walks reach_in / reach_gen, accumulating a single forwardable result. Stmt *result = nullptr; bool result_visible = false; - auto visible = [&](Stmt *stmt) { - // Check if |stmt| is before |position| here. - if (stmt->parent == block) { - return stmt->parent->locate(stmt) < position; - } - // |parent_blocks| is precomputed in the constructor of CFGNode. - // TODO: What if |stmt| appears in an ancestor of |block| but after - // |position|? - return parent_blocks_.find(stmt->parent) != parent_blocks_.end(); - }; - /** - * |stmt| is a definition in the UD-chain of |var|. Update |result| with - * |stmt|. If either the stored data of |stmt| is not forwardable or the - * stored data of |stmt| is not definitely the same as other definitions of - * |var|, return false to show that there is no store-to-load forwardable - * data. 
- */ - auto update_result = [&](Stmt *stmt) { - auto data = irpass::analysis::get_store_data(stmt); - if (!data) { // not forwardable - return false; // return nullptr - } - if (!result) { - result = data; - result_visible = visible(data); - return true; // continue the following loops - } - if (!irpass::analysis::same_value(result, data)) { - // check the special case of alloca (initialized to 0) - if (!(result->is() && data->is() && data->as()->val.equal_value(0))) { - return false; // return nullptr - } - } - if (!result_visible && visible(data)) { - // pick the visible one for store-to-load forwarding - result = data; - result_visible = true; - } - return true; // continue the following loops - }; - - // [Global Addr only] - // test whether there's a store to the same dest_addr in a previous block. - // if the store values are the same, then return the value - last_def_position = -1; - for (auto stmt : reach_in) { - // var == stmt is for the case that a global ptr is never stored. - // In this case, stmt is from nodes[start_node]->reach_gen. - if (var == stmt || may_contain_address(stmt, var)) { - if (!update_result(stmt)) - return nullptr; - else - last_def_position = 0; - } - } - - // test whether there's a store to the same dest_addr before this stmt (in - // reach_gen) - // if the store values are the same, then return the value - for (auto stmt : reach_gen) { - if (may_contain_address(stmt, var) && stmt->parent->locate(stmt) < position) { - if (!update_result(stmt)) - return nullptr; - else - last_def_position = stmt->parent->locate(stmt); - } + std::optional last_def = find_cross_block_def(var, position, result, result_visible); + if (!last_def.has_value()) { + return nullptr; } if (!result) { // The UD-chain is empty. 
- auto offending_load = block->statements[position].get(); + auto *offending_load = block->statements[position].get(); ErrorEmitter(QuadrantsIrWarning(), offending_load, fmt::format("Loading variable {} before anything is stored to it.", var->id)); return nullptr; } if (!result_visible) { - // The data is store-to-load forwardable but not visible at the place we - // are going to forward. We cannot forward it in this case. + // The data is store-to-load forwardable but not visible at the place we are going to forward. return nullptr; } - - if (last_def_position == -1) + if (*last_def == -1) { + return nullptr; + } + if (any_aliased_store_breaks_forwarding(result, var, *last_def, position)) { return nullptr; - - // Check for aliased address - // There's a store to the same dest_addr before this stmt - bool is_tensor_involved = var->ret_type.ptr_removed()->is(); - if (!(var->is() && !is_tensor_involved)) { - // In between the store stmt and current stmt, - // if there's a third-stmt that "may" have stored a "different value" to - // the "same dest_addr", then we can't forward the stored data. - for (int i = last_def_position; i < position; i++) { - if (!irpass::analysis::same_value(result, irpass::analysis::get_store_data(block->statements[i].get()))) { - if (may_contain_address(block->statements[i].get(), var)) { - return nullptr; - } - } - } } - return result; } @@ -367,7 +480,7 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { // After lower_access, we only analyze local variables. continue; } - if (!reach_kill_variable(data_source_ptr)) { + if (!is_reach_killed(data_source_ptr)) { reach_gen.insert(stmt); reach_kill.insert(data_source_ptr); } @@ -375,109 +488,120 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { } } -bool CFGNode::store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled) { - // Contains two separate parts: - // 1. 
Store-to-load Forwarding: for each load stmt, find the closest previous - // store stmt - // that stores to the same address as the load stmt, then replace - // load with the "val". - // 2. Identical Store Elimination: for each store stmt, find the closest - // previous store stmt - // that stores to the same address as the store stmt. If the "val"s - // are the same, then remove the store stmt. - bool modified = false; - for (int i = begin_location; i < end_location; i++) { - // Store-to-load forwarding - auto stmt = block->statements[i].get(); - - // result: the value to be store/load - Stmt *result = nullptr; - - // [get_store_forwarding_data] find the store stmt that: - // 1. stores to the same address and as the load stmt - // 2. (one value at a time) closest to the load stmt but before the load - // stmt - Stmt *load_src = nullptr; - if (auto local_load = stmt->cast()) { - result = get_store_forwarding_data(local_load->src, i); - load_src = local_load->src; - } else if (auto global_load = stmt->cast()) { - if (!after_lower_access && !autodiff_enabled) { - result = get_store_forwarding_data(global_load->src, i); - load_src = global_load->src; - } +bool CFGNode::try_forward_load_at(int &i, Stmt *stmt, bool after_lower_access, bool autodiff_enabled, bool &modified) { + // Find the closest preceding store-to-same-address. For GlobalLoadStmt the forwarder is gated: + // after lower_access or under autodiff we don't trust cross-block forwarding to be sound. + Stmt *load_src = nullptr; + Stmt *result = nullptr; + if (auto *local_load = stmt->cast()) { + load_src = local_load->src; + result = find_forwardable_store_value(load_src, i); + } else if (auto *global_load = stmt->cast()) { + if (!after_lower_access && !autodiff_enabled) { + load_src = global_load->src; + result = find_forwardable_store_value(load_src, i); } - - // [Apply Load-Store-Forwarding] - // replace load stmt with the value-"result" - if (result) { - // Forward the stored data |result|. 
- if (result->is()) { - // TensorType does not apply to this special case - if (result->ret_type.ptr_removed()->is()) - continue; - - // special case of alloca (initialized to 0) - auto zero = Stmt::make(TypedConstant(result->ret_type.ptr_removed(), 0)); - replace_with(i, std::move(zero), true); - } else { - if (result->ret_type.ptr_removed()->is() && !stmt->ret_type->is()) { - QD_ASSERT(load_src->is() && load_src->as()->offset->is()); - QD_ASSERT(result->is()); - - int offset = load_src->as()->offset->as()->val.val_int32(); - - result = result->as()->values[offset]; - } - - stmt->replace_usages_with(result); - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop - modified = true; - } - continue; + } + if (!result) { + return false; + } + if (result->is()) { + // TensorType does not apply to this special case; skip further handling for this stmt. + if (result->ret_type.ptr_removed()->is()) { + return true; } + // [SEMANTIC TRAP -- DO NOT "fix" by setting modified = true here] + // Alloca initialized to 0 (default): replace the load with a zero const. `modified` is + // intentionally NOT flipped -- preserved verbatim from the pre-refactor implementation. The + // exact reason was not captured in the original code; the safe assumption is that some + // upstream caller relies on `modified` strictly tracking "could a further S2L iteration find + // more work" rather than "IR changed", and flipping it here may regress that. If a future + // change needs to flip it, do it behind a regression-tested benchmark. + auto zero = Stmt::make(TypedConstant(result->ret_type.ptr_removed(), 0)); + replace_with(i, std::move(zero), true); + return true; + } + // Non-alloca result: forward it. Extract a MatrixInitStmt element when the forwarded data is + // a TensorType-typed init but the load is for a scalar slot. 
+ if (result->ret_type.ptr_removed()->is() && !stmt->ret_type->is()) { + QD_ASSERT(load_src->is() && load_src->as()->offset->is()); + QD_ASSERT(result->is()); + const int offset = load_src->as()->offset->as()->val.val_int32(); + result = result->as()->values[offset]; + } + stmt->replace_usages_with(result); + erase(i); // end_location-- + i--; // cancel the for-loop's i++ + modified = true; + return true; +} - // [Identical store elimination] - // find the store stmt that: - // 1. stores to the same address as the current store stmt - // 2. has the same store value as the current store stmt - // 3. (one value at a time) closest to the current store stmt but before the - // current store stmt then erase the current store stmt - if (auto local_store = stmt->cast()) { - result = get_store_forwarding_data(local_store->dest, i); - if (result && result->is() && !autodiff_enabled) { - // TensorType does not apply to this special case - if (result->ret_type.ptr_removed()->is()) { - continue; - } - - // special case of alloca (initialized to 0) - if (auto stored_data = local_store->val->cast()) { - if (stored_data->val.equal_value(0)) { - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop - modified = true; - } - } - } else { - // not alloca - if (irpass::analysis::same_value(result, local_store->val)) { - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop - modified = true; - } +void CFGNode::try_eliminate_identical_store_at(int &i, + Stmt *stmt, + bool after_lower_access, + bool autodiff_enabled, + bool &modified) { + // Eliminate a store whose value is provably identical to the closest preceding store to the + // same address. For local stores under non-autodiff there's also an alloca-zero special case: + // writing a zero to a freshly-allocated alloca is redundant. 
+ if (auto *local_store = stmt->cast()) { + Stmt *result = find_forwardable_store_value(local_store->dest, i); + if (result && result->is() && !autodiff_enabled) { + // [SEMANTIC TRAP -- DO NOT REMOVE the !autodiff_enabled gate, or generalize past const 0] + // The "default-initialized alloca implicitly holds 0" rewrite only fires under non-autodiff. + // Under autodiff, AdStack/AdStackPush semantics depend on every primal write being observed + // by the recorder; treating a store-zero-to-fresh-alloca as redundant breaks the push + // ordering in the backward pass. The const-zero check is also load-bearing: it's the only + // value we know matches the default-init contract; other constants would need a separate + // proof that the alloca has not been written yet on every path reaching here. + if (result->ret_type.ptr_removed()->is()) { + return; // TensorType does not apply to this special case. } - } else if (auto global_store = stmt->cast()) { - if (!after_lower_access) { - result = get_store_forwarding_data(global_store->dest, i); - if (irpass::analysis::same_value(result, global_store->val)) { - erase(i); // This causes end_location-- - i--; // to cancel i++ in the for loop + if (auto *stored = local_store->val->cast()) { + if (stored->val.equal_value(0)) { + erase(i); + i--; modified = true; } } + return; + } + if (irpass::analysis::same_value(result, local_store->val)) { + erase(i); + i--; + modified = true; + } + return; + } + if (auto *global_store = stmt->cast()) { + if (after_lower_access) { + return; + } + Stmt *result = find_forwardable_store_value(global_store->dest, i); + if (irpass::analysis::same_value(result, global_store->val)) { + erase(i); + i--; + modified = true; + } + } +} + +bool CFGNode::store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled) { + // Two passes fused into one forward walk: + // 1. Store-to-load forwarding: replace a load with the value from the closest preceding store + // to the same address. + // 2. 
Identical-store elimination: erase a store whose value is identical to the closest + // preceding store to the same address. + // The forwarding pass takes precedence: if we forwarded a load, we don't also try to treat the + // same stmt as a store. erase() shrinks end_location, so per-elimination we both `erase(i)` and + // decrement `i` to keep the for-loop pointing at the right next stmt. + bool modified = false; + for (int i = begin_location; i < end_location; i++) { + auto *stmt = block->statements[i].get(); + if (try_forward_load_at(i, stmt, after_lower_access, autodiff_enabled, modified)) { + continue; } + try_eliminate_identical_store_at(i, stmt, after_lower_access, autodiff_enabled, modified); } return modified; } @@ -643,197 +767,280 @@ static void update_container_with_alias( update_aliased_stmts(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, container, key, to_erase); } -bool CFGNode::dead_store_elimination(bool after_lower_access) { - bool modified = false; - // Map a variable to its nearest load - std::unordered_map live_load_in_this_node; +namespace { - // For any stmt with TensorType'd address, the address can be either partially - // or fully stored/loaded, which will eventually influence the - // dead-store-elimination strategy - // - // Here we use CFGNode::UseDefineStatus to mark whether a TensorType'd address - // is fully or partially modified. +// === Helpers for CFGNode::dead_store_elimination === + +// Aliasing tables built once per CFGNode pass over the block, then threaded through every +// container update so that touching a `MatrixPtrStmt` propagates to its tensor origin (and back). +struct DseAliasMaps { + // MatrixPtrStmt->origin -> list of MatrixPtrStmts that share that origin + std::unordered_map> tensor_to_matrix_ptrs; + // MatrixPtrStmt -> its origin + std::unordered_map matrix_ptr_to_tensor; +}; + +// Per-pass state mutated in lockstep during the reverse-order walk. 
`live_in_this_node` and +// `killed_in_this_node` use `UseDefineStatus` to distinguish full vs partial modification of +// tensor-typed addresses; `live_load_in_this_node` maps a variable to its nearest later load +// for identical-load elimination. +struct DseLiveState { + std::unordered_map live_load_in_this_node; std::unordered_map live_in_this_node; std::unordered_map killed_in_this_node; +}; - // Search for aliased addresses - // tensor_to_matrix_ptrs_map: map MatrixPtrStmt->origin to list of - // MatrixPtrStmts - // matrix_ptr_to_tensor_map: map MatrixPtrStmt to - // MatrixPtrStmt->origin - std::unordered_map> tensor_to_matrix_ptrs_map; - std::unordered_map matrix_ptr_to_tensor_map; +DseAliasMaps build_matrix_ptr_alias_maps(Block *block, int begin_location, int end_location) { + DseAliasMaps alias; for (int i = begin_location; i < end_location; i++) { - auto stmt = block->statements[i].get(); - if (stmt->is()) { - auto origin = stmt->as()->origin; - if (tensor_to_matrix_ptrs_map.count(origin) == 0) { - tensor_to_matrix_ptrs_map[origin] = {stmt}; - } else { - tensor_to_matrix_ptrs_map[origin].push_back(stmt); - } - matrix_ptr_to_tensor_map[stmt] = origin; + auto *stmt = block->statements[i].get(); + if (!stmt->is()) { + continue; } + auto *origin = stmt->as()->origin; + alias.tensor_to_matrix_ptrs[origin].push_back(stmt); + alias.matrix_ptr_to_tensor[stmt] = origin; } + return alias; +} - // Reverse order traversal, starting from the last IR to the first IR - for (int i = end_location - 1; i >= begin_location; i--) { - auto stmt = block->statements[i].get(); - if (stmt->is()) { - killed_in_this_node.clear(); - live_load_in_this_node.clear(); - continue; +// Pointer eligibility predicate, applied uniformly to store ptrs, load ptrs, and live-update load +// ptrs. After `lower_access`, only local variables (allocas) and AD-stacks remain analyzable; +// before it, everything is in scope. 
+bool is_dse_eligible_pointer(Stmt *ptr, bool after_lower_access) { + if (!after_lower_access) { + return true; + } + if (ptr->is() || ptr->is()) { + return true; + } + if (ptr->is()) { + auto *origin = ptr->as()->origin; + if (origin->is() || origin->is()) { + return true; } - auto store_ptrs = irpass::analysis::get_store_destination(stmt); + } + return false; +} - // TODO: Consider AD-stacks in get_store_destination instead of here - // for store-to-load forwarding on AD-stacks - if (auto stack_pop = stmt->cast()) { - store_ptrs = std::vector(1, stack_pop->stack); - } else if (auto stack_push = stmt->cast()) { - store_ptrs = std::vector(1, stack_push->stack); - } else if (auto stack_acc_adj = stmt->cast()) { - store_ptrs = std::vector(1, stack_acc_adj->stack); - } else if (stmt->is()) { - store_ptrs = std::vector(1, stmt); - } +// Compute the store destinations of |stmt|, including AD-stack stmts whose store semantics +// aren't yet captured by `get_store_destination`. Returns the same `stmt_refs` (= one_or_more<>) +// type the analyzer uses so iteration and size queries are uniform. +// TODO: Consider AD-stacks in get_store_destination instead of here for store-to-load forwarding +// on AD-stacks. +stmt_refs dse_store_destinations(Stmt *stmt) { + if (auto *pop = stmt->cast()) { + return stmt_refs(pop->stack); + } + if (auto *push = stmt->cast()) { + return stmt_refs(push->stack); + } + if (auto *acc = stmt->cast()) { + return stmt_refs(acc->stack); + } + if (stmt->is()) { + return stmt_refs(stmt); + } + return irpass::analysis::get_store_destination(stmt); +} - if (store_ptrs.size() == 1) { - // Dead store elimination - auto store_ptr = *store_ptrs.begin(); +// Is |store_ptr| guaranteed dead at this point in the reverse-order walk? 
+// - !may_contain_variable(live_in_this_node, store_ptr): not loaded after this store in-node +// - contain_variable(killed_in_this_node, store_ptr): already overwritten in-node, OR +// - !may_contain_variable(live_out, store_ptr): not used in any successor node +bool is_store_dead(Stmt *store_ptr, + const std::unordered_set &live_out, + const DseLiveState &state) { + bool is_used_in_next_nodes = false; + for (auto *ptr : irpass::analysis::include_aliased_stmts(store_ptr)) { + is_used_in_next_nodes |= CFGNode::may_contain_variable(live_out, ptr); + } + const bool is_killed_in_current_node = CFGNode::contain_variable(state.killed_in_this_node, store_ptr); + bool is_dead = is_killed_in_current_node || !is_used_in_next_nodes; + is_dead &= !CFGNode::may_contain_variable(state.live_in_this_node, store_ptr); + return is_dead; +} - if (!after_lower_access || - (store_ptr->is() && store_ptr->as()->origin->is()) || - (store_ptr->is() && store_ptr->as()->origin->is()) || - (store_ptr->is() || store_ptr->is())) { - // !may_contain_variable(live_in_this_node, store_ptr): address is not - // loaded after this store - // contain_variable(killed_in_this_node, store_ptr): address is already - // stored by another store stmt in this node (thus killed) - // !may_contain_variable(live_out, store_ptr): address is not used - // in the next nodes - bool is_used_in_next_nodes = false; - for (auto ptr : irpass::analysis::include_aliased_stmts(store_ptr)) { - is_used_in_next_nodes |= may_contain_variable(live_out, ptr); - } +// On dead-store elimination of an `AtomicOpStmt`, the store part is dropped but the load +// (= the atomic's return value) remains; record the load and mark the dest killed/loaded. 
+void record_weakened_atomic_as_load(Stmt *dest, Stmt *new_load, const DseAliasMaps &alias, DseLiveState &state) { + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.live_in_this_node, dest, + false); + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.killed_in_this_node, dest, + true); + state.live_load_in_this_node[dest] = new_load; +} - bool is_killed_in_current_node = contain_variable(killed_in_this_node, store_ptr); - bool is_dead = is_killed_in_current_node || !is_used_in_next_nodes; - is_dead &= !may_contain_variable(live_in_this_node, store_ptr); - if (!stmt->is() && !stmt->is() && !stmt->is() && is_dead) { - // If an address is neither used in this node, nor used in the next - // nodes, then we can consider eliminating any stores to this address - // (it's not used anyway). There's two different scenerios though: - // 1. Any direct store stmt can be eliminated immediately (LocalStore, - // GlobalStore, AdStackPush, ...) - // 2. AtomicStmt (load + store): remove the store part, thus - // converting the AtomicStmt into a LoadStmt - if (!stmt->is()) { - // Eliminate the dead store. - erase(i); - modified = true; - continue; - } - auto atomic = stmt->cast(); - // Weaken the atomic operation to a load. - if (atomic->dest->is()) { - auto local_load = Stmt::make(atomic->dest); - local_load->ret_type = atomic->ret_type; - // Notice that we have a load here - // (the return value of AtomicOpStmt). 
- update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, live_in_this_node, - atomic->dest, false); - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, killed_in_this_node, - atomic->dest, true); - live_load_in_this_node[atomic->dest] = local_load.get(); - - replace_with(i, std::move(local_load), true); - modified = true; - continue; - } else if (!is_parallel_executed || - (atomic->dest->is() && atomic->dest->as()->snode->is_scalar())) { - // If this node is parallel executed, we can't weaken a global - // atomic operation to a global load. - // TODO: we can weaken it if it's element-wise (i.e. never - // accessed by other threads). - auto global_load = Stmt::make(atomic->dest); - global_load->ret_type = atomic->ret_type; - // Notice that we have a load here - // (the return value of AtomicOpStmt). - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, live_in_this_node, - atomic->dest, false); - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, killed_in_this_node, - atomic->dest, true); - live_load_in_this_node[atomic->dest] = global_load.get(); - - replace_with(i, std::move(global_load), true); - modified = true; - continue; - } - } else { - // A non-eliminated store. - // Insert to killed_in_this_node if it's stored in this node. - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, killed_in_this_node, - store_ptr, false); - - // Remove the address from live_in_this_node if it's stored in this - // node. - auto old_live_in_this_node = live_in_this_node; - for (auto &var : old_live_in_this_node) { - if (irpass::analysis::definitely_same_address(store_ptr, var.first)) { - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, live_in_this_node, - store_ptr, true); - } - } - } - } +// Try to eliminate a dead store (or weaken a dead-store atomic to a pure load) at position |i|. 
+// If no elimination applies, fall through to update killed/live state for this non-eliminated +// store and return false; the caller will then move on to load handling for the same stmt. +bool try_eliminate_dead_store_at(CFGNode *node, + int i, + Stmt *stmt, + Stmt *store_ptr, + bool after_lower_access, + const DseAliasMaps &alias, + DseLiveState &state) { + if (!is_dse_eligible_pointer(store_ptr, after_lower_access)) { + return false; + } + const bool is_dead = is_store_dead(store_ptr, node->live_out, state); + const bool stmt_eliminable = + !stmt->is() && !stmt->is() && !stmt->is(); + if (stmt_eliminable && is_dead) { + // If an address is neither used in this node, nor in the next nodes, eliminate any stores to + // it. Two scenarios: + // 1. Direct store stmts (LocalStore, GlobalStore, AdStackPush, ...): erase outright. + // 2. AtomicOpStmt (load+store): drop the store half, leaving a pure load. + if (!stmt->is()) { + node->erase(i); + return true; } - auto load_ptrs = irpass::analysis::get_load_pointers(stmt); - if (load_ptrs.size() == 1 && store_ptrs.empty()) { - // Identical load elimination - auto load_ptr = load_ptrs.begin()[0]; + auto *atomic = stmt->cast(); + if (atomic->dest->is()) { + auto local_load = Stmt::make(atomic->dest); + local_load->ret_type = atomic->ret_type; + record_weakened_atomic_as_load(atomic->dest, local_load.get(), alias, state); + node->replace_with(i, std::move(local_load), true); + return true; + } + // If this node is parallel executed, we can't weaken a global atomic to a global load. + // TODO: we can weaken it if it's element-wise (i.e. never accessed by other threads). 
+ const bool atomic_global_weakable = + !node->is_parallel_executed || + (atomic->dest->is() && atomic->dest->as()->snode->is_scalar()); + if (atomic_global_weakable) { + auto global_load = Stmt::make(atomic->dest); + global_load->ret_type = atomic->ret_type; + record_weakened_atomic_as_load(atomic->dest, global_load.get(), alias, state); + node->replace_with(i, std::move(global_load), true); + return true; + } + // [SEMANTIC TRAP -- DO NOT REMOVE] + // Atomic was dead but not safely weakable (parallel global, non-scalar). + // The legacy code intentionally returns *without* updating + // `killed_in_this_node` / `live_in_this_node` here -- the atomic still executes at runtime, so + // its dest is neither newly-killed nor a fresh live entry. Falling through into the + // state-update block below would mark this address as killed-by-an-erased-store, which it + // isn't, and silently corrupt downstream DSE decisions on aliasing addresses. + // A previous refactor of this function broke exactly this branch; the existing AD/atomic test + // suites do not cover it directly. Touch with care. + return false; + } + // Non-eliminated store (not dead, or stmt not eliminable): update state. Insert into killed, + // and drop any live-in entry that aliases the same address. 
+ update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.killed_in_this_node, + store_ptr, false); + auto old_live_in = state.live_in_this_node; + for (auto &var : old_live_in) { + if (irpass::analysis::definitely_same_address(store_ptr, var.first)) { + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.live_in_this_node, + store_ptr, true); + } + } + return false; +} - if (!after_lower_access || - (load_ptr->is() && load_ptr->as()->origin->is()) || - (load_ptr->is() && load_ptr->as()->origin->is()) || - (load_ptr->is() || load_ptr->is())) { - // live_load_in_this_node[addr]: tracks the - // next load to the same address - // "!may_contain_variable(killed_in_this_node, load_ptr)": means it's - // not been stored in between the two loads - if (live_load_in_this_node.find(load_ptr) != live_load_in_this_node.end() && - !may_contain_variable(killed_in_this_node, load_ptr)) { - // Only perform identical load elimination within a CFGNode. - auto next_load_stmt = live_load_in_this_node[load_ptr]; - if (irpass::analysis::same_statements(stmt, next_load_stmt)) { - next_load_stmt->replace_usages_with(stmt); - erase(block->locate(next_load_stmt)); - modified = true; - } - } +// Try to eliminate an identical (redundant) load at |stmt|. We only do this within a single +// CFGNode, and only if no store has intervened since the prior load. The prior load (later in the +// block, since we walk in reverse) is the one erased; |stmt| takes over its usages. 
+bool try_eliminate_identical_load_at(CFGNode *node, + Stmt *stmt, + Stmt *load_ptr, + bool after_lower_access, + const DseAliasMaps &alias, + DseLiveState &state) { + if (!is_dse_eligible_pointer(load_ptr, after_lower_access)) { + return false; + } + bool modified = false; + auto it = state.live_load_in_this_node.find(load_ptr); + if (it != state.live_load_in_this_node.end() && + !CFGNode::may_contain_variable(state.killed_in_this_node, load_ptr)) { + auto *next_load_stmt = it->second; + if (irpass::analysis::same_statements(stmt, next_load_stmt)) { + next_load_stmt->replace_usages_with(stmt); + node->erase(node->block->locate(next_load_stmt)); + modified = true; + } + } + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.killed_in_this_node, + load_ptr, true); + state.live_load_in_this_node[load_ptr] = stmt; + return modified; +} - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, killed_in_this_node, load_ptr, - true); - live_load_in_this_node[load_ptr] = stmt; - } +// Mark all (eligible) loads of |stmt| as live-in-this-node so that earlier-in-reverse-order stores +// to the same address see them and abort their dead-store check. Taken by non-const reference +// because `one_or_more::begin()` is non-const. 
+void mark_loads_live_in_this_node(stmt_refs &load_ptrs, + bool after_lower_access, + const DseAliasMaps &alias, + DseLiveState &state) { + for (auto *load_ptr : load_ptrs) { + if (is_dse_eligible_pointer(load_ptr, after_lower_access)) { + update_container_with_alias(alias.tensor_to_matrix_ptrs, alias.matrix_ptr_to_tensor, state.live_in_this_node, + load_ptr, false); } + } +} - // Update live_in_this_node - for (auto &load_ptr : load_ptrs) { - if (!after_lower_access || - (load_ptr->is() && load_ptr->as()->origin->is()) || - (load_ptr->is() && load_ptr->as()->origin->is()) || - (load_ptr->is() || load_ptr->is())) { - // Addr is used in this node, so it's live in this node - update_container_with_alias(tensor_to_matrix_ptrs_map, matrix_ptr_to_tensor_map, live_in_this_node, load_ptr, - false); +} // namespace + +bool CFGNode::dead_store_elimination(bool after_lower_access) { + // Reverse-order walk over this node's statements. At each statement we may (a) eliminate a dead + // store, (b) eliminate an identical load, or (c) update live/killed state for later iterations. + // `FuncCallStmt` is treated as a full barrier: it can read/write anything, so we drop the kill + // and live-load tables and start fresh. 
+ const DseAliasMaps alias = build_matrix_ptr_alias_maps(block, begin_location, end_location); + DseLiveState state; + bool modified = false; + for (int i = end_location - 1; i >= begin_location; i--) { + auto *stmt = block->statements[i].get(); + if (stmt->is()) { + state.killed_in_this_node.clear(); + state.live_load_in_this_node.clear(); + continue; + } + auto store_ptrs = dse_store_destinations(stmt); + if (store_ptrs.size() == 1) { + if (try_eliminate_dead_store_at(this, i, stmt, *store_ptrs.begin(), after_lower_access, alias, state)) { + modified = true; + continue; + } + } + auto load_ptrs = irpass::analysis::get_load_pointers(stmt); + if (load_ptrs.size() == 1 && store_ptrs.empty()) { + if (try_eliminate_identical_load_at(this, stmt, *load_ptrs.begin(), after_lower_access, alias, state)) { + modified = true; } } + mark_loads_live_in_this_node(load_ptrs, after_lower_access, alias, state); } return modified; } +// =========================================================================== +// ControlFlowGraph -- whole-graph methods. +// Each analysis/transform below is the driver for the same-named per-node +// method on CFGNode above. It calls the per-node method on every node, then +// runs whatever cross-node work the pass needs (worklist fixpoint for RD/LV, +// a flat per-node loop for S2L/DSE). +// =========================================================================== + +void ControlFlowGraph::assert_structural_invariants() const { + // These invariants are established by `analysis::build_cfg` and preserved by every + // ControlFlowGraph method below. They are also assumed by every helper in this file (e.g. the + // worklist seeding `nodes[start_node]->reach_gen.insert(...)`, the DP scratch buffer indexed by + // start_node, the dump-graph loop that skips start_node / final_node by index). If they ever + // fail, the code following will segfault or silently corrupt -- catch it here instead. 
+ QD_ASSERT_INFO(!nodes.empty(), "ControlFlowGraph has no nodes"); + QD_ASSERT_INFO(start_node >= 0 && start_node < (int)nodes.size(), "start_node out of range"); + QD_ASSERT_INFO(final_node >= 0 && final_node < (int)nodes.size(), "final_node out of range"); + QD_ASSERT_INFO(nodes[start_node] != nullptr, "start_node entry is null"); + QD_ASSERT_INFO(nodes[final_node] != nullptr, "final_node entry is null"); +} + void ControlFlowGraph::erase(int node_id) { // Erase an empty node. QD_ASSERT(node_id >= 0 && node_id < (int)size()); @@ -866,12 +1073,88 @@ CFGNode *ControlFlowGraph::back() const { return nodes.back().get(); } +namespace { + +// === Helpers for ControlFlowGraph::dump_graph_to_file === + +// Range label: "empty" or "~ (size=N)". +std::string format_cfg_node_range_label(const CFGNode *node) { + if (node->empty()) { + return "empty"; + } + return fmt::format("{}~{} (size={})", node->block->statements[node->begin_location]->name(), + node->block->statements[node->end_location - 1]->name(), node->size()); +} + +// Brace-delimited list of neighbor indices, e.g. "{0, 1, 2}". +std::string format_neighbor_indices(const std::vector &neighbors, + const std::unordered_map &to_index) { + std::vector parts; + parts.reserve(neighbors.size()); + for (auto *n : neighbors) { + parts.push_back(std::to_string(to_index.at(n))); + } + return fmt::format("{{{}}}", fmt::join(parts, ", ")); +} + +// Brace-delimited list of stmt names, e.g. "{$3, $5}". +std::string format_live_var_names(const std::unordered_set &vars) { + std::vector parts; + parts.reserve(vars.size()); + for (auto *stmt : vars) { + parts.push_back(stmt->name()); + } + return fmt::format("{{{}}}", fmt::join(parts, ", ")); +} + +// Write one node's header line: index, range label, optional prev/next/live_out lists. 
+void write_cfg_node_header(std::ostream &out, + int index, + const CFGNode *node, + const std::unordered_map &to_index) { + out << fmt::format("Node {} : ", index) << format_cfg_node_range_label(node); + if (!node->prev.empty()) { + out << "; prev=" << format_neighbor_indices(node->prev, to_index); + } + if (!node->next.empty()) { + out << "; next=" << format_neighbor_indices(node->next, to_index); + } + if (!node->live_out.empty()) { + out << "; live_out=" << format_live_var_names(node->live_out); + } + out << "\n"; +} + +// Write one node's statements, each line indented 4 spaces, blank line after. +void write_cfg_node_statements(std::ostream &out, const CFGNode *node) { + if (node->empty()) { + return; + } + for (int j = node->begin_location; j < node->end_location; j++) { + auto *stmt = node->block->statements[j].get(); + std::string stmt_output; + // print_kernel_wrapper=false to avoid the surrounding "kernel { }" wrapper. + irpass::print(stmt, &stmt_output, false, false); + std::istringstream iss(stmt_output); + std::string line; + while (std::getline(iss, line)) { + if (!line.empty()) { + out << " " << line << "\n"; + } + } + } + out << "\n"; +} + +} // namespace + void ControlFlowGraph::dump_graph_to_file(const CompileConfig &config, const std::string &kernel_name, const std::string &suffix) const { - std::filesystem::path ir_dump_dir = config.debug_dump_path; + assert_structural_invariants(); + const std::filesystem::path ir_dump_dir = config.debug_dump_path; std::filesystem::create_directories(ir_dump_dir); - std::filesystem::path filename = ir_dump_dir / (kernel_name + "_CFG" + suffix + ".txt"); + const std::filesystem::path filename = ir_dump_dir / (kernel_name + "_CFG" + suffix + ".txt"); std::ofstream out_file(filename.string()); if (!out_file) { @@ -879,117 +1162,106 @@ void ControlFlowGraph::dump_graph_to_file(const CompileConfig &config, return; } - // Write directly to the file using fmt::format const int num_nodes = size(); std::unordered_map 
to_index; + to_index.reserve(num_nodes); for (int i = 0; i < num_nodes; i++) { to_index[nodes[i].get()] = i; } for (int i = 0; i < num_nodes; i++) { - out_file << fmt::format("Node {} : ", i); - if (nodes[i]->empty()) { - out_file << "empty"; - } else { - out_file << fmt::format("{}~{} (size={})", nodes[i]->block->statements[nodes[i]->begin_location]->name(), - nodes[i]->block->statements[nodes[i]->end_location - 1]->name(), nodes[i]->size()); - } - if (!nodes[i]->prev.empty()) { - std::vector indices; - for (auto prev_node : nodes[i]->prev) { - indices.push_back(std::to_string(to_index[prev_node])); - } - out_file << fmt::format("; prev={{{}}}", fmt::join(indices, ", ")); - } - if (!nodes[i]->next.empty()) { - std::vector indices; - for (auto next_node : nodes[i]->next) { - indices.push_back(std::to_string(to_index[next_node])); - } - out_file << fmt::format("; next={{{}}}", fmt::join(indices, ", ")); - } - if (!nodes[i]->live_out.empty()) { - std::vector vars; - for (auto stmt : nodes[i]->live_out) { - vars.push_back(stmt->name()); - } - out_file << fmt::format("; live_out={{{}}}", fmt::join(vars, ", ")); + write_cfg_node_header(out_file, i, nodes[i].get(), to_index); + write_cfg_node_statements(out_file, nodes[i].get()); + } + + out_file.close(); + QD_INFO("CFG dumped to: {}", filename.string()); +} + +namespace { + +// === Helpers for ControlFlowGraph::reaching_definition_analysis === + +// Statements that carry data into this kernel from outside its CFG. Treated as definitions in the +// synthetic entry node's reach_gen so cross-block forwarding correctly sees them. After access +// lowering, only MatrixPtrStmt aliases of allocas/matrix-pointers remain in scope; before, the +// full menagerie of pointer flavors applies. +// TODO: unify them. 
+bool is_external_input_pointer(const Stmt *stmt, bool after_lower_access) { + if (stmt->is()) { + const auto *origin = stmt->as()->origin; + if (origin->is() || origin->is()) { + return true; } - out_file << "\n"; - - // Print the actual statements in this node - if (!nodes[i]->empty()) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - auto stmt = nodes[i]->block->statements[j].get(); - std::string stmt_output; - // Use print_kernel_wrapper=false to avoid the "kernel { }" wrapper - irpass::print(stmt, &stmt_output, false, false); - - // Add indentation to each line - std::istringstream iss(stmt_output); - std::string line; - while (std::getline(iss, line)) { - if (!line.empty()) { - out_file << " " << line << "\n"; - } - } + } + if (after_lower_access) { + return false; + } + return stmt->is() || stmt->is() || stmt->is() || + stmt->is() || stmt->is() || stmt->is() || + stmt->is() || stmt->is() || stmt->is(); +} + +// Seed the synthetic entry node's reach_gen with definitions that "exist" before this kernel +// begins executing: external pointer loads (so a load before the first store still has a +// UD-chain) plus FuncCallStmt store destinations (the callee may have written anything declared +// in its `store_dests`). 
+void seed_start_node_reach_gen(CFGNode *start_node, + const std::vector> &nodes, + bool after_lower_access) { + start_node->reach_gen.clear(); + start_node->reach_kill.clear(); + for (const auto &node : nodes) { + for (int j = node->begin_location; j < node->end_location; j++) { + auto *stmt = node->block->statements[j].get(); + if (is_external_input_pointer(stmt, after_lower_access)) { + start_node->reach_gen.insert(stmt); + } else if (auto *func_call = stmt->cast()) { + const auto &dests = func_call->func->store_dests; + start_node->reach_gen.insert(dests.begin(), dests.end()); } - out_file << "\n"; } } +} - out_file.close(); - QD_INFO("CFG dumped to: {}", filename.string()); +// Is |stmt| killed at |node| (i.e., excluded from reach_out)? For stmts with explicit store +// destinations, killed iff every dest is killed at |node|. For stmts without dests (e.g. raw +// global pointers seeded into start_node's reach_gen), killed iff the stmt itself is killed. +bool is_reach_in_stmt_killed_at(CFGNode *node, Stmt *stmt) { + // Not const: `one_or_more::begin()` is non-const, so the range-for below would not compile. + auto store_ptrs = irpass::analysis::get_store_destination(stmt); + if (store_ptrs.empty()) { + return node->is_reach_killed(stmt); + } + for (auto *store_ptr : store_ptrs) { + if (!node->is_reach_killed(store_ptr)) { + return false; + } + } + return true; } +} // namespace + void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { - // Prerequisite analysis for load-store-forwarding to help determine - // cross-block use-define chain - // - // The algorithm is separated into two parts: - // 1. Determine reach_gen and reach_kill within each node - // 2. Propagate reach_in and reach_out through the graph - // - // - reach_gen: instruction that defines a variable (store stmts) in the - // current node - // - reach_kill: address (GlobalPtrStmt, AllocaStmt, ...) 
that's been defined - // (stored to) in the current node - // - // In general, reach_gen and reach_kill are the same except that reach_gen - // tracks the store stmts and reach_kill tracks the address + // Prerequisite analysis for load-store-forwarding; computes the cross-block use-define chain. // - // - reach_out: reach_gen + { reach_in's dest not in reach_kill } - // - reach_in: collection of all the reach_out of previous nodes + // Per-node: + // - reach_gen: store stmts that define a variable in this node. + // - reach_kill: addresses (GlobalPtrStmt, AllocaStmt, ...) stored to in this node. + // (reach_gen tracks the defining stmts; reach_kill tracks the killed addresses.) // - // reach_out and reach_in is the ultimate result that helps analyze - // cross-block use-define chain - + // Per-graph (worklist fixpoint): + // - reach_in: union of reach_out of all predecessor nodes. + // - reach_out: reach_gen + { stmts from reach_in whose dest is not in reach_kill }. QD_AUTO_PROF; + assert_structural_invariants(); const int num_nodes = size(); - std::queue to_visit; - std::unordered_map in_queue; QD_ASSERT(nodes[start_node]->empty()); - nodes[start_node]->reach_gen.clear(); - nodes[start_node]->reach_kill.clear(); - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - auto stmt = nodes[i]->block->statements[j].get(); - if ((stmt->is() && stmt->as()->origin->is()) || - (stmt->is() && stmt->as()->origin->is()) || - (!after_lower_access && - (stmt->is() || stmt->is() || stmt->is() || - stmt->is() || stmt->is() || stmt->is() || - stmt->is() || stmt->is() || stmt->is()))) { - // TODO: unify them - // A global pointer that may contain some data before this kernel. 
- nodes[start_node]->reach_gen.insert(stmt); - } else if (auto func_call = stmt->cast()) { - const auto &dests = func_call->func->store_dests; - nodes[start_node]->reach_gen.insert(dests.begin(), dests.end()); - } - } - } + seed_start_node_reach_gen(nodes[start_node].get(), nodes, after_lower_access); + std::queue to_visit; + std::unordered_map in_queue; for (int i = 0; i < num_nodes; i++) { if (i != start_node) { nodes[i]->reaching_definition_analysis(after_lower_access); @@ -1000,40 +1272,25 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { in_queue[nodes[i].get()] = true; } - // [The worklist algorithm] - // Determines reach_in and reach_out for each node iteratively. + // [Worklist algorithm] Converge reach_in / reach_out iteratively. while (!to_visit.empty()) { - auto now = to_visit.front(); + auto *now = to_visit.front(); to_visit.pop(); in_queue[now] = false; now->reach_in.clear(); - for (auto prev_node : now->prev) { + for (auto *prev_node : now->prev) { now->reach_in.insert(prev_node->reach_out.begin(), prev_node->reach_out.end()); } auto old_out = std::move(now->reach_out); now->reach_out = now->reach_gen; - for (auto stmt : now->reach_in) { - auto store_ptrs = irpass::analysis::get_store_destination(stmt); - bool killed; - if (store_ptrs.empty()) { // the case of a global pointer - killed = now->reach_kill_variable(stmt); - } else { - killed = true; - for (auto store_ptr : store_ptrs) { - if (!now->reach_kill_variable(store_ptr)) { - killed = false; - break; - } - } - } - if (!killed) { + for (auto *stmt : now->reach_in) { + if (!is_reach_in_stmt_killed_at(now, stmt)) { now->reach_out.insert(stmt); } } if (now->reach_out != old_out) { - // changed - for (auto next_node : now->next) { + for (auto *next_node : now->next) { if (!in_queue[next_node]) { to_visit.push(next_node); in_queue[next_node] = true; @@ -1043,52 +1300,57 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { } } +namespace { + 
+// === Helpers for ControlFlowGraph::live_variable_analysis === + +// Seed the synthetic exit node's live_gen with every store destination in the kernel that may +// still be observable after the kernel exits (globals not flagged eliminable by the SFG, plus any +// aliased pointers via `get_alias`). Skipped entirely after access lowering, when only locals +// remain and nothing escapes. +void seed_final_node_live_gen(CFGNode *final_node_ptr, + const std::vector> &nodes, + bool after_lower_access, + const std::optional &config_opt) { + final_node_ptr->live_gen.clear(); + final_node_ptr->live_kill.clear(); + if (after_lower_access) { + return; + } + for (const auto &node : nodes) { + for (int j = node->begin_location; j < node->end_location; j++) { + auto *stmt = node->block->statements[j].get(); + for (auto *store_ptr : irpass::analysis::get_store_destination(stmt, /*get_alias=*/true)) { + if (in_final_node_live_gen(store_ptr, config_opt)) { + final_node_ptr->live_gen.insert(store_ptr); + } + } + } + } +} + +} // namespace + void ControlFlowGraph::live_variable_analysis(bool after_lower_access, const std::optional &config_opt) { - // [live_variable_analysis] - // live_gen: address loaded with no previous stored in this node. One cannot - // load before storing so - // addrs in live_gen must come from previous nodes - // live_kill: address stored in this node - // live_in: live_gen + (live_out - live_kill) - // live_out: collection of all the live_in of next nodes + // Per-node: + // - live_gen: address loaded with no preceding store in this node (must live in from + // somewhere -- you can't load before storing). + // - live_kill: address stored in this node. + // Per-graph (worklist fixpoint, propagated backwards): + // - live_in: live_gen + (live_out - live_kill). + // - live_out: union of live_in of all successor nodes. 
QD_AUTO_PROF; + assert_structural_invariants(); const int num_nodes = size(); - std::queue to_visit; - std::unordered_map in_queue; QD_ASSERT(nodes[final_node]->empty()); - nodes[final_node]->live_gen.clear(); - nodes[final_node]->live_kill.clear(); - - auto in_final_node_live_gen = [&config_opt](const Stmt *stmt) -> bool { - if (stmt->is() || stmt->is()) { - return false; - } - if (stmt->is() && stmt->cast()->origin->is()) { - return false; - } - if (auto *gptr = stmt->cast(); gptr && config_opt.has_value()) { - const bool res = (config_opt->eliminable_snodes.count(gptr->snode) == 0); - return res; - } - // A global pointer that may be loaded after this kernel. - return true; - }; - if (!after_lower_access) { - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - auto stmt = nodes[i]->block->statements[j].get(); - for (auto store_ptr : irpass::analysis::get_store_destination(stmt, true /*get_alias*/)) { - if (in_final_node_live_gen(store_ptr)) { - nodes[final_node]->live_gen.insert(store_ptr); - } - } - } - } - } + seed_final_node_live_gen(nodes[final_node].get(), nodes, after_lower_access, config_opt); + std::queue to_visit; + std::unordered_map in_queue; + // Push in reversed order: backwards analysis converges slightly faster when the worklist seeds + // are dequeued from the back of the graph first. for (int i = num_nodes - 1; i >= 0; i--) { - // push into the queue in reversed order to make it slightly faster if (i != final_node) { nodes[i]->live_variable_analysis(after_lower_access); } @@ -1098,26 +1360,25 @@ void ControlFlowGraph::live_variable_analysis(bool after_lower_access, in_queue[nodes[i].get()] = true; } - // The worklist algorithm. + // [Worklist algorithm] Converge live_in / live_out iteratively (backwards). 
while (!to_visit.empty()) { - auto now = to_visit.front(); + auto *now = to_visit.front(); to_visit.pop(); in_queue[now] = false; now->live_out.clear(); - for (auto next_node : now->next) { + for (auto *next_node : now->next) { now->live_out.insert(next_node->live_in.begin(), next_node->live_in.end()); } auto old_in = std::move(now->live_in); now->live_in = now->live_gen; - for (auto stmt : now->live_out) { + for (auto *stmt : now->live_out) { if (!CFGNode::contain_variable(now->live_kill, stmt)) { now->live_in.insert(stmt); } } if (now->live_in != old_in) { - // changed - for (auto prev_node : now->prev) { + for (auto *prev_node : now->prev) { if (!in_queue[prev_node]) { to_visit.push(prev_node); in_queue[prev_node] = true; @@ -1129,6 +1390,7 @@ void ControlFlowGraph::live_variable_analysis(bool after_lower_access, void ControlFlowGraph::simplify_graph() { // Simplify the graph structure, do not modify the IR. + assert_structural_invariants(); const int num_nodes = size(); while (true) { bool modified = false; @@ -1163,6 +1425,7 @@ bool ControlFlowGraph::unreachable_code_elimination() { // Note that container statements are not in the control-flow graph, so // this pass cannot eliminate container statements properly for now. 
QD_AUTO_PROF; + assert_structural_invariants(); std::unordered_set visited; std::queue to_visit; to_visit.push(nodes[start_node].get()); @@ -1205,6 +1468,7 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, bool au // This is done in CFGNode::store_to_load_forwarding() of each node QD_AUTO_PROF; + assert_structural_invariants(); reaching_definition_analysis(after_lower_access); const int num_nodes = size(); bool modified = false; @@ -1218,6 +1482,7 @@ bool ControlFlowGraph::store_to_load_forwarding(bool after_lower_access, bool au bool ControlFlowGraph::dead_store_elimination(bool after_lower_access, const std::optional &lva_config_opt) { QD_AUTO_PROF; + assert_structural_invariants(); live_variable_analysis(after_lower_access, lva_config_opt); const int num_nodes = size(); bool modified = false; @@ -1230,6 +1495,7 @@ bool ControlFlowGraph::dead_store_elimination(bool after_lower_access, std::unordered_set ControlFlowGraph::gather_loaded_snodes() { QD_AUTO_PROF; + assert_structural_invariants(); reaching_definition_analysis(/*after_lower_access=*/false); const int num_nodes = size(); std::unordered_set snodes; @@ -1256,447 +1522,4 @@ std::unordered_set ControlFlowGraph::gather_loaded_snodes() { return snodes; } -void ControlFlowGraph::determine_ad_stack_size() { - /** - * Determine the necessary size of every adaptive AD-stack on the control-flow graph (CFG). For each AD-stack we - * compute the maximum running net push count along any walk from the kernel entry. AD-stacks whose forward kernel - * contains a positive cycle (pushes > pops around a loop) are left at `max_size = 0`, and the caller routes them - * through the structural bounded-loop pre-pass for a symbolic `SizeExpr`, hard-erroring if the grammar still - * cannot resolve them. There is no compile-time size fallback. - * - * Implementation notes for compile-time perf on large reverse-mode kernels: - * 1. 
Per-stack per-node pre-aggregates (`max_increased_size`, `increased_size`) are stored in dense - * `vector>` indexed by a contiguous int stack id, instead of an - * `unordered_map>` -- this removes hash traffic from the hot inner loop. - * 2. Stacks whose `(increased_size, max_increased_size)` row pair is bit-identical share a single dynamic - * programming run -- typical kernels generate one alloca per autodiff variable in the same loop body, so - * most rows collapse to a few representatives. - * 3. The CFG is condensed via Tarjan into strongly connected components (SCCs). DFS finish times recorded - * during the same Tarjan pass split each cyclic SCC's intra-edges into a forward set (target finishes - * before source) and a back set (target finishes at or after source). Per representative we run a - * single-pass dynamic-programming (DP) sweep over the forward edges in descending finish-time order, then - * check the back edges once for positive-cycle relaxation. Correctness: any walk inside an SCC decomposes - * into a forward path plus zero or more cycles, and an SCC with no positive cycle has the same max-walk-sum - * as the back-edge-removed DAG; a positive cycle is exactly the case where some back-edge would still relax - * after the forward DP. This drops the per-cyclic-SCC cost from O(|S| * |E_S|) to O(|S| + |E_S|). - * 4. Two sign-based fast paths short-circuit the DP for trivial cyclic SCCs: an SCC with `min_is >= 0 && max_is - * > 0` for this stack must contain a positive cycle (every node lies on some cycle, and a cycle through a - * strictly-positive node with all non-negative `is` along it sums positive); an SCC with `min_is == 0 == - * max_is` has no `is` contribution at all and is handled by spreading the max entry-side - * `max_size_at_node_begin` to every node in O(|S|). 
- * Per-rep cost becomes O(V + E + sum_{cyclic S} |S| * |E_S|) (with the SCC sum dropping to O(|S| + |E_S|) for - * the common autodiff push/pop pattern); overall cost is O(V + E + R * (V + E)) with R the number of distinct - * row-pair representatives. - */ - const int num_nodes = size(); - - // Map AdStackAllocaStmt* to a contiguous int index. Only stacks whose `max_size` is still 0 (i.e. unresolved by an - // earlier pass) participate in the DP; resolved stacks are skipped so the pass cannot clobber them. - std::unordered_map stack_id; - std::vector stacks; - - std::unordered_map node_ids; - for (int i = 0; i < num_nodes; i++) - node_ids[nodes[i].get()] = i; - - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - Stmt *stmt = nodes[i]->block->statements[j].get(); - if (auto *stack = stmt->cast()) { - if (stack->max_size != 0) { - continue; - } - if (stack_id.emplace(stack, static_cast(stacks.size())).second) { - stacks.push_back(stack); - } - } - } - } - - const int num_stacks = static_cast(stacks.size()); - if (num_stacks == 0) { - return; - } - - // max_increased_size[s][j] is the maximum number of (pushes - pops) of stack |s| among all prefixes of the - // CFGNode |j|. increased_size[s][j] is the net (pushes - pops) of stack |s| in the CFGNode |j|. Both are indexed - // by contiguous stack id, so the per-stack DP loop reads them with cheap vector access. - std::vector> max_increased_size(num_stacks, std::vector(num_nodes, 0)); - std::vector> increased_size(num_stacks, std::vector(num_nodes, 0)); - - // Track which stacks actually participate in any push/pop in the CFG. Stacks with no push/pop end with - // `max_size = 0` regardless of the DP (every node has zero increase), so we skip the DP for them and reproduce - // the original "Unused autodiff stack" warning here. 
- std::vector stack_active(num_stacks, false); - - for (int i = 0; i < num_nodes; i++) { - for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { - Stmt *stmt = nodes[i]->block->statements[j].get(); - if (auto *stack_push = stmt->cast()) { - auto *stack = stack_push->stack->as(); - if (stack->max_size == 0 /*adaptive*/) { - auto it = stack_id.find(stack); - QD_ASSERT(it != stack_id.end()); - const int sid = it->second; - stack_active[sid] = true; - int &cur = increased_size[sid][i]; - cur++; - if (cur > max_increased_size[sid][i]) { - max_increased_size[sid][i] = cur; - } - } - } else if (auto *stack_pop = stmt->cast()) { - auto *stack = stack_pop->stack->as(); - if (stack->max_size == 0 /*adaptive*/) { - auto it = stack_id.find(stack); - QD_ASSERT(it != stack_id.end()); - const int sid = it->second; - stack_active[sid] = true; - increased_size[sid][i]--; - } - } - } - } - - // Precompute outgoing-edge node ids once per node so the per-stack DP walks an int vector instead of hashing - // `node_ids[next_node]` on every traversal. - std::vector> next_ids(num_nodes); - for (int i = 0; i < num_nodes; i++) { - auto &dst = next_ids[i]; - dst.reserve(nodes[i]->next.size()); - for (auto *next_node : nodes[i]->next) { - dst.push_back(node_ids[next_node]); - } - } - - // Tarjan strongly-connected-component (SCC) decomposition of the CFG, computed once and shared across all per-stack - // dynamic-programming runs. Output: - // scc_id[n] : SCC index of each node (lower index = topologically deeper / sink-side, so sources are at index - // num_sccs - 1, matching Tarjan's natural emission order which is reverse topological order). - // scc_nodes[s] : list of node ids in SCC s. - // dfs_finish[n]: DFS post-order index, used below to split each cyclic SCC's intra-edges into a forward and a - // back set without a separate edge classification pass. 
- // We then split each node's outgoing edges into intra-SCC and inter-SCC sets so the per-stack DP iterates each - // set without filtering inside the hot loop. Cyclic SCCs (|S| > 1 or |S| == 1 with self-loop) are flagged so - // cycle detection runs only on those, and only at SCC scope. - std::vector scc_id(num_nodes, -1); - std::vector dfs_finish(num_nodes, -1); // DFS post-order index per node; ancestors finish AFTER descendants. - std::vector> scc_nodes; - { - std::vector tarjan_index(num_nodes, -1); - std::vector tarjan_lowlink(num_nodes, 0); - std::vector on_stack(num_nodes, 0); - std::vector tarjan_stack; - tarjan_stack.reserve(num_nodes); - std::vector> dfs_stack; - dfs_stack.reserve(num_nodes); - int next_index = 0; - int next_finish = 0; - int next_scc = 0; - for (int origin = 0; origin < num_nodes; origin++) { - if (tarjan_index[origin] != -1) { - continue; - } - tarjan_index[origin] = next_index; - tarjan_lowlink[origin] = next_index; - next_index++; - tarjan_stack.push_back(origin); - on_stack[origin] = 1; - dfs_stack.emplace_back(origin, 0); - while (!dfs_stack.empty()) { - auto &frame = dfs_stack.back(); - const int u = frame.first; - const auto &nb = next_ids[u]; - if (frame.second < nb.size()) { - const int v = nb[frame.second++]; - if (tarjan_index[v] == -1) { - tarjan_index[v] = next_index; - tarjan_lowlink[v] = next_index; - next_index++; - tarjan_stack.push_back(v); - on_stack[v] = 1; - dfs_stack.emplace_back(v, 0); - } else if (on_stack[v]) { - if (tarjan_index[v] < tarjan_lowlink[u]) { - tarjan_lowlink[u] = tarjan_index[v]; - } - } - } else { - if (tarjan_lowlink[u] == tarjan_index[u]) { - std::vector component; - while (true) { - const int w = tarjan_stack.back(); - tarjan_stack.pop_back(); - on_stack[w] = 0; - scc_id[w] = next_scc; - component.push_back(w); - if (w == u) { - break; - } - } - scc_nodes.push_back(std::move(component)); - next_scc++; - } - dfs_finish[u] = next_finish++; - dfs_stack.pop_back(); - if (!dfs_stack.empty()) { - 
const int parent = dfs_stack.back().first; - if (tarjan_lowlink[u] < tarjan_lowlink[parent]) { - tarjan_lowlink[parent] = tarjan_lowlink[u]; - } - } - } - } - } - } - const int num_sccs = static_cast(scc_nodes.size()); - - // Each intra-SCC edge is either a "forward" edge in the DFS spanning sense (source finishes after target, so source - // can be processed before target with a single topological pass) or a "back" edge (source finishes before target, - // closing a cycle). The forward set carries the per-stack DAG dynamic programming; back-edges only need a single - // post-DP relaxation check to detect positive cycles. Inter-SCC edges always relax forward in topological order. - std::vector> next_ids_intra_fwd(num_nodes); - std::vector> next_ids_intra_back(num_nodes); - std::vector> next_ids_inter(num_nodes); - for (int u = 0; u < num_nodes; u++) { - const int su = scc_id[u]; - for (int v : next_ids[u]) { - if (scc_id[v] == su) { - if (dfs_finish[v] < dfs_finish[u]) { - next_ids_intra_fwd[u].push_back(v); - } else { - next_ids_intra_back[u].push_back(v); - } - } else { - next_ids_inter[u].push_back(v); - } - } - } - - // An SCC is cyclic iff it contains a cycle. By definition that is |S| > 1 (any two nodes lie on a cycle since the - // SCC is strongly connected) or |S| == 1 with a self-loop edge. For cyclic SCCs we also precompute a topological - // ordering of the SCC's nodes (descending DFS finish time) so the per-stack DAG DP visits each node once with all - // forward predecessors already finalized. - std::vector scc_is_cyclic(num_sccs, 0); - std::vector> scc_topo(num_sccs); - for (int s = 0; s < num_sccs; s++) { - auto &nodes_in_s = scc_nodes[s]; - if (nodes_in_s.size() > 1) { - scc_is_cyclic[s] = 1; - } else { - const int n = nodes_in_s[0]; - // A singleton SCC is cyclic iff it has a self-loop. Self-loops are classified as back-edges above - // (dfs_finish[v] == dfs_finish[u]), so check there. 
- for (int v : next_ids_intra_back[n]) { - if (v == n) { - scc_is_cyclic[s] = 1; - break; - } - } - } - if (scc_is_cyclic[s]) { - auto topo = nodes_in_s; - std::sort(topo.begin(), topo.end(), [&](int a, int b) { return dfs_finish[a] > dfs_finish[b]; }); - scc_topo[s] = std::move(topo); - } - } - - // Group AD-stacks whose per-node (increased_size, max_increased_size) rows are bit-identical, so that the DP runs - // once per equivalence class instead of once per stack. In practice many AD-stacks in the same kernel share their - // push/pop schedule (one alloca per autodiff variable in the same loop body), and the DP on a large CFG dwarfs - // the dedup cost. We hash a sparse fingerprint (only nodes where the stack actually has push/pop activity) and - // group by it. Worst case (no duplicates): one DP run per stack, equivalent to running it directly per stack. - using Fingerprint = std::vector>; // (node_id, is, mis), sorted by node_id - struct FingerprintHash { - std::size_t operator()(const Fingerprint &f) const noexcept { - std::size_t h = 1469598103934665603ULL; - for (auto &[n, i, m] : f) { - auto mix = [&](std::size_t x) { - h ^= x; - h *= 1099511628211ULL; - }; - mix(static_cast(n)); - mix(static_cast(static_cast(i))); - mix(static_cast(static_cast(m))); - } - return h; - } - }; - std::unordered_map fp_to_rep; // fingerprint -> representative stack id - std::vector stack_to_rep(num_stacks, -1); - std::vector rep_stack_ids; // stack ids that actually run BF - for (int sid = 0; sid < num_stacks; sid++) { - if (!stack_active[sid]) { - continue; - } - Fingerprint fp; - const auto &is_row = increased_size[sid]; - const auto &mis_row = max_increased_size[sid]; - for (int n = 0; n < num_nodes; n++) { - if (is_row[n] != 0 || mis_row[n] != 0) { - fp.emplace_back(n, is_row[n], mis_row[n]); - } - } - auto [it, inserted] = fp_to_rep.emplace(std::move(fp), sid); - stack_to_rep[sid] = it->second; - if (inserted) { - rep_stack_ids.push_back(sid); - } - } - - // Scratch 
buffer reused across representatives to avoid reallocating per iteration. - std::vector max_size_at_node_begin(num_nodes); - - // Per-representative results, broadcast below to every stack that hashed to this representative. - struct DPResult { - int max_size; - bool has_positive_loop; - }; - std::unordered_map rep_results; - rep_results.reserve(rep_stack_ids.size()); - - // For each representative stack, walk the SCC condensation in topological order (sources first, since Tarjan emits - // SCCs in reverse-topological order so source SCCs end up at the largest indices). Inter-SCC edges only relax - // forward (predecessors are already finalized when an SCC is entered), so each is touched O(1) times per stack. - for (int rep_sid : rep_stack_ids) { - const std::vector &mis_for_stack = max_increased_size[rep_sid]; - const std::vector &is_for_stack = increased_size[rep_sid]; - - std::fill(max_size_at_node_begin.begin(), max_size_at_node_begin.end(), -1); - - int max_size = 0; - max_size_at_node_begin[start_node] = 0; - - bool has_positive_loop = false; - - for (int s = num_sccs - 1; s >= 0 && !has_positive_loop; s--) { - const auto &nodes_in_s = scc_nodes[s]; - - if (scc_is_cyclic[s]) { - // Sign-based fast paths sidestep the cyclic-SCC dynamic programming when the stack's `is` contribution - // inside this SCC is structurally trivial. Two cases short-circuit: - // 1. min_is >= 0 with max_is > 0: every node in the SCC lies on some cycle (SCC property), and a cycle - // through a strictly-positive node with all non-negative `is` along it sums to a positive value, so a - // positive cycle exists for this stack. This is the autodiff push-only-in-SCC pattern. - // 2. min_is == 0 == max_is (no `is` contribution in this SCC at all): the DP would only spread the maximum - // entry-side `max_size_at_node_begin` value to every node in S; we do that directly in O(|S|). 
- int min_is = INT_MAX; - int max_is = INT_MIN; - for (int u : nodes_in_s) { - const int v = is_for_stack[u]; - if (v < min_is) { - min_is = v; - } - if (v > max_is) { - max_is = v; - } - } - if (min_is >= 0 && max_is > 0) { - has_positive_loop = true; - break; - } - if (min_is == 0 && max_is == 0) { - int max_begin = -1; - for (int u : nodes_in_s) { - if (max_size_at_node_begin[u] > max_begin) { - max_begin = max_size_at_node_begin[u]; - } - } - if (max_begin >= 0) { - for (int u : nodes_in_s) { - if (max_size_at_node_begin[u] < max_begin) { - max_size_at_node_begin[u] = max_begin; - } - } - } - } else { - // Mixed-sign case: single-pass dynamic programming on the SCC's forward edges (processed in descending DFS - // finish-time so every forward predecessor is finalized before its successor relaxes), followed by one - // relaxation check on the back-edges. Correctness: every walk inside the SCC decomposes into a forward - // path (along non-back edges) plus zero or more closed cycles formed by a back-edge plus a forward path. - // For SCCs with no positive cycle, traversing a cycle adds a non-positive amount to the running size and - // so cannot improve max_size beyond what the forward DP already computed. The back-edge relaxation check - // after the DP detects the only failure mode (some back-edge would still improve a forward predecessor's - // value, which can only happen if a positive cycle exists for this stack). 
- for (int u : scc_topo[s]) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_intra_fwd[u]) { - if (exit_val > max_size_at_node_begin[v]) { - max_size_at_node_begin[v] = exit_val; - } - } - } - for (int u : nodes_in_s) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_intra_back[u]) { - if (exit_val > max_size_at_node_begin[v]) { - has_positive_loop = true; - break; - } - } - if (has_positive_loop) { - break; - } - } - if (has_positive_loop) { - break; - } - } - } - - // SCC has converged for this stack. Update global max_size from each node's max-prefix contribution and relax - // inter-SCC outgoing edges into successor SCCs. - for (int u : nodes_in_s) { - const int begin = max_size_at_node_begin[u]; - if (begin < 0) { - continue; - } - const int prefix = begin + mis_for_stack[u]; - if (prefix > max_size) { - max_size = prefix; - } - const int exit_val = begin + is_for_stack[u]; - for (int v : next_ids_inter[u]) { - if (exit_val > max_size_at_node_begin[v]) { - max_size_at_node_begin[v] = exit_val; - } - } - } - } - - rep_results[rep_sid] = {max_size, has_positive_loop}; - } - - // Broadcast representative results to every active stack. - for (int sid = 0; sid < num_stacks; sid++) { - AdStackAllocaStmt *stack = stacks[sid]; - if (!stack_active[sid]) { - // No push/pop in the CFG: the DP would visit reachable nodes with all-zero edge weights and settle with - // `max_size = 0`, no positive loop. Reproduce that result directly. 
- QD_WARN("Unused autodiff stack {} should have been eliminated.", stack->name()); - continue; - } - const DPResult &res = rep_results[stack_to_rep[sid]]; - if (res.has_positive_loop) { - // Leave `max_size = 0` so the structural bounded-loop pre-pass in `irpass::determine_ad_stack_size` gets a - // chance to derive a symbolic bound (a statically-bounded inner loop whose push-only body defeats the - // longest-path computation, resolved from the outer ranges). If it also cannot, the caller emits a hard - // compile error - there is no compile-time `default_ad_stack_size` fallback. - } else { - // Since we use |max_size| == 0 for adaptive sizes, we do not want stacks with maximum capacity indeed equal - // to 0. - QD_WARN_IF(res.max_size == 0, "Unused autodiff stack {} should have been eliminated.", stack->name()); - stack->max_size = res.max_size; - } - } -} - } // namespace quadrants::lang diff --git a/quadrants/ir/control_flow_graph.h b/quadrants/ir/control_flow_graph.h index c1f01d7d5e..a6834d9103 100644 --- a/quadrants/ir/control_flow_graph.h +++ b/quadrants/ir/control_flow_graph.h @@ -9,12 +9,21 @@ namespace quadrants::lang { class Function; /** - * A basic block in control-flow graph. - * A CFGNode contains a reference to a part of the CHI IR, or more precisely, - * an interval of statements in a Block. - * The edges in the graph are stored in |prev| and |next|. The control flow is - * possible to go from any node in |prev| to this node, and is possible to go - * from this node to any node in |next|. + * A basic block in the control-flow graph (one node). + * + * A CFGNode references an interval of statements in a Block: + * `block->statements[i]` for `i in [begin_location, end_location)`. + * The graph edges are stored in `prev` and `next`: control may flow from any + * node in `prev` into this node, and out of this node into any node in `next`. 
+ * + * Scope of the methods on this class: each analysis/transform method on + * CFGNode does only the *intra-block* work for that pass (e.g. + * `reaching_definition_analysis` computes `reach_gen` / `reach_kill` from this + * node's statements; `store_to_load_forwarding` rewrites loads/stores within + * this node's slice of the block). The same-named method on + * `ControlFlowGraph` is the *whole-graph driver*: it calls this per-node + * method on every node, then runs the worklist fixpoint or post-processing + * that needs cross-node state. */ class CFGNode { public: @@ -28,7 +37,7 @@ class CFGNode { }; private: - // For accelerating get_store_forwarding_data() + // For accelerating find_forwardable_store_value() std::unordered_set parent_blocks_; public: @@ -79,22 +88,110 @@ class CFGNode { static bool contain_variable(const std::unordered_map &var_set, Stmt *var); static bool may_contain_variable(const std::unordered_set &var_set, Stmt *var); static bool may_contain_variable(const std::unordered_map &var_set, Stmt *var); - bool reach_kill_variable(Stmt *var) const; - Stmt *get_store_forwarding_data(Stmt *var, int position) const; + bool is_reach_killed(Stmt *var) const; + Stmt *find_forwardable_store_value(Stmt *var, int position) const; - // Analyses and optimizations inside a CFGNode. + // Per-node (intra-block) analyses and transforms. Each is driven across the + // whole graph by the same-named method on ControlFlowGraph; see below. + // Per-node `reaching_definition_analysis`: populate this node's reach_gen / + // reach_kill from its statements. void reaching_definition_analysis(bool after_lower_access); + // Per-node `store_to_load_forwarding`: rewrite loads/stores within this + // node's statement range. Returns true if any IR change was made. bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); + // Per-node helper for `ControlFlowGraph::gather_loaded_snodes`: append this + // node's loaded SNodes into `snodes`. 
void gather_loaded_snodes(std::unordered_set &snodes) const; + // Per-node `live_variable_analysis`: populate this node's live_gen / + // live_kill from its statements. void live_variable_analysis(bool after_lower_access); + // Per-node `dead_store_elimination`: erase dead stores / weaken dead + // atomics within this node's statement range. Returns true if any IR change + // was made. bool dead_store_elimination(bool after_lower_access); + + private: + // Helper for find_forwardable_store_value: is |stmt| visible at |position| + // inside this node's block? A stmt is visible if it lives in the same block + // and precedes |position|, or if its parent block is an ancestor of + // |this->block|. + bool is_visible_at(Stmt *stmt, int position) const; + + // Helper for find_forwardable_store_value: incorporate |stmt|, a definition in + // the UD-chain of the variable being forwarded, into the running |result| / + // |result_visible| state. Returns false to signal that forwarding must + // abort (the caller should return nullptr), true to continue scanning. + bool fold_definition_into_result(Stmt *stmt, + int position, + Stmt *&result, + bool &result_visible) const; + + // Helper for find_forwardable_store_value: walk this node's block backwards + // from |position| and return the index of the most recent store to |var|, + // or -1 if none is in this block. Handles the quant-store exclusion plus the + // MatrixInitStmt-via-MatrixPtrStmt forwarding special case. + int find_intra_block_last_def(Stmt *var, int position) const; + + // Helper for find_forwardable_store_value: scan |reach_in| and |reach_gen| for + // definitions of |var| reaching |position|, folding each into |result| / + // |result_visible| via fold_definition_into_result. Returns nullopt if any + // visited def is unforwardable (caller must return nullptr); otherwise the + // last_def_position (0 if only reach_in matched, an in-block index if + // reach_gen matched, -1 if no eligible def was found). 
+ std::optional find_cross_block_def(Stmt *var, + int position, + Stmt *&result, + bool &result_visible) const; + + // Helper for find_forwardable_store_value: scan block statements in + // [from, to_exclusive) for a store that may write a different value to an + // address aliasing |var|. Returns true iff such a store exists (so + // forwarding |result| must abort). The check is skipped (returns false) for + // non-tensor alloca destinations, where aliasing through MatrixPtrStmt + // cannot apply. + bool any_aliased_store_breaks_forwarding(Stmt *result, Stmt *var, int from, int to_exclusive) const; + + // Helper for store_to_load_forwarding: if |stmt| is a load whose source has + // a forwardable preceding store, replace it. Returns true iff the load was + // handled (the caller must skip identical-store elimination for this stmt + // even if no IR change happened, matching the legacy fall-through). |i| is + // decremented on erase so the for-loop's natural `i++` lands on the next + // unread stmt; |modified| is flipped on erase only (preserved from legacy: + // replace-with-zero does not flip it). + bool try_forward_load_at(int &i, Stmt *stmt, bool after_lower_access, bool autodiff_enabled, bool &modified); + + // Helper for store_to_load_forwarding: if |stmt| is a store identical to a + // preceding store to the same address, erase it. Handles the alloca-init-0 + // special case under non-autodiff. Same |i|/|modified| accounting as + // try_forward_load_at. + void try_eliminate_identical_store_at(int &i, + Stmt *stmt, + bool after_lower_access, + bool autodiff_enabled, + bool &modified); }; +/** + * The whole control-flow graph (all nodes plus their edges). + * + * Owns the `nodes` vector, the synthetic entry node (`start_node`, always + * empty), and the synthetic exit node (`final_node`, always empty). 
Each + * analysis/transform method on this class is the *whole-graph driver*: it + * invokes the same-named per-node method on `CFGNode` for every node, then + * runs the worklist fixpoint (RD / LV) or per-node IR rewrite loop (S2L / + * DSE). `determine_ad_stack_size` has no per-node counterpart. + */ class ControlFlowGraph { private: // Erase an empty node. void erase(int node_id); + // Assert structural invariants that every whole-graph driver assumes: `nodes` non-empty, + // `start_node` and `final_node` in range, both endpoints actually allocated. Called from each + // public driver. Cheap (a handful of comparisons); leave on even in release builds since the + // alternative on violation is silent corruption. + void assert_structural_invariants() const; + public: struct LiveVarAnalysisConfig { // This is mostly useful for SFG task-level dead store elimination. SFG may @@ -122,22 +219,26 @@ class ControlFlowGraph { const std::string &suffix = "") const; /** - * Perform reaching definition analysis using the worklist algorithm, - * and store the results in CFGNodes. + * Whole-graph driver: reaching-definition analysis via worklist fixpoint. + * Seeds the entry node's reach_gen with external-input pointers, calls + * `CFGNode::reaching_definition_analysis` per node, then converges + * reach_in/reach_out. Results are stored on each CFGNode. * https://en.wikipedia.org/wiki/Reaching_definition * * @param after_lower_access - * When after_lower_access is true, only consider local variables (allocas). + * When true, only consider local variables (allocas). */ void reaching_definition_analysis(bool after_lower_access); /** - * Perform live variable analysis using the worklist algorithm, - * and store the results in CFGNodes. + * Whole-graph driver: live-variable analysis via worklist fixpoint (run + * backwards). Seeds the exit node's live_gen with kernel-escaping stores, + * calls `CFGNode::live_variable_analysis` per node, then converges + * live_in/live_out. 
Results are stored on each CFGNode. * https://en.wikipedia.org/wiki/Live_variable_analysis * * @param after_lower_access - * When after_lower_access is true, only consider local variables (allocas). + * When true, only consider local variables (allocas). * @param config_opt * The set of SNodes which is never loaded after this task. */ @@ -153,12 +254,19 @@ class ControlFlowGraph { bool unreachable_code_elimination(); /** - * Perform store-to-load forwarding and identical store elimination. + * Whole-graph driver: store-to-load forwarding + identical-store elimination. + * Calls `CFGNode::store_to_load_forwarding` on every node and ORs the + * per-node "did we change the IR" return. This driver itself (re)runs + * `reaching_definition_analysis` first to populate reach_in / reach_out, so + * callers need not. */ bool store_to_load_forwarding(bool after_lower_access, bool autodiff_enabled); /** - * Perform dead store elimination and identical load elimination. + * Whole-graph driver: dead-store elimination + identical-load elimination. + * Calls `CFGNode::dead_store_elimination` on every node and ORs the + * per-node return. This driver itself (re)runs `live_variable_analysis` + * first to populate live_in / live_out, so callers need not. */ bool dead_store_elimination(bool after_lower_access, const std::optional &lva_config_opt); diff --git a/quadrants/ir/determine_ad_stack_size.cpp b/quadrants/ir/determine_ad_stack_size.cpp new file mode 100644 index 0000000000..ca416b1a3a --- /dev/null +++ b/quadrants/ir/determine_ad_stack_size.cpp @@ -0,0 +1,665 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "quadrants/common/exceptions.h" +#include "quadrants/ir/control_flow_graph.h" +#include "quadrants/ir/statements.h" + +// =========================================================================== +// Adaptive AD-stack sizing. +// +// This file implements ControlFlowGraph::determine_ad_stack_size() and its +// file-local helpers.
Pulled out of control_flow_graph.cpp because it is a +// self-contained ~500-line algorithm (Tarjan SCC condensation + per-rep DAG +// DP + fingerprint dedup) that has no per-node CFGNode counterpart and is +// not used by any other pass in that file. +// +// See the docstring on ControlFlowGraph::determine_ad_stack_size() below for +// the "what problem this solves / how" overview. +// =========================================================================== + +namespace quadrants::lang { + +namespace { + +// === Helpers for ControlFlowGraph::determine_ad_stack_size === + +struct AdStackIndex { + // AdStackAllocaStmt* -> contiguous int id, populated only for adaptive stacks + // (max_size == 0) that have not yet been resolved by an earlier pass. + std::unordered_map stack_id; + std::vector stacks; +}; + +AdStackIndex collect_adaptive_ad_stacks(const std::vector> &nodes) { + AdStackIndex idx; + for (const auto &node : nodes) { + for (int j = node->begin_location; j < node->end_location; j++) { + Stmt *stmt = node->block->statements[j].get(); + auto *stack = stmt->cast(); + if (!stack || !stack->is_adaptive()) { + continue; + } + if (idx.stack_id.emplace(stack, static_cast(idx.stacks.size())).second) { + idx.stacks.push_back(stack); + } + } + } + return idx; +} + +struct AdStackPerNodeSizes { + // [stack_id][node_id]. `max_increased_size[s][j]` is the maximum (pushes - pops) of stack |s| + // among all prefixes of CFGNode |j|; `increased_size[s][j]` is the net (pushes - pops) in the + // whole node. Indexed by contiguous stack id so the per-stack DP reads cheap vectors instead of + // hashing. + std::vector> increased_size; + std::vector> max_increased_size; + // True iff the stack actually appears in any push/pop in the CFG. Inactive stacks would settle + // at `max_size = 0` regardless and are short-circuited below to reproduce the original "Unused + // autodiff stack" warning.
+ std::vector stack_active; +}; + +AdStackPerNodeSizes accumulate_per_stack_per_node_size_deltas( + const std::vector> &nodes, + const std::unordered_map &stack_id, + int num_stacks) { + const int num_nodes = static_cast(nodes.size()); + AdStackPerNodeSizes out; + out.increased_size.assign(num_stacks, std::vector(num_nodes, 0)); + out.max_increased_size.assign(num_stacks, std::vector(num_nodes, 0)); + out.stack_active.assign(num_stacks, false); + for (int i = 0; i < num_nodes; i++) { + for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { + Stmt *stmt = nodes[i]->block->statements[j].get(); + AdStackAllocaStmt *stack = nullptr; + int delta = 0; + if (auto *push = stmt->cast()) { + stack = push->stack->as(); + delta = +1; + } else if (auto *pop = stmt->cast()) { + stack = pop->stack->as(); + delta = -1; + } else { + continue; + } + if (!stack->is_adaptive()) { + continue; + } + auto it = stack_id.find(stack); + QD_ASSERT(it != stack_id.end()); + const int sid = it->second; + out.stack_active[sid] = true; + int &cur = out.increased_size[sid][i]; + cur += delta; + if (cur > out.max_increased_size[sid][i]) { + out.max_increased_size[sid][i] = cur; + } + } + } + return out; +} + +// Precompute outgoing-edge node ids once per node so the per-stack DP walks an int vector instead +// of hashing `node_ids[next_node]` on every traversal. 
+std::vector> compute_outgoing_node_ids(const std::vector> &nodes) { + const int num_nodes = static_cast(nodes.size()); + std::unordered_map node_ids; + node_ids.reserve(num_nodes); + for (int i = 0; i < num_nodes; i++) { + node_ids[nodes[i].get()] = i; + } + std::vector> next_ids(num_nodes); + for (int i = 0; i < num_nodes; i++) { + auto &dst = next_ids[i]; + dst.reserve(nodes[i]->next.size()); + for (auto *next_node : nodes[i]->next) { + dst.push_back(node_ids[next_node]); + } + } + return next_ids; +} + +struct TarjanResult { + std::vector scc_id; // node -> SCC index + std::vector dfs_finish; // node -> DFS post-order index; ancestors finish AFTER descendants + std::vector> scc_nodes; // SCC -> list of node ids +}; + +// Iterative Tarjan SCC. Emits SCCs in reverse topological order (sources end up at the largest +// indices). Also records DFS post-order finish times, used downstream to split each cyclic SCC's +// intra-edges into forward vs back without a second edge-classification pass. 
+TarjanResult tarjan_scc(const std::vector> &next_ids) { + const int num_nodes = static_cast(next_ids.size()); + TarjanResult out; + out.scc_id.assign(num_nodes, -1); + out.dfs_finish.assign(num_nodes, -1); + + std::vector tarjan_index(num_nodes, -1); + std::vector tarjan_lowlink(num_nodes, 0); + std::vector on_stack(num_nodes, 0); + std::vector tarjan_stack; + tarjan_stack.reserve(num_nodes); + std::vector> dfs_stack; + dfs_stack.reserve(num_nodes); + int next_index = 0; + int next_finish = 0; + int next_scc = 0; + for (int origin = 0; origin < num_nodes; origin++) { + if (tarjan_index[origin] != -1) { + continue; + } + tarjan_index[origin] = next_index; + tarjan_lowlink[origin] = next_index; + next_index++; + tarjan_stack.push_back(origin); + on_stack[origin] = 1; + dfs_stack.emplace_back(origin, 0); + while (!dfs_stack.empty()) { + auto &frame = dfs_stack.back(); + const int u = frame.first; + const auto &nb = next_ids[u]; + if (frame.second < nb.size()) { + const int v = nb[frame.second++]; + if (tarjan_index[v] == -1) { + tarjan_index[v] = next_index; + tarjan_lowlink[v] = next_index; + next_index++; + tarjan_stack.push_back(v); + on_stack[v] = 1; + dfs_stack.emplace_back(v, 0); + } else if (on_stack[v]) { + if (tarjan_index[v] < tarjan_lowlink[u]) { + tarjan_lowlink[u] = tarjan_index[v]; + } + } + } else { + if (tarjan_lowlink[u] == tarjan_index[u]) { + std::vector component; + while (true) { + const int w = tarjan_stack.back(); + tarjan_stack.pop_back(); + on_stack[w] = 0; + out.scc_id[w] = next_scc; + component.push_back(w); + if (w == u) { + break; + } + } + out.scc_nodes.push_back(std::move(component)); + next_scc++; + } + out.dfs_finish[u] = next_finish++; + dfs_stack.pop_back(); + if (!dfs_stack.empty()) { + const int parent = dfs_stack.back().first; + if (tarjan_lowlink[u] < tarjan_lowlink[parent]) { + tarjan_lowlink[parent] = tarjan_lowlink[u]; + } + } + } + } + } + return out; +} + +struct SccEdgeSets { + // Each intra-SCC edge is either "forward" in 
the DFS spanning sense (target finishes before + // source, so source can be relaxed before target with a single topological pass) or "back" + // (target finishes at or after source, closing a cycle). The forward set drives the per-stack + // DAG dynamic programming; back-edges only need a single post-DP relaxation check to detect + // positive cycles. Inter-SCC edges always relax forward in topological order. + std::vector> next_ids_intra_fwd; + std::vector> next_ids_intra_back; + std::vector> next_ids_inter; +}; + +SccEdgeSets classify_scc_edges(const std::vector> &next_ids, + const std::vector &scc_id, + const std::vector &dfs_finish) { + const int num_nodes = static_cast(next_ids.size()); + SccEdgeSets out; + out.next_ids_intra_fwd.assign(num_nodes, {}); + out.next_ids_intra_back.assign(num_nodes, {}); + out.next_ids_inter.assign(num_nodes, {}); + for (int u = 0; u < num_nodes; u++) { + const int su = scc_id[u]; + for (int v : next_ids[u]) { + if (scc_id[v] == su) { + if (dfs_finish[v] < dfs_finish[u]) { + out.next_ids_intra_fwd[u].push_back(v); + } else { + out.next_ids_intra_back[u].push_back(v); + } + } else { + out.next_ids_inter[u].push_back(v); + } + } + } + return out; +} + +struct CyclicSccInfo { + std::vector scc_is_cyclic; // 1 iff SCC contains a cycle + std::vector> scc_topo; // for cyclic SCCs: nodes sorted by descending dfs_finish +}; + +// An SCC is cyclic iff |S| > 1 (any two nodes in a non-trivial SCC lie on a cycle) or |S| == 1 +// with a self-loop edge. For cyclic SCCs we precompute the topological ordering of their nodes +// so the per-stack DAG DP visits each node once with all forward predecessors already finalized. 
+CyclicSccInfo identify_cyclic_sccs_and_topo(const std::vector> &scc_nodes, + const std::vector> &next_ids_intra_back, + const std::vector &dfs_finish) { + const int num_sccs = static_cast(scc_nodes.size()); + CyclicSccInfo out; + out.scc_is_cyclic.assign(num_sccs, 0); + out.scc_topo.assign(num_sccs, {}); + for (int s = 0; s < num_sccs; s++) { + const auto &nodes_in_s = scc_nodes[s]; + if (nodes_in_s.size() > 1) { + out.scc_is_cyclic[s] = 1; + } else { + // Self-loops are classified as back-edges above, so a singleton SCC is cyclic iff it has a + // back-edge pointing at itself. + const int n = nodes_in_s[0]; + for (int v : next_ids_intra_back[n]) { + if (v == n) { + out.scc_is_cyclic[s] = 1; + break; + } + } + } + if (out.scc_is_cyclic[s]) { + struct DfsFinishGreater { + const std::vector &dfs_finish; + bool operator()(int a, int b) const { + return dfs_finish[a] > dfs_finish[b]; + } + }; + auto topo = nodes_in_s; + std::sort(topo.begin(), topo.end(), DfsFinishGreater{dfs_finish}); + out.scc_topo[s] = std::move(topo); + } + } + return out; +} + +struct FingerprintGroups { + std::vector stack_to_rep; // stack id -> representative stack id (whose DP run it shares) + std::vector rep_stack_ids; // representative stack ids that actually run the DP +}; + +// Group AD-stacks whose per-node (increased_size, max_increased_size) rows are bit-identical so +// the DP runs once per equivalence class instead of once per stack. In practice many AD-stacks in +// the same kernel share their push/pop schedule (one alloca per autodiff variable in the same +// loop body), and the DP on a large CFG dwarfs the dedup cost. We hash a sparse fingerprint (only +// nodes where the stack has activity) and group by it. Worst case (no duplicates): one DP run per +// stack, equivalent to running it directly per stack. 
+FingerprintGroups group_stacks_by_fingerprint(int num_nodes, + int num_stacks, + const std::vector &stack_active, + const std::vector> &increased_size, + const std::vector> &max_increased_size) { + using Fingerprint = std::vector>; // (node_id, is, mis), sorted by node_id + struct FingerprintHash { + std::size_t operator()(const Fingerprint &f) const noexcept { + // FNV-1a mix of three components per fingerprint entry. + constexpr std::size_t fnv_prime = 1099511628211ULL; + std::size_t h = 1469598103934665603ULL; + for (auto &[n, i, m] : f) { + h ^= static_cast(n); + h *= fnv_prime; + h ^= static_cast(static_cast(i)); + h *= fnv_prime; + h ^= static_cast(static_cast(m)); + h *= fnv_prime; + } + return h; + } + }; + std::unordered_map fp_to_rep; + FingerprintGroups out; + out.stack_to_rep.assign(num_stacks, -1); + for (int sid = 0; sid < num_stacks; sid++) { + if (!stack_active[sid]) { + continue; + } + Fingerprint fp; + const auto &is_row = increased_size[sid]; + const auto &mis_row = max_increased_size[sid]; + for (int n = 0; n < num_nodes; n++) { + if (is_row[n] != 0 || mis_row[n] != 0) { + fp.emplace_back(n, is_row[n], mis_row[n]); + } + } + auto [it, inserted] = fp_to_rep.emplace(std::move(fp), sid); + out.stack_to_rep[sid] = it->second; + if (inserted) { + out.rep_stack_ids.push_back(sid); + } + } + return out; +} + +enum class CyclicSccFastPath { + kPositiveCycle, // proven positive cycle exists for this stack inside this SCC + kZeroSpread, // no `is` contribution in the SCC; just spread max entry-side begin value + kFallback, // mixed-sign, run the full DP +}; + +// Sign-based fast paths sidestep the cyclic-SCC dynamic programming when the stack's `is` +// contribution inside this SCC is structurally trivial: +// 1. min_is >= 0 with max_is > 0: every node in the SCC lies on some cycle (SCC property), and +// a cycle through a strictly-positive node with all non-negative `is` along it sums to a +// positive value, so a positive cycle exists for this stack. 
This is the autodiff +// push-only-in-SCC pattern. +// 2. min_is == max_is == 0: no `is` contribution at all in this SCC; the DP would only spread +// the maximum entry-side `max_size_at_node_begin` value to every node in O(|S|). +CyclicSccFastPath classify_cyclic_scc_fast_path(const std::vector &nodes_in_s, + const std::vector &is_for_stack) { + int min_is = INT_MAX; + int max_is = INT_MIN; + for (int u : nodes_in_s) { + const int v = is_for_stack[u]; + if (v < min_is) { + min_is = v; + } + if (v > max_is) { + max_is = v; + } + } + if (min_is >= 0 && max_is > 0) { + return CyclicSccFastPath::kPositiveCycle; + } + if (min_is == 0 && max_is == 0) { + return CyclicSccFastPath::kZeroSpread; + } + return CyclicSccFastPath::kFallback; +} + +// Spread the maximum entry-side `max_size_at_node_begin` value across every node in the SCC. +// Equivalent to running the DP with all-zero `is` weights, in O(|S|). +void spread_max_begin_over_zero_scc(const std::vector &nodes_in_s, + std::vector &max_size_at_node_begin) { + int max_begin = -1; + for (int u : nodes_in_s) { + if (max_size_at_node_begin[u] > max_begin) { + max_begin = max_size_at_node_begin[u]; + } + } + if (max_begin < 0) { + return; + } + for (int u : nodes_in_s) { + if (max_size_at_node_begin[u] < max_begin) { + max_size_at_node_begin[u] = max_begin; + } + } +} + +// Mixed-sign case: single-pass dynamic programming on the SCC's forward edges (processed in +// descending DFS finish-time so every forward predecessor is finalized before its successor +// relaxes), followed by one relaxation check on the back-edges. Correctness: every walk inside +// the SCC decomposes into a forward path plus zero or more closed cycles (back-edge + forward +// path). For SCCs with no positive cycle, traversing a cycle adds a non-positive amount to the +// running size and so cannot improve max_size beyond what the forward DP already computed. 
// The back-edge relaxation check after the DP detects the only failure mode (some back-edge would
// still improve a forward predecessor's value, which can only happen if a positive cycle exists
// for this stack). Returns true iff a positive cycle was detected.
//
// topo:       the SCC's nodes in the order the forward pass must visit them.
// nodes_in_s: the SCC's nodes (any order); used for the back-edge check.
// max_size_at_node_begin: in/out; -1 marks a node not reached yet. Such nodes are skipped in both
//             passes. Values may already have been raised by the forward pass when true is
//             returned.
//
// NOTE(review): inferring a positive cycle from a single back-edge check appears to rely on the
// SCC being entered through the node the DFS discovered first (single-entry loops). With several
// entry nodes carrying different begin values the check could fire without a positive cycle --
// confirm that the CFG builder only produces single-entry loops.
bool dp_mixed_sign_cyclic_scc(const std::vector<int> &topo,
                              const std::vector<int> &nodes_in_s,
                              const std::vector<int> &is_for_stack,
                              const std::vector<std::vector<int>> &next_ids_intra_fwd,
                              const std::vector<std::vector<int>> &next_ids_intra_back,
                              std::vector<int> &max_size_at_node_begin) {
  // Forward pass: longest-path relaxation over the forward edges.
  for (int u : topo) {
    const int begin = max_size_at_node_begin[u];
    if (begin < 0) {
      continue;
    }
    const int exit_val = begin + is_for_stack[u];
    for (int v : next_ids_intra_fwd[u]) {
      if (exit_val > max_size_at_node_begin[v]) {
        max_size_at_node_begin[v] = exit_val;
      }
    }
  }
  // Back-edge check: any back-edge that could still raise its target is reported as a positive
  // cycle; nothing is written here.
  for (int u : nodes_in_s) {
    const int begin = max_size_at_node_begin[u];
    if (begin < 0) {
      continue;
    }
    const int exit_val = begin + is_for_stack[u];
    for (int v : next_ids_intra_back[u]) {
      if (exit_val > max_size_at_node_begin[v]) {
        return true;
      }
    }
  }
  return false;
}

// SCC has converged for this stack. Update global `max_size` from each node's max-prefix
// contribution, then relax inter-SCC outgoing edges into successor SCCs (predecessors are already
// finalized when an SCC is entered, so each inter-SCC edge is touched exactly once).
// nodes_in_s:    nodes of the SCC that has just converged.
// is_for_stack:  per-node net (pushes - pops) for the stack being sized.
// mis_for_stack: per-node max running prefix for the stack being sized.
// max_size_at_node_begin: in/out; nodes still at -1 are treated as unreached and skipped.
// max_size:      in/out running maximum over all nodes processed so far.
void update_global_max_and_relax_inter_scc(const std::vector<int> &nodes_in_s,
                                           const std::vector<int> &is_for_stack,
                                           const std::vector<int> &mis_for_stack,
                                           const std::vector<std::vector<int>> &next_ids_inter,
                                           std::vector<int> &max_size_at_node_begin,
                                           int &max_size) {
  for (int u : nodes_in_s) {
    const int begin = max_size_at_node_begin[u];
    if (begin < 0) {
      continue;
    }
    // Highest size reached anywhere inside node u.
    const int prefix = begin + mis_for_stack[u];
    if (prefix > max_size) {
      max_size = prefix;
    }
    // Size on leaving node u, pushed to successors in other SCCs.
    const int exit_val = begin + is_for_stack[u];
    for (int v : next_ids_inter[u]) {
      if (exit_val > max_size_at_node_begin[v]) {
        max_size_at_node_begin[v] = exit_val;
      }
    }
  }
}

struct AdStackDPResult {
  int max_size;
  bool has_positive_loop;
};

// Bundle of "graph-shape" inputs to the per-stack DP: same for every stack, computed once before
// the loop. Named-field aggregate construction (designated initializers at the call site) makes
// it a compile error to swap, e.g. `next_ids_intra_fwd` and `next_ids_intra_back` -- the highest-
// risk swap surface in the old 11-positional-int-arg signature this struct replaces.
// All container members are references: the referenced objects must outlive this struct.
struct AdStackDPGraph {
  int start_node;
  int num_sccs;
  const std::vector<std::vector<int>> &scc_nodes;
  const std::vector<char> &scc_is_cyclic;
  const std::vector<std::vector<int>> &scc_topo;
  const std::vector<std::vector<int>> &next_ids_intra_fwd;
  const std::vector<std::vector<int>> &next_ids_intra_back;
  const std::vector<std::vector<int>> &next_ids_inter;
};

// Bundle of per-stack per-node values for one DP representative. Same shape as the graph
// (`increased_size[i]` and `max_increased_size[i]` are the deltas for CFG node i). Bundling them
// together keeps `(increased_size, max_increased_size)` named at the call site -- before this
// struct, both were `const std::vector<int> &` and trivial to swap.
struct AdStackPerNodeRow {
  const std::vector<int> &increased_size;      // net (pushes - pops) inside each node
  const std::vector<int> &max_increased_size;  // max running (pushes - pops) prefix inside each node
};

// Run the per-representative DP.
Walks the SCC condensation in topological order (sources first, +// since Tarjan emits SCCs in reverse-topological order so source SCCs end up at the largest +// indices). `max_size_at_node_begin` is taken by reference as a scratch buffer reused across reps +// to avoid reallocating per iteration. +AdStackDPResult run_ad_stack_size_dp_for_representative(const AdStackDPGraph &graph, + const AdStackPerNodeRow &row, + std::vector &max_size_at_node_begin) { + std::fill(max_size_at_node_begin.begin(), max_size_at_node_begin.end(), -1); + max_size_at_node_begin[graph.start_node] = 0; + int max_size = 0; + for (int s = graph.num_sccs - 1; s >= 0; s--) { + const auto &nodes_in_s = graph.scc_nodes[s]; + if (graph.scc_is_cyclic[s]) { + switch (classify_cyclic_scc_fast_path(nodes_in_s, row.increased_size)) { + case CyclicSccFastPath::kPositiveCycle: + return {max_size, /*has_positive_loop=*/true}; + case CyclicSccFastPath::kZeroSpread: + spread_max_begin_over_zero_scc(nodes_in_s, max_size_at_node_begin); + break; + case CyclicSccFastPath::kFallback: + if (dp_mixed_sign_cyclic_scc(graph.scc_topo[s], nodes_in_s, row.increased_size, graph.next_ids_intra_fwd, + graph.next_ids_intra_back, max_size_at_node_begin)) { + return {max_size, /*has_positive_loop=*/true}; + } + break; + } + } + update_global_max_and_relax_inter_scc(nodes_in_s, row.increased_size, row.max_increased_size, + graph.next_ids_inter, max_size_at_node_begin, max_size); + } + return {max_size, /*has_positive_loop=*/false}; +} + +// Broadcast per-representative DP results to every active stack and apply the resolved +// `max_size`. Stacks with positive cycles are left at `max_size = 0` so the structural +// bounded-loop pre-pass in `irpass::determine_ad_stack_size` gets a chance to derive a symbolic +// bound; if it also cannot, the caller emits a hard compile error (there is no compile-time +// `default_ad_stack_size` fallback). 
+void apply_ad_stack_dp_results(const std::vector &stacks, + const std::vector &stack_active, + const std::vector &stack_to_rep, + const std::unordered_map &rep_results) { + const int num_stacks = static_cast(stacks.size()); + for (int sid = 0; sid < num_stacks; sid++) { + AdStackAllocaStmt *stack = stacks[sid]; + if (!stack_active[sid]) { + // No push/pop in the CFG: the DP would visit reachable nodes with all-zero edge weights and + // settle with `max_size = 0`, no positive loop. Reproduce that result directly. + QD_WARN("Unused autodiff stack {} should have been eliminated.", stack->name()); + continue; + } + const AdStackDPResult &res = rep_results.at(stack_to_rep[sid]); + if (res.has_positive_loop) { + // Leave `max_size = 0` so the symbolic-bound pre-pass can take over. + continue; + } + // Since we use |max_size| == 0 for adaptive sizes, we do not want stacks with maximum capacity + // indeed equal to 0. + QD_WARN_IF(res.max_size == 0, "Unused autodiff stack {} should have been eliminated.", stack->name()); + stack->max_size = res.max_size; + } +} + +} // namespace + +void ControlFlowGraph::determine_ad_stack_size() { + /** + * What problem this solves + * ------------------------ + * Reverse-mode autodiff records each primal write to a variable by pushing the value onto a + * per-variable stack (an "AD-stack"), then pops in reverse order during the backward pass. The + * stack's *capacity* must be known at allocation time. Some stacks are sized explicitly by the + * user (`max_size != 0`); the rest are "adaptive" (`max_size == 0`) and we have to figure out + * how many pushes can stack up on them at runtime by walking the CFG and bookkeeping pushes + * minus pops along every reachable path from kernel entry. + * + * For each adaptive AD-stack we want: the maximum value of (pushes - pops) over any walk from + * the entry node. Set that as `max_size`. 
If the kernel has a loop where pushes > pops on every + * iteration the maximum is unbounded; we leave `max_size = 0` and let the caller's structural + * bounded-loop pre-pass derive a symbolic `SizeExpr` from the loop ranges (and hard-error if + * even that fails -- there is no compile-time default fallback). + * + * High-level algorithm + * -------------------- + * 1. Index every adaptive AD-stack (`collect_adaptive_ad_stacks`). + * 2. For every (stack, CFG node) pair, precompute the net push/pop delta inside that node and + * the max running push-count prefix inside that node + * (`accumulate_per_stack_per_node_size_deltas`). After this point the actual stmts are + * irrelevant; everything below operates on int matrices. + * 3. Condense the CFG with Tarjan's SCC algorithm (`tarjan_scc`). Inside any cyclic SCC, the + * "max walk-sum from entry" question becomes "is there a positive cycle, and if not, what's + * the DAG longest path inside this SCC?". The condensation lets us solve those questions + * once per SCC, in topological order across SCCs. + * 4. For each *distinct* push/pop fingerprint across stacks (`group_stacks_by_fingerprint`), + * run the DP once: walk SCCs in topological order, per cyclic SCC try two cheap sign-based + * fast paths and fall back to a single-pass DAG DP + back-edge relaxation check for cycle + * detection. Broadcast the result to every stack that shares this fingerprint + * (`apply_ad_stack_dp_results`). In practice many autodiff kernels generate identical + * push/pop schedules across stacks (one alloca per AD variable in the same loop body), so + * the fingerprint dedup collapses N stacks to a handful of DP runs. + * + * The two cheap fast paths for a cyclic SCC: + * (a) min_is >= 0 and max_is > 0 => positive cycle exists for this stack (every node in an + * SCC is on a cycle; a cycle through a strictly-positive node with non-negative weights + * sums positive). Bail out. 
+ * (b) min_is == 0 == max_is => the SCC adds nothing; just spread the maximum entry- + * side `max_size_at_node_begin` to every node in the SCC in O(|S|). + * The full mixed-sign DP runs only when neither shortcut applies. + * + * Performance summary + * ------------------- + * Per representative cost: O(V + E + sum_{cyclic SCC S} |S| * |E_S|), with the inner sum + * collapsing to O(|S| + |E_S|) for the common autodiff push/pop pattern (one DP sweep + one + * back-edge check). Overall: O(V + E + R * (V + E)), with R = number of distinct row-pair + * fingerprints (typically << number of stacks). Per-stack per-node tables are dense + * `vector>` indexed by contiguous int stack id, not `unordered_map`, to keep the + * hot inner loop branch-and-cache friendly. + */ + assert_structural_invariants(); + AdStackIndex idx = collect_adaptive_ad_stacks(nodes); + const int num_stacks = static_cast(idx.stacks.size()); + if (num_stacks == 0) { + return; + } + + const int num_nodes = size(); + AdStackPerNodeSizes sizes = accumulate_per_stack_per_node_size_deltas(nodes, idx.stack_id, num_stacks); + std::vector> next_ids = compute_outgoing_node_ids(nodes); + + TarjanResult tarjan = tarjan_scc(next_ids); + const int num_sccs = static_cast(tarjan.scc_nodes.size()); + SccEdgeSets edges = classify_scc_edges(next_ids, tarjan.scc_id, tarjan.dfs_finish); + CyclicSccInfo cyc = identify_cyclic_sccs_and_topo(tarjan.scc_nodes, edges.next_ids_intra_back, tarjan.dfs_finish); + + FingerprintGroups groups = group_stacks_by_fingerprint(num_nodes, num_stacks, sizes.stack_active, sizes.increased_size, + sizes.max_increased_size); + + std::vector max_size_at_node_begin(num_nodes); + std::unordered_map rep_results; + rep_results.reserve(groups.rep_stack_ids.size()); + const AdStackDPGraph graph{ + .start_node = start_node, + .num_sccs = num_sccs, + .scc_nodes = tarjan.scc_nodes, + .scc_is_cyclic = cyc.scc_is_cyclic, + .scc_topo = cyc.scc_topo, + .next_ids_intra_fwd = edges.next_ids_intra_fwd, + 
.next_ids_intra_back = edges.next_ids_intra_back, + .next_ids_inter = edges.next_ids_inter, + }; + for (int rep_sid : groups.rep_stack_ids) { + const AdStackPerNodeRow row{ + .increased_size = sizes.increased_size[rep_sid], + .max_increased_size = sizes.max_increased_size[rep_sid], + }; + rep_results[rep_sid] = run_ad_stack_size_dp_for_representative(graph, row, max_size_at_node_begin); + } + + apply_ad_stack_dp_results(idx.stacks, sizes.stack_active, groups.stack_to_rep, rep_results); +} + +} // namespace quadrants::lang diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index 49c6de35fb..a78d8d046c 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -1633,8 +1633,13 @@ class InternalFuncStmt : public Stmt { */ class AdStackAllocaStmt : public Stmt { public: + // Sentinel value of `max_size` meaning "size has not been resolved yet; treat as adaptive". + // Used as a state marker in multiple passes (see `is_adaptive()`); do not redefine to a real + // capacity even if 0-sized adstacks were valid. + static constexpr std::size_t kAdaptiveSentinel = 0; + DataType dt; - std::size_t max_size{0}; // 0 = adaptive + std::size_t max_size{kAdaptiveSentinel}; // Compile-time captured symbolic expression for `max_size`, populated by // `determine_ad_stack_size` when the bound is derivable from constants and scalar field loads. // Host-evaluated pre-launch to size the adstack heap; null until the pre-pass runs. @@ -1648,6 +1653,14 @@ class AdStackAllocaStmt : public Stmt { QD_STMT_REG_FIELDS; } + // True iff this stack's capacity has not yet been resolved by `determine_ad_stack_size` (or any + // prior pass that may seed it). Prefer this over comparing `max_size` to a literal `0` -- it + // keeps the "0 means adaptive" sentinel encapsulated and lets a future change to the sentinel + // (e.g. to a typed enum) touch only this class. 
+ bool is_adaptive() const { + return max_size == kAdaptiveSentinel; + } + std::size_t element_size_in_bytes() const { return data_type_size(ret_type); } diff --git a/quadrants/transforms/determine_ad_stack_size.cpp b/quadrants/transforms/determine_ad_stack_size.cpp index 203d72a48b..13536e5e84 100644 --- a/quadrants/transforms/determine_ad_stack_size.cpp +++ b/quadrants/transforms/determine_ad_stack_size.cpp @@ -1313,7 +1313,7 @@ bool size_expr_contains_inner_domain_enumeration(const SizeExpr *e) { bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { auto adaptive_allocas = irpass::analysis::gather_statements(root, [&](Stmt *s) { auto *ad_stack = s->cast(); - return ad_stack != nullptr && ad_stack->max_size == 0; + return ad_stack != nullptr && ad_stack->is_adaptive(); }); if (adaptive_allocas.empty()) { return false; @@ -1339,7 +1339,7 @@ bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { std::unordered_map alloca_cycle_detected; for (Stmt *s : adaptive_allocas) { auto *alloca = s->as(); - if (alloca->max_size != 0) { + if (!alloca->is_adaptive()) { continue; // Already resolved by Bellman-Ford in phase 1. } t_cycle_detected = false; @@ -1368,7 +1368,7 @@ bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { // uniform representation regardless of which phase resolved the bound. for (Stmt *s : adaptive_allocas) { auto *alloca = s->as(); - if (!alloca->size_expr && alloca->max_size != 0) { + if (!alloca->size_expr && !alloca->is_adaptive()) { alloca->size_expr = SizeExpr::make_const(static_cast(alloca->max_size)); } } @@ -1381,7 +1381,7 @@ bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { // `/ir_adstack_unresolved/unresolved_alloca_.ll` for offline inspection. for (Stmt *s : adaptive_allocas) { auto *alloca = s->as(); - if (alloca->max_size != 0 || alloca->size_expr) { + if (!alloca->is_adaptive() || alloca->size_expr) { continue; } std::string dump_hint;