diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index a795f9c1c8..ba7b66f1b7 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7235,13 +7235,42 @@ struct BroadcastPad final // Given a value and index idx, determine whether all values are the same along // idx. If so, return said value std::optional is_same_in_axis(OpBuilder &rewriter, ShapedType outTy, - Value v, size_t idx) { + Value v, size_t idx, + int64_t indexVectorDim) { mlir::SplatElementsAttr splat; if (matchPattern(v, m_Constant(&splat))) { return stablehlo::ConstantOp::create(rewriter, v.getLoc(), outTy, splat.resizeSplat(outTy)); } + if (auto iota = v.getDefiningOp()) { + if (iota.getIotaDimension() == indexVectorDim) { + return stablehlo::ConstantOp::create( + rewriter, v.getLoc(), outTy, + cast(makeAttr(outTy, (int64_t)idx))); + } + } + + return {}; +} + +std::optional> +is_iota_in_axis(OpBuilder &rewriter, ShapedType outTy, Value v, size_t idx, + int64_t indexVectorDim) { + auto iota = detectIotaLikeTensor(v); + if (iota && iota->dimension != indexVectorDim) { + if (isOneAttr(iota->scale)) { + auto startVal = getDoubleFromAttr(iota->start); + if (startVal) { + return std::make_pair( + stablehlo::ConstantOp::create( + rewriter, v.getLoc(), outTy, + cast(makeAttr(outTy, (int64_t)*startVal))), + iota->dimension); + } + } + } + return {}; } @@ -7255,6 +7284,7 @@ struct ScatterToDynamicUpdateSlice final Block &body = op.getUpdateComputation().front(); if (body.getOperations().size() != 1) return failure(); + } Operation &innerOp = body.front(); if (!isa(&innerOp)) { @@ -13140,8 +13170,7 @@ struct DUSSliceSimplify final }); LLVM_DEBUG( - for (auto [idx, operandSize, updateSize] - : llvm::zip_equal( + for (auto [idx, operandSize, updateSize] : llvm::zip_equal( newDusIndices, cast(preSliceOperand.getType()).getShape(), cast(preSliceUpdate.getType()).getShape())) {