From cc3feb0ce21fab695031761514cb24a737bbe57f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 20 Apr 2026 22:47:17 -0500 Subject: [PATCH 1/2] refactor: update ScatterToDynamicUpdateSlice to detect special scatter update detection --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 99 +++++++++-------------- 1 file changed, 39 insertions(+), 60 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index a795f9c1c8..2672d18e4c 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7252,80 +7252,59 @@ struct ScatterToDynamicUpdateSlice final LogicalResult matchAndRewriteImpl(stablehlo::ScatterOp op, PatternRewriter &rewriter) const { - Block &body = op.getUpdateComputation().front(); - if (body.getOperations().size() != 1) - return failure(); - - Operation &innerOp = body.front(); - if (!isa(&innerOp)) { - return failure(); - } - if (innerOp.getNumOperands() != 1) { + if (op.getInputs().size() != 1) { return failure(); } - if (op.getInputs().size() != 1) + CheckCommonScatterOp scatterCheck(op); + if (scatterCheck.kind != ScatterOpKind::Setindex && + scatterCheck.kind != ScatterOpKind::ConstantSetindex) { return failure(); - - // For us to proceed, either we are returning the last block argument or we - // are returning a constant - Value update = nullptr; - DenseElementsAttr splatAttr; - - auto retop = dyn_cast(innerOp.getOperand(0)); - if (retop) { - if (retop.getOwner() != &body) - return failure(); - if (retop.getArgNumber() != 1) - return failure(); - update = op.getUpdates()[0]; - } else { - DenseElementsAttr attr; - if (matchPattern(innerOp.getOperand(0), m_Constant(&attr))) { - splatAttr = DenseElementsAttr::get( - cast(op.getUpdates()[0].getType()), - attr.getSplatValue()); - } else { - return failure(); - } } auto dims = op.getScatterDimensionNumbers(); auto input = op.getInputs()[0]; auto scatter = op.getScatterIndices(); - auto updateShape = - cast(op.getUpdates()[0].getType()).getShape(); + auto updatesTy = cast(op.getUpdates()[0].getType()); + auto updateRank = updatesTy.getRank(); - if (dims.getInsertedWindowDims().size() == 0 && - dims.getUpdateWindowDims().size() == updateShape.size()) { - - if (update == nullptr) { - update = stablehlo::ConstantOp::create( - rewriter, op.getLoc(), op.getUpdates()[0].getType(), splatAttr); - } + if (dims.getInsertedWindowDims().size() != 0 || + dims.getUpdateWindowDims().size() != updateRank) { + return failure(); + } - auto ity = RankedTensorType::get( - {}, cast(scatter.getType()).getElementType()); - SmallVector start(updateShape.size(), 0); - for (auto en : llvm::enumerate(dims.getScatterDimsToOperandDims())) { - auto startval = is_same_in_axis(rewriter, ity, scatter, en.index()); - if (!startval) - return failure(); - start[en.value()] = *startval; - } - for (auto &v : start) { - if (v != nullptr) - continue; - v = stablehlo::ConstantOp::create(rewriter, op.getLoc(), ity, - cast(makeAttr(ity, 0))); - } - rewriter.replaceOpWithNewOp( - op, op.getResult(0).getType(), input, update, start); - return success(); + Value update; + if (scatterCheck.kind == ScatterOpKind::Setindex) { + update = op.getUpdates()[0]; + } else if (scatterCheck.kind == ScatterOpKind::ConstantSetindex) { + auto splatAttr = DenseElementsAttr::get( + cast(op.getUpdates()[0].getType()), + scatterCheck.constant.getSplatValue()); + update = stablehlo::ConstantOp::create(rewriter, op.getLoc(), updatesTy, + splatAttr); + } else { + return failure(); } - return failure(); + auto ity = RankedTensorType::get( + {}, cast(scatter.getType()).getElementType()); + SmallVector start(updateRank, 0); + for (auto en : llvm::enumerate(dims.getScatterDimsToOperandDims())) { + auto startval = is_same_in_axis(rewriter, ity, scatter, en.index()); + if (!startval) + return failure(); + start[en.value()] = *startval; + } + for (auto &v : start) { + if (v != nullptr) + continue; + v = stablehlo::ConstantOp::create(rewriter, op.getLoc(), ity, + cast(makeAttr(ity, 0))); + } + rewriter.replaceOpWithNewOp( + op, op.getResult(0).getType(), input, update, start); + return success(); } }; From bbaaecbc7e73ac14c96ea6e5eb75b6ce0aef32fd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 18 Apr 2026 20:35:05 -0500 Subject: [PATCH 2/2] feat: extend scatter to dynamic update slice simplification --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 134 +++++++++++++++------- 1 file changed, 92 insertions(+), 42 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 2672d18e4c..ba7b66f1b7 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7235,13 +7235,42 @@ struct BroadcastPad final // Given a value and index idx, determine whether all values are the same along // idx. If so, return said value std::optional is_same_in_axis(OpBuilder &rewriter, ShapedType outTy, - Value v, size_t idx) { + Value v, size_t idx, + int64_t indexVectorDim) { mlir::SplatElementsAttr splat; if (matchPattern(v, m_Constant(&splat))) { return stablehlo::ConstantOp::create(rewriter, v.getLoc(), outTy, splat.resizeSplat(outTy)); } + if (auto iota = v.getDefiningOp()) { + if (iota.getIotaDimension() == indexVectorDim) { + return stablehlo::ConstantOp::create( + rewriter, v.getLoc(), outTy, + cast(makeAttr(outTy, (int64_t)idx))); + } + } + + return {}; +} + +std::optional> +is_iota_in_axis(OpBuilder &rewriter, ShapedType outTy, Value v, size_t idx, + int64_t indexVectorDim) { + auto iota = detectIotaLikeTensor(v); + if (iota && iota->dimension != indexVectorDim) { + if (isOneAttr(iota->scale)) { + auto startVal = getDoubleFromAttr(iota->start); + if (startVal) { + return std::make_pair( + stablehlo::ConstantOp::create( + rewriter, v.getLoc(), outTy, + cast(makeAttr(outTy, (int64_t)*startVal))), + iota->dimension); + } + } + } + return {}; } @@ -7252,59 +7281,81 @@ struct ScatterToDynamicUpdateSlice final LogicalResult matchAndRewriteImpl(stablehlo::ScatterOp op, PatternRewriter &rewriter) const { - if (op.getInputs().size() != 1) { + Block &body = op.getUpdateComputation().front(); + if (body.getOperations().size() != 1) return failure(); } - CheckCommonScatterOp scatterCheck(op); - if (scatterCheck.kind != ScatterOpKind::Setindex && - scatterCheck.kind != ScatterOpKind::ConstantSetindex) { + Operation &innerOp = body.front(); + if (!isa(&innerOp)) { return failure(); } + if (innerOp.getNumOperands() != 1) { + return failure(); + } + + if (op.getInputs().size() != 1) + return failure(); + + // For us to proceed, either we are returning the last block argument or we + // are returning a constant + Value update = nullptr; + DenseElementsAttr splatAttr; + + auto retop = dyn_cast(innerOp.getOperand(0)); + if (retop) { + if (retop.getOwner() != &body) + return failure(); + if (retop.getArgNumber() != 1) + return failure(); + update = op.getUpdates()[0]; + } else { + DenseElementsAttr attr; + if (matchPattern(innerOp.getOperand(0), m_Constant(&attr))) { + splatAttr = DenseElementsAttr::get( + cast(op.getUpdates()[0].getType()), + attr.getSplatValue()); + } else { + return failure(); + } + } auto dims = op.getScatterDimensionNumbers(); auto input = op.getInputs()[0]; auto scatter = op.getScatterIndices(); - auto updatesTy = cast(op.getUpdates()[0].getType()); - auto updateRank = updatesTy.getRank(); + auto updateShape = + cast(op.getUpdates()[0].getType()).getShape(); - if (dims.getInsertedWindowDims().size() != 0 || - dims.getUpdateWindowDims().size() != updateRank) { - return failure(); - } + if (dims.getInsertedWindowDims().size() == 0 && + dims.getUpdateWindowDims().size() == updateShape.size()) { - Value update; - if (scatterCheck.kind == ScatterOpKind::Setindex) { - update = op.getUpdates()[0]; - } else if (scatterCheck.kind == ScatterOpKind::ConstantSetindex) { - auto splatAttr = DenseElementsAttr::get( - cast(op.getUpdates()[0].getType()), - scatterCheck.constant.getSplatValue()); - update = stablehlo::ConstantOp::create(rewriter, op.getLoc(), updatesTy, - splatAttr); - } else { - return failure(); - } + if (update == nullptr) { + update = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), op.getUpdates()[0].getType(), splatAttr); + } - auto ity = RankedTensorType::get( - {}, cast(scatter.getType()).getElementType()); - SmallVector start(updateRank, 0); - for (auto en : llvm::enumerate(dims.getScatterDimsToOperandDims())) { - auto startval = is_same_in_axis(rewriter, ity, scatter, en.index()); - if (!startval) - return failure(); - start[en.value()] = *startval; - } - for (auto &v : start) { - if (v != nullptr) - continue; - v = stablehlo::ConstantOp::create(rewriter, op.getLoc(), ity, - cast(makeAttr(ity, 0))); + auto ity = RankedTensorType::get( + {}, cast(scatter.getType()).getElementType()); + SmallVector start(updateShape.size(), 0); + for (auto en : llvm::enumerate(dims.getScatterDimsToOperandDims())) { + auto startval = is_same_in_axis(rewriter, ity, scatter, en.index()); + if (!startval) + return failure(); + start[en.value()] = *startval; + } + for (auto &v : start) { + if (v != nullptr) + continue; + v = stablehlo::ConstantOp::create(rewriter, op.getLoc(), ity, + cast(makeAttr(ity, 0))); + } + rewriter.replaceOpWithNewOp( + op, op.getResult(0).getType(), input, update, start); + return success(); } - rewriter.replaceOpWithNewOp( - op, op.getResult(0).getType(), input, update, start); - return success(); + + return failure(); } }; @@ -13119,8 +13170,7 @@ struct DUSSliceSimplify final }); LLVM_DEBUG( - for (auto [idx, operandSize, updateSize] - : llvm::zip_equal( + for (auto [idx, operandSize, updateSize] : llvm::zip_equal( newDusIndices, cast(preSliceOperand.getType()).getShape(), cast(preSliceUpdate.getType()).getShape())) {