Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,35 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
return res;
}

/// Decide whether the shadow accumulation for `origptr` (performed on behalf
/// of `orig`) must be emitted as an atomic update.
///
/// Returns false as soon as one of the following holds, and true otherwise:
///  - atomic accumulation is disabled (`AtomicAdd` is false);
///  - the base object is an alloca and the module targets nvptx, nvptx64 or
///    amdgcn;
///  - the base object is registered in `backwardsOnlyShadows`;
///  - `elementwiseReadForContext(orig, origptr)` reports an elementwise read.
bool DiffeGradientUtils::shouldUseAtomicShadowUpdate(Instruction *orig,
                                                     Value *origptr) const {
  if (!AtomicAdd)
    return false;

  Value *underlying = getBaseObject(origptr);
  auto targetArch =
      llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch();

  // On the GPU targets a stack slot is private to the executing work item, so
  // no other work item can race on it and a plain load/add/store suffices.
  bool gpuTarget = false;
  switch (targetArch) {
  case Triple::nvptx:
  case Triple::nvptx64:
  case Triple::amdgcn:
    gpuTarget = true;
    break;
  default:
    break;
  }
  const bool privateStackSlot = gpuTarget && isa<AllocaInst>(underlying);

  // Backwards-only shadows are created inside this function and never escape.
  // This relies on any additional parallelism in the function being outlined.
  const bool localShadow =
      backwardsOnlyShadows.find(underlying) != backwardsOnlyShadows.end();

  if (privateStackSlot || localShadow)
    return false;

  // Under the elementwise-read contract every work item reads its own input
  // element and therefore accumulates into its own shadow location; only
  // when that contract is absent is the atomic update required.
  return !elementwiseReadForContext(orig, origptr);
}

AllocaInst *DiffeGradientUtils::getDifferential(Value *val) {
assert(mode != DerivativeMode::ForwardMode);
assert(mode != DerivativeMode::ForwardModeSplit);
Expand Down Expand Up @@ -994,27 +1023,10 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
dif = applyChainRule(addingType, BuilderM, rule, dif);
}

auto TmpOrig = getBaseObject(origptr);

// atomics
bool Atomic = AtomicAdd;
bool Atomic = shouldUseAtomicShadowUpdate(orig, origptr);
auto Arch = llvm::Triple(newFunc->getParent()->getTargetTriple()).getArch();

// No need to do atomic on local memory for CUDA since it can't be raced
// upon
if (isa<AllocaInst>(TmpOrig) &&
(Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
Arch == Triple::amdgcn)) {
Atomic = false;
}
// Moreover no need to do atomic on local shadows regardless since they are
// not captured/escaping and created in this function. This assumes that
// all additional parallelism in this function is outlined.
if (backwardsOnlyShadows.find(TmpOrig) != backwardsOnlyShadows.end())
Atomic = false;
if (Atomic && elementwiseReadForContext(orig, origptr))
Atomic = false;

if (Atomic) {
// For amdgcn constant AS is 4 and if the primal is in it we need to cast
// the derivative value to AS 1
Expand Down
3 changes: 3 additions & 0 deletions enzyme/Enzyme/DiffeGradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class DiffeGradientUtils final : public GradientUtils {
DerivativeMode mode, bool runtimeActivity, bool strongZero,
unsigned width, bool omp);

bool shouldUseAtomicShadowUpdate(llvm::Instruction *orig,
llvm::Value *origptr) const;

public:
/// Whether to free memory in reverse pass or split forward.
bool FreeMemory;
Expand Down
50 changes: 50 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/cuda-elementwise-atomic.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -preserve-nvvm -enzyme -enzyme-detect-readthrow=0 -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -enzyme-detect-readthrow=0 -passes="preserve-nvvm,enzyme" -S | FileCheck %s

; ModuleID = 'cuda-elementwise-atomic.ll'
source_filename = "cuda-elementwise-atomic.ll"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16-v32:32-v64:64:64-v128:128:128-n16:32:64-ni:10:11:12:13"
target triple = "nvptx64-nvidia-cuda"

@.str.enzyme_elementwise_read = private unnamed_addr constant [24 x i8] c"enzyme_elementwise_read\00", section "llvm.metadata"
@.str.file = private unnamed_addr constant [27 x i8] c"cuda-elementwise-atomic.ll\00", section "llvm.metadata"
@llvm.global.annotations = appending global [1 x { i8*, i8*, i8*, i32, i8* }] [{ i8*, i8*, i8*, i32, i8* } { i8* bitcast (float (float addrspace(1)*)* @vmul_elementwise to i8*), i8* getelementptr inbounds ([24 x i8], [24 x i8]* @.str.enzyme_elementwise_read, i32 0, i32 0), i8* getelementptr inbounds ([27 x i8], [27 x i8]* @.str.file, i32 0, i32 0), i32 1, i8* null }], section "llvm.metadata"

declare float @llvm.nvvm.ldg.global.f.f32.p1f32(float addrspace(1)* nocapture, i32)

; Primal for the elementwise case: returns the float read from %inp through
; the nvvm ldg intrinsic. This is the function that @llvm.global.annotations
; tags with "enzyme_elementwise_read"; the FileCheck expectations for its
; derivative require a plain load/fadd/store on the shadow and no atomicrmw.
define float @vmul_elementwise(float addrspace(1)* %inp) {
top:
%ld = call float @llvm.nvvm.ldg.global.f.f32.p1f32(float addrspace(1)* %inp, i32 4)
ret float %ld
}

; Control case: same body as @vmul_elementwise but with no entry in
; @llvm.global.annotations. The FileCheck expectations for its derivative
; require the shadow update to be an atomicrmw fadd.
define float @vmul_unknown(float addrspace(1)* %inp) {
top:
%ld = call float @llvm.nvvm.ldg.global.f.f32.p1f32(float addrspace(1)* %inp, i32 4)
ret float %ld
}

; Driver that requests the derivative of @vmul_elementwise, passing %inp as
; the primal pointer and %dinp as the pointer that follows it in the
; __enzyme_autodiff argument list (its shadow, named %"inp'" in the checks).
define float @test_elementwise(float addrspace(1)* %inp, float addrspace(1)* %dinp) {
entry:
%0 = tail call float (float (float addrspace(1)*)*, ...) @__enzyme_autodiff(float (float addrspace(1)*)* nonnull @vmul_elementwise, float addrspace(1)* %inp, float addrspace(1)* %dinp)
ret float %0
}

; Driver that requests the derivative of the un-annotated @vmul_unknown with
; the same primal/shadow argument pair as @test_elementwise.
define float @test_unknown(float addrspace(1)* %inp, float addrspace(1)* %dinp) {
entry:
%0 = tail call float (float (float addrspace(1)*)*, ...) @__enzyme_autodiff(float (float addrspace(1)*)* nonnull @vmul_unknown, float addrspace(1)* %inp, float addrspace(1)* %dinp)
ret float %0
}

declare float @__enzyme_autodiff(float (float addrspace(1)*)*, ...)

; CHECK-LABEL: define internal void @diffevmul_elementwise(
; CHECK-NOT: atomicrmw
; CHECK: load float, {{.*}}%"inp'"
; CHECK: fadd fast float
; CHECK: store float {{.*}}%"inp'"
; CHECK: ret void

; CHECK-LABEL: define internal void @diffevmul_unknown(
; CHECK: atomicrmw fadd {{.*}}%"inp'", float %differeturn monotonic
; CHECK: ret void
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
#include <cuda_runtime.h>

#include <math.h>
#include <stdio.h>
#include <stdlib.h>

// Evaluates `expr` (an expression yielding cudaError_t) exactly once. On any
// result other than cudaSuccess it prints the stringified expression and the
// CUDA error string to stderr and executes `return 2;` in the ENCLOSING
// function. Consequences for callers:
//   - only usable inside functions whose return type accepts the value 2;
//   - the early return skips any cleanup written after the macro invocation.
#define CUDA_CHECK(expr) \
do { \
cudaError_t err__ = (expr); \
if (err__ != cudaSuccess) { \
fprintf(stderr, "%s failed: %s\n", #expr, cudaGetErrorString(err__)); \
return 2; \
} \
} while (0)

// Baseline primal: writes x[i] * x[i] into y[i]. It carries no
// "enzyme_elementwise_read" annotation, in contrast to square_elementwise,
// whose body is otherwise the same. It is differentiated by
// square_grad_atomic. Kept noinline so it remains a separate function.
__device__ __attribute__((noinline)) void square_atomic(float *x, float *y,
int i) {
y[i] = x[i] * x[i];
}

// Same computation as square_atomic (y[i] = x[i] * x[i]) but annotated with
// "enzyme_elementwise_read". It is differentiated by square_grad_elementwise.
// NOTE(review): the annotation is presumably a promise to Enzyme that each
// thread reads a distinct input element — confirm against the Enzyme docs.
__device__ __attribute__((annotate("enzyme_elementwise_read")))
__attribute__((noinline)) void
square_elementwise(float *x, float *y, int i) {
y[i] = x[i] * x[i];
}

// Signature shared by the two primal functions: (x, y, i).
typedef void (*square_fn)(float *, float *, int);
// Differentiation entry point, declared but not defined in this file. The
// call sites pass: primal function, enzyme_dup, x, dx, enzyme_dup, y, dy,
// enzyme_const, i.
extern __device__ void __enzyme_autodiff(square_fn, int, float *, float *, int,
float *, float *, int, int);
// Marker values passed in front of __enzyme_autodiff arguments; enzyme_dup
// precedes a (primal, shadow) pointer pair and enzyme_const precedes the
// index argument at the call sites in this file.
extern __device__ int enzyme_dup;
extern __device__ int enzyme_const;

// Fills the four arrays for every index below n, one element per thread
// (1D launch; threads whose index is >= n do nothing):
//   x[i]  = 1.0f + 0.001f * (i & 1023)
//   dx[i] = 0, y[i] = 0, dy[i] = 1
__global__ void init_arrays(float *x, float *dx, float *y, float *dy, int n) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    x[idx] = 1.0f + 0.001f * (float)(idx & 1023);
    dx[idx] = 0.0f;
    y[idx] = 0.0f;
    dy[idx] = 1.0f;
  }
}

// Restores the gradient state for every index below n, one element per
// thread: dx[i] = 0, y[i] = 0 and dy[i] = 1. Threads whose index is >= n do
// nothing.
__global__ void reset_gradients(float *dx, float *y, float *dy, int n) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    dx[idx] = 0.0f;
    y[idx] = 0.0f;
    dy[idx] = 1.0f;
  }
}

// Each thread with index below n invokes __enzyme_autodiff on square_atomic
// for its own index, passing (x, dx) and (y, dy) as enzyme_dup pairs and the
// index as enzyme_const. Threads past the end of the arrays do nothing.
__global__ void square_grad_atomic(float *x, float *dx, float *y, float *dy,
                                   int n) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    __enzyme_autodiff(square_atomic, enzyme_dup, x, dx, enzyme_dup, y, dy,
                      enzyme_const, idx);
  }
}

// Each thread with index below n invokes __enzyme_autodiff on the annotated
// square_elementwise for its own index, passing (x, dx) and (y, dy) as
// enzyme_dup pairs and the index as enzyme_const. Threads past the end of the
// arrays do nothing.
__global__ void square_grad_elementwise(float *x, float *dx, float *y,
                                        float *dy, int n) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    __enzyme_autodiff(square_elementwise, enzyme_dup, x, dx, enzyme_dup, y, dy,
                      enzyme_const, idx);
  }
}

// Copies n floats from each of the device arrays x, dx and y into the host
// buffers hx, hdx and hy. Returns 0 on success; on a failed copy CUDA_CHECK
// reports the error and this helper returns 2. Kept separate so that the
// macro's early return cannot skip the caller's cleanup.
static int copy_results_to_host(float *hx, float *hdx, float *hy, float *x,
                                float *dx, float *y, int n) {
  CUDA_CHECK(
      cudaMemcpy(hx, x, (size_t)n * sizeof(float), cudaMemcpyDeviceToHost));
  CUDA_CHECK(
      cudaMemcpy(hdx, dx, (size_t)n * sizeof(float), cudaMemcpyDeviceToHost));
  CUDA_CHECK(
      cudaMemcpy(hy, y, (size_t)n * sizeof(float), cudaMemcpyDeviceToHost));
  return 0;
}

// Checks the device results against a host reference: for every i in [0, n)
// y[i] must equal x[i]^2 and dx[i] must equal 2*x[i], each within an absolute
// tolerance of 2e-5.
//
// label     tag printed in front of the mismatch diagnostic.
// x, dx, y  device pointers holding at least n floats each.
// n         number of elements to check; 0 verifies trivially.
//
// Returns 0 when everything matches, 2 when a device-to-host copy fails,
// 3 when a host buffer cannot be allocated, and 4 on the first mismatch
// (which is also reported on stderr). The host buffers are released on every
// path.
static int verify_result(const char *label, float *x, float *dx, float *y,
                         int n) {
  // Nothing to compare; also avoids the implementation-defined malloc(0).
  if (n == 0)
    return 0;

  const size_t bytes = (size_t)n * sizeof(float);
  float *hx = (float *)malloc(bytes);
  float *hdx = (float *)malloc(bytes);
  float *hy = (float *)malloc(bytes);

  int status = 0;
  if (!hx || !hdx || !hy) {
    fprintf(stderr, "host allocation failed\n");
    status = 3;
  } else {
    status = copy_results_to_host(hx, hdx, hy, x, dx, y, n);
  }

  const float tolerance = 2.0e-5f;
  for (int i = 0; status == 0 && i < n; ++i) {
    float expected_y = hx[i] * hx[i];
    float expected_dx = 2.0f * hx[i];
    // Written as !(err <= tol) so that a NaN result counts as a mismatch;
    // `err > tol` is false for NaN and would let it pass.
    if (!(fabsf(hy[i] - expected_y) <= tolerance) ||
        !(fabsf(hdx[i] - expected_dx) <= tolerance)) {
      fprintf(stderr,
              "%s mismatch at %d: x=%g y=%g expected_y=%g dx=%g "
              "expected_dx=%g\n",
              label, i, hx[i], hy[i], expected_y, hdx[i], expected_dx);
      status = 4;
    }
  }

  // Single exit point: free(NULL) is a no-op, so this is safe after a
  // partial allocation failure as well.
  free(hx);
  free(hdx);
  free(hy);
  return status;
}

// Signature shared by the two gradient kernels: (x, dx, y, dy, n).
typedef void (*grad_kernel_fn)(float *, float *, float *, float *, int);

// Creates one CUDA event. Returns 0 on success, 2 (via CUDA_CHECK) on failure.
static int create_timing_event(cudaEvent_t *event) {
  CUDA_CHECK(cudaEventCreate(event));
  return 0;
}

// Destroys one CUDA event. Returns 0 on success, 2 (via CUDA_CHECK) on
// failure.
static int destroy_timing_event(cudaEvent_t event) {
  CUDA_CHECK(cudaEventDestroy(event));
  return 0;
}

// Resets the gradient arrays, then records `start`, launches `kernel` reps
// times, records `stop`, waits for it and stores the mean milliseconds per
// launch in *ms. Returns 0 on success or 2 on the first CUDA error. Event
// ownership stays with the caller, so an early return here leaks nothing.
static int run_timed_launches(grad_kernel_fn kernel, cudaEvent_t start,
                              cudaEvent_t stop, float *x, float *dx, float *y,
                              float *dy, int n, int blocks, int threads,
                              int reps, float *ms) {
  reset_gradients<<<blocks, threads>>>(dx, y, dy, n);
  CUDA_CHECK(cudaGetLastError());
  CUDA_CHECK(cudaEventRecord(start));
  for (int i = 0; i < reps; ++i)
    kernel<<<blocks, threads>>>(x, dx, y, dy, n);
  CUDA_CHECK(cudaGetLastError());
  CUDA_CHECK(cudaEventRecord(stop));
  CUDA_CHECK(cudaEventSynchronize(stop));
  CUDA_CHECK(cudaEventElapsedTime(ms, start, stop));
  *ms /= (float)reps;
  return 0;
}

// Shared implementation of time_atomic / time_elementwise: owns the two
// events and guarantees they are destroyed whatever the outcome of the timed
// section.
static int time_grad_kernel(grad_kernel_fn kernel, float *x, float *dx,
                            float *y, float *dy, int n, int blocks,
                            int threads, int reps, float *ms) {
  // A non-positive repetition count would divide the elapsed time by zero (or
  // a negative number) and hand back a meaningless figure.
  if (reps <= 0) {
    fprintf(stderr, "repetition count must be positive (got %d)\n", reps);
    return 1;
  }

  cudaEvent_t start;
  cudaEvent_t stop;
  int status = create_timing_event(&start);
  if (status != 0)
    return status;
  status = create_timing_event(&stop);
  if (status != 0) {
    // Best effort: the creation failure is the error that gets reported.
    (void)cudaEventDestroy(start);
    return status;
  }

  status = run_timed_launches(kernel, start, stop, x, dx, y, dy, n, blocks,
                              threads, reps, ms);

  // Always release both events; a destroy failure is only surfaced when the
  // timed section itself succeeded.
  const int start_status = destroy_timing_event(start);
  const int stop_status = destroy_timing_event(stop);
  if (status == 0)
    status = start_status != 0 ? start_status : stop_status;
  return status;
}

// Times square_grad_atomic.
//
// x, dx, y, dy     device arrays of at least n floats each.
// blocks, threads  1D launch configuration used for every launch.
// reps             number of timed launches; must be positive.
// ms               receives the mean time per launch in milliseconds.
//
// Returns 0 on success, 1 when reps is not positive, 2 on a CUDA error.
static int time_atomic(float *x, float *dx, float *y, float *dy, int n,
                       int blocks, int threads, int reps, float *ms) {
  return time_grad_kernel(square_grad_atomic, x, dx, y, dy, n, blocks, threads,
                          reps, ms);
}

// Times square_grad_elementwise; parameters and return values are the same as
// for time_atomic.
static int time_elementwise(float *x, float *dx, float *y, float *dy, int n,
                            int blocks, int threads, int reps, float *ms) {
  return time_grad_kernel(square_grad_elementwise, x, dx, y, dy, n, blocks,
                          threads, reps, ms);
}

// Parses `text` as a strictly positive base-10 integer that fits in an int.
// Returns 0 and stores the value in *value on success; returns 1 (leaving
// *value untouched) for empty input, trailing garbage, non-positive values or
// values that do not survive the round trip through int.
static int parse_positive_int(const char *text, int *value) {
  char *end = nullptr;
  const long parsed = strtol(text, &end, 10);
  if (end == text || *end != '\0' || parsed <= 0 ||
      (long)(int)parsed != parsed)
    return 1;
  *value = (int)parsed;
  return 0;
}

// Allocates the four device arrays of n floats. Returns 0 on success or 2 on
// the first failed allocation; pointers allocated before the failure stay
// set so that the caller can release them.
static int allocate_buffers(float *&x, float *&dx, float *&y, float *&dy,
                            int n) {
  CUDA_CHECK(cudaMalloc(&x, (size_t)n * sizeof(float)));
  CUDA_CHECK(cudaMalloc(&dx, (size_t)n * sizeof(float)));
  CUDA_CHECK(cudaMalloc(&y, (size_t)n * sizeof(float)));
  CUDA_CHECK(cudaMalloc(&dy, (size_t)n * sizeof(float)));
  return 0;
}

// Frees all four device arrays (null pointers are accepted by cudaFree).
// Every buffer is attempted even if an earlier free fails. Returns 0 when all
// frees succeed and 2 otherwise.
static int release_buffers(float *x, float *dx, float *y, float *dy) {
  float *buffers[] = {x, dx, y, dy};
  int status = 0;
  for (float *buffer : buffers) {
    cudaError_t err = cudaFree(buffer);
    if (err != cudaSuccess) {
      fprintf(stderr, "cudaFree failed: %s\n", cudaGetErrorString(err));
      status = 2;
    }
  }
  return status;
}

// Runs the verification and timing phases on already-allocated buffers.
// Returns 0 on success, 2 on a CUDA error, 3/4 from verify_result, a non-zero
// timing status, or 5 when the elementwise path is not faster than the atomic
// one. Buffer ownership stays with the caller, so early returns are safe.
static int run_benchmark(float *x, float *dx, float *y, float *dy, int n,
                         int reps, int blocks, int threads) {
  init_arrays<<<blocks, threads>>>(x, dx, y, dy, n);
  CUDA_CHECK(cudaGetLastError());
  CUDA_CHECK(cudaDeviceSynchronize());

  // Correctness of the baseline (atomic) derivative.
  square_grad_atomic<<<blocks, threads>>>(x, dx, y, dy, n);
  CUDA_CHECK(cudaGetLastError());
  CUDA_CHECK(cudaDeviceSynchronize());
  int verify = verify_result("atomic", x, dx, y, n);
  if (verify != 0)
    return verify;

  // Correctness of the annotated (elementwise) derivative, from a clean
  // gradient state.
  reset_gradients<<<blocks, threads>>>(dx, y, dy, n);
  CUDA_CHECK(cudaGetLastError());
  CUDA_CHECK(cudaDeviceSynchronize());
  square_grad_elementwise<<<blocks, threads>>>(x, dx, y, dy, n);
  CUDA_CHECK(cudaGetLastError());
  CUDA_CHECK(cudaDeviceSynchronize());
  verify = verify_result("elementwise", x, dx, y, n);
  if (verify != 0)
    return verify;

  float atomic_ms = 0.0f;
  float elementwise_ms = 0.0f;
  int timing = time_atomic(x, dx, y, dy, n, blocks, threads, reps, &atomic_ms);
  if (timing != 0)
    return timing;
  timing =
      time_elementwise(x, dx, y, dy, n, blocks, threads, reps, &elementwise_ms);
  if (timing != 0)
    return timing;

  printf("n=%d reps=%d atomic_ms=%.6f elementwise_ms=%.6f speedup=%.3fx\n", n,
         reps, atomic_ms, elementwise_ms, atomic_ms / elementwise_ms);
  if (elementwise_ms >= atomic_ms) {
    fprintf(stderr,
            "expected elementwise path to be faster than atomic path\n");
    return 5;
  }
  return 0;
}

// Benchmark entry point.
//
// Usage: <program> [n [reps]]
//   n     number of elements (default 1 << 22), must be a positive integer
//   reps  timed launches per kernel (default 50), must be a positive integer
//
// Exit codes: 0 success, 1 invalid arguments, 2 CUDA failure, 3 host
// allocation failure, 4 result mismatch, 5 elementwise path not faster than
// the atomic path. Device buffers are released on every path.
int main(int argc, char **argv) {
  int n = 1 << 22;
  int reps = 50;
  if ((argc > 1 && parse_positive_int(argv[1], &n) != 0) ||
      (argc > 2 && parse_positive_int(argv[2], &reps) != 0)) {
    fprintf(stderr, "usage: %s [n > 0] [reps > 0]\n", argv[0]);
    return 1;
  }

  const int threads = 256;
  // Ceil-division written so that it cannot overflow for n close to INT_MAX.
  const int blocks = n / threads + (n % threads != 0 ? 1 : 0);
  // The kernels compute their element index in an int; refuse sizes for which
  // the last block's indices would not fit.
  const long long launched = (long long)blocks * threads;
  if ((long long)(int)launched != launched) {
    fprintf(stderr, "n=%d is too large for int-indexed kernels\n", n);
    return 1;
  }

  float *x = nullptr;
  float *dx = nullptr;
  float *y = nullptr;
  float *dy = nullptr;

  int status = allocate_buffers(x, dx, y, dy, n);
  if (status == 0)
    status = run_benchmark(x, dx, y, dy, n, reps, blocks, threads);

  // Always release the device memory; a cleanup failure is only reported as
  // the exit code when everything before it succeeded.
  const int cleanup = release_buffers(x, dx, y, dy);
  return status != 0 ? status : cleanup;
}
Loading