From a8a778b9da1d414a6a176c62b5d921052baf86c6 Mon Sep 17 00:00:00 2001 From: Sri Hari Krishna Narayanan Date: Sat, 24 Jan 2026 16:57:24 -0600 Subject: [PATCH] Internediate commit --- .../StableHLOAutoDiffOpInterfaceImpl.cpp | 1000 ++++++++++++++++- 1 file changed, 969 insertions(+), 31 deletions(-) diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 16d8c491d2..2694825e8a 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -482,7 +482,7 @@ class AutoDiffWhileRev : public ReverseAutoDiffOpInterface::ExternalModel { - enum ReverseMode { CONSTANT, CONSTANT_CHECKPOINTING, UNKNOWN }; + enum ReverseMode { CONSTANT, CONSTANT_CHECKPOINTING, REVOLVE_CHECKPOINTING, UNKNOWN }; struct ReverseModeInfo { enum ReverseMode mode = UNKNOWN; WhileLoopInfo info; @@ -498,8 +498,13 @@ class AutoDiffWhileRev const char *checkpointAttrName = "enzymexla.enable_checkpointing"; auto enableCheckpointing = orig->getAttrOfType(checkpointAttrName); + const char *binomialCheckpointAttrName = "enzymexla.enable_binomial_checkpointing"; + auto enableBinomialCheckpointing = + orig->getAttrOfType(binomialCheckpointAttrName); if (enableCheckpointing && enableCheckpointing.getValue()) revInfo.mode = CONSTANT_CHECKPOINTING; + else if (enableBinomialCheckpointing && enableBinomialCheckpointing.getValue()) + revInfo.mode = REVOLVE_CHECKPOINTING; else revInfo.mode = CONSTANT; } @@ -515,6 +520,22 @@ class AutoDiffWhileRev makeI64Constant(loc, builder, step), operands); } + static stablehlo::WhileOp makeForLoop(OpBuilder &builder, Location loc, + int64_t start, Value limit, + int64_t step, ValueRange operands) { + return makeForLoop(builder, loc, makeI64Constant(loc, builder, start), + limit, + makeI64Constant(loc, builder, step), operands); + } + + static stablehlo::WhileOp makeForLoop(OpBuilder &builder, Location loc, + Value start, Value limit, + int64_t step, ValueRange operands) { + return makeForLoop(builder, loc, start, + limit, + makeI64Constant(loc, builder, step), operands); + } + static stablehlo::WhileOp makeForLoop(OpBuilder &builder, Location loc, Value start, Value limit, Value step, ValueRange operands) { @@ -750,6 +771,554 @@ class AutoDiffWhileRev return success(!anyFailed); } + static LogicalResult reverseWithCheckpointingRevolve(stablehlo::WhileOp orig, + struct ReverseModeInfo revInfo, + OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches, + ArrayRef operandsActive) { + // return success(true); + (void)caches; // Silence unused variable warning + auto unrankedTensorType = RankedTensorType::get({}, builder.getI64Type()); + auto numItersInit = + builder + .create( + orig->getLoc(), unrankedTensorType, + SplatElementsAttr::get( + unrankedTensorType, + ArrayRef(IntegerAttr::get(builder.getI64Type(), 0)))) + .getResult(); + + SetVector outsideRefs; + getUsedValuesDefinedAbove(orig->getRegions(), outsideRefs); + (void)outsideRefs; // TODO: use in REVOLVE_CHECKPOINTING implementation + // caches is used in rvStore action to cache state during reverse pass + + SmallVector operands; + for (auto [active, res] : llvm::zip(operandsActive, orig->getResults())) { + if (active) { + operands.push_back(gutils->diffe(res, builder)); + if (!gutils->isConstantValue(res)) + gutils->zeroDiffe(res, builder); + } + } + bool anyFailed = false; + // RvInit is used to initialize the revolve internals. + OpBuilder::InsertionGuard guard(builder); + + // Create a while loop that never ends (i =0; i<=1; i+=0) + stablehlo::WhileOp revOuter = + makeForLoop(builder, orig.getLoc(), 0, 1, 0, operands); + + // This is the body of the revolve loop + Block *revOuterBody = &revOuter.getBody().front(); + builder.setInsertionPointToStart(revOuterBody); + + // This is the call to RvNextAction which should go inside while loop + auto rvNextActionCallOp = builder.create( + orig.getLoc(), + RankedTensorType::get({}, builder.getI32Type()), + ValueRange(numItersInit), + builder.getStringAttr("RvNextAction"), + /*has_side_effect*/ nullptr, + /*backend_config*/ nullptr, + /*api_version*/ nullptr, + /*calledcomputations*/ nullptr, + /*operand_layouts*/ nullptr, + /*result_layouts*/ nullptr, + /*output_operand_aliases*/ nullptr); + + auto RvNextActionCountCallOp = builder.create( + orig.getLoc(), + RankedTensorType::get({}, builder.getI32Type()), + ValueRange(numItersInit), + builder.getStringAttr("RvNextActionCount"), + /*has_side_effect*/ nullptr, + /*backend_config*/ nullptr, + /*api_version*/ nullptr, + /*calledcomputations*/ nullptr, + /*operand_layouts*/ nullptr, + /*result_layouts*/ nullptr, + /*output_operand_aliases*/ nullptr); + IRMapping mapping; + Value nextActionVal = rvNextActionCallOp.getResult(0); + Block *origBody = &orig.getBody().front(); + + //============================================================= + // rvStore Action (action == 1) + //============================================================= + /*llvm::errs() << " rvStore Begin \n"; + { + Value cmp = builder.create( + orig.getLoc(), nextActionVal, + makeI32Constant(orig.getLoc(), builder, 1), ComparisonDirection::EQ); + OpBuilder::InsertionGuard guard(builder); + auto newIf = builder.create(orig.getLoc(), TypeRange(), cmp); + Block *ifBody_false = builder.createBlock(&newIf.getFalseBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody_false); + builder.create(orig.getLoc(), ValueRange()); + Block *ifBody = builder.createBlock(&newIf.getTrueBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody); + + // rvStore action in reverse pass: Cache the current state (operands) after + // forward iterations have been run. These are the loop-carried values that + // represent the state at this point in the reverse pass. + // The operands are the loop-carried values from the outer reverse loop + SmallVector currentOperands; + for (auto arg : revOuterBody->getArguments().slice(1)) { + currentOperands.push_back(arg); + } + + // Cache each operand (same as forward pass Store action) + for (Value operand : currentOperands) { + caches.push_back(gutils->initAndPushCache(operand, builder)); + } + + builder.create(orig.getLoc(), ValueRange()); + }*/ + + //============================================================= + // rvForward Action (action == 2) + //============================================================= + llvm::errs() << " rvForward Begin \n"; + { + Value cmp = builder.create( + orig.getLoc(), nextActionVal, + makeI32Constant(orig.getLoc(), builder, 2), ComparisonDirection::EQ); + OpBuilder::InsertionGuard guard(builder); + auto newIf = builder.create(orig.getLoc(), TypeRange(), cmp); + Block *ifBody_false = builder.createBlock(&newIf.getFalseBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody_false); + builder.create(orig.getLoc(), ValueRange()); + Block *ifBody = builder.createBlock(&newIf.getTrueBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody); + + // iterationLimitVal is a runtime Value from the RvNextActionCount custom call + // Convert from i32 to i64 for compatibility with makeForLoop + Value iterationLimitI32 = RvNextActionCountCallOp.getResult(0); + Value iterationLimitVal = builder.create( + orig.getLoc(), RankedTensorType::get({}, builder.getI64Type()), iterationLimitI32); + + auto inner = makeForLoop(builder, orig.getLoc(), 0, iterationLimitVal, 1, operands); + inner->setAttrs(orig->getAttrs()); + inner->removeAttr("enzymexla.enable_binomial_checkpointing"); + Block *innerBody = &inner.getBody().front(); + builder.setInsertionPointToStart(innerBody); + Value currentIV = innerBody->getArgument(0); + + for (auto &&[origarg, innerarg] : llvm::zip_equal( + origBody->getArguments(), innerBody->getArguments())) { + mapping.map(origarg, innerarg); + gutils->originalToNewFn.map(origarg, innerarg); + } + mapping.map(origBody->getArgument(0), currentIV); + gutils->originalToNewFn.map(origBody->getArgument(0), currentIV); + + gutils->originalToNewFnOps[orig] = inner; + builder.setInsertionPointAfter(inner); + + builder.create(orig.getLoc(), ValueRange()); + + builder.setInsertionPointAfter(newIf); + } + + //============================================================= + // rvFirstUTurn Action (action == 3) + //============================================================= + /*llvm::errs() << " rvFirstUTurn Begin \n"; + { + Value cmp = builder.create( + orig.getLoc(), nextActionVal, + makeI32Constant(orig.getLoc(), builder, 3), ComparisonDirection::EQ); + OpBuilder::InsertionGuard guard(builder); + auto newIf = builder.create(orig.getLoc(), TypeRange(), cmp); + Block *ifBody_false = builder.createBlock(&newIf.getFalseBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody_false); + builder.create(orig.getLoc(), ValueRange()); + Block *ifBody = builder.createBlock(&newIf.getTrueBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody); + Block *origBodyLocal = &orig.getBody().front(); + + int revIdx = 1; + SmallVector operandsOuter(revOuterBody->getArguments()); + for (auto &&[active, operand] : llvm::zip_equal( + operandsActive, origBodyLocal->getTerminator()->getOperands())) { + if (active) { + // Use the transformed version of the operand to avoid dominance errors + Value transformedOperand = gutils->getNewFromOriginal(operand); + gutils->addToDiffe(transformedOperand, operandsOuter[revIdx], builder); + revIdx++; + } + } + + { + OpBuilder cacheBuilder(newIf); + auto loc = orig->getLoc(); + auto cacheCreator = [&](Type t) { + Value cache = cacheBuilder.create(loc, t); + return std::make_pair(cache, cache); + }; + gutils->registerCacheCreatorHook(cacheCreator); + + auto rstart = origBodyLocal->rbegin(), rend = origBodyLocal->rend(); + rstart++; + for (auto it = rstart; it != rend; it++) { + Operation *op = &*it; + anyFailed |= gutils->Logic.visitChild(op, builder, gutils).failed(); + } + gutils->deregisterCacheCreatorHook(cacheCreator); + } + + SmallVector newResults; + for (auto &&[active, arg] : + llvm::zip_equal(operandsActive, origBodyLocal->getArguments())) { + if (active) { + newResults.push_back(gutils->diffe(arg, builder)); + if (!gutils->isConstantValue(arg)) + gutils->zeroDiffe(arg, builder); + } + } + + builder.create(orig.getLoc(), ValueRange()); + + builder.setInsertionPointAfter(newIf); + }*/ + + + //============================================================= + // rvUTurn Action (action == 4) + // This executes for single step the forward partial computation + // and one reverse computation + //============================================================= + llvm::errs() << " UTurn Begin \n"; + { + Value cmp = builder.create( + orig.getLoc(), nextActionVal, + makeI32Constant(orig.getLoc(), builder, 4), ComparisonDirection::EQ); + OpBuilder::InsertionGuard guard(builder); + auto newIf = builder.create(orig.getLoc(), TypeRange(), cmp); + Block *ifBody_false = builder.createBlock(&newIf.getFalseBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody_false); + builder.create(orig.getLoc(), ValueRange()); + Block *ifBody = builder.createBlock(&newIf.getTrueBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody); + + // rvUTurn: Run one forward iteration and one reverse iteration (same as sqrt checkpointing) + // Step 1: Create forward loop first + SmallVector forwardOperands; + for (auto arg : revOuterBody->getArguments().slice(1)) { + forwardOperands.push_back(arg); + } + + // Create a single-iteration forward loop (same pattern as sqrt checkpointing's revInner) + auto forwardLoop = makeForLoop(builder, orig.getLoc(), 0, 1, 1, forwardOperands); + forwardLoop->setAttrs(orig->getAttrs()); + forwardLoop->removeAttr("enzymexla.enable_binomial_checkpointing"); + Block *forwardBody = &forwardLoop.getBody().front(); + builder.setInsertionPointToStart(forwardBody); + + // Step 2: Map original loop arguments to forward loop arguments FIRST + // This must be done before processing operands so block arguments are available + IRMapping forwardMapping; + for (auto &&[origarg, forwardarg] : llvm::zip_equal( + origBody->getArguments(), forwardBody->getArguments())) { + forwardMapping.map(origarg, forwardarg); + gutils->originalToNewFn.map(origarg, forwardarg); + } + + // Step 3: Pre-clone constants and other values defined outside IfOp into forwardBody + // This ensures they're accessible to all operations inside the forward loop + // First, collect all values that need to be cloned + DenseSet processedValues; + SmallVector> valuesToClone; + + // Helper to check if a value is defined outside the IfOp + auto isDefinedOutsideIfOp = [&](Value v) -> bool { + if (auto *defOp = v.getDefiningOp()) { + return !newIf->isAncestor(defOp); + } else if (auto arg = dyn_cast(v)) { + return !newIf->isAncestor(arg.getOwner()->getParentOp()); + } + return false; + }; + + // Collect all operands used in the original body that are defined outside the IfOp + for (Operation &op : origBody->without_terminator()) { + for (Value operand : op.getOperands()) { + if (processedValues.contains(operand)) + continue; + processedValues.insert(operand); + + // Skip if already mapped (block arguments are already mapped) + if (forwardMapping.contains(operand)) + continue; + + // Try Enzyme's mapping first + Value valueToUse = operand; + if (gutils->originalToNewFn.contains(operand)) { + valueToUse = gutils->originalToNewFn.lookup(operand); + } + + // If the value is defined outside the IfOp region, we need to clone it into forwardBody + // This is necessary for region isolation - values outside IfOp can't be used inside + Operation *opToClone = nullptr; + + // Check if valueToUse is defined outside IfOp + if (isDefinedOutsideIfOp(valueToUse)) { + if (auto *defOp = valueToUse.getDefiningOp()) { + opToClone = defOp; + } + } else if (auto *origDefOp = operand.getDefiningOp()) { + // Also check original operand - if it's outside and Enzyme hasn't transformed it + if (isDefinedOutsideIfOp(operand)) { + opToClone = origDefOp; + } + } + + if (opToClone) { + // Store for cloning later (ensures proper ordering) + valuesToClone.push_back({operand, opToClone}); + } else { + // For values already in IfOp, use as-is + forwardMapping.map(operand, valueToUse); + } + } + } + + // Now clone all constants/operations at the start of forwardBody (ensures dominance) + // Clone them in order at the start of the block, and track the last one + builder.setInsertionPointToStart(forwardBody); + Operation *lastClonedConstant = nullptr; + for (auto [operand, opToClone] : valuesToClone) { + IRMapping emptyMap; + auto clonedOp = builder.clone(*opToClone, emptyMap); + forwardMapping.map(operand, clonedOp->getResult(0)); + lastClonedConstant = clonedOp; + } + + // Step 4: Clone forward operations using builder.clone (same as rvForward) + // Use InsertionGuard to ensure proper insertion point management + { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(forwardBody); + + // If we cloned constants, move insertion point to after them + if (lastClonedConstant) { + builder.setInsertionPointAfter(lastClonedConstant); + } + + // Clone operations using IRMapping (same pattern as rvForward) + // Skip enzyme operations to avoid caching issues + // NOTE: Temporarily using a workaround to avoid dominance issues + // The issue is that block arguments from forwardBody don't dominate + // when used in cloned operations, even though they should. + for (Operation &op : origBody->without_terminator()) { + if (isa(&op)) { + continue; // Skip all enzyme operations + } + + // Try cloning with explicit operand remapping to avoid dominance issues + SmallVector remappedOperands; + for (Value operand : op.getOperands()) { + if (forwardMapping.contains(operand)) { + remappedOperands.push_back(forwardMapping.lookup(operand)); + } else { + // Fallback: try Enzyme's mapping + if (gutils->originalToNewFn.contains(operand)) { + remappedOperands.push_back(gutils->originalToNewFn.lookup(operand)); + } else { + remappedOperands.push_back(operand); + } + } + } + + // Create operation manually with remapped operands instead of cloning + // This avoids dominance issues with builder.clone + Operation *newOp = nullptr; + if (auto addOp = dyn_cast(&op)) { + newOp = stablehlo::AddOp::create(builder, op.getLoc(), + remappedOperands[0], remappedOperands[1]); + } else if (auto mulOp = dyn_cast(&op)) { + newOp = stablehlo::MulOp::create(builder, op.getLoc(), + remappedOperands[0], remappedOperands[1]); + } else if (auto convertOp = dyn_cast(&op)) { + newOp = stablehlo::ConvertOp::create(builder, op.getLoc(), + remappedOperands[0]); + } else { + // For other operations, fall back to cloning + newOp = builder.clone(op, forwardMapping); + } + + // Update Enzyme's mapping for reverse pass + gutils->originalToNewFnOps[&op] = newOp; + for (auto &&[oldv, newv] : + llvm::zip(op.getResults(), newOp->getResults())) { + gutils->originalToNewFn.map(oldv, newv); + } + } + } + + // Update forward loop terminator (same as rvForward) + { + auto oldTerm = cast(origBody->getTerminator()); + auto newTerm = cast(forwardBody->getTerminator()); + SmallVector vals; + for (auto v : oldTerm.getResults().drop_front()) { + vals.push_back(forwardMapping.lookupOrDefault(v)); + } + newTerm.getResultsMutable() + .slice(1, newTerm.getResultsMutable().size() - 1) + .assign(vals); + } + + // Set the mapping for orig to forwardLoop (needed for reverse pass) + gutils->originalToNewFnOps[orig] = forwardLoop; + + // Step 2: Run reverse iteration - visit operations in reverse order (same as sqrt checkpointing) + builder.setInsertionPointAfter(forwardLoop); + + // Get the forward loop results for reverse computation + SmallVector reverseOperands; + for (auto result : forwardLoop.getResults().slice(1, forwardLoop.getNumResults() - 1)) { + reverseOperands.push_back(result); + } + + // Create reverse loop body (single iteration) + auto reverseLoop = makeForLoop(builder, orig.getLoc(), 0, 1, 1, reverseOperands); + reverseLoop->setAttrs(orig->getAttrs()); + reverseLoop->removeAttr("enzymexla.enable_binomial_checkpointing"); + Block *reverseBody = &reverseLoop.getBody().front(); + builder.setInsertionPointToStart(reverseBody); + + // Add to diffe (same as sqrt checkpointing's revLoop) + int revIdx = 1; + for (auto &&[active, operand] : llvm::zip_equal( + operandsActive, origBody->getTerminator()->getOperands())) { + if (active) { + gutils->addToDiffe(operand, reverseBody->getArgument(revIdx), builder); + revIdx++; + } + } + + // Register cache creator hook and visit operations in reverse order (same as sqrt) + { + OpBuilder cacheBuilder(forwardLoop); + auto loc = orig->getLoc(); + auto cacheCreator = [&](Type t) { + Value cache = enzyme::InitOp::create(cacheBuilder, loc, t); + return std::make_pair(cache, cache); + }; + gutils->registerCacheCreatorHook(cacheCreator); + + // Visit operations in reverse order to compute gradients + // Use the builder from ifBody context, not reverseBody, to avoid nested enzyme.set issues + builder.setInsertionPointToEnd(ifBody); + auto rstart = origBody->rbegin(), rend = origBody->rend(); + rstart++; // Skip terminator + for (auto it = rstart; it != rend; it++) { + Operation *op = &*it; + anyFailed |= gutils->Logic.visitChild(op, builder, gutils).failed(); + } + gutils->deregisterCacheCreatorHook(cacheCreator); + } + + // Set insertion point to end of ifBody to ensure return is last + builder.setInsertionPointToEnd(ifBody); + + // Get diffe values and zero them (same as sqrt checkpointing) + SmallVector newResults; + for (auto &&[active, arg] : + llvm::zip_equal(operandsActive, origBody->getArguments())) { + if (active) { + newResults.push_back(gutils->diffe(arg, builder)); + if (!gutils->isConstantValue(arg)) + gutils->zeroDiffe(arg, builder); + } + } + + // Create return as the last operation in ifBody + builder.create(orig.getLoc(), ValueRange()); + + builder.setInsertionPointAfter(newIf); + } + + //============================================================= + // rvRestore Action (action == 5) + //============================================================= + /*llvm::errs() << " rvRestore Begin \n"; + + // Get the types of the operands (loop-carried values, excluding induction variable) + SmallVector operandTypes; + for (auto arg : revOuterBody->getArguments().slice(1)) { + operandTypes.push_back(arg.getType()); + } + + { + Value cmp = builder.create( + orig.getLoc(), nextActionVal, + makeI32Constant(orig.getLoc(), builder, 5), ComparisonDirection::EQ); + OpBuilder::InsertionGuard guard(builder); + auto newIf = builder.create(orig.getLoc(), operandTypes, cmp); + + // False branch: return current operands unchanged + Block *ifBody_false = builder.createBlock(&newIf.getFalseBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody_false); + SmallVector currentOperands; + for (auto arg : revOuterBody->getArguments().slice(1)) { + currentOperands.push_back(arg); + } + builder.create(orig.getLoc(), currentOperands); + + // True branch: restore values from cache + Block *ifBody = builder.createBlock(&newIf.getTrueBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(ifBody); + + // rvRestore action: Restore values from cache that were stored in rvStore + // Pop the most recently cached values (stored in the last rvStore call) + SmallVector restoredOperands; + size_t numOperands = revOuterBody->getArguments().size() - 1; // Exclude induction variable + + // Pop from the most recent caches (last numOperands caches in the vector) + if (caches.size() >= numOperands) { + size_t startIdx = caches.size() - numOperands; + for (size_t i = startIdx; i < caches.size(); ++i) { + Value restored = gutils->popCache(caches[i], builder); + restoredOperands.push_back(restored); + } + } else { + // Not enough caches - this shouldn't happen in correct REVOLVE usage + orig->emitError() << "rvRestore: Not enough caches to restore. Expected " + << numOperands << " but have " << caches.size() << "\n"; + // Return current operands as fallback + for (auto arg : revOuterBody->getArguments().slice(1)) { + restoredOperands.push_back(arg); + } + } + + builder.create(orig.getLoc(), restoredOperands); + + builder.setInsertionPointAfter(newIf); + + // Update the loop terminator to use the restored/current values from the IfOp + // The IfOp returns either restored values (if restore happened) or current values + Operation *loopTerminator = revOuterBody->getTerminator(); + if (auto returnOp = dyn_cast(loopTerminator)) { + // The terminator currently uses revOuterBody->getArguments() + // We need to replace the loop-carried values (slice(1)) with IfOp results + SmallVector newTerminatorOperands; + newTerminatorOperands.push_back(revOuterBody->getArgument(0)); // Keep induction variable + // Add the IfOp results (restored or current operands) + for (Value result : newIf.getResults()) { + newTerminatorOperands.push_back(result); + } + returnOp->setOperands(newTerminatorOperands); + } + } // End of InsertionGuard scope - newIf is still valid here + llvm::errs() << " rvRestore End \n";*/ + + return success(!anyFailed); + } + public: LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, MGradientUtilsReverse *gutils, @@ -777,6 +1346,11 @@ class AutoDiffWhileRev if (revInfo.mode == CONSTANT_CHECKPOINTING) { return reverseWithCheckpointing(cast(orig), revInfo, builder, gutils, caches, operandsActive); + } else if (revInfo.mode == REVOLVE_CHECKPOINTING) { + llvm::errs() <<"REVERSE IF CONDITION \n"; + return reverseWithCheckpointingRevolve(cast(orig), revInfo, + builder, gutils, caches, operandsActive); + llvm::errs() <<"REVERSE IF CONDITION DONE \n"; } else if (revInfo.mode == CONSTANT) { auto iterType = orig->getOperand(0).getType(); numIters = stablehlo::ConstantOp::create( @@ -1047,6 +1621,243 @@ class AutoDiffWhileRev return caches; } + } else if (getReverseMode(orig).mode == REVOLVE_CHECKPOINTING) { + OpBuilder builder(newWhile); + llvm::errs() << "Cache REVOLVE_CHECKPOINTING"; + + SetVector outsideRefs; + getUsedValuesDefinedAbove(orig->getRegions(), outsideRefs); + SmallVector caches; + + OpBuilder::InsertionGuard guard(builder); + //int64_t nOuter =info.getConstantNumIters(); + auto unrankedTensorType = RankedTensorType::get({}, builder.getI64Type()); + auto nOuter = + builder + .create( + orig->getLoc(), unrankedTensorType, + SplatElementsAttr::get( + unrankedTensorType, + ArrayRef(IntegerAttr::get(builder.getI64Type(), info.getConstantNumIters())))) + .getResult(); + + //RvInit is used to initialize the revolve internals. + (void)builder.create( + orig->getLoc(), + //builder.getIntegerType(32, false), + UnrankedTensorType::get(builder.getI32Type()), + //TypeRange{type_input, type_tau} nullptr, + ValueRange(nOuter), + builder.getStringAttr("RvInit"), + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr); + + + // Get the transformed operands from newWhile (the augmented primal) + // newWhile is already the transformed version, so its operands are in the correct region + SmallVector vals; + for (auto operand : newWhile->getOperands().slice(1, newWhile->getNumOperands() - 1)) { + vals.push_back(operand); + } + auto outer = makeForLoop(builder, orig->getLoc(), 0, 1, 0, vals); + + // Store outsideRefs before the outer loop (same as sqrt checkpointing) + builder.setInsertionPoint(outer); + for (auto ref : outsideRefs) { + caches.push_back(gutils->initAndPushCache( + gutils->getNewFromOriginal(ref), builder)); + } + + Block *outerBody = &outer.getBody().front(); + builder.setInsertionPointToStart(outerBody); + + + //This is the call to RvNextAction which should go inside while loop + //SmallVector iterShape; + //iterShape.push_back(1); + auto rvNextActionCallOp = builder.create( + orig->getLoc(), + RankedTensorType::get({}, builder.getI32Type()), + ValueRange(nOuter), + builder.getStringAttr("RvNextAction"), + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + + auto RvNextActionCountCallOp = builder.create( + orig->getLoc(), + RankedTensorType::get({}, builder.getI32Type()), + //TypeRange{type_input, type_tau} nullptr, + ValueRange(nOuter), + builder.getStringAttr("RvNextActionCount"), + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr); + + //SHK: What is getting pushed back here? + // Suspect is is basically storing the state at the start of the + //outer loop body. + + // For REVOLVE_CHECKPOINTING, caching is handled conditionally in the + // reverse pass based on Revolve actions, not in the forward pass. + // The forward pass just executes iterations based on Revolve's actions. + // TODO: Implement Store action in forward pass once region/dominance issues are resolved + + // Set insertion point in outerBody after the custom calls + builder.setInsertionPointAfter(RvNextActionCountCallOp); + + IRMapping mapping; + + //SHK: What are the arguments to the outerbody? + // Use block arguments directly - they're defined in outerBody and will dominate + // all uses within the outer loop body and nested regions + SmallVector operands; + for (auto arg : outerBody->getArguments().slice(1)) { + operands.push_back(arg); + } + + Value nextActionVal = rvNextActionCallOp.getResult(0); + + auto types = ValueRange(vals).getTypes(); + Block *oldInnerBody = &newWhile.getBody().front(); + + // Use operands directly - will be updated by Store action if needed + SmallVector forwardOperands = operands; + + //============================================================= + // Store Action (action == 1): Store current operands to cache + //============================================================= + { + Value cmpStore = builder.create( + orig->getLoc(), nextActionVal, + makeI32Constant(orig->getLoc(), builder, 1), ComparisonDirection::EQ); + + auto storeIf = builder.create(orig->getLoc(), types, cmpStore); + + // False branch: pass through unchanged operands + Block *storeFalse = builder.createBlock(&storeIf.getFalseBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(storeFalse); + builder.create(orig->getLoc(), forwardOperands); + + // True branch: store operands to cache + Block *storeTrue = builder.createBlock(&storeIf.getTrueBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(storeTrue); + + // Store loop-carried values to cache (same as sqrt checkpointing) + // These are the same operands that sqrt checkpointing stores + for (Value operand : operands) { + caches.push_back(gutils->initAndPushCache(operand, builder)); + } + + // Return operands unchanged (caching is done via side effects) + builder.create(orig->getLoc(), forwardOperands); + + builder.setInsertionPointAfter(storeIf); + + // Update forwardOperands to use results from storeIf + forwardOperands.clear(); + for (Value result : storeIf.getResults()) { + forwardOperands.push_back(result); + } + } + + //============================================================= + // Forward Action (action == 2) or FirstUTurn (action == 4): Execute iterations + // Both actions execute forward iterations; FirstUTurn marks the + // transition point where the backward pass will begin after. + //============================================================= + { + Value cmpForward = builder.create( + orig->getLoc(), nextActionVal, + makeI32Constant(orig->getLoc(), builder, 2), ComparisonDirection::EQ); + Value cmpFirstUTurn = builder.create( + orig->getLoc(), nextActionVal, + makeI32Constant(orig->getLoc(), builder, 4), ComparisonDirection::EQ); + Value cmpForwardOrFirstUTurn = builder.create( + orig->getLoc(), cmpForward, cmpFirstUTurn); + + auto forwardIf = builder.create(orig->getLoc(), types, cmpForwardOrFirstUTurn); + + // False branch: pass through unchanged (use forwardOperands from Store) + Block *forwardFalse = builder.createBlock(&forwardIf.getFalseBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(forwardFalse); + builder.create(orig->getLoc(), forwardOperands); + + // True branch: execute forward iterations + Block *forwardTrue = builder.createBlock(&forwardIf.getTrueBranch(), {}, TypeRange()); + builder.setInsertionPointToStart(forwardTrue); + + // Get iteration limit from RvNextActionCount and convert from i32 to i64 + Value iterationLimitI32 = RvNextActionCountCallOp.getResult(0); + Value iterationLimitVal = builder.create( + orig->getLoc(), RankedTensorType::get({}, builder.getI64Type()), iterationLimitI32); + + // Create inner loop to execute iterations using forwardOperands + auto inner = makeForLoop(builder, orig->getLoc(), 0, iterationLimitVal, 1, forwardOperands); + Block *innerBody = &inner.getBody().front(); + + builder.setInsertionPointToStart(innerBody); + + IRMapping mapping; + + // Map original loop arguments to inner loop arguments + for (auto [oldArg, newArg] : llvm::zip_equal( + oldInnerBody->getArguments(), innerBody->getArguments())) { + mapping.map(oldArg, newArg); + } + + // Handle induction variable: newIV = innerIV + outerIV + // For REVOLVE, outerIV is always 0 (outer loop is 0 to 1), but we compute + // it for consistency with sqrt checkpointing + Value oldIV = oldInnerBody->getArgument(0); + Value outerIV = outerBody->getArgument(0); + Value newIV = stablehlo::AddOp::create( + builder, oldIV.getLoc(), innerBody->getArgument(0), outerIV); + mapping.map(oldIV, newIV); + + // Clone the original loop body operations + // Skip enzyme operations to avoid caching issues (caching handled in reverse pass) + for (Operation &innerOp : oldInnerBody->without_terminator()) { + if (isa(&innerOp)) { + continue; // Skip all enzyme operations + } + builder.clone(innerOp, mapping); + } + + // Update the inner loop's return values (same as sqrt checkpointing) + SmallVector newReturns; + for (auto oldRes : + oldInnerBody->getTerminator()->getOperands().slice( + 1, oldInnerBody->getTerminator()->getNumOperands() - 1)) { + newReturns.push_back(mapping.lookupOrDefault(oldRes)); + } + Operation *term = innerBody->getTerminator(); + term->setOperands(1, term->getNumOperands() - 1, newReturns); + + // Return inner loop results from the forward if-true branch + builder.setInsertionPointAfter(inner); + builder.create(orig->getLoc(), + inner.getResults().slice(1, inner.getNumResults() - 1)); + + builder.setInsertionPointAfter(forwardIf); + + // Connect the forwardIf results back to the outer loop terminator + // The outer loop terminator should return the updated operands from forwardIf + auto terminator = cast(outerBody->getTerminator()); + SmallVector newOperands; + newOperands.push_back(outerBody->getArgument(0)); // Keep the induction variable + newOperands.append(forwardIf.getResults().begin(), forwardIf.getResults().end()); + terminator->setOperands(newOperands); + } + + return caches; } return {}; @@ -2817,7 +3628,17 @@ struct WhileOpEnzymeOpsRemover } cachesMap[pushedValue] = info; - otherWhileOp = cast(info.popOp->getParentOp()); + // Find the WhileOp that contains the popOp, walking up the parent chain + // if necessary (e.g., if popOp is inside an IfOp) + Operation *parentOp = info.popOp->getParentOp(); + while (parentOp && !isa(parentOp)) { + parentOp = parentOp->getParentOp(); + } + if (!parentOp) { + return rewriter.notifyMatchFailure( + op, "Could not find WhileOp parent for popOp"); + } + otherWhileOp = cast(parentOp); } } @@ -2925,6 +3746,7 @@ struct WhileOpEnzymeOpsRemover } Value itersV = nullptr; + SmallVector newCacheValues; // Collect all new cache values before modifying terminator for (auto &cinfo : caches) { Value cache = cinfo.initOp.getResult(); @@ -2961,6 +3783,20 @@ struct WhileOpEnzymeOpsRemover cast(cast(cinfo.cachedType()) .getShadowType(numIters)); + // If original type is rank-0, ensure newType is rank-1 for DynamicUpdateSliceOp + // (getShadowType might not add a dimension for rank-0 tensors) + // This also ensures the cache type matches what we're pushing + if (auto TT = dyn_cast(cinfo.cachedType())) { + if (TT.getShape().empty()) { + auto newTypeTT = dyn_cast(newType); + if (newTypeTT && newTypeTT.getShape().empty()) { + // newType is still rank-0, so manually create rank-1 version + // Use size 1 as placeholder (actual size set by DynamicPadOp) + newType = RankedTensorType::get({1}, TT.getElementType()); + } + } + } + Value initValue; if (info.isConstant()) { initValue = cast(newType).createNullValue( @@ -3012,47 +3848,121 @@ struct WhileOpEnzymeOpsRemover newOperands.push_back(initValue); - auto cacheValue = body->addArgument(newType, cinfo.pushOp->getLoc()); + // Add arguments to body and condition one at a time, matching gradient handling pattern + // This preserves region isolation by making modifications incrementally + Value cacheArg = body->addArgument(newType, cinfo.pushOp->getLoc()); cond->addArgument(newType, cinfo.pushOp->getLoc()); + + // Update terminator immediately with cache argument as placeholder (same as gradient handling) + // This preserves region isolation by keeping the terminator valid + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(term); + term->insertOperands(term->getNumOperands(), ValueRange(cacheArg)); + } + } + auto numInitArgs = whileOp->getNumOperands(); + auto newWhile = + stablehlo::WhileOp::create(rewriter, op->getLoc(), newOperands); + + newWhile.getCond().takeBody(whileOp.getCond()); + newWhile.getBody().takeBody(whileOp.getBody()); + + // Now create replacement operations and erase pushOps after takeBody + // This preserves region isolation by ensuring the body is moved before modifications + Block *newBody = &newWhile.getBody().front(); + auto newTerm = newBody->getTerminator(); + SmallVector finalCacheValues; + + // Track which cache index we're processing + size_t cacheArgIdx = 0; + for (auto &cinfo : caches) { + // Skip caches that were hoisted outside the loop (they were handled in first loop) + // Note: After takeBody, whileOp.getBody() is empty, so we can't check here + // Instead, we check if we can find the pushOp in the new body + + // Find the pushOp in the new body (it was moved by takeBody) + enzyme::PushOp newPushOp; + bool foundPushOp = false; + for (auto &op : *newBody) { + if (auto push = dyn_cast(&op)) { + // Check if this is the pushOp we're looking for by comparing the cache + if (push.getCache() == cinfo.initOp.getResult()) { + newPushOp = push; + foundPushOp = true; + break; + } + } + } + + if (!foundPushOp) { + continue; // PushOp not found, skip + } + + // Get the cache value argument (accounting for gradients that were added) + size_t argIdx = numInitArgs + updatedGradients.size() + cacheArgIdx; + Value cacheValue = newBody->getArgument(argIdx); + + Value newCacheValue; { OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(cinfo.pushOp); + rewriter.setInsertionPoint(newPushOp); - Value newCacheValue; if (auto TT = dyn_cast(cinfo.cachedType())) { - auto shape = TT.getShape(); + auto cacheType = cast(cacheValue.getType()); + auto cacheShapeRef = cacheType.getShape(); + SmallVector cacheShape(cacheShapeRef.begin(), cacheShapeRef.end()); + auto originalShape = TT.getShape(); + + Value originalValue = newPushOp.getValue(); + Value valueForUpdate = originalValue; + auto pushLoc = newPushOp->getLoc(); - SmallVector startIndices(shape.size() + 1, zero); - startIndices[0] = inductionVariable; + Value actualCacheValue = cacheValue; + + Value zeroInBody = makeI64Constant(pushLoc, rewriter, 0); + SmallVector startIndices(cacheShape.size(), zeroInBody); + startIndices[0] = newBody->getArgument(0); // induction variable SmallVector updateShape; - updateShape.push_back(1); - updateShape.append(shape.begin(), shape.end()); + if (originalShape.empty()) { + updateShape = {1}; + } else { + updateShape.push_back(1); + if (cacheShape.size() > 1) { + updateShape.append(cacheShape.begin() + 1, cacheShape.end()); + } + } + auto updateType = cast(actualCacheValue.getType()).clone(updateShape); Value reshapedUpdate = stablehlo::ReshapeOp::create( - rewriter, cinfo.pushOp->getLoc(), TT.clone(updateShape), - cinfo.pushOp.getValue()); + rewriter, pushLoc, updateType, valueForUpdate); newCacheValue = stablehlo::DynamicUpdateSliceOp::create( - rewriter, cinfo.pushOp->getLoc(), cacheValue, reshapedUpdate, + rewriter, pushLoc, actualCacheValue, reshapedUpdate, startIndices); } else { assert(false && "todo"); - // newCacheValue = tensor::InsertOp::create(rewriter, - // info.pushOp->getLoc(), info.pushOp.getValue(), cacheValue, - // inductionVariable); } - term->insertOperands(term->getNumOperands(), ValueRange(newCacheValue)); + rewriter.setInsertionPointAfter(newCacheValue.getDefiningOp()); + rewriter.eraseOp(newPushOp); } + + finalCacheValues.push_back(newCacheValue); + cacheArgIdx++; } - auto numInitArgs = whileOp->getNumOperands(); - auto newWhile = - stablehlo::WhileOp::create(rewriter, op->getLoc(), newOperands); - - newWhile.getCond().takeBody(whileOp.getCond()); - newWhile.getBody().takeBody(whileOp.getBody()); + // Replace placeholder cache arguments in terminator with actual cache values + if (!finalCacheValues.empty()) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(newTerm); + // The terminator already has cache arguments as placeholders (the last N operands) + // Replace them with the actual cache values + size_t numCacheArgs = finalCacheValues.size(); + size_t termStartIdx = newTerm->getNumOperands() - numCacheArgs; + newTerm->setOperands(termStartIdx, numCacheArgs, finalCacheValues); + } unsigned resultIdx = numInitArgs; for (auto grad : updatedGradients) { @@ -3138,8 +4048,11 @@ struct WhileOpEnzymeOpsRemover Value cache = info.initOp.getResult(); + // Save the cached type before erasing initOp (needed later at line 3998) + Type cachedType = info.cachedType(); + auto newType = - cast(cast(info.cachedType()) + cast(cast(cachedType) .getShadowType(numIters)); enzyme::InitOp newInit = ({ OpBuilder::InsertionGuard guard(rewriter); @@ -3149,10 +4062,16 @@ struct WhileOpEnzymeOpsRemover rewriter, info.initOp->getLoc(), enzyme::CacheType::get(cache.getContext(), newType)); }); + + // Replace all uses of the old initOp with the new one + rewriter.replaceAllUsesWith(info.initOp.getResult(), newInit.getResult()); + // Erase the old initOp after replacing all uses + rewriter.eraseOp(info.initOp); + info.pushOp = ({ OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(newWhile); - auto newPush = enzyme::PushOp::create(rewriter, cache.getLoc(), + auto newPush = enzyme::PushOp::create(rewriter, newInit.getLoc(), newInit.getResult(), newWhile->getResult(resultIdx)); rewriter.eraseOp(info.pushOp); @@ -3176,17 +4095,36 @@ struct WhileOpEnzymeOpsRemover popBody->getArgument(popBody->getNumArguments() - 1); Value popValue; - if (auto TT = dyn_cast(info.cachedType())) { - auto shape = TT.getShape(); - SmallVector startIndices(shape.size() + 1, zero); + if (auto TT = dyn_cast(cachedType)) { + auto popType = cast(popNewValue.getType()); + auto popShapeRef = popType.getShape(); + SmallVector popShape(popShapeRef.begin(), popShapeRef.end()); + + // If popNewValue is rank 0, we need to reshape it to rank 1 first + Value actualPopValue = popNewValue; + if (popShape.empty()) { + // Reshape rank-0 to rank-1: tensor -> tensor<1xi64> + auto rank1Type = RankedTensorType::get( + {1}, popType.getElementType()); + actualPopValue = stablehlo::ReshapeOp::create( + rewriter, info.popOp->getLoc(), rank1Type, popNewValue); + popShape = {1}; + } + + SmallVector startIndices(popShape.size(), zero); startIndices[0] = newInductionVariable; SmallVector sliceSizes; - sliceSizes.reserve(shape.size() + 1); + sliceSizes.reserve(popShape.size()); sliceSizes.push_back(1); - sliceSizes.append(shape.begin(), shape.end()); + if (popShape.size() > 1) { + sliceSizes.append(popShape.begin() + 1, popShape.end()); + } + auto sliceResultType = popShape.empty() + ? RankedTensorType::get({1}, popType.getElementType()) + : cast(actualPopValue.getType()).clone(sliceSizes); popValue = stablehlo::DynamicSliceOp::create( - rewriter, info.popOp->getLoc(), TT.clone(sliceSizes), popNewValue, + rewriter, info.popOp->getLoc(), sliceResultType, actualPopValue, startIndices, sliceSizes); popValue = stablehlo::ReshapeOp::create( rewriter, info.popOp->getLoc(), TT, popValue);