GPU fixes#32
Conversation
…ault enzyme machinery.
# Conflicts: # Project.toml # ext/DynamicPPLExt.jl # ext/LogDensityProblemsExt.jl # test/test-DEER-Turing-Logistic.jl # test/test-Turing-Integration.jl
Codecov Report❌ Patch coverage is
❌ Your patch check has failed because the patch coverage (22.56%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #32 +/- ##
==========================================
- Coverage 88.65% 80.29% -8.37%
==========================================
Files 6 8 +2
Lines 1040 1162 +122
==========================================
+ Hits 922 933 +11
- Misses 118 229 +111 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Hey @rsenne, need a review here? |
|
Hi @gdalle yes I would love that. This is my first pass on this and it would be much appreciated! |
| @@ -0,0 +1,283 @@ | |||
| module EnzymeExt | |||
There was a problem hiding this comment.
@wsmoses is the best person to review this one
| return DI.prepare_pushforward( | ||
| f, _hvp_forward_backend(backend), x_template, (v_template,); strict=Val(false) | ||
| ) |
There was a problem hiding this comment.
Why not use DI.hvp directly here?
There was a problem hiding this comment.
If it is because you need a batched gradient, you may be interested in JuliaDiff/DifferentiationInterface.jl#991
There was a problem hiding this comment.
Note that you can already batch a small number of tangents by passing a tuple though
There was a problem hiding this comment.
Two reasons:
- The lack of batching (happy to help tackle this!)
- Perhaps, more importantly, every second order method I've tried breaks
Here are two MRWE if interested
#=
Direct `DI.hvp(logp, ...)` fails for every Mooncake-based config on GPU.
A single Mooncake reverse pass over the user's `gradlogp` succeeds.
Target (same shape as `test/test-GPU-AD-HVP.jl`):
logp(β) = -0.5 * (||β||^2 + ||Xβ||^2 / N)
gradlogp(β) = -(β + Xᵀ X β / N)
Hv = -v - Xᵀ X v / N
Run from repo root:
julia --project=test dev/di_hvp_gpu_mwe.jl
=#
push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))
using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Mooncake
import ForwardDiff
import CUDA
import LinearAlgebra: dot
using Random
CUDA.functional() || error("requires functional CUDA device")
function _logp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
-oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end
function _gradlogp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
Y = pmcmc_matmul(transpose(X), Xβ)
Y = Y ./ N
Y = Y .+ β
return -Y
end
const D = 20
const Ndat = 64
rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)
β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)
logp(β) = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)
ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])
println()
println("== reverse-on-grad (Mooncake gradient of dot ∘ gradlogp) ==")
try
closure = β -> dot(gradlogp(β), v_gpu)
g = DI.gradient(closure, AutoMooncake(; config=nothing), β_gpu)
g_vec = g isa Tuple ? first(g) : g
Hv = Array(g_vec)
println(" Hv[1:3] = ", Hv[1:3])
println(" matches: ", isapprox(Hv, ref; atol=1f-3, rtol=1f-2))
catch e
println(" FAILED: ", first(sprint(showerror, e), 600))
end
println()
configs = [
"AutoMooncake (DI default SecondOrder)" => AutoMooncake(; config=nothing),
"SecondOrder(AutoForwardDiff, AutoMooncake)" => SecondOrder(AutoForwardDiff(), AutoMooncake(; config=nothing)),
"SecondOrder(AutoMooncake, AutoForwardDiff)" => SecondOrder(AutoMooncake(; config=nothing), AutoForwardDiff()),
]
for (lbl, b) in configs
print("== $lbl ==\n ")
try
prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
ok = isapprox(host, ref; atol=1f-3, rtol=1f-2)
println("Hv[1:3]: ", host[1:3], " matches: ", ok)
catch e
println("FAILED: ", first(sprint(showerror, e), 600))
end
end#=
Direct `DI.hvp(logp, AutoEnzyme(...))` fails for every config on GPU.
A forward-mode Enzyme pushforward of the user's `gradlogp` succeeds.
Same target as `dev/di_hvp_gpu_mwe.jl` / `test/test-GPU-AD-HVP.jl`:
logp(β) = -0.5 * (||β||^2 + ||Xβ||^2 / N)
gradlogp(β) = -(β + Xᵀ X β / N)
Hv = -v - Xᵀ X v / N
Five Enzyme variants tested. Three hard-abort the Julia process during
Enzyme compilation and therefore cannot share a script with the rest:
AutoEnzyme() — hard abort during compile
AutoEnzyme(mode=Reverse) — hard abort during compile
SecondOrder(AutoEnzyme(Forward),
AutoEnzyme(Reverse)) — hard abort during compile
To reproduce the aborts run one variant at a time. The two below throw
catchable exceptions and run cleanly in the same process.
Run from repo root:
julia --project=test dev/di_hvp_gpu_enzyme_mwe.jl
=#
push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))
using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Enzyme
import CUDA
using Random
CUDA.functional() || error("requires functional CUDA device")
function _logp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
-oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end
function _gradlogp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
Y = pmcmc_matmul(transpose(X), Xβ)
Y = Y ./ N
Y = Y .+ β
return -Y
end
const D = 20
const Ndat = 64
rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)
β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)
logp(β) = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)
ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])
println()
println("== forward-on-grad (Enzyme.Forward pushforward of gradlogp) ==")
try
be = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)
prep = DI.prepare_pushforward(gradlogp, be, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.pushforward(gradlogp, prep, be, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
println(" Hv[1:3] = ", host[1:3])
println(" matches: ", isapprox(host, ref; atol=1f-3, rtol=1f-2))
catch e
println(" FAILED: ", first(sprint(showerror, e), 600))
end
println()
configs = [
"AutoEnzyme(mode=Forward)" =>
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
"SecondOrder(AutoEnzyme(Reverse), AutoEnzyme(Forward))" =>
SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Const),
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)),
]
for (lbl, b) in configs
print("== $lbl ==\n ")
try
prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
ok = isapprox(host, ref; atol=1f-3, rtol=1f-2)
println("Hv[1:3]: ", host[1:3], " matches: ", ok)
catch e
println("FAILED: ", first(sprint(showerror, e), 800))
end
endTotally and completely possible I am doing something fundamentally wrong--in which case please correct me
| (c::_BatchHvpReverseClosure)(X, V) = pmcmc_dotsum(c.grad_batch(X), V) | ||
|
|
||
| #= | ||
| Pick the AD-HVP fallback strategy from the user's backend. |
There was a problem hiding this comment.
This whole logic exists in DI so it would be great if we added batching there and that alleviated your troubles here
| We bundle the closure with the prep so `prepare_gradient` and `gradient` | ||
| see the same function instance (DI keys preparations on function identity). |
There was a problem hiding this comment.
That's a constant pain in my behind, you can check out the hvp.jl file in DI and the ForwardDiff extension to see how much I struggle
| Owned wrappers: identical semantics to their Base counterparts, but provide | ||
| stable function identities for backend-specific AD rules in `ext/EnzymeExt.jl` | ||
| without committing type piracy on `Base.*` / `Base.dot` / `Base.sum`. User | ||
| model code that wants those rules to fire (notably on GPU, where the default | ||
| rules trip on cuBLAS gc-transition bundles) should call these instead. | ||
|
|
||
| Why both a matmul and reductions: the GPU AD-HVP path runs Enzyme reverse mode | ||
| through `gradlogp` *and* through a scalar reduction wrapping it (see DEER's | ||
| `_HvpReverseClosure`). The reduction is what actually emits the | ||
| `cuMemcpyDtoHAsync_v2` that crashes Enzyme. Owning both the matmul AND the | ||
| reduction lets Enzyme treat each as opaque, so neither bundle ever enters | ||
| its IR. |
There was a problem hiding this comment.
I'm really surprised by Enzyme reacting badly to LinearAlgebra, I suspect you might be holding it wrong. @wsmoses send help
| LinearAlgebra = "1" | ||
| LogDensityProblems = "2" | ||
| MCMCChains = "7.7.0" | ||
| Mooncake = "0.5.26" |
There was a problem hiding this comment.
If you're interested in second-order Mooncake features through DI, you wanna track JuliaDiff/DifferentiationInterface.jl#990.
|
hey @gdalle I addressed the two addressable point (e.g., type instability and nixing symbols) thoughts? Also, included 2 MWRE above. If I'm not crazy--I can open issues for these on the respective repos though I'm not confident yet till someone who knows better than i says so. Also, happy to help tackle the linked DI issue for batching--it seems reasonably approachable? Let me know what you think! |
This branch lands the GPU + AD-backend overhaul on top of main. The big themes:
Modular AD via DifferentiationInterface
EnzymeExt: GPU-safe Enzyme rules
Tests
Needed Changes Prior to Merging
Resolves #29 and provides a workaround to #25