Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

Expand Down Expand Up @@ -297,6 +299,129 @@ struct InsertValueOpInterfaceReverse
MGradientUtilsReverse *gutils) const {}
};

/// Reverse-mode rule for `llvm.intr.memcpy`.
///
/// The adjoint emitted by this model is an element-wise `scf.for` loop over
/// the shadow buffers:
///
///   d_src[i] += d_dst[i];   // only when the source is active
///   d_dst[i]  = 0;
///
/// Nothing is emitted when the destination is constant.
struct MemcpyOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<MemcpyOpInterfaceReverse,
                                                       LLVM::MemcpyOp> {

  /// Best-effort recovery of the element type of the copied buffer, since
  /// opaque pointers carry none.
  ///
  /// Sources are consulted in this order, destination before source each time:
  ///   1. an explicit `enzyme.elem_type` TypeAttr on the memcpy,
  ///   2. the element type of a defining `llvm.alloca` / `llvm.getelementptr`,
  ///   3. the type of a differentiable value loaded from / stored by a direct
  ///      user of the pointer.
  ///
  /// Returns a null Type when none of these yields an answer.
  static Type inferElemType(LLVM::MemcpyOp cp) {
    if (auto t = cp->getAttrOfType<TypeAttr>("enzyme.elem_type"))
      return t.getValue();

    // Element type recorded on the op that produced the pointer.
    auto fromDef = [](Value p) -> Type {
      if (auto alloca = p.getDefiningOp<LLVM::AllocaOp>())
        return alloca.getElemType();
      if (auto gep = p.getDefiningOp<LLVM::GEPOp>())
        return gep.getElemType();
      return nullptr;
    };
    if (Type t = fromDef(cp.getDst()))
      return t;
    if (Type t = fromDef(cp.getSrc()))
      return t;

    // Element type implied by a load/store that uses the pointer. Only types
    // implementing AutoDiffTypeInterface are accepted. Note that a store user
    // is matched on its stored value's type without checking that `p` is the
    // address operand.
    auto walk = [](Value p) -> Type {
      for (Operation *user : p.getUsers()) {
        if (auto ld = dyn_cast<LLVM::LoadOp>(user))
          if (isa<AutoDiffTypeInterface>(ld.getType()))
            return ld.getType();
        if (auto st = dyn_cast<LLVM::StoreOp>(user))
          if (isa<AutoDiffTypeInterface>(st.getValue().getType()))
            return st.getValue().getType();
      }
      return nullptr;
    };
    if (Type t = walk(cp.getDst()))
      return t;
    if (Type t = walk(cp.getSrc()))
      return t;

    return Type();
  }

  /// Caches the values the reverse pass needs, in the order
  /// `{dst shadow, src pointer, len}`.
  ///
  /// The second entry is the source shadow when the source is active and the
  /// primal source pointer otherwise. Returns an empty list when the
  /// destination is constant, matching the early exit in
  /// createReverseModeAdjoint.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    auto cp = cast<LLVM::MemcpyOp>(op);
    if (gutils->isConstantValue(cp.getDst()))
      return {};
    bool srcActive = !gutils->isConstantValue(cp.getSrc());
    // Builder positioned at the cloned memcpy in the new function.
    OpBuilder cb(gutils->getNewFromOriginal(op));
    SmallVector<Value> caches;
    caches.push_back(
        gutils->initAndPushCache(gutils->invertPointerM(cp.getDst(), cb), cb));
    caches.push_back(gutils->initAndPushCache(
        srcActive ? gutils->invertPointerM(cp.getSrc(), cb)
                  : gutils->getNewFromOriginal(cp.getSrc()),
        cb));
    caches.push_back(
        gutils->initAndPushCache(gutils->getNewFromOriginal(cp.getLen()), cb));
    return caches;
  }

  /// No shadow values are created for memcpy.
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}

  /// Emits the adjoint loop described on the struct.
  ///
  /// \param op      the original `llvm.intr.memcpy`.
  /// \param builder insertion point in the reverse region; restored to just
  ///                after the generated loop on return.
  /// \param caches  the three handles produced by cacheValues.
  /// \returns failure (with a diagnostic on \p op) when the element type
  ///          cannot be inferred or is not an int/float type implementing
  ///          AutoDiffTypeInterface; success otherwise.
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    auto cp = cast<LLVM::MemcpyOp>(op);
    if (gutils->isConstantValue(cp.getDst()))
      return success();
    bool srcActive = !gutils->isConstantValue(cp.getSrc());

    // All three caches are popped unconditionally, even though dSrc is only
    // read when the source is active.
    Value dDst = gutils->popCache(caches[0], builder);
    Value dSrc = gutils->popCache(caches[1], builder);
    Value len = gutils->popCache(caches[2], builder);

    Type elemTy = inferElemType(cp);
    if (!elemTy)
      return op->emitError()
             << "memcpy reverse: cannot infer element type "
                "(annotate enzyme.elem_type or lower to scalar stores)";

    auto adt = dyn_cast<AutoDiffTypeInterface>(elemTy);
    if (!adt || !elemTy.isIntOrFloat())
      return op->emitError() << "memcpy reverse: element type " << elemTy
                             << " is not a supported scalar";

    Location loc = op->getLoc();
    // Packed byte size of one element.
    // NOTE(review): the GEPs below step by the element type's own stride;
    // presumably that equals this value for the common f32/f64 case, but it
    // may differ for odd bit widths -- confirm against the data layout.
    unsigned bytes = (elemTy.getIntOrFloatBitWidth() + 7) / 8;

    // n_elements = len / sizeof(elemTy)
    // Integer division: if len is not a multiple of the element size, the
    // trailing bytes are not visited.
    // NOTE(review): the division is signed, while a memcpy length is
    // presumably an unsigned byte count; the two only differ for lengths
    // with the top bit set.
    Value byteSz =
        LLVM::ConstantOp::create(builder, loc, len.getType(),
                                 builder.getIntegerAttr(len.getType(), bytes));
    Value nInt = LLVM::SDivOp::create(builder, loc, len, byteSz);
    Value n =
        arith::IndexCastOp::create(builder, loc, builder.getIndexType(), nInt);

    Value c0 = arith::ConstantIndexOp::create(builder, loc, 0);
    Value c1 = arith::ConstantIndexOp::create(builder, loc, 1);
    // Loop-invariant zero used to clear d_dst; materialized before the loop.
    Value zeroElem = adt.createNullValue(builder, loc);

    auto forOp = scf::ForOp::create(builder, loc, c0, n, c1);

    // Emit the loop body before the loop's terminator; the guard puts the
    // insertion point back after the loop when this function returns.
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPoint(forOp.getBody()->getTerminator());
    Value ivIdx = forOp.getInductionVar();
    // GEP indices use the integer type of `len`, not `index`.
    Value iv = arith::IndexCastOp::create(builder, loc, len.getType(), ivIdx);

    // Each GEP result takes the type of its own base pointer: dst and src of
    // a memcpy may live in different address spaces, so the destination's
    // pointer type must not be reused for the source.
    Value gDst = LLVM::GEPOp::create(builder, loc, dDst.getType(), elemTy, dDst,
                                     ArrayRef<LLVM::GEPArg>{iv});
    Value vDst = LLVM::LoadOp::create(builder, loc, elemTy, gDst);
    if (srcActive) {
      Value gSrc = LLVM::GEPOp::create(builder, loc, dSrc.getType(), elemTy,
                                       dSrc, ArrayRef<LLVM::GEPArg>{iv});
      Value vSrc = LLVM::LoadOp::create(builder, loc, elemTy, gSrc);
      Value sum = adt.createAddOp(builder, loc, vSrc, vDst);
      LLVM::StoreOp::create(builder, loc, sum, gSrc);
    }
    // d_dst[i] is zeroed after its value has been read (and accumulated).
    LLVM::StoreOp::create(builder, loc, zeroElem, gDst);

    return success();
  }
};

std::optional<Value> findPtrSize(Value ptr) {
if (auto allocOp = ptr.getDefiningOp<llvm_ext::AllocOp>())
return allocOp.getSize();
Expand Down Expand Up @@ -467,6 +592,7 @@ void mlir::enzyme::registerLLVMDialectAutoDiffInterface(
*context);
LLVM::InsertValueOp::attachInterface<InsertValueOpInterfaceReverse>(
*context);
LLVM::MemcpyOp::attachInterface<MemcpyOpInterfaceReverse>(*context);
LLVM::UnreachableOp::template attachInterface<
detail::NoopRevAutoDiffInterface<LLVM::UnreachableOp>>(*context);
LLVM::LLVMFuncOp::attachInterface<AutoDiffLLVMFuncOpFunctionInterface>(
Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -53,7 +54,8 @@ struct DifferentiatePass

registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect,
mlir::linalg::LinalgDialect, mlir::enzyme::EnzymeDialect>();
mlir::scf::SCFDialect, mlir::linalg::LinalgDialect,
mlir::enzyme::EnzymeDialect>();
}

static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Interfaces/FunctionInterfaces.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

#define DEBUG_TYPE "enzyme"

Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def DifferentiatePass : Pass<"enzyme"> {
"complex::ComplexDialect",
"cf::ControlFlowDialect",
"tensor::TensorDialect",
"scf::SCFDialect",
"enzyme::EnzymeDialect",
];
let options = [
Expand Down Expand Up @@ -85,6 +86,7 @@ def DifferentiateWrapperPass : Pass<"enzyme-wrap"> {
"arith::ArithDialect",
"complex::ComplexDialect",
"cf::ControlFlowDialect",
"scf::SCFDialect",
"enzyme::EnzymeDialect"
];
let options = [
Expand Down
70 changes: 70 additions & 0 deletions enzyme/test/MLIR/ReverseMode/memcpy.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: %eopt --enzyme %s | FileCheck %s

// Primal under test: a single llvm.intr.memcpy of %n bytes from %src to %dst.
// Both pointers go through a zero-index getelementptr whose element type is
// f64. Here arg_attrs lists a single entry (llvm.align = 8).
func.func @copy1(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) {
%c0 = llvm.mlir.constant(0 : i64) : i64
%dst_p = llvm.getelementptr %dst[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
%src_p = llvm.getelementptr %src[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
"llvm.intr.memcpy"(%dst_p, %src_p, %n)
<{arg_attrs = [{llvm.align = 8 : i64}], isVolatile = false}>
: (!llvm.ptr, !llvm.ptr, i64) -> ()
return
}

// Differentiation entry point for @copy1. %dst and %src are marked
// enzyme_dup and are each followed by their shadow (%ddst, %dsrc); the byte
// count %n is enzyme_const. ret_activity is empty because @copy1 returns
// nothing.
func.func @dcopy1(%dst: !llvm.ptr, %ddst: !llvm.ptr,
%src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) {
enzyme.autodiff @copy1(%dst, %ddst, %src, %dsrc, %n) {
activity = [#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_const>],
ret_activity = []
} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
return
}

// Same primal as @copy1, except that arg_attrs carries one entry per operand:
// llvm.align = 8 on both pointers and an empty dictionary for the length.
func.func @copy2(%dst: !llvm.ptr, %src: !llvm.ptr, %n: i64) {
%c0 = llvm.mlir.constant(0 : i64) : i64
%dst_p = llvm.getelementptr %dst[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
%src_p = llvm.getelementptr %src[%c0] : (!llvm.ptr, i64) -> !llvm.ptr, f64
"llvm.intr.memcpy"(%dst_p, %src_p, %n)
<{arg_attrs = [{llvm.align = 8 : i64}, {llvm.align = 8 : i64}, {}],
isVolatile = false}>
: (!llvm.ptr, !llvm.ptr, i64) -> ()
return
}

// Differentiation entry point for @copy2, with the same activity layout as
// the @copy1 driver: enzyme_dup for %dst/%ddst and %src/%dsrc, enzyme_const
// for %n, and no return activity.
func.func @dcopy2(%dst: !llvm.ptr, %ddst: !llvm.ptr,
%src: !llvm.ptr, %dsrc: !llvm.ptr, %n: i64) {
enzyme.autodiff @copy2(%dst, %ddst, %src, %dsrc, %n) {
activity = [#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_dup>,
#enzyme<activity enzyme_const>],
ret_activity = []
} : (!llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr, i64) -> ()
return
}

// CHECK-LABEL: func.func private @diffecopy1(
// Forward: the primal memcpy is preserved.
// CHECK: "llvm.intr.memcpy"
// Reverse: element-wise loop over n / sizeof(f64) elements: d_src[i] += d_dst[i]; d_dst[i] = 0.
// CHECK: %[[BYTES:.+]] = llvm.mlir.constant(8 : i64) : i64
// CHECK: llvm.sdiv %{{.+}}, %[[BYTES]] : i64
// CHECK: arith.index_cast
// CHECK: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: scf.for
// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64
// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64
// CHECK: llvm.getelementptr %{{.+}}[%{{.+}}] : (!llvm.ptr, i64) -> !llvm.ptr, f64
// CHECK: llvm.load %{{.+}} : !llvm.ptr -> f64
// CHECK: %[[SUM:.+]] = arith.addf
// CHECK: llvm.store %[[SUM]], %{{.+}} : f64, !llvm.ptr
// CHECK: llvm.store %[[ZERO]], %{{.+}} : f64, !llvm.ptr

// CHECK-LABEL: func.func private @diffecopy2(
// CHECK: "llvm.intr.memcpy"
// CHECK: llvm.mlir.constant(8 : i64) : i64
// CHECK: %[[ZERO2:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: scf.for
// CHECK: %[[SUM2:.+]] = arith.addf
// CHECK: llvm.store %[[SUM2]], %{{.+}} : f64, !llvm.ptr
// CHECK: llvm.store %[[ZERO2]], %{{.+}} : f64, !llvm.ptr
Loading