diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp
index 472eccddd49..52547f841ed 100644
--- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp
+++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp
@@ -17,7 +17,9 @@
 #include "Interfaces/AutoDiffTypeInterface.h"
 #include "Interfaces/GradientUtils.h"
 #include "Interfaces/GradientUtilsReverse.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -297,6 +299,150 @@ struct InsertValueOpInterfaceReverse
                           MGradientUtilsReverse *gutils) const {}
 };
 
+/// Reverse-mode rule for the LLVM dialect memcpy op.
+///
+/// The forward pass only records the values the adjoint needs. The reverse
+/// pass walks the copied range element by element, adds the destination shadow
+/// into the source shadow (only when the source is active) and then zeroes it.
+struct MemcpyOpInterfaceReverse
+    : public ReverseAutoDiffOpInterface::ExternalModel {
+
+  /// Guesses the scalar type moved by `cp`, or returns a null Type. Tried in
+  /// order: the `enzyme.elem_type` attribute, the element type of the op that
+  /// defines dst or src, then the type of a load/store user of dst or src.
+  // NOTE(review): the defining-op path applies no type filter, while the
+  // load/store walk goes through an isa check - confirm that is intended.
+  static Type inferElemType(LLVM::MemcpyOp cp) {
+    if (auto t = cp->getAttrOfType("enzyme.elem_type"))
+      return t.getValue();
+
+    auto fromDef = [](Value p) -> Type {
+      if (auto alloca = p.getDefiningOp())
+        return alloca.getElemType();
+      if (auto gep = p.getDefiningOp())
+        return gep.getElemType();
+      return nullptr;
+    };
+    if (Type t = fromDef(cp.getDst()))
+      return t;
+    if (Type t = fromDef(cp.getSrc()))
+      return t;
+
+    auto walk = [](Value p) -> Type {
+      for (Operation *user : p.getUsers()) {
+        if (auto ld = dyn_cast(user))
+          if (isa(ld.getType()))
+            return ld.getType();
+        if (auto st = dyn_cast(user))
+          if (isa(st.getValue().getType()))
+            return st.getValue().getType();
+      }
+      return nullptr;
+    };
+    if (Type t = walk(cp.getDst()))
+      return t;
+    if (Type t = walk(cp.getSrc()))
+      return t;
+
+    return Type();
+  }
+
+  /// Records what the adjoint needs: the destination shadow pointer, the
+  /// source shadow pointer (only when the source is active) and the byte
+  /// count. Nothing is cached when the destination is constant.
+  SmallVector cacheValues(Operation *op,
+                                 MGradientUtilsReverse *gutils) const {
+    auto cp = cast(op);
+    if (gutils->isConstantValue(cp.getDst()))
+      return {};
+    bool srcActive = !gutils->isConstantValue(cp.getSrc());
+    OpBuilder cb(gutils->getNewFromOriginal(op));
+    SmallVector caches;
+    caches.push_back(
+        gutils->initAndPushCache(gutils->invertPointerM(cp.getDst(), cb), cb));
+    // The adjoint never reads an inactive source, so it is not cached.
+    if (srcActive)
+      caches.push_back(gutils->initAndPushCache(
+          gutils->invertPointerM(cp.getSrc(), cb), cb));
+    caches.push_back(
+        gutils->initAndPushCache(gutils->getNewFromOriginal(cp.getLen()), cb));
+    return caches;
+  }
+
+  /// No shadow values are created for this op.
+  void createShadowValues(Operation *op, OpBuilder &builder,
+                          MGradientUtilsReverse *gutils) const {}
+
+  /// Emits the adjoint loop described on the struct. Fails with a diagnostic
+  /// when no element type can be inferred or it is not a supported scalar.
+  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector caches) const {
+    auto cp = cast(op);
+    if (gutils->isConstantValue(cp.getDst()))
+      return success();
+    bool srcActive = !gutils->isConstantValue(cp.getSrc());
+
+    // Cache layout mirrors cacheValues: [dDst, dSrc (if active), len].
+    Value dDst = gutils->popCache(caches[0], builder);
+    Value dSrc = srcActive ? gutils->popCache(caches[1], builder) : Value();
+    Value len = gutils->popCache(caches.back(), builder);
+
+    Type elemTy = inferElemType(cp);
+    if (!elemTy)
+      return op->emitError()
+             << "memcpy reverse: cannot infer element type "
+                "(annotate enzyme.elem_type or lower to scalar stores)";
+
+    auto adt = dyn_cast(elemTy);
+    if (!adt || !elemTy.isIntOrFloat())
+      return op->emitError() << "memcpy reverse: element type " << elemTy
+                             << " is not a supported scalar";
+
+    Location loc = op->getLoc();
+    unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8;
+
+    // n_elements = len / sizeof(elemTy). len is a byte count, so it is
+    // divided and cast to index as an unsigned value.
+    // NOTE(review): bytes past the last whole element are not visited.
+    Value byteSz =
+        LLVM::ConstantOp::create(builder, loc, len.getType(),
+                                 builder.getIntegerAttr(len.getType(), bytes));
+    Value nInt = LLVM::UDivOp::create(builder, loc, len, byteSz);
+    Value n = arith::IndexCastUIOp::create(builder, loc,
+                                           builder.getIndexType(), nInt);
+
+    Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
+    Value c1 = arith::ConstantIndexOp::create(builder, loc, 1);
+    Value zeroElem = adt.createNullValue(builder, loc);
+
+    auto forOp = scf::ForOp::create(builder, loc, c0, n, c1);
+
+    // Build the loop body; the guard restores the insertion point on return.
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPoint(forOp.getBody()->getTerminator());
+    Value ivIdx = forOp.getInductionVar();
+    Value iv = arith::IndexCastOp::create(builder, loc, len.getType(), ivIdx);
+
+    // Each GEP result takes the pointer type of its own base, so dst and src
+    // are not assumed to share one pointer type.
+    Value gDst = LLVM::GEPOp::create(builder, loc, dDst.getType(), elemTy,
+                                     dDst, ArrayRef{iv});
+    Value vDst = LLVM::LoadOp::create(builder, loc, elemTy, gDst);
+    if (srcActive) {
+      Value gSrc = LLVM::GEPOp::create(builder, loc, dSrc.getType(), elemTy,
+                                       dSrc, ArrayRef{iv});
+      Value vSrc = LLVM::LoadOp::create(builder, loc, elemTy, gSrc);
+      Value sum = adt.createAddOp(builder, loc, vSrc, vDst);
+      LLVM::StoreOp::create(builder, loc, sum, gSrc);
+    }
+    LLVM::StoreOp::create(builder, loc, zeroElem, gDst);
+
+    return success();
+  }
+};
+
 std::optional findPtrSize(Value ptr) {
   if (auto allocOp = ptr.getDefiningOp())
     return allocOp.getSize();
@@ -467,6 +613,7 @@ void mlir::enzyme::registerLLVMDialectAutoDiffInterface(
         *context);
     LLVM::InsertValueOp::attachInterface(
         *context);
+    LLVM::MemcpyOp::attachInterface(*context);
     LLVM::UnreachableOp::template attachInterface<
         detail::NoopRevAutoDiffInterface>(*context);
     LLVM::LLVMFuncOp::attachInterface(
diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
index 685bdae38bb..1a563833f98 100644
--- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/PassManager.h"
@@ -53,8 +54,8 @@ struct DifferentiatePass
     registry.insert();
+                    mlir::scf::SCFDialect, mlir::memref::MemRefDialect,
+                    mlir::linalg::LinalgDialect, mlir::enzyme::EnzymeDialect>();
   }
 
   static std::vector mode_from_fn(FunctionOpInterface fn,
diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
index 4a3697cb74c..e52d2c6f9a1 100644
--- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Interfaces/FunctionInterfaces.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 #define DEBUG_TYPE "enzyme"
 
diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td
index f835b800218..c846371a8ea 100644
--- a/enzyme/Enzyme/MLIR/Passes/Passes.td
+++ b/enzyme/Enzyme/MLIR/Passes/Passes.td
@@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> {
     "complex::ComplexDialect",
     "cf::ControlFlowDialect",
     "tensor::TensorDialect",
+    "scf::SCFDialect",
     "enzyme::EnzymeDialect",
     "linalg::LinalgDialect",
     "memref::MemRefDialect"
@@ -87,6 +88,7 @@ def DifferentiateWrapperPass : Pass<"enzyme-wrap"> {
     "arith::ArithDialect",
     "complex::ComplexDialect",
     "cf::ControlFlowDialect",
+    "scf::SCFDialect",
     "enzyme::EnzymeDialect",
     "linalg::LinalgDialect",
     "memref::MemRefDialect"
diff --git a/enzyme/test/MLIR/ReverseMode/memcpy.mlir b/enzyme/test/MLIR/ReverseMode/memcpy.mlir
new file mode 100644
index 00000000000..fbafc8756e7
--- /dev/null
+++ b/enzyme/test/MLIR/ReverseMode/memcpy.mlir
@@ -0,0 +1,70 @@
+// RUN: %eopt --enzyme %s | FileCheck %s
+
+func.func @copy1(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) {
+  %c0 = llvm.mlir.constant(0 : i64) : i64
+  %dst_p = llvm.getelementptr %dst[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
+  %src_p = llvm.getelementptr %src[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
+  "llvm.intr.memcpy"(%dst_p, %src_p, %n)
+      <{arg_attrs = [{llvm.align = 8 : i64}], isVolatile = false}>
+      : (!llvm.ptr, !llvm.ptr, i64) -> ()
+  return
+}
+
+func.func @dcopy1(%dst: !llvm.ptr, %ddst: !llvm.ptr,
+                  %src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) {
+  enzyme.autodiff @copy1(%dst, %ddst, %src, %dsrc, %n) {
+    activity = [#enzyme,
+                #enzyme,
+                #enzyme],
+    ret_activity = []
+  } : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
+  return
+}
+
+func.func @copy2(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) {
+  %c0 = llvm.mlir.constant(0 : i64) : i64
+  %dst_p = llvm.getelementptr %dst[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
+  %src_p = llvm.getelementptr %src[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
+  "llvm.intr.memcpy"(%dst_p, %src_p, %n)
+      <{arg_attrs = [{llvm.align = 8 : i64}, {llvm.align = 8 : i64}, {}],
+        isVolatile = false}>
+      : (!llvm.ptr, !llvm.ptr, i64) -> ()
+  return
+}
+
+func.func @dcopy2(%dst: !llvm.ptr, %ddst: !llvm.ptr,
+                  %src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) {
+  enzyme.autodiff @copy2(%dst, %ddst, %src, %dsrc, %n) {
+    activity = [#enzyme,
+                #enzyme,
+                #enzyme],
+    ret_activity = []
+  } : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
+  return
+}
+
+// CHECK-LABEL: func.func private @diffecopy1(
+// Forward: the primal memcpy is preserved.
+// CHECK: "llvm.intr.memcpy"
+// Reverse: n / sizeof(f64) element-wise loop, d_src[i] += d_dst[i]; d_dst[i]=0.
+// CHECK: %[[BYTES:.+]] = llvm.mlir.constant(8 : i64) : i64
+// CHECK: llvm.udiv %{{.+}}, %[[BYTES]] : i64
+// CHECK: arith.index_castui
+// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64
+// CHECK: scf.for
+// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64
+// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64
+// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64
+// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64
+// CHECK: %[[SUM:.+]] = arith.addf
+// CHECK: llvm.store %[[SUM]], %{{.+}} : f64, !llvm.ptr
+// CHECK: llvm.store %[[ZERO]], %{{.+}} : f64, !llvm.ptr
+
+// CHECK-LABEL: func.func private @diffecopy2(
+// CHECK: "llvm.intr.memcpy"
+// CHECK: llvm.mlir.constant(8 : i64) : i64
+// CHECK: %[[ZERO2:.+]] = arith.constant 0.000000e+00 : f64
+// CHECK: scf.for
+// CHECK: %[[SUM2:.+]] = arith.addf
+// CHECK: llvm.store %[[SUM2]], %{{.+}} : f64, !llvm.ptr
+// CHECK: llvm.store %[[ZERO2]], %{{.+}} : f64, !llvm.ptr