From 5ee13f0eddf4b3da735717f6507e940e04979aff Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 28 May 2026 03:43:15 -0500 Subject: [PATCH 1/7] add reverse AD for llvm.intr.memcpy --- .../LLVMAutoDiffOpInterfaceImpl.cpp | 106 ++++++++++++++++++ enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 3 +- enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp | 1 + enzyme/Enzyme/MLIR/Passes/Passes.td | 2 + 4 files changed, 111 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 472eccddd496..f85c85f25170 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -17,7 +17,9 @@ #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" #include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" @@ -297,6 +299,109 @@ struct InsertValueOpInterfaceReverse MGradientUtilsReverse *gutils) const {} }; +struct MemcpyOpInterfaceReverse + : public ReverseAutoDiffOpInterface::ExternalModel { + + static Type inferElemType(LLVM::MemcpyOp cp) { + if (auto t = cp->getAttrOfType("enzyme.elem_type")) + return t.getValue(); + auto walk = [](Value p) -> Type { + for (Operation *user : p.getUsers()) { + if (auto ld = dyn_cast(user)) + if (isa(ld.getType())) + return ld.getType(); + if (auto st = dyn_cast(user)) + if (isa(st.getValue().getType())) + return st.getValue().getType(); + } + return nullptr; + }; + if (Type t = walk(cp.getDst())) + return t; + if (Type t = walk(cp.getSrc())) + return t; + return Float64Type::get(cp.getContext()); + } + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + auto cp = cast(op); + if (gutils->isConstantValue(cp.getDst())) + return {}; + bool srcActive = !gutils->isConstantValue(cp.getSrc()); + OpBuilder cb(gutils->getNewFromOriginal(op)); + SmallVector caches; + caches.push_back( + gutils->initAndPushCache(gutils->invertPointerM(cp.getDst(), cb), cb)); + caches.push_back(gutils->initAndPushCache( + srcActive ? gutils->invertPointerM(cp.getSrc(), cb) + : gutils->getNewFromOriginal(cp.getSrc()), + cb)); + caches.push_back( + gutils->initAndPushCache(gutils->getNewFromOriginal(cp.getLen()), cb)); + return caches; + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} + + LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + auto cp = cast(op); + if (gutils->isConstantValue(cp.getDst())) + return success(); + bool srcActive = !gutils->isConstantValue(cp.getSrc()); + + Value dDst = gutils->popCache(caches[0], builder); + Value dSrc = gutils->popCache(caches[1], builder); + Value len = gutils->popCache(caches[2], builder); + + Type elemTy = inferElemType(cp); + auto adt = dyn_cast(elemTy); + if (!adt || !elemTy.isIntOrFloat()) + return op->emitError() + << "memcpy reverse: unsupported element type " << elemTy + << " (annotate enzyme.elem_type or lower to scalar stores)"; + + Location loc = op->getLoc(); + unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8; + + // n_elements = len / sizeof(elemTy) + Value byteSz = LLVM::ConstantOp::create( + builder, loc, len.getType(), + builder.getIntegerAttr(len.getType(), bytes)); + Value nInt = LLVM::SDivOp::create(builder, loc, len, byteSz); + Value n = + arith::IndexCastOp::create(builder, loc, builder.getIndexType(), nInt); + + Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); + Value c1 = arith::ConstantIndexOp::create(builder, loc, 1); + Value zeroElem = adt.createNullValue(builder, loc); + Type ptrTy = cp.getDst().getType(); + + auto forOp = scf::ForOp::create(builder, loc, c0, n, c1); + OpBuilder body(forOp.getBody()->getTerminator()); + Value ivIdx = forOp.getInductionVar(); + Value iv = arith::IndexCastOp::create(body, loc, len.getType(), ivIdx); + + Value gDst = LLVM::GEPOp::create(body, loc, ptrTy, elemTy, dDst, + ArrayRef{iv}); + Value vDst = LLVM::LoadOp::create(body, loc, elemTy, gDst); + if (srcActive) { + Value gSrc = LLVM::GEPOp::create(body, loc, ptrTy, elemTy, dSrc, + ArrayRef{iv}); + Value vSrc = LLVM::LoadOp::create(body, loc, elemTy, gSrc); + Value sum = adt.createAddOp(body, loc, vSrc, vDst); + LLVM::StoreOp::create(body, loc, sum, gSrc); + } + LLVM::StoreOp::create(body, loc, zeroElem, gDst); + + return success(); + } +}; + std::optional findPtrSize(Value ptr) { if (auto allocOp = ptr.getDefiningOp()) return allocOp.getSize(); @@ -467,6 +572,7 @@ void mlir::enzyme::registerLLVMDialectAutoDiffInterface( *context); LLVM::InsertValueOp::attachInterface( *context); + LLVM::MemcpyOp::attachInterface(*context); LLVM::UnreachableOp::template attachInterface< detail::NoopRevAutoDiffInterface>(*context); LLVM::LLVMFuncOp::attachInterface( diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 5985a5eae92c..4ae2bf9e1a5b 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/PassManager.h" @@ -52,7 +53,7 @@ struct DifferentiatePass registry.insert(); + mlir::scf::SCFDialect, mlir::enzyme::EnzymeDialect>(); } static std::vector mode_from_fn(FunctionOpInterface fn, diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 4a3697cb74c1..e52d2c6f9a11 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -21,6 +21,7 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #define DEBUG_TYPE "enzyme" diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index 2cbc20453c56..9505541d344f 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> { "complex::ComplexDialect", "cf::ControlFlowDialect", "tensor::TensorDialect", + "scf::SCFDialect", "enzyme::EnzymeDialect", ]; let options = [ @@ -85,6 +86,7 @@ def DifferentiateWrapperPass : Pass<"enzyme-wrap"> { "arith::ArithDialect", "complex::ComplexDialect", "cf::ControlFlowDialect", + "scf::SCFDialect", "enzyme::EnzymeDialect" ]; let options = [ From 2348bf94b6c802a500dcde29bcd8e7353480693e Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 28 May 2026 03:54:36 -0500 Subject: [PATCH 2/7] add lit tests --- .../LLVMAutoDiffOpInterfaceImpl.cpp | 6 +- enzyme/test/MLIR/ReverseMode/memcpy.mlir | 64 +++++++++++++++++++ 2 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 enzyme/test/MLIR/ReverseMode/memcpy.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index f85c85f25170..a30695e1f57d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -369,9 +369,9 @@ struct MemcpyOpInterfaceReverse unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8; // n_elements = len / sizeof(elemTy) - Value byteSz = LLVM::ConstantOp::create( - builder, loc, len.getType(), - builder.getIntegerAttr(len.getType(), bytes)); + Value byteSz = + LLVM::ConstantOp::create(builder, loc, len.getType(), + builder.getIntegerAttr(len.getType(), bytes)); Value nInt = LLVM::SDivOp::create(builder, loc, len, byteSz); Value n = arith::IndexCastOp::create(builder, loc, builder.getIndexType(), nInt); diff --git a/enzyme/test/MLIR/ReverseMode/memcpy.mlir b/enzyme/test/MLIR/ReverseMode/memcpy.mlir new file mode 100644 index 000000000000..4b90efe66d88 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/memcpy.mlir @@ -0,0 +1,64 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +func.func @copy1(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) { + "llvm.intr.memcpy"(%dst, %src, %n) + <{arg_attrs = [{llvm.align = 8 : i64}], isVolatile = false}> + : (!llvm.ptr, !llvm.ptr, i64) -> () + return +} + +func.func @dcopy1(%dst: !llvm.ptr, %ddst: !llvm.ptr, + %src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) { + enzyme.autodiff @copy1(%dst, %ddst, %src, %dsrc, %n) { + activity = [#enzyme, + #enzyme, + #enzyme], + ret_activity = [] + } : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> () + return +} + +func.func @copy2(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) { + "llvm.intr.memcpy"(%dst, %src, %n) + <{arg_attrs = [{llvm.align = 8 : i64}, {llvm.align = 8 : i64}, {}], + isVolatile = false}> + : (!llvm.ptr, !llvm.ptr, i64) -> () + return +} + +func.func @dcopy2(%dst: !llvm.ptr, %ddst: !llvm.ptr, + %src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) { + enzyme.autodiff @copy2(%dst, %ddst, %src, %dsrc, %n) { + activity = [#enzyme, + #enzyme, + #enzyme], + ret_activity = [] + } : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> () + return +} + +// CHECK-LABEL: func.func private @diffecopy1( +// Forward: the primal memcpy is preserved. +// CHECK: "llvm.intr.memcpy" +// Reverse: n / sizeof(f64) element-wise loop, d_src[i] += d_dst[i]; d_dst[i]=0. +// CHECK: %[[BYTES:.+]] = llvm.mlir.constant(8 : i64) : i64 +// CHECK: llvm.sdiv %{{.+}}, %[[BYTES]] : i64 +// CHECK: arith.index_cast +// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: scf.for +// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64 +// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64 +// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64 +// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64 +// CHECK: %[[SUM:.+]] = arith.addf +// CHECK: llvm.store %[[SUM]], %{{.+}} : f64, !llvm.ptr +// CHECK: llvm.store %[[ZERO]], %{{.+}} : f64, !llvm.ptr + +// CHECK-LABEL: func.func private @diffecopy2( +// CHECK: "llvm.intr.memcpy" +// CHECK: llvm.mlir.constant(8 : i64) : i64 +// CHECK: %[[ZERO2:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: scf.for +// CHECK: %[[SUM2:.+]] = arith.addf +// CHECK: llvm.store %[[SUM2]], %{{.+}} : f64, !llvm.ptr +// CHECK: llvm.store %[[ZERO2]], %{{.+}} : f64, !llvm.ptr From ad9fd9a2759710dadd218d33b9ee34b617500a1b Mon Sep 17 00:00:00 2001 From: xys-syx Date: Thu, 28 May 2026 04:01:42 -0500 Subject: [PATCH 3/7] fmt --- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 979dc23b2f07..106bc2a4caf9 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -17,8 +17,8 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/PassManager.h" @@ -54,7 +54,8 @@ struct DifferentiatePass registry.insert(); + mlir::scf::SCFDialect, mlir::linalg::LinalgDialect, + mlir::enzyme::EnzymeDialect>(); } static std::vector mode_from_fn(FunctionOpInterface fn, From 205b86ec88cfa85058494e62a3244803aea36177 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 4 Jun 2026 04:13:07 -0500 Subject: [PATCH 4/7] InsertionGuard + setInsertionPoint --- .../LLVMAutoDiffOpInterfaceImpl.cpp | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index a30695e1f57d..b02f88cdf6da 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -382,21 +382,23 @@ struct MemcpyOpInterfaceReverse Type ptrTy = cp.getDst().getType(); auto forOp = scf::ForOp::create(builder, loc, c0, n, c1); - OpBuilder body(forOp.getBody()->getTerminator()); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(forOp.getBody()->getTerminator()); Value ivIdx = forOp.getInductionVar(); - Value iv = arith::IndexCastOp::create(body, loc, len.getType(), ivIdx); + Value iv = arith::IndexCastOp::create(builder, loc, len.getType(), ivIdx); - Value gDst = LLVM::GEPOp::create(body, loc, ptrTy, elemTy, dDst, + Value gDst = LLVM::GEPOp::create(builder, loc, ptrTy, elemTy, dDst, ArrayRef{iv}); - Value vDst = LLVM::LoadOp::create(body, loc, elemTy, gDst); + Value vDst = LLVM::LoadOp::create(builder, loc, elemTy, gDst); if (srcActive) { - Value gSrc = LLVM::GEPOp::create(body, loc, ptrTy, elemTy, dSrc, + Value gSrc = LLVM::GEPOp::create(builder, loc, ptrTy, elemTy, dSrc, ArrayRef{iv}); - Value vSrc = LLVM::LoadOp::create(body, loc, elemTy, gSrc); - Value sum = adt.createAddOp(body, loc, vSrc, vDst); - LLVM::StoreOp::create(body, loc, sum, gSrc); + Value vSrc = LLVM::LoadOp::create(builder, loc, elemTy, gSrc); + Value sum = adt.createAddOp(builder, loc, vSrc, vDst); + LLVM::StoreOp::create(builder, loc, sum, gSrc); } - LLVM::StoreOp::create(body, loc, zeroElem, gDst); + LLVM::StoreOp::create(builder, loc, zeroElem, gDst); return success(); } From e4f465899527005965c790efad6c91511c615f94 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Mon, 8 Jun 2026 00:09:45 -0500 Subject: [PATCH 5/7] add type analyis using upstream --- .../LLVMAutoDiffOpInterfaceImpl.cpp | 33 +++++++++++++++++-- enzyme/test/MLIR/ReverseMode/memcpy.mlir | 10 ++++-- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index b02f88cdf6da..fb71f1f7fd16 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -303,9 +303,30 @@ struct MemcpyOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { + // Light-weight "type analysis" for memcpy's opaque pointer operands. The + // priority is: + // 1. an explicit `enzyme.elem_type` TypeAttr on the memcpy op, + // 2. the producing `LLVM::AllocaOp` / `LLVM::GEPOp` of dst/src — both ops + // carry an explicit `elemType` in the opaque-pointer LLVM dialect, so + // they are the most reliable in-IR source of the element type, + // 3. a neighboring typed `LLVM::LoadOp` / `LLVM::StoreOp` on dst/src, + // 4. give up (return a null Type — the caller emits a diagnostic). static Type inferElemType(LLVM::MemcpyOp cp) { if (auto t = cp->getAttrOfType("enzyme.elem_type")) return t.getValue(); + + auto fromDef = [](Value p) -> Type { + if (auto alloca = p.getDefiningOp()) + return alloca.getElemType(); + if (auto gep = p.getDefiningOp()) + return gep.getElemType(); + return nullptr; + }; + if (Type t = fromDef(cp.getDst())) + return t; + if (Type t = fromDef(cp.getSrc())) + return t; + auto walk = [](Value p) -> Type { for (Operation *user : p.getUsers()) { if (auto ld = dyn_cast(user)) @@ -321,7 +342,8 @@ struct MemcpyOpInterfaceReverse return t; if (Type t = walk(cp.getSrc())) return t; - return Float64Type::get(cp.getContext()); + + return Type(); } SmallVector cacheValues(Operation *op, @@ -359,11 +381,16 @@ struct MemcpyOpInterfaceReverse Value len = gutils->popCache(caches[2], builder); Type elemTy = inferElemType(cp); + if (!elemTy) + return op->emitError() + << "memcpy reverse: cannot infer element type " + "(annotate enzyme.elem_type or lower to scalar stores)"; + auto adt = dyn_cast(elemTy); if (!adt || !elemTy.isIntOrFloat()) return op->emitError() - << "memcpy reverse: unsupported element type " << elemTy - << " (annotate enzyme.elem_type or lower to scalar stores)"; + << "memcpy reverse: element type " << elemTy + << " is not a supported scalar"; Location loc = op->getLoc(); unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8; diff --git a/enzyme/test/MLIR/ReverseMode/memcpy.mlir b/enzyme/test/MLIR/ReverseMode/memcpy.mlir index 4b90efe66d88..fbafc8756e72 100644 --- a/enzyme/test/MLIR/ReverseMode/memcpy.mlir +++ b/enzyme/test/MLIR/ReverseMode/memcpy.mlir @@ -1,7 +1,10 @@ // RUN: %eopt --enzyme %s | FileCheck %s func.func @copy1(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) { - "llvm.intr.memcpy"(%dst, %src, %n) + %c0 = llvm.mlir.constant(0 : i64) : i64 + %dst_p = llvm.getelementptr %dst[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64 + %src_p = llvm.getelementptr %src[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64 + "llvm.intr.memcpy"(%dst_p, %src_p, %n) <{arg_attrs = [{llvm.align = 8 : i64}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> () return @@ -19,7 +22,10 @@ func.func @dcopy1(%dst: !llvm.ptr, %ddst: !llvm.ptr, } func.func @copy2(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) { - "llvm.intr.memcpy"(%dst, %src, %n) + %c0 = llvm.mlir.constant(0 : i64) : i64 + %dst_p = llvm.getelementptr %dst[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64 + %src_p = llvm.getelementptr %src[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64 + "llvm.intr.memcpy"(%dst_p, %src_p, %n) <{arg_attrs = [{llvm.align = 8 : i64}, {llvm.align = 8 : i64}, {}], isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i64) -> () From 3f1cba4557c9aa08525e7d3ac777e61611f48eda Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Mon, 8 Jun 2026 00:21:50 -0500 Subject: [PATCH 6/7] fmt --- .../MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index fb71f1f7fd16..f55e6dba3f13 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -385,12 +385,11 @@ struct MemcpyOpInterfaceReverse return op->emitError() << "memcpy reverse: cannot infer element type " "(annotate enzyme.elem_type or lower to scalar stores)"; - + auto adt = dyn_cast(elemTy); if (!adt || !elemTy.isIntOrFloat()) - return op->emitError() - << "memcpy reverse: element type " << elemTy - << " is not a supported scalar"; + return op->emitError() << "memcpy reverse: element type " << elemTy + << " is not a supported scalar"; Location loc = op->getLoc(); unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8; From 05e12deb5d3c440d9a22b42821fc35462ad5507c Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Mon, 8 Jun 2026 22:49:05 -0500 Subject: [PATCH 7/7] fmt --- .../MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index f55e6dba3f13..52547f841edc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -303,14 +303,6 @@ struct MemcpyOpInterfaceReverse : public ReverseAutoDiffOpInterface::ExternalModel { - // Light-weight "type analysis" for memcpy's opaque pointer operands. The - // priority is: - // 1. an explicit `enzyme.elem_type` TypeAttr on the memcpy op, - // 2. the producing `LLVM::AllocaOp` / `LLVM::GEPOp` of dst/src — both ops - // carry an explicit `elemType` in the opaque-pointer LLVM dialect, so - // they are the most reliable in-IR source of the element type, - // 3. a neighboring typed `LLVM::LoadOp` / `LLVM::StoreOp` on dst/src, - // 4. give up (return a null Type — the caller emits a diagnostic). static Type inferElemType(LLVM::MemcpyOp cp) { if (auto t = cp->getAttrOfType("enzyme.elem_type")) return t.getValue();