Skip to content

GPU fixes#32

Draft
rsenne wants to merge 9 commits into
mainfrom
GPU_Fixes_
Draft

GPU fixes#32
rsenne wants to merge 9 commits into
mainfrom
GPU_Fixes_

Conversation

@rsenne
Copy link
Copy Markdown
Owner

@rsenne rsenne commented May 9, 2026

This branch lands the GPU + AD-backend overhaul on top of main. The big themes:

Modular AD via DifferentiationInterface

  1. Make DifferentiationInterface the unified entry point into AD; Enzyme, Mooncake, and ForwardDiff are now [weakdeps] with extensions, not hard deps.
  2. Remove the old hard-coded Enzyme machinery (DEER.DEFAULT_BACKEND etc.) — backend is now a required kwarg on ParallelMALASampler, with the user choosing the AD package they want loaded.
  3. New DEER._hvp_strategy(backend) dispatch picks forward_on_grad (Enzyme/ForwardDiff) vs reverse_on_grad (Mooncake/Zygote) for the AD-HVP fallback. Forward-on-grad is the only AD-HVP path that's reliable on GPU.

EnzymeExt: GPU-safe Enzyme rules

  1. New ext/EnzymeExt.jl (~280 lines) defines native Enzyme forward / augmented_primal / reverse rules for pmcmc_matmul, pmcmc_dot, pmcmc_dotsum. These wrappers are treated as opaque by Enzyme's IR rewriter, sidestepping the gc-transition aborts Enzyme hits when lowering cuBLAS / cuMemcpyDtoHAsync_v2 bundles inside *(::CuArray, ::CuArray).
  2. DEER._hvp_forward_backend / _hvp_closure_backend overloads pin Enzyme.Forward + set_runtime_activity + function_annotation=Const for plain AutoEnzyme(). Runtime activity is load-bearing for composed pmcmc_matmul(transpose(X), pmcmc_matmul(X, β)) calls.
  3. Forward HVP now routes through forward-mode Enzyme with runtime activity (instead of forward-over-reverse, which crashes on MvNormal / Dirichlet log-pdfs).

Tests

  1. New test/test-GPU-AD-HVP.jl and test/test-GPU-Performance.jl covering the GPU AD-HVP path.
  2. New test/test-Owned-Matmul.jl exercises the custom Enzyme rules on pmcmc_matmul / pmcmc_dot / pmcmc_dotsum.
  3. Existing tests updated to pass an explicit backend=ADTypes.AutoEnzyme() (now required).

Needed Changes Prior to Merging

  1. Expand documentation to include a limitations section
  2. Expand documentation to explain new AD choices
  3. Add worked GPU examples as indicated in Add more worked examples #28
  4. Update changelog

Resolves #29 and provides a workaround to #25

@codecov
Copy link
Copy Markdown

codecov Bot commented May 9, 2026

Codecov Report

❌ Patch coverage is 22.56098% with 127 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.29%. Comparing base (9d3b516) to head (90a1ddd).

Files with missing lines Patch % Lines
ext/EnzymeExt.jl 3.09% 94 Missing ⚠️
src/DEER/DEER.jl 38.88% 33 Missing ⚠️

❌ Your patch check has failed because the patch coverage (22.56%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (80.29%) is below the target coverage (90.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #32      +/-   ##
==========================================
- Coverage   88.65%   80.29%   -8.37%     
==========================================
  Files           6        8       +2     
  Lines        1040     1162     +122     
==========================================
+ Hits          922      933      +11     
- Misses        118      229     +111     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@gdalle
Copy link
Copy Markdown
Collaborator

gdalle commented May 10, 2026

Hey @rsenne, need a review here?

@rsenne
Copy link
Copy Markdown
Owner Author

rsenne commented May 10, 2026

Hi @gdalle yes I would love that. This is my first pass on this and it would be much appreciated!

Comment thread ext/EnzymeExt.jl
@@ -0,0 +1,283 @@
module EnzymeExt
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wsmoses is the best person to review this one

Comment thread src/DEER/DEER.jl
Comment on lines +95 to +97
return DI.prepare_pushforward(
f, _hvp_forward_backend(backend), x_template, (v_template,); strict=Val(false)
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use DI.hvp directly here?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is because you need a batched gradient, you may be interested in JuliaDiff/DifferentiationInterface.jl#991

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that you can already batch a small number of tangents by passing a tuple though

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two reasons:

  1. The lack of batching (happy to help tackle this!)
  2. Perhaps, more importantly, every second order method I've tried breaks

Here are two MRWE if interested

#=
Direct `DI.hvp(logp, ...)` fails for every Mooncake-based config on GPU.
A single Mooncake reverse pass over the user's `gradlogp` succeeds.

Target (same shape as `test/test-GPU-AD-HVP.jl`):

  logp(β)     = -0.5 * (||β||^2 + ||Xβ||^2 / N)
  gradlogp(β) = -(β + Xᵀ X β / N)
  Hv         = -v - Xᵀ X v / N

Run from repo root:

  julia --project=test dev/di_hvp_gpu_mwe.jl
=#

push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))

using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Mooncake
import ForwardDiff
import CUDA
import LinearAlgebra: dot
using Random

CUDA.functional() || error("requires functional CUDA device")

function _logp_single(β, X)
    Xβ = pmcmc_matmul(X, β)
    N = oftype(zero(eltype(β)), size(X, 1))
    -oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end

function _gradlogp_single(β, X)
    Xβ = pmcmc_matmul(X, β)
    N  = oftype(zero(eltype(β)), size(X, 1))
    Y  = pmcmc_matmul(transpose(X), Xβ)
    Y  = Y ./ N
    Y  = Y .+ β
    return -Y
end

const D    = 20
const Ndat = 64

rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)

β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)

logp(β)     = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)

ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])

println()
println("== reverse-on-grad (Mooncake gradient of dot ∘ gradlogp) ==")
try
    closure = β -> dot(gradlogp(β), v_gpu)
    g     = DI.gradient(closure, AutoMooncake(; config=nothing), β_gpu)
    g_vec = g isa Tuple ? first(g) : g
    Hv    = Array(g_vec)
    println("  Hv[1:3] = ", Hv[1:3])
    println("  matches: ", isapprox(Hv, ref; atol=1f-3, rtol=1f-2))
catch e
    println("  FAILED: ", first(sprint(showerror, e), 600))
end

println()
configs = [
    "AutoMooncake (DI default SecondOrder)"      => AutoMooncake(; config=nothing),
    "SecondOrder(AutoForwardDiff, AutoMooncake)" => SecondOrder(AutoForwardDiff(), AutoMooncake(; config=nothing)),
    "SecondOrder(AutoMooncake, AutoForwardDiff)" => SecondOrder(AutoMooncake(; config=nothing), AutoForwardDiff()),
]

for (lbl, b) in configs
    print("== $lbl ==\n  ")
    try
        prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
        Hv   = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
        vec_ = Hv isa Tuple ? first(Hv) : Hv
        host = Array(vec_)
        ok   = isapprox(host, ref; atol=1f-3, rtol=1f-2)
        println("Hv[1:3]: ", host[1:3], "  matches: ", ok)
    catch e
        println("FAILED: ", first(sprint(showerror, e), 600))
    end
end
#=
Direct `DI.hvp(logp, AutoEnzyme(...))` fails for every config on GPU.
A forward-mode Enzyme pushforward of the user's `gradlogp` succeeds.

Same target as `dev/di_hvp_gpu_mwe.jl` / `test/test-GPU-AD-HVP.jl`:

  logp(β)     = -0.5 * (||β||^2 + ||Xβ||^2 / N)
  gradlogp(β) = -(β + Xᵀ X β / N)
  Hv         = -v - Xᵀ X v / N

Five Enzyme variants tested. Three hard-abort the Julia process during
Enzyme compilation and therefore cannot share a script with the rest:

  AutoEnzyme()                                 — hard abort during compile
  AutoEnzyme(mode=Reverse)                     — hard abort during compile
  SecondOrder(AutoEnzyme(Forward),
              AutoEnzyme(Reverse))             — hard abort during compile

To reproduce the aborts run one variant at a time. The two below throw
catchable exceptions and run cleanly in the same process.

Run from repo root:

  julia --project=test dev/di_hvp_gpu_enzyme_mwe.jl
=#

push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))

using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Enzyme
import CUDA
using Random

CUDA.functional() || error("requires functional CUDA device")

function _logp_single(β, X)
    Xβ = pmcmc_matmul(X, β)
    N = oftype(zero(eltype(β)), size(X, 1))
    -oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end

function _gradlogp_single(β, X)
    Xβ = pmcmc_matmul(X, β)
    N  = oftype(zero(eltype(β)), size(X, 1))
    Y  = pmcmc_matmul(transpose(X), Xβ)
    Y  = Y ./ N
    Y  = Y .+ β
    return -Y
end

const D    = 20
const Ndat = 64

rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)

β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)

logp(β)     = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)

ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])

println()
println("== forward-on-grad (Enzyme.Forward pushforward of gradlogp) ==")
try
    be    = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)
    prep  = DI.prepare_pushforward(gradlogp, be, β_gpu, (v_gpu,); strict=Val(false))
    Hv    = DI.pushforward(gradlogp, prep, be, β_gpu, (v_gpu,))
    vec_  = Hv isa Tuple ? first(Hv) : Hv
    host  = Array(vec_)
    println("  Hv[1:3] = ", host[1:3])
    println("  matches: ", isapprox(host, ref; atol=1f-3, rtol=1f-2))
catch e
    println("  FAILED: ", first(sprint(showerror, e), 600))
end

println()
configs = [
    "AutoEnzyme(mode=Forward)" =>
        AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
    "SecondOrder(AutoEnzyme(Reverse), AutoEnzyme(Forward))" =>
        SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Const),
                    AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)),
]

for (lbl, b) in configs
    print("== $lbl ==\n  ")
    try
        prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
        Hv   = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
        vec_ = Hv isa Tuple ? first(Hv) : Hv
        host = Array(vec_)
        ok   = isapprox(host, ref; atol=1f-3, rtol=1f-2)
        println("Hv[1:3]: ", host[1:3], "  matches: ", ok)
    catch e
        println("FAILED: ", first(sprint(showerror, e), 800))
    end
end

Totally and completely possible I am doing something fundamentally wrong--in which case please correct me

Comment thread src/DEER/DEER.jl
(c::_BatchHvpReverseClosure)(X, V) = pmcmc_dotsum(c.grad_batch(X), V)

#=
Pick the AD-HVP fallback strategy from the user's backend.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole logic exists in DI so it would be great if we added batching there and that alleviated your troubles here

Comment thread src/DEER/DEER.jl
Comment on lines +199 to +200
We bundle the closure with the prep so `prepare_gradient` and `gradient`
see the same function instance (DI keys preparations on function identity).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a constant pain in my behind, you can check out the hvp.jl file in DI and the ForwardDiff extension to see how much I struggle

Comment thread src/DEER/DEER.jl Outdated
Comment thread src/interface.jl
Comment thread src/ParallelMCMC.jl
Comment on lines +11 to +22
Owned wrappers: identical semantics to their Base counterparts, but provide
stable function identities for backend-specific AD rules in `ext/EnzymeExt.jl`
without committing type piracy on `Base.*` / `Base.dot` / `Base.sum`. User
model code that wants those rules to fire (notably on GPU, where the default
rules trip on cuBLAS gc-transition bundles) should call these instead.

Why both a matmul and reductions: the GPU AD-HVP path runs Enzyme reverse mode
through `gradlogp` *and* through a scalar reduction wrapping it (see DEER's
`_HvpReverseClosure`). The reduction is what actually emits the
`cuMemcpyDtoHAsync_v2` that crashes Enzyme. Owning both the matmul AND the
reduction lets Enzyme treat each as opaque, so neither bundle ever enters
its IR.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm really surprised by Enzyme reacting badly to LinearAlgebra, I suspect you might be holding it wrong. @wsmoses send help

Comment thread Project.toml
LinearAlgebra = "1"
LogDensityProblems = "2"
MCMCChains = "7.7.0"
Mooncake = "0.5.26"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're interested in second-order Mooncake features through DI, you wanna track JuliaDiff/DifferentiationInterface.jl#990.

See also JuliaDiff/DifferentiationInterface.jl#986

@rsenne rsenne requested a review from gdalle May 18, 2026 13:41
@rsenne
Copy link
Copy Markdown
Owner Author

rsenne commented May 18, 2026

hey @gdalle I addressed the two addressable point (e.g., type instability and nixing symbols) thoughts?

Also, included 2 MWRE above. If I'm not crazy--I can open issues for these on the respective repos though I'm not confident yet till someone who knows better than i says so. Also, happy to help tackle the linked DI issue for batching--it seems reasonably approachable?

Let me know what you think!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Remove default backend (Enzyme) and make DI the main UI

2 participants