Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,35 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}
}

// v6 NULL-shadow guard (restored): Cromwell's
// WallFcnMomentumWall<12> reverse-mode adjoint loads a shadow
// pointer that may be NULL at runtime when the upstream may-
// aliased shadow alloca was skipped, then segfaults on the
// following load/setPtrDiffe/addToDiffe sequence. The
// SubTransferHelper srcConstant memcpy fix alone is not
// sufficient -- restore the runtime IsNotNull branch around
// the slice-store body so a NULL shadow short-circuits to the
// existing merge epilogue. Reuses the surrounding loop's
// `merge` variable so the existing `if (merge)` close at the
// end of the storeSize loop emits the join block.
if (!merge) {
auto shadow_ptr_nc = lookup(
gutils->invertPointerM(orig_ptr, Builder2), Builder2);
Value *shadow_ptr_v = shadow_ptr_nc;
if (gutils->getWidth() != 1) {
shadow_ptr_v =
gutils->extractMeta(Builder2, shadow_ptr_v, 0);
}
Value *notnull = Builder2.CreateIsNotNull(shadow_ptr_v);
BasicBlock *current = Builder2.GetInsertBlock();
BasicBlock *conditional = gutils->addReverseBlock(
current, current->getName() + "_nnactive");
merge = gutils->addReverseBlock(
conditional, current->getName() + "_nnmerge");
Builder2.CreateCondBr(notnull, conditional, merge);
Builder2.SetInsertPoint(conditional);
}

if (constantval) {
gutils->setPtrDiffe(
&I, orig_ptr,
Expand Down
58 changes: 38 additions & 20 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8973,27 +8973,45 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode,
// Don't zero in forward mode.
if (mode != DerivativeMode::ForwardModeSplit) {

Value *args[] = {
shadowsLookedUp ? shadow_dst
: gutils->lookupM(shadow_dst, Builder2),
ConstantInt::get(Type::getInt8Ty(MTI->getContext()), 0),
gutils->lookupM(length, Builder2),
ConstantInt::getFalse(MTI->getContext())};

if (args[0]->getType()->isIntegerTy())
args[0] = Builder2.CreateIntToPtr(args[0],
getInt8PtrTy(MTI->getContext()));

Type *tys[] = {args[0]->getType(), args[2]->getType()};
auto memsetIntr = getIntrinsicDeclaration(
MTI->getParent()->getParent()->getParent(), Intrinsic::memset,
tys);
auto cal = Builder2.CreateCall(memsetIntr, args);
cal->setCallingConv(memsetIntr->getCallingConv());
if (dstalign != 0) {
cal->addParamAttr(0, Attribute::getWithAlignment(MTI->getContext(),
Align(dstalign)));
// Option X root-cause fix: in srcConstant mode the shadow_dst
// alloca may legitimately hold pointer-typed entries (placed by
// visitCommonStore as scratch shadow pointers, or by other
// shadow-allocation paths). Upstream Enzyme memsets the byte range
// to zero — which clobbers those pointer entries to NULL and
// causes downstream reverse-pass loads to dereference NULL.
// Instead, copy the primal source bytes into the shadow
// (identity-shadow for inactive sources). Aliasing primal data
// into the shadow is safe here because the source is srcConstant,
// so no adjoint contribution should ever be accumulated through
// it; the runtime-activity / NULL-shadow guards in
// visitStoreInst suppress any stray write that would otherwise
// mutate primal memory.
Value *raw_shadow_dst = shadowsLookedUp
? shadow_dst
: gutils->lookupM(shadow_dst, Builder2);
// shadow_src in srcConstant mode is the primal source (set by the
// caller). Copy primal bytes verbatim into the shadow alloca.
Value *raw_shadow_src = shadowsLookedUp
? shadow_src
: gutils->lookupM(shadow_src, Builder2);
if (raw_shadow_dst->getType()->isIntegerTy())
raw_shadow_dst = Builder2.CreateIntToPtr(
raw_shadow_dst, getInt8PtrTy(MTI->getContext()));
if (raw_shadow_src->getType()->isIntegerTy())
raw_shadow_src = Builder2.CreateIntToPtr(
raw_shadow_src, getInt8PtrTy(MTI->getContext()));
Value *dstp = raw_shadow_dst;
Value *srcp = raw_shadow_src;
if (offset != 0) {
dstp = Builder2.CreateConstInBoundsGEP1_64(
Type::getInt8Ty(MTI->getContext()), dstp, offset);
srcp = Builder2.CreateConstInBoundsGEP1_64(
Type::getInt8Ty(MTI->getContext()), srcp, offset);
}
MaybeAlign dalign = dstalign ? MaybeAlign(dstalign) : MaybeAlign();
MaybeAlign salign = srcalign ? MaybeAlign(srcalign) : MaybeAlign();
Builder2.CreateMemCpy(dstp, dalign, srcp, salign,
gutils->lookupM(length, Builder2));
}

} else {
Expand Down
Loading