Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7235,13 +7235,42 @@ struct BroadcastPad final
// Given a value and index idx, determine whether all values are the same along
// idx. If so, return said value
std::optional<Value> is_same_in_axis(OpBuilder &rewriter, ShapedType outTy,
Value v, size_t idx) {
Value v, size_t idx,
int64_t indexVectorDim) {
mlir::SplatElementsAttr splat;
if (matchPattern(v, m_Constant(&splat))) {
return stablehlo::ConstantOp::create(rewriter, v.getLoc(), outTy,
splat.resizeSplat(outTy));
}

if (auto iota = v.getDefiningOp<stablehlo::IotaOp>()) {
if (iota.getIotaDimension() == indexVectorDim) {
return stablehlo::ConstantOp::create(
rewriter, v.getLoc(), outTy,
cast<ElementsAttr>(makeAttr(outTy, (int64_t)idx)));
}
}

return {};
}

std::optional<std::pair<Value, int64_t>>
is_iota_in_axis(OpBuilder &rewriter, ShapedType outTy, Value v, size_t idx,
int64_t indexVectorDim) {
auto iota = detectIotaLikeTensor(v);
if (iota && iota->dimension != indexVectorDim) {
if (isOneAttr(iota->scale)) {
auto startVal = getDoubleFromAttr(iota->start);
if (startVal) {
return std::make_pair(
stablehlo::ConstantOp::create(
rewriter, v.getLoc(), outTy,
cast<ElementsAttr>(makeAttr(outTy, (int64_t)*startVal))),
iota->dimension);
}
}
}

return {};
}

Expand All @@ -7255,6 +7284,7 @@ struct ScatterToDynamicUpdateSlice final
Block &body = op.getUpdateComputation().front();
if (body.getOperations().size() != 1)
return failure();
}

Operation &innerOp = body.front();
if (!isa<stablehlo::ReturnOp>(&innerOp)) {
Expand Down Expand Up @@ -13140,8 +13170,7 @@ struct DUSSliceSimplify final
});

LLVM_DEBUG(
for (auto [idx, operandSize, updateSize]
: llvm::zip_equal(
for (auto [idx, operandSize, updateSize] : llvm::zip_equal(
newDusIndices,
cast<RankedTensorType>(preSliceOperand.getType()).getShape(),
cast<RankedTensorType>(preSliceUpdate.getType()).getShape())) {
Expand Down
Loading