diff --git a/.gitignore b/.gitignore index 94aa3f1..367f7a7 100644 --- a/.gitignore +++ b/.gitignore @@ -21,4 +21,8 @@ LocalPreferences.toml CLAUDE.md AGENTS.md CODEX.md -.gemini \ No newline at end of file +.gemini + +# random scripts for debugging +/scripts +/dev \ No newline at end of file diff --git a/Project.toml b/Project.toml index 04cc852..dc6021b 100644 --- a/Project.toml +++ b/Project.toml @@ -7,12 +7,15 @@ version = "0.0.2" ADTypes = "1.21.0" AbstractMCMC = "5.10.0" CUDA = "5.11.0" +CUDA_Runtime_jll = "0.21" DifferentiationInterface = "0.7.13" DynamicPPL = "0.40.6, 0.41" -Enzyme = "0.13.131" +Enzyme = "0.13.146" +ForwardDiff = "1" LinearAlgebra = "1" LogDensityProblems = "2" MCMCChains = "7.7.0" +Mooncake = "0.5.26" Random = "1" Statistics = "1" julia = "1.10" @@ -22,14 +25,14 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [extensions] -DynamicPPLExt = ["DynamicPPL", "LogDensityProblems"] +DynamicPPLExt = ["DynamicPPL", "ForwardDiff", "LogDensityProblems"] +EnzymeExt = "Enzyme" LogDensityProblemsExt = "LogDensityProblems" [extras] @@ -37,4 +40,7 @@ CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" [weakdeps] DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" diff --git a/benchmarks/ParallelMCMCBenchmarks/Project.toml b/benchmarks/ParallelMCMCBenchmarks/Project.toml index 0111b50..7d36f10 100644 --- a/benchmarks/ParallelMCMCBenchmarks/Project.toml +++ b/benchmarks/ParallelMCMCBenchmarks/Project.toml @@ -6,6 +6,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ParallelMCMC = "1a970f40-4406-51c9-a967-cb3143c111e8" @@ -14,8 +15,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +[sources] +ParallelMCMC = {path = "../.."} + [extras] CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" - -[sources.ParallelMCMC] -path = "../.." diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl index 4e578a3..831ed63 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl @@ -23,6 +23,8 @@ using Statistics using AbstractMCMC: sample using ParallelMCMC +using ADTypes +using Enzyme using ParallelMCMCBenchmarks const BayesLogReg = ParallelMCMCBenchmarks.BayesLogReg @@ -135,6 +137,7 @@ if _cuda_ok tol_rel=tol_rel, damping=damping, probes=probes, + backend=ADTypes.AutoEnzyme(), ) println(" T=$T") diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl index 5fd7f16..466e8d9 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl @@ -17,6 +17,8 @@ using Statistics using AbstractMCMC: sample using ParallelMCMC +using ADTypes +using Enzyme using ParallelMCMCBenchmarks const BayesLinReg = ParallelMCMCBenchmarks.BayesLinReg const MALARunner = ParallelMCMCBenchmarks.MALARunner @@ -46,7 +48,13 @@ x_warm, ϵ_tuned = MALARunner.tune_stepsize_mala( mala_sampler = AdaptiveMALASampler(ϵ_tuned; n_warmup=500) deer_sampler = ParallelMALASampler( - ϵ_tuned; T=64, maxiter=200, tol_abs=1e-6, tol_rel=1e-5, damping=0.5 + ϵ_tuned; + T=64, + maxiter=200, + tol_abs=1e-6, + tol_rel=1e-5, + damping=0.5, + backend=ADTypes.AutoEnzyme(), ) # Benchmark helper diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl index d811efa..b2d3fda 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl @@ -108,7 +108,10 @@ function write_markdown(path, rows; warn_ratio, fail_ratio) "`.", ) println(io) - println(io, "| Benchmark | Base median | PR median | Ratio | Allocs | Memory delta | Status |") + println( + io, + "| Benchmark | Base median | PR median | Ratio | Allocs | Memory delta | Status |", + ) println(io, "|---|---:|---:|---:|---:|---:|---|") for row in rows println( @@ -167,8 +170,12 @@ function main(args=ARGS) base === nothing && error("--base is required") head === nothing && error("--head is required") - warn_ratio = parse(Float64, _option(args, "--warn-ratio", get(ENV, "PMCMC_BENCH_WARN_RATIO", "1.25"))) - fail_ratio = parse(Float64, _option(args, "--fail-ratio", get(ENV, "PMCMC_BENCH_FAIL_RATIO", "1.75"))) + warn_ratio = parse( + Float64, _option(args, "--warn-ratio", get(ENV, "PMCMC_BENCH_WARN_RATIO", "1.25")) + ) + fail_ratio = parse( + Float64, _option(args, "--fail-ratio", get(ENV, "PMCMC_BENCH_FAIL_RATIO", "1.75")) + ) markdown = _option(args, "--markdown", "") warn_ratio > 1 || error("--warn-ratio must be greater than 1") diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl index 7fbf2b6..f4ba655 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl @@ -18,6 +18,8 @@ using Printf using Statistics using ParallelMCMC +using ADTypes +using Enzyme using ParallelMCMCBenchmarks using CUDA @@ -79,7 +81,7 @@ function build_raw_deer_problem( damping::Float32, probes::Int, cholM=nothing, - backend=DEER.DEFAULT_BACKEND, + backend=ADTypes.AutoEnzyme(), ) FP = typeof(epsilon) D = model.dim diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl index acc2e84..dbc06b8 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl @@ -35,7 +35,9 @@ function main(args=ARGS) return 0 end - seconds = parse(Float64, _option(args, "--seconds", get(ENV, "PMCMC_BENCH_SECONDS", "0.5"))) + seconds = parse( + Float64, _option(args, "--seconds", get(ENV, "PMCMC_BENCH_SECONDS", "0.5")) + ) samples = parse(Int, _option(args, "--samples", get(ENV, "PMCMC_BENCH_SAMPLES", "8"))) output = _option(args, "--output", "") markdown = _option(args, "--markdown", "") diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl index 98fc4fa..b31176e 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl @@ -27,6 +27,8 @@ or: using Random using CUDA using ParallelMCMC +using ADTypes +using Enzyme using ParallelMCMCBenchmarks const BayesLogReg = ParallelMCMCBenchmarks.BayesLogReg @@ -52,7 +54,7 @@ function build_raw_deer_problem( epsilon::Float32, T::Int, cholM=nothing, - backend=DEER.DEFAULT_BACKEND, + backend=ADTypes.AutoEnzyme(), ) FP = typeof(epsilon) D = model.dim diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl index 46336fd..e76e923 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl @@ -9,6 +9,8 @@ using Statistics using AbstractMCMC: sample using ParallelMCMC +using ADTypes +using Enzyme using ParallelMCMCBenchmarks using CUDA @@ -67,7 +69,9 @@ function build_problem() end function build_rec(model_gpu, x0_gpu, tape) - return ParallelMCMC._build_mala_deer_rec(model_gpu, epsilon, tape, x0_gpu;) + return ParallelMCMC._build_mala_deer_rec( + model_gpu, epsilon, tape, x0_gpu; backend=ADTypes.AutoEnzyme() + ) end function solve_prebuilt(rec, x0_gpu, ws; seed=42, return_info=false) @@ -131,6 +135,7 @@ deer_gpu = ParallelMALASampler( tol_rel=tol_rel, damping=damping, probes=probes, + backend=ADTypes.AutoEnzyme(), ) println("=" ^ 96) diff --git a/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_linreg.jl b/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_linreg.jl index 02e3f55..c2e1d67 100644 --- a/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_linreg.jl +++ b/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_linreg.jl @@ -35,9 +35,11 @@ function make_problem(X::AbstractMatrix, y::AbstractVector; σ::Real=1.0, τ::Re return g end - # Analytic Gaussian posterior: - # Σ_post = (XᵀX/σ² + I/τ²)^{-1} - # μ_post = Σ_post * (Xᵀy/σ²) + #= + Analytic Gaussian posterior: + Σ_post = (XᵀX/σ² + I/τ²)^{-1} + μ_post = Σ_post * (Xᵀy/σ²) + =# A = (X' * X) ./ σ2 A += Diagonal(fill(1.0 / τ2, p)) Σ_post = inv(Symmetric(A)) diff --git a/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl b/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl index 0af562a..e47b0f0 100644 --- a/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl +++ b/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl @@ -36,9 +36,11 @@ function make_problem(X::AbstractMatrix, y::AbstractVector) function logp(β::AbstractVector) mul!(logits, X, β) - # ll = sum(@. y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits)))) - # We perform this in-place to avoid allocations - # Note: log1p(exp(x)) is the softplus function + #= + ll = sum(@. y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits)))) + We perform this in-place to avoid allocations + Note: log1p(exp(x)) is the softplus function + =# @. p = y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits))) return sum(p) - 0.5 * sum(abs2, β) end @@ -47,8 +49,10 @@ function make_problem(X::AbstractMatrix, y::AbstractVector) mul!(logits, X, β) @. p = 1 / (1 + exp(-logits)) @. resid = y - p - # grad = X' * (y - p) - β - # We reuse the output allocation here by using mul! + #= + grad = X' * (y - p) - β + We reuse the output allocation here by using mul! + =# grad = (X' * resid) .- β return grad end diff --git a/benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl b/benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl index c281d42..df4622f 100644 --- a/benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl +++ b/benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl @@ -9,6 +9,8 @@ using Statistics using TOML using ParallelMCMC +using ADTypes +using Enzyme const MALA = ParallelMCMC.MALA const DEER = ParallelMCMC.DEER @@ -49,8 +51,9 @@ function _make_stdnormal_rec(rng::AbstractRNG, dim::Int, steps::Int, epsilon::Fl tape = [(noise=randn(rng, dim), u=rand(rng)) for _ in 1:steps] step_fwd = - (x, te) -> - MALA.mala_step_taped(_stdnormal_logp, _stdnormal_grad, x, epsilon, te.noise, te.u) + (x, te) -> MALA.mala_step_taped( + _stdnormal_logp, _stdnormal_grad, x, epsilon, te.noise, te.u + ) jvp = (x, te, v) -> MALA.mala_step_surrogate_sigmoid_jvp( _stdnormal_logp, @@ -88,14 +91,7 @@ function _mala_step_case() epsilon = 0.04 bench = @benchmarkable MALA.mala_step_with_logα!( - $x_next, - $workspace, - $_stdnormal_logp, - $_stdnormal_grad, - $x, - $epsilon, - $noise, - $u, + $x_next, $workspace, $_stdnormal_logp, $_stdnormal_grad, $x, $epsilon, $noise, $u ) evals = 1 return BenchmarkCase( @@ -144,7 +140,7 @@ function _deer_solve_case() jacobian=:diag, damping=0.5, rng=rng, - workspace=$workspace, + workspace=($workspace), copy_result=false, ) setup = (rng = MersenneTwister(42)) evals = 1 @@ -170,16 +166,12 @@ function _parallel_mala_sample_case() tol_rel=1e-5, jacobian=:diag, damping=0.5, + backend=ADTypes.AutoEnzyme(), ) initial_params = zeros(dim) bench = @benchmarkable sample( - rng, - $model, - $sampler, - 64; - initial_params=$initial_params, - progress=false, + rng, $model, $sampler, 64; initial_params=($initial_params), progress=false ) setup = (rng = MersenneTwister(42)) evals = 1 return BenchmarkCase( @@ -201,16 +193,12 @@ function _parallel_mala_sample_no_hvp_case() tol_rel=1e-5, jacobian=:diag, damping=0.5, + backend=ADTypes.AutoEnzyme(), ) initial_params = zeros(dim) bench = @benchmarkable sample( - rng, - $model, - $sampler, - 64; - initial_params=$initial_params, - progress=false, + rng, $model, $sampler, 64; initial_params=($initial_params), progress=false ) setup = (rng = MersenneTwister(42)) evals = 1 return BenchmarkCase( diff --git a/docs/src/10-getting-started.md b/docs/src/10-getting-started.md index 159afa0..d5ca16e 100644 --- a/docs/src/10-getting-started.md +++ b/docs/src/10-getting-started.md @@ -30,7 +30,7 @@ model = DensityModel(logp, grad_logp, 2; param_names=[:x1, :x2]) --- -## ParallelMALASampler — the primary algorithm +## ParallelMALASampler [`ParallelMALASampler`](@ref) reformulates a trajectory of `T` MALA steps as a fixed-point problem and solves it via Newton iterations, each of which costs $O(\log T)$ parallel work via an associative prefix scan. Wall-clock time per sample is therefore sublinear in chain length on multi-core CPUs and GPUs. @@ -84,6 +84,29 @@ chain = sample(model, sampler, 500; chain_type=MCMCChains.Chains) ``` +#### AD-HVP fallback + +DEER needs a Hessian-vector product (HVP) at every Newton step. If you +supply `hvp`/`hvp_batch` analytically, those run as plain kernels and +nothing else is required. If you only supply `grad_logdensity` / +`grad_logdensity_batch`, the sampler computes the HVP via **Mooncake** +reverse-mode applied to `x -> dot(grad_logdensity(x), v)`. + +```julia +using ParallelMCMC, CUDA + +const X = CUDA.CuMatrix(randn(Float32, N, D)) + +# Plain Julia / CUDA operations. +grad_logp(β) = -β .- (transpose(X) * (X * β)) ./ Float32(N) +grad_logp_batch(B) = -B .- (transpose(X) * (X * B)) ./ Float32(N) + +model = DensityModel(logp, grad_logp, D; + logdensity_batch=logp_batch, + grad_logdensity_batch=grad_logp_batch) +sampler = ParallelMALASampler(0.05f0; T=64) +``` + --- ## Turing.jl integration diff --git a/docs/src/assets/make_julia_deer_gif.jl b/docs/src/assets/make_julia_deer_gif.jl index 2f61130..92617d7 100644 --- a/docs/src/assets/make_julia_deer_gif.jl +++ b/docs/src/assets/make_julia_deer_gif.jl @@ -27,12 +27,7 @@ const CENTERS = [ ] const SIGMAS = [0.38, 0.38, 0.38, 0.25] const LOG_WEIGHTS = log.([1.0, 1.0, 1.0, 0.50]) -const LOGO_COLORS = [ - (56, 152, 38), - (203, 60, 51), - (149, 88, 178), - (64, 99, 216), -] +const LOGO_COLORS = [(56, 152, 38), (203, 60, 51), (149, 88, 178), (64, 99, 216)] struct TapeStep xi::Vector{Float64} @@ -120,9 +115,7 @@ function make_recursion(tape, epsilon) return (; step_fwd, jvp, tape) end -function deer_diag_update!( - output, A, B, scan_ws, rec, s0, current; damping=0.55 -) +function deer_diag_update!(output, A, B, scan_ws, rec, s0, current; damping=0.55) dim, steps = size(current) basis = zeros(dim) @@ -346,21 +339,24 @@ function main() epsilon = 0.095 damping = 0.55 x0 = [ - rand(rng) * (X_RANGE[2] - X_RANGE[1]) + X_RANGE[1], - rand(rng) * (Y_RANGE[2] - Y_RANGE[1]) + Y_RANGE[1], + rand(rng) * (X_RANGE[2] - X_RANGE[1]) + X_RANGE[1], + rand(rng) * (Y_RANGE[2] - Y_RANGE[1]) + Y_RANGE[1], ] - tape = make_tape(rng, 2, steps) rec = make_recursion(tape, epsilon) - iterates, metrics = record_iterates(rec, x0; steps=steps, maxiter=maxiter, damping=damping) + iterates, metrics = record_iterates( + rec, x0; steps=steps, maxiter=maxiter, damping=damping + ) selected = unique(round.(Int, range(0, maxiter; length=256))) selected_iterates = [iterates[i + 1] for i in selected] noise = [step.xi for step in tape] uniforms = [step.u for step in tape] - sequential = MALA.run_mala_sequential_taped(logposterior, gradposterior, x0, epsilon, noise, uniforms) + sequential = MALA.run_mala_sequential_taped( + logposterior, gradposterior, x0, epsilon, noise, uniforms + ) final_trajectory = reduce(hcat, sequential[2:end]) output = isempty(ARGS) ? joinpath(@__DIR__, "julia_deer_posterior.gif") : ARGS[1] diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index ae0c6cb..e87c90e 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -4,12 +4,11 @@ using ParallelMCMC using ADTypes: ADTypes using DynamicPPL: DynamicPPL using AbstractMCMC: AbstractMCMC -using Enzyme: Enzyme using MCMCChains: MCMCChains using LogDensityProblems: LogDensityProblems """ - DensityModel(turing_model::DynamicPPL.Model; ad_backend=ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Duplicated), hvp=nothing) + DensityModel(turing_model::DynamicPPL.Model; ad_backend=ADTypes.AutoForwardDiff(), hvp=nothing) Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a `DensityModel`, automatically extracting parameter names and wiring up gradient @@ -26,18 +25,15 @@ using Turing, ParallelMCMC, MCMCChains y ~ Normal(μ, 0.5) end +# AutoForwardDiff is the default. For larger models pass an explicit backend +# and `using` the corresponding package (Enzyme, Mooncake). model = DensityModel(mymodel(1.5)) chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; chain_type=MCMCChains.Chains, discard_warmup=true, progress=true) ``` """ function ParallelMCMC.DensityModel( - turing_model::DynamicPPL.Model; - ad_backend=ADTypes.AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Duplicated, - ), - hvp=nothing, + turing_model::DynamicPPL.Model; ad_backend=ADTypes.AutoForwardDiff(), hvp=nothing ) # Sample in linked/unconstrained space and let DynamicPPL provide the gradient. ld = DynamicPPL.LogDensityFunction( diff --git a/ext/EnzymeExt.jl b/ext/EnzymeExt.jl new file mode 100644 index 0000000..061f7a4 --- /dev/null +++ b/ext/EnzymeExt.jl @@ -0,0 +1,288 @@ +module EnzymeExt + +using ParallelMCMC: pmcmc_matmul, pmcmc_dot, pmcmc_dotsum +using ParallelMCMC.DEER: DEER +using ADTypes: ADTypes +using LinearAlgebra: dot +using Enzyme +using Enzyme.EnzymeCore +using Enzyme.EnzymeCore.EnzymeRules: + EnzymeRules, + FwdConfig, + RevConfig, + Annotation, + AugmentedReturn, + needs_primal, + needs_shadow, + overwritten, + width + +#= +Tell DEER's forward-on-grad HVP path how to normalize a plain `AutoEnzyme()`: +pin `mode=Enzyme.Forward` and `function_annotation=Enzyme.Const`. Pinning +Forward is load-bearing on GPU — without it, DI defaults to reverse mode +and Enzyme aborts on the cuBLAS / `cuPointerGetAttribute` gc-transition +bundle inside the user's GPU `gradlogp`. If the user already specified +either field, respect their choice. +=# +function DEER._hvp_forward_backend(backend::ADTypes.AutoEnzyme{M,A}) where {M,A} + A === Nothing || return backend # user already specified function_annotation + #= + `set_runtime_activity` is load-bearing for composed `pmcmc_matmul` calls + (e.g. `pmcmc_matmul(transpose(X), pmcmc_matmul(X, β))`). Static activity + analysis can't prove the outer call's `transpose(X)` shadow is safe to + reuse, and Enzyme aborts with `EnzymeRuntimeActivityError`. With runtime + activity, the shadow is tracked dynamically. + =# + mode = backend.mode === nothing ? Enzyme.set_runtime_activity(Enzyme.Forward) : + backend.mode + return ADTypes.AutoEnzyme(; mode=mode, function_annotation=Enzyme.Const) +end + +#= +Tell DEER's reverse-on-grad HVP path how to normalize a plain `AutoEnzyme()`: +fill in `function_annotation=Enzyme.Const` so Enzyme doesn't throw +`EnzymeMutabilityException` on the read-only `_HvpReverseClosure` / +`_BatchHvpReverseClosure` wrappers. If the user already specified +`function_annotation`, respect their choice. +=# +function DEER._hvp_closure_backend(backend::ADTypes.AutoEnzyme{M,A}) where {M,A} + A === Nothing || return backend # user already specified function_annotation + mode = backend.mode === nothing ? Enzyme.set_runtime_activity(Enzyme.Reverse) : + backend.mode + return ADTypes.AutoEnzyme(; mode=mode, function_annotation=Enzyme.Const) +end + +#= +Native Enzyme rules for the owned wrappers `pmcmc_matmul`, `pmcmc_dot`, +`pmcmc_dotsum`. These exist so that on GPU paths Enzyme treats each as +opaque: primal and cotangents are computed by plain `*`, `'`, `dot`, `sum`, +and broadcast — *outside* Enzyme's IR rewriter. That sidesteps the +"unsupported tag gc-transition" abort Enzyme hits when it tries to lower +`cuMemcpyDtoHAsync_v2` (emitted by every CuArray scalar reduction) or the +cuBLAS bundles inside `*(::CuArray, ::CuArray)`. + +For each function we define + - `EnzymeRules.forward` (forward-mode JVP) + - `EnzymeRules.augmented_primal` (forward sweep of reverse mode) + - `EnzymeRules.reverse` (reverse sweep) + +Width=1 only — we don't need batch-mode AD here, and DI uses width=1. +=# + +#= +pmcmc_matmul(A, B) = A * B — matrix output (Duplicated) + +JVP: dY = dA * B + A * dB +Pullback: dA += dY * B' + dB += A' * dY +=# + +function EnzymeRules.forward( + config::FwdConfig, + ::Const{typeof(pmcmc_matmul)}, + RT::Type{<:Annotation}, + A::Annotation, + B::Annotation, +) + primal = needs_primal(config) ? pmcmc_matmul(A.val, B.val) : nothing + shadow = if A isa Const && B isa Const + #= + Both args Const → output tangent is structurally zero. We still need + to return an output-shaped array (Enzyme's shadow-type check rejects + `nothing` when the caller asked for `Duplicated`/`DuplicatedNoNeed`). + =# + zero(primal === nothing ? pmcmc_matmul(A.val, B.val) : primal) + elseif A isa Const + pmcmc_matmul(A.val, B.dval) + elseif B isa Const + pmcmc_matmul(A.dval, B.val) + else + pmcmc_matmul(A.dval, B.val) .+ pmcmc_matmul(A.val, B.dval) + end + if RT <: Const + return needs_primal(config) ? primal : nothing + elseif RT <: DuplicatedNoNeed + return shadow + elseif RT <: Duplicated + return Duplicated(primal === nothing ? pmcmc_matmul(A.val, B.val) : primal, shadow) + else + return needs_primal(config) ? primal : shadow + end +end + +function EnzymeRules.augmented_primal( + config::RevConfig, + ::Const{typeof(pmcmc_matmul)}, + RT::Type{<:Annotation}, + A::Annotation, + B::Annotation, +) + Y = pmcmc_matmul(A.val, B.val) + cache_A = overwritten(config)[2] ? copy(A.val) : A.val + cache_B = overwritten(config)[3] ? copy(B.val) : B.val + dY = if RT <: Duplicated || RT <: DuplicatedNoNeed + zero(Y) + else + nothing + end + primal = needs_primal(config) ? Y : nothing + return AugmentedReturn(primal, dY, (dY, cache_A, cache_B)) +end + +function EnzymeRules.reverse( + config::RevConfig, + ::Const{typeof(pmcmc_matmul)}, + ::Type{<:Annotation}, + tape, + A::Annotation, + B::Annotation, +) + dY, cache_A, cache_B = tape + if dY !== nothing + if !(A isa Const) + A.dval .+= pmcmc_matmul(dY, transpose(cache_B)) + end + if !(B isa Const) + B.dval .+= pmcmc_matmul(transpose(cache_A), dY) + end + fill!(dY, zero(eltype(dY))) + end + return (nothing, nothing) +end + +#= +pmcmc_dot(a, b) = dot(a, b) — scalar output (Active) + +JVP: dr = dot(da, b) + dot(a, db) +Pullback: da += dr * b + db += dr * a +=# + +function EnzymeRules.forward( + config::FwdConfig, + ::Const{typeof(pmcmc_dot)}, + RT::Type{<:Annotation}, + a::Annotation, + b::Annotation, +) + primal = needs_primal(config) ? pmcmc_dot(a.val, b.val) : nothing + if RT <: Const + return needs_primal(config) ? primal : nothing + end + da_term = a isa Const ? zero(eltype(a.val)) : pmcmc_dot(a.dval, b.val) + db_term = b isa Const ? zero(eltype(b.val)) : pmcmc_dot(a.val, b.dval) + tangent = da_term + db_term + if RT <: DuplicatedNoNeed + return tangent + elseif RT <: Duplicated + return Duplicated(primal === nothing ? pmcmc_dot(a.val, b.val) : primal, tangent) + else + return needs_primal(config) ? primal : tangent + end +end + +function EnzymeRules.augmented_primal( + config::RevConfig, + ::Const{typeof(pmcmc_dot)}, + ::Type{<:Annotation}, + a::Annotation, + b::Annotation, +) + primal = needs_primal(config) ? pmcmc_dot(a.val, b.val) : nothing + cache_a = (!(b isa Const) && overwritten(config)[2]) ? copy(a.val) : nothing + cache_b = (!(a isa Const) && overwritten(config)[3]) ? copy(b.val) : nothing + return AugmentedReturn(primal, nothing, (cache_a, cache_b)) +end + +function EnzymeRules.reverse( + config::RevConfig, + ::Const{typeof(pmcmc_dot)}, + dret, + tape, + a::Annotation, + b::Annotation, +) + cache_a, cache_b = tape + if !(dret isa Const) + dr = dret.val + if !(a isa Const) + bv = cache_b !== nothing ? cache_b : b.val + a.dval .+= dr .* bv + end + if !(b isa Const) + av = cache_a !== nothing ? cache_a : a.val + b.dval .+= dr .* av + end + end + return (nothing, nothing) +end + +#= +pmcmc_dotsum(A, B) = sum(A .* B) — scalar output (Active) + +Same algebra as `pmcmc_dot`, just with matrix args. +=# + +function EnzymeRules.forward( + config::FwdConfig, + ::Const{typeof(pmcmc_dotsum)}, + RT::Type{<:Annotation}, + A::Annotation, + B::Annotation, +) + primal = needs_primal(config) ? pmcmc_dotsum(A.val, B.val) : nothing + if RT <: Const + return needs_primal(config) ? primal : nothing + end + dA_term = A isa Const ? zero(eltype(A.val)) : pmcmc_dotsum(A.dval, B.val) + dB_term = B isa Const ? zero(eltype(B.val)) : pmcmc_dotsum(A.val, B.dval) + tangent = dA_term + dB_term + if RT <: DuplicatedNoNeed + return tangent + elseif RT <: Duplicated + return Duplicated( + primal === nothing ? pmcmc_dotsum(A.val, B.val) : primal, tangent + ) + else + return needs_primal(config) ? primal : tangent + end +end + +function EnzymeRules.augmented_primal( + config::RevConfig, + ::Const{typeof(pmcmc_dotsum)}, + ::Type{<:Annotation}, + A::Annotation, + B::Annotation, +) + primal = needs_primal(config) ? pmcmc_dotsum(A.val, B.val) : nothing + cache_A = (!(B isa Const) && overwritten(config)[2]) ? copy(A.val) : nothing + cache_B = (!(A isa Const) && overwritten(config)[3]) ? copy(B.val) : nothing + return AugmentedReturn(primal, nothing, (cache_A, cache_B)) +end + +function EnzymeRules.reverse( + config::RevConfig, + ::Const{typeof(pmcmc_dotsum)}, + dret, + tape, + A::Annotation, + B::Annotation, +) + cache_A, cache_B = tape + if !(dret isa Const) + dr = dret.val + if !(A isa Const) + Bv = cache_B !== nothing ? cache_B : B.val + A.dval .+= dr .* Bv + end + if !(B isa Const) + Av = cache_A !== nothing ? cache_A : A.val + B.dval .+= dr .* Av + end + end + return (nothing, nothing) +end + +end # module diff --git a/src/DEER/DEER.jl b/src/DEER/DEER.jl index 1dcb15b..01471e4 100644 --- a/src/DEER/DEER.jl +++ b/src/DEER/DEER.jl @@ -3,7 +3,6 @@ module DEER using LinearAlgebra using DifferentiationInterface using ADTypes: ADTypes, AbstractADType -import Enzyme: Enzyme using Random using CUDA: CUDA @@ -11,29 +10,9 @@ include("DEERScan.jl") using .DEERScan const DI = DifferentiationInterface -const DEFAULT_BACKEND = ADTypes.AutoEnzyme(; - mode=Enzyme.Forward, function_annotation=Enzyme.Duplicated -) -const DEFAULT_HVP_BACKEND = DI.SecondOrder( - DEFAULT_BACKEND, - # The log-density function itself is non-active in the reverse pass. - # Marking it duplicated makes DynamicPPL closure fields look active to Enzyme. - ADTypes.AutoEnzyme(; - mode=Enzyme.set_runtime_activity(Enzyme.Reverse), - function_annotation=Enzyme.Const, - ), -) export TapedRecursion, - DEERWorkspace, - deer_update!, - deer_update, - solve, - _prepare_hvp, - _hvp_prepared, - _hvp_nopre, - DEFAULT_BACKEND, - DEFAULT_HVP_BACKEND + DEERWorkspace, deer_update!, deer_update, solve, _prepare_hvp, _hvp_prepared, _hvp_nopre """ Deterministic recursion driven by a pre-generated tape. @@ -91,7 +70,11 @@ function DEERWorkspace(S_template::AbstractMatrix, s0_template::AbstractVector) S_tmp = similar(S_template) diff_buf = similar(S_template) zbuf = similar(s0_template) - zhost = zbuf isa CUDA.CuArray ? Vector{eltype(s0_template)}(undef, length(s0_template)) : nothing + zhost = if zbuf isa CUDA.CuArray + Vector{eltype(s0_template)}(undef, length(s0_template)) + else + nothing + end jt_buf = similar(s0_template) xbar_buf = similar(s0_template) scan = DEERScan.AffineScanWorkspace(S_template) @@ -109,7 +92,9 @@ end function _prepare_hvp(f, backend::AbstractADType, x_template::AbstractVector) v_template = similar(x_template) fill!(v_template, zero(eltype(x_template))) - return DI.prepare_pushforward(f, backend, x_template, (v_template,); strict=Val(false)) + return DI.prepare_pushforward( + f, _hvp_forward_backend(backend), x_template, (v_template,); strict=Val(false) + ) end function _materialize_ad_array(x::AbstractArray) @@ -131,69 +116,199 @@ function _hvp_prepared( ) x_exec = _materialize_ad_vector(x) v_exec = _tangent_like(x_exec, v) - res = DI.pushforward(f, prep, backend, x_exec, (v_exec,)) + res = DI.pushforward(f, prep, _hvp_forward_backend(backend), x_exec, (v_exec,)) return res isa Tuple ? first(res) : res end function _hvp_nopre(f, backend::AbstractADType, x::AbstractVector, v::AbstractVector) x_exec = _materialize_ad_vector(x) v_exec = _tangent_like(x_exec, v) - prep = DI.prepare_pushforward(f, backend, x_exec, (v_exec,); strict=Val(false)) - res = DI.pushforward(f, prep, backend, x_exec, (v_exec,)) + eff_backend = _hvp_forward_backend(backend) + prep = DI.prepare_pushforward(f, eff_backend, x_exec, (v_exec,); strict=Val(false)) + res = DI.pushforward(f, prep, eff_backend, x_exec, (v_exec,)) return res isa Tuple ? first(res) : res end -_hvp_backend(backend::AbstractADType) = backend - -_hvp_second_order_backend(backend::AbstractADType) = DI.SecondOrder(backend, backend) -_hvp_second_order_backend(backend::DI.SecondOrder) = backend +function _prepare_batch_hvp_from_grad( + grad_batch, backend::AbstractADType, X_template::AbstractMatrix +) + V_template = similar(X_template) + fill!(V_template, zero(eltype(X_template))) + return DI.prepare_pushforward( + grad_batch, _hvp_forward_backend(backend), X_template, (V_template,); + strict=Val(false), + ) +end -_hvp_backend(::ADTypes.AutoEnzyme) = DEFAULT_BACKEND +function _batch_hvp_from_grad_prepared( + grad_batch, prep, backend::AbstractADType, X::AbstractMatrix, V::AbstractMatrix +) + X_exec = _materialize_ad_matrix(X) + V_exec = _tangent_like(X_exec, V) + res = DI.pushforward( + grad_batch, prep, _hvp_forward_backend(backend), X_exec, (V_exec,) + ) + return res isa Tuple ? first(res) : res +end -_hvp_second_order_backend(::ADTypes.AutoEnzyme) = DEFAULT_HVP_BACKEND +#= +--------------------------------------------------------------------------- +Reverse-on-grad HVP. Computes Hv as the gradient of +`x -> pmcmc_dot(gradlogp(x), v)`, with `v` carried as a DI Constant context +so a single `prepare_gradient` covers all subsequent Newton steps. The +batched variant uses `B -> pmcmc_dotsum(grad_batch(B), V)` and exploits the +column-independence of the user's batched gradient — its gradient w.r.t. B +is the columnwise HVP. + +The reductions go through `pmcmc_dot`/`pmcmc_dotsum` (rather than `dot`/`sum`) +so that on GPU the EnzymeExt reverse rules intercept them. Without that, +Enzyme reverse-mode tries to invert `cuMemcpyDtoHAsync_v2` (the device→host +copy emitted by every CuArray scalar reduction) and crashes on its +gc-transition bundle. + +We bundle the closure with the prep so `prepare_gradient` and `gradient` +see the same function instance (DI keys preparations on function identity). +--------------------------------------------------------------------------- +=# +import ..ParallelMCMC: pmcmc_dot, pmcmc_dotsum + +struct _HvpReverseClosure{F} + grad::F +end +(c::_HvpReverseClosure)(x, v) = pmcmc_dot(c.grad(x), v) -function _prepare_logdensity_hvp(f, backend::AbstractADType, x_template::AbstractVector) - v_template = similar(x_template) - fill!(v_template, zero(eltype(x_template))) - return DI.prepare_hvp(f, backend, x_template, (v_template,); strict=Val(false)) +struct _BatchHvpReverseClosure{F} + grad_batch::F +end +(c::_BatchHvpReverseClosure)(X, V) = pmcmc_dotsum(c.grad_batch(X), V) + +#= +Pick the AD-HVP fallback strategy from the user's backend. + + ForwardOnGrad() — `pushforward(gradlogp, x, v)`. Routes through the + `pmcmc_matmul` frule, works on GPU. Requires the + backend to support forward mode. + ReverseOnGrad() — `gradient(x -> pmcmc_dot(gradlogp(x), v))`. Routes + through both the matmul and dot/sum rrules. Works on + CPU with reverse-only backends (Mooncake, Zygote); + not GPU-safe because Enzyme reverse trips on CUDA.jl + internals beyond what we wrap. + +These are singleton types rather than symbols so the choice dispatches +statically — `_make_hvp_fn(_hvp_strategy(backend), ...)` resolves to one +concrete method (and one concrete return type) at compile time, without +relying on constant propagation through `===`. + +Default is `ForwardOnGrad()` since most DI backends support forward mode. +We override to `ReverseOnGrad()` for reverse-only backends. AutoEnzyme stays +on forward because Enzyme.Forward is robust on CuArrays once the matmul is +wrapped. +=# +abstract type HVPStrategy end +struct ForwardOnGrad <: HVPStrategy end +struct ReverseOnGrad <: HVPStrategy end + +_hvp_strategy(::AbstractADType) = ForwardOnGrad() +_hvp_strategy(::ADTypes.AutoMooncake) = ReverseOnGrad() +if isdefined(ADTypes, :AutoZygote) + _hvp_strategy(::ADTypes.AutoZygote) = ReverseOnGrad() +end +if isdefined(ADTypes, :AutoReverseDiff) + _hvp_strategy(::ADTypes.AutoReverseDiff) = ReverseOnGrad() +end +if isdefined(ADTypes, :AutoTracker) + _hvp_strategy(::ADTypes.AutoTracker) = ReverseOnGrad() end -function _logdensity_hvp_prepared( - f, prep, backend::AbstractADType, x::AbstractVector, v::AbstractVector +#= +Hooks for backend-specific normalization of the user's `backend`. + +`_hvp_forward_backend` is for the forward-on-grad pushforward path +(differentiates the user's `gradlogp` directly). EnzymeExt specializes it +to pin `mode=Enzyme.Forward` and `function_annotation=Enzyme.Const` when +the user passed plain `AutoEnzyme()` — without pinning Forward, DI lowers +through reverse mode and Enzyme aborts on the cuBLAS / cuPointerGetAttribute +gc-transition bundle on GPU. + +`_hvp_closure_backend` is for the reverse-on-grad gradient path on the +read-only `_HvpReverseClosure` / `_BatchHvpReverseClosure` wrappers. +EnzymeExt specializes it to set `function_annotation=Enzyme.Const` so +Enzyme doesn't throw `EnzymeMutabilityException` on a closure that captures +`gradlogp`. + +Default for both is identity. +=# +_hvp_forward_backend(backend::AbstractADType) = backend +_hvp_closure_backend(backend::AbstractADType) = backend + +function _prepare_hvp_via_grad_reverse( + gradlogp, backend::AbstractADType, x_template::AbstractVector ) - x_exec = _materialize_ad_vector(x) - v_exec = _tangent_like(x_exec, v) - res = DI.hvp(f, prep, backend, x_exec, (v_exec,)) - return res isa Tuple ? first(res) : res + v_template = similar(x_template) + fill!(v_template, zero(eltype(x_template))) + f = _HvpReverseClosure(gradlogp) + eff_backend = _hvp_closure_backend(backend) + prep = DI.prepare_gradient(f, eff_backend, x_template, DI.Constant(v_template)) + return (f, prep, eff_backend) end -function _logdensity_hvp_nopre( - f, backend::AbstractADType, x::AbstractVector, v::AbstractVector +function _hvp_via_grad_reverse_prepared( + prep_pair, x::AbstractVector, v::AbstractVector ) - x_exec = _materialize_ad_vector(x) - v_exec = _tangent_like(x_exec, v) - prep = DI.prepare_hvp(f, backend, x_exec, (v_exec,); strict=Val(false)) - res = DI.hvp(f, prep, backend, x_exec, (v_exec,)) - return res isa Tuple ? first(res) : res + f, prep, eff_backend = prep_pair + return DI.gradient(f, prep, eff_backend, x, DI.Constant(v)) end -function _prepare_batch_hvp_from_grad( +function _prepare_batch_hvp_via_grad_reverse( grad_batch, backend::AbstractADType, X_template::AbstractMatrix ) V_template = similar(X_template) fill!(V_template, zero(eltype(X_template))) - return DI.prepare_pushforward( - grad_batch, backend, X_template, (V_template,); strict=Val(false) - ) + f = _BatchHvpReverseClosure(grad_batch) + eff_backend = _hvp_closure_backend(backend) + prep = DI.prepare_gradient(f, eff_backend, X_template, DI.Constant(V_template)) + return (f, prep, eff_backend) end -function _batch_hvp_from_grad_prepared( - grad_batch, prep, backend::AbstractADType, X::AbstractMatrix, V::AbstractMatrix +function _batch_hvp_via_grad_reverse_prepared( + prep_pair, X::AbstractMatrix, V::AbstractMatrix ) - X_exec = _materialize_ad_matrix(X) - V_exec = _tangent_like(X_exec, V) - res = DI.pushforward(grad_batch, prep, backend, X_exec, (V_exec,)) - return res isa Tuple ? first(res) : res + f, prep, eff_backend = prep_pair + return DI.gradient(f, prep, eff_backend, X, DI.Constant(V)) +end + +#= +Strategy-dispatched factories. Each method returns a closure with one +concrete type, so the call site `_make_hvp_fn(_hvp_strategy(backend), ...)` +is type-stable: dispatch on the singleton `HVPStrategy` picks the method +at compile time, and the returned closure type is statically known. +=# +function _make_hvp_fn( + ::ForwardOnGrad, gradlogp, backend::AbstractADType, x_template::AbstractVector +) + prep = _prepare_hvp(gradlogp, backend, x_template) + return (pt, dir) -> _hvp_prepared(gradlogp, prep, backend, pt, dir) +end + +function _make_hvp_fn( + ::ReverseOnGrad, gradlogp, backend::AbstractADType, x_template::AbstractVector +) + prep = _prepare_hvp_via_grad_reverse(gradlogp, backend, x_template) + return (pt, dir) -> _hvp_via_grad_reverse_prepared(prep, pt, dir) +end + +function _make_hvp_batch_fn( + ::ForwardOnGrad, grad_batch, backend::AbstractADType, X_template::AbstractMatrix +) + prep = _prepare_batch_hvp_from_grad(grad_batch, backend, X_template) + return (X, V) -> _batch_hvp_from_grad_prepared(grad_batch, prep, backend, X, V) +end + +function _make_hvp_batch_fn( + ::ReverseOnGrad, grad_batch, backend::AbstractADType, X_template::AbstractMatrix +) + prep = _prepare_batch_hvp_via_grad_reverse(grad_batch, backend, X_template) + return (X, V) -> _batch_hvp_via_grad_reverse_prepared(prep, X, V) end @inline function _rademacher!(z::AbstractArray{T}, rng::AbstractRNG) where {T} @@ -221,8 +336,9 @@ end @inline _rademacher!(z::AbstractArray, rng::AbstractRNG, ::Nothing) = _rademacher!(z, rng) @inline _rademacher_matrix!(Z::AbstractMatrix, rng::AbstractRNG) = _rademacher!(Z, rng) -@inline _rademacher_matrix!(Z::AbstractMatrix, rng::AbstractRNG, host) = - _rademacher!(Z, rng, host) +@inline _rademacher_matrix!(Z::AbstractMatrix, rng::AbstractRNG, host) = _rademacher!( + Z, rng, host +) function jac_diag_via_jvps(rec::TapedRecursion, x::AbstractVector, t::Int) D = length(x) diff --git a/src/DEER/DEERScan.jl b/src/DEER/DEERScan.jl index 601af90..abdfad4 100644 --- a/src/DEER/DEERScan.jl +++ b/src/DEER/DEERScan.jl @@ -140,8 +140,10 @@ function solve_affine_scan_diag!( last_level = (offset << 1) >= T @views begin if !last_level - # The destination already has the older prefix up to offset ÷ 2; - # only the newly exposed unchanged segment needs refreshing. + #= + The destination already has the older prefix up to offset ÷ 2; + only the newly exposed unchanged segment needs refreshing. + =# copy_start = offset == 1 ? 1 : (offset >> 1) + 1 alpha_new[:, copy_start:offset] .= alpha[:, copy_start:offset] beta_new[:, copy_start:offset] .= beta[:, copy_start:offset] diff --git a/src/MALA/MALA.jl b/src/MALA/MALA.jl index abf1016..465434f 100644 --- a/src/MALA/MALA.jl +++ b/src/MALA/MALA.jl @@ -37,10 +37,10 @@ end _apply_L!(out, ξ, ::Nothing) = (out .= ξ) _apply_L!(out, ξ, cholM::Cholesky) = mul!(out, cholM.L, ξ) -_quad_Minv!(tmp, r, ::Nothing) = dot(r, r) +_quad_Minv!(tmp, r, ::Nothing) = sum(abs2, r) function _quad_Minv!(tmp, r, cholM::Cholesky) ldiv!(tmp, cholM.L, r) - return dot(tmp, tmp) + return sum(abs2, tmp) end _logdet_M(::Nothing) = false # Bool promotes to any numeric type without widening @@ -715,7 +715,9 @@ function mala_step_surrogate_sigmoid_jvp( Minv_r = ws.solve_buf end - dlogα = dot(ws.g_y, ws.w) - dot(ws.g_x, v) - inv(2 * ε) * dot(Minv_r, ws.dr) + inv_2ε = inv(2 * ε) + @. ws.Hv_y = ws.g_y * ws.w - ws.g_x * v - inv_2ε * Minv_r * ws.dr + dlogα = sum(ws.Hv_y) dg = g * (one(g) - g) * dlogα @. ws.jvp_out = a * ws.w + (ws.y - x) * dg + (one(a) - a) * v @@ -785,7 +787,9 @@ function mala_step_taped_and_jvp!( Minv_r = ws.solve_buf end - dlogα = dot(ws.g_y, ws.w) - dot(ws.g_x, v) - inv(2 * ε) * dot(Minv_r, ws.dr) + inv_2ε = inv(2 * ε) + @. ws.Hv_y = ws.g_y * ws.w - ws.g_x * v - inv_2ε * Minv_r * ws.dr + dlogα = sum(ws.Hv_y) dg = g * (one(g) - g) * dlogα @. jvp_out = a * ws.w + (ws.y - x) * dg + (one(a) - a) * v diff --git a/src/ParallelMCMC.jl b/src/ParallelMCMC.jl index 5c68f28..f6ffd0c 100644 --- a/src/ParallelMCMC.jl +++ b/src/ParallelMCMC.jl @@ -2,12 +2,29 @@ module ParallelMCMC using AbstractMCMC using CUDA -using Enzyme using MCMCChains using LinearAlgebra using Random using Statistics +#= +Owned wrappers: identical semantics to their Base counterparts, but provide +stable function identities for backend-specific AD rules in `ext/EnzymeExt.jl` +without committing type piracy on `Base.*` / `Base.dot` / `Base.sum`. User +model code that wants those rules to fire (notably on GPU, where the default +rules trip on cuBLAS gc-transition bundles) should call these instead. + +Why both a matmul and reductions: the GPU AD-HVP path runs Enzyme reverse mode +through `gradlogp` *and* through a scalar reduction wrapping it (see DEER's +`_HvpReverseClosure`). The reduction is what actually emits the +`cuMemcpyDtoHAsync_v2` that crashes Enzyme. Owning both the matmul AND the +reduction lets Enzyme treat each as opaque, so neither bundle ever enters +its IR. +=# +pmcmc_matmul(A::AbstractVecOrMat, B::AbstractVecOrMat) = A * B +pmcmc_dot(a::AbstractVector, b::AbstractVector) = dot(a, b) +pmcmc_dotsum(A::AbstractVecOrMat, B::AbstractVecOrMat) = sum(A .* B) + include("MALA/MALA.jl") include("DEER/DEERScan.jl") include("DEER/DEER.jl") @@ -18,5 +35,6 @@ export MALASampler, MALATransition, MALAState export AdaptiveMALASampler, AdaptiveMALATransition, AdaptiveMALAState export ParallelMALASampler, ParallelMALATransition, ParallelMALAState export MALA, DEER +export pmcmc_matmul, pmcmc_dot, pmcmc_dotsum end diff --git a/src/interface.jl b/src/interface.jl index 4ad6fd4..0c7df70 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -273,7 +273,7 @@ function ParallelMALASampler( damping::Real=0.5, probes::Int=1, cholM=nothing, - backend=DEER.DEFAULT_BACKEND, + backend, ) epsilon > 0 || throw(ArgumentError("epsilon must be > 0, got $epsilon")) (jacobian === :stoch_diag || jacobian === :diag) || @@ -368,21 +368,32 @@ function _build_mala_deer_rec( tape::Vector{<:MALATapeElement}, x0_like::AbstractVector; cholM=nothing, - backend=DEER.DEFAULT_BACKEND, + backend, tape_noise=nothing, tape_uniforms=nothing, ) logp = model.logdensity gradlogp = model.grad_logdensity - # Use a model-provided HVP when available. Otherwise differentiate the - # scalar logdensity directly + #= + Use a model-provided HVP when available. Otherwise compute Hv via AD, + picking the path from the user's `backend` (see `DEER._hvp_strategy`): + + ForwardOnGrad() — `pushforward(gradlogp, x, v)`. Used for forward- + capable backends (AutoEnzyme, AutoForwardDiff). + Routes through the `pmcmc_matmul` frule and is the + only AD-HVP mode that's reliable on GPU. + ReverseOnGrad() — `gradient(x -> pmcmc_dot(gradlogp(x), v))`. Used + for reverse-only backends (AutoMooncake, AutoZygote + et al.). + + Either path can be bypassed by providing an analytical `hvp` / + `hvp_batch` on the `DensityModel` (the recommended GPU production path). + =# hvp_fn = if model.hvp !== nothing model.hvp else - hvp_backend = DEER._hvp_second_order_backend(backend) - prep_hvp = DEER._prepare_logdensity_hvp(logp, hvp_backend, x0_like) - (pt, dir) -> DEER._logdensity_hvp_prepared(logp, prep_hvp, hvp_backend, pt, dir) + DEER._make_hvp_fn(DEER._hvp_strategy(backend), gradlogp, backend, x0_like) end # Exact forward step. @@ -436,16 +447,11 @@ function _build_mala_deer_rec( hvp_batch = if model.hvp_batch !== nothing model.hvp_batch else - hvp_batch_backend = DEER._hvp_backend(backend) - prep_hvp_batch = DEER._prepare_batch_hvp_from_grad( - model.grad_logdensity_batch, hvp_batch_backend, X_template - ) - (X, V) -> DEER._batch_hvp_from_grad_prepared( + DEER._make_hvp_batch_fn( + DEER._hvp_strategy(backend), model.grad_logdensity_batch, - prep_hvp_batch, - hvp_batch_backend, - X, - V, + backend, + X_template, ) end diff --git a/test/test-Adaptive-MALA.jl b/test/test-Adaptive-MALA.jl index fa6a987..b10bb3e 100644 --- a/test/test-Adaptive-MALA.jl +++ b/test/test-Adaptive-MALA.jl @@ -156,8 +156,10 @@ end _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) end - # After warmup, ε̄ should yield acceptance near target. - # Run 500 post-warmup steps and measure actual rate. + #= + After warmup, ε̄ should yield acceptance near target. + Run 500 post-warmup steps and measure actual rate. + =# n_accept = 0 for _ in 1:500 t, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler, state) diff --git a/test/test-Code-Quality.jl b/test/test-Code-Quality.jl index 4504761..20d326b 100644 --- a/test/test-Code-Quality.jl +++ b/test/test-Code-Quality.jl @@ -1,11 +1,11 @@ -# using Aqua, JET +using Test +using Aqua, JET +using ParallelMCMC -# @testset "Code Quality" begin -# @testset "Aqua Tests" begin -# Aqua.test_all(ParallelMCMC) -# end +@testset "Aqua" begin + Aqua.test_all(ParallelMCMC) +end -# @testset "Code Linting (JET)" begin -# JET.test_package(ParallelMCMC; target_modules=(ParallelMCMC,)) -# end -# end +@testset "JET" begin + JET.test_package(ParallelMCMC; target_modules=(ParallelMCMC,)) +end diff --git a/test/test-DEER-Interface.jl b/test/test-DEER-Interface.jl index 5cfee5d..64bfe52 100644 --- a/test/test-DEER-Interface.jl +++ b/test/test-DEER-Interface.jl @@ -5,6 +5,10 @@ using Statistics using MCMCChains using ParallelMCMC +using ADTypes +using Enzyme + +const _AD = ADTypes.AutoEnzyme() logp_deer(x) = -0.5 * dot(x, x) gradlogp_deer(x) = -x @@ -16,7 +20,7 @@ gradlogp_quartic_deer(x) = @. -x - 0.4 * x^3 hvp_quartic_deer(x, v) = @. (-1 - 1.2 * x^2) * v @testset "ParallelMALASampler construction" begin - s = ParallelMALASampler(0.05) + s = ParallelMALASampler(0.05; backend=_AD) @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler @test s.epsilon == 0.05 @test s.T == 64 @@ -24,17 +28,17 @@ hvp_quartic_deer(x, v) = @. (-1 - 1.2 * x^2) * v @test s.jacobian === :stoch_diag # keyword overrides - s2 = ParallelMALASampler(0.1; T=32, jacobian=:stoch_diag, damping=0.8) + s2 = ParallelMALASampler(0.1; T=32, jacobian=:stoch_diag, damping=0.8, backend=_AD) @test s2.T == 32 @test s2.jacobian === :stoch_diag @test s2.damping == 0.8 - @test_throws ArgumentError ParallelMALASampler(0.1; jacobian=:full) + @test_throws ArgumentError ParallelMALASampler(0.1; jacobian=:full, backend=_AD) end @testset "ParallelMALASampler initial step" begin rng = MersenneTwister(42) model = DensityModel(logp_deer, gradlogp_deer, 3) - sampler = ParallelMALASampler(0.05; T=16) + sampler = ParallelMALASampler(0.05; T=16, backend=_AD) trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) @@ -53,7 +57,7 @@ end @testset "ParallelMALASampler initial step respects initial_params" begin rng = MersenneTwister(42) model = DensityModel(logp_deer, gradlogp_deer, 3) - sampler = ParallelMALASampler(0.05; T=8) + sampler = ParallelMALASampler(0.05; T=8, backend=_AD) x0 = [1.0, 2.0, 3.0] x0_copy = copy(x0) @@ -74,23 +78,14 @@ end tape = [ParallelMCMC.MALATapeElement(randn(rng, D), rand(rng)) for _ in 1:T] model = DensityModel(logp_quartic_deer, gradlogp_quartic_deer, D) - rec_ad = ParallelMCMC._build_mala_deer_rec( - model, ε, tape, zeros(D); backend=ParallelMCMC.DEER.DEFAULT_BACKEND - ) + rec_ad = ParallelMCMC._build_mala_deer_rec(model, ε, tape, zeros(D); backend=_AD) x = randn(rng, D) v = randn(rng, D) te = tape[1] f_ad, jvp_ad = rec_ad.fwd_and_jvp(x, te, v) f_ref, jvp_ref = ParallelMCMC.MALA.mala_step_taped_and_jvp( - logp_quartic_deer, - gradlogp_quartic_deer, - x, - ε, - te.ξ, - te.u, - v, - hvp_quartic_deer, + logp_quartic_deer, gradlogp_quartic_deer, x, ε, te.ξ, te.u, v, hvp_quartic_deer ) @test f_ad ≈ f_ref atol=1e-10 rtol=1e-10 @@ -111,7 +106,7 @@ end grad_logdensity_batch=gradlogp_batch_deer, ) - rec = ParallelMCMC._build_mala_deer_rec(model, ε, tape, zeros(D)) + rec = ParallelMCMC._build_mala_deer_rec(model, ε, tape, zeros(D); backend=_AD) @test rec.fwd_and_jvp_batch !== nothing X = randn(rng, D, T) @@ -132,7 +127,7 @@ end @testset "ParallelMALASampler sequential steps advance trajectory index" begin rng = MersenneTwister(7) model = DensityModel(logp_deer, gradlogp_deer, 2) - sampler = ParallelMALASampler(0.05; T=8) + sampler = ParallelMALASampler(0.05; T=8, backend=_AD) trans, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) @test state.t == 1 @@ -167,7 +162,7 @@ end grad_logdensity_batch=gradlogp_batch_deer, hvp_batch=(X, V) -> -V, ) - sampler = ParallelMALASampler(0.05; T=8) + sampler = ParallelMALASampler(0.05; T=8, backend=_AD) _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=zeros(2)) @test batch_calls[] ≥ 1 @@ -186,7 +181,7 @@ end rng = MersenneTwister(99) model = DensityModel(logp_deer, gradlogp_deer, 2) T = 4 - sampler = ParallelMALASampler(0.05; T=T) + sampler = ParallelMALASampler(0.05; T=T, backend=_AD) _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler) @@ -206,7 +201,7 @@ end @testset "ParallelMALASampler sample() end-to-end" begin model = DensityModel(logp_deer, gradlogp_deer, 2) - sampler = ParallelMALASampler(0.05; T=16) + sampler = ParallelMALASampler(0.05; T=16, backend=_AD) samples = sample(MersenneTwister(1), model, sampler, 50; progress=false) @test length(samples) == 50 @@ -214,7 +209,7 @@ end @testset "ParallelMALASampler sample() with chain_type=Chains" begin model = DensityModel(logp_deer, gradlogp_deer, 2) - sampler = ParallelMALASampler(0.05; T=16) + sampler = ParallelMALASampler(0.05; T=16, backend=_AD) chain = sample( MersenneTwister(1), @@ -236,7 +231,7 @@ end @testset "ParallelMALASampler sample() with custom param_names" begin model = DensityModel(logp_deer, gradlogp_deer, 2) - sampler = ParallelMALASampler(0.05; T=16) + sampler = ParallelMALASampler(0.05; T=16, backend=_AD) chain = sample( MersenneTwister(2), @@ -255,7 +250,7 @@ end @testset "ParallelMALASampler stationary distribution" begin D = 3 model = DensityModel(logp_deer, gradlogp_deer, D) - sampler = ParallelMALASampler(0.1; T=32, damping=0.5) + sampler = ParallelMALASampler(0.1; T=32, damping=0.5, backend=_AD) chain = sample( MersenneTwister(2025), @@ -278,7 +273,7 @@ end @testset "ParallelMALASampler parallel chains via MCMCThreads" begin model = DensityModel(logp_deer, gradlogp_deer, 2) - sampler = ParallelMALASampler(0.05; T=8) + sampler = ParallelMALASampler(0.05; T=8, backend=_AD) chains = sample( MersenneTwister(42), diff --git a/test/test-DEER-Turing-Logistic.jl b/test/test-DEER-Turing-Logistic.jl index e6e5ff3..b025d54 100644 --- a/test/test-DEER-Turing-Logistic.jl +++ b/test/test-DEER-Turing-Logistic.jl @@ -9,6 +9,8 @@ using ParallelMCMC using DynamicPPL using LogDensityProblems using ADTypes +using Enzyme +using ForwardDiff using Distributions: MvNormal, Bernoulli #= @@ -17,13 +19,6 @@ Bayesian logistic regression: y_i | β ~ Bernoulli(sigmoid(X_i β)) Synthetic data with known β_true lets us verify posterior means. - -The tests are split into two groups: - -Turing integration: uses the DynamicPPL convenience constructor. These are -CPU-only because DynamicPPL model evaluation does not support GPU arrays. -ParallelMALASampler works fine with Turing on CPU; for GPU you supply a manual -logp/gradlogp that is array-type-agnostic (see below). =# const _LR_D = 2 @@ -39,8 +34,10 @@ end const _LR_X, _LR_y = _lr_data(MersenneTwister(1234), _LR_N, _LR_D, _LR_β_true) -# Array-type-agnostic logp and gradlogp — identical math as the Turing model, -# but expressed as plain broadcasts so they run on CPU or GPU without change. +#= +Array-type-agnostic logp and gradlogp — identical math as the Turing model, +but expressed as plain broadcasts so they run on CPU or GPU without change. +=# function _logp_lr(β, X, y) logits = X * β ll = sum(@. y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits)))) @@ -89,7 +86,14 @@ end @testset "ParallelMALASampler Turing logistic: chains output well-formed" begin model = _deer_logistic_turing_density_model() sampler = ParallelMALASampler( - 0.05; T=16, maxiter=80, tol_abs=1e-4, tol_rel=1e-3, jacobian=:diag, damping=0.5 + 0.05; + T=16, + maxiter=80, + tol_abs=1e-4, + tol_rel=1e-3, + jacobian=:diag, + damping=0.5, + backend=ADTypes.AutoEnzyme(), ) chain = sample( @@ -112,7 +116,14 @@ end @testset "ParallelMALASampler Turing logistic: posterior sign correct" begin model = _deer_logistic_turing_density_model() sampler = ParallelMALASampler( - 0.05; T=16, maxiter=80, tol_abs=1e-4, tol_rel=1e-3, jacobian=:diag, damping=0.5 + 0.05; + T=16, + maxiter=80, + tol_abs=1e-4, + tol_rel=1e-3, + jacobian=:diag, + damping=0.5, + backend=ADTypes.AutoEnzyme(), ) chain = sample( @@ -249,7 +260,7 @@ else @test size(S_gpu) == (_LR_D, T) @test all(isfinite, Array(S_gpu)) S_ref = reduce(hcat, xs_seq[2:end]) - @test Array(S_gpu) ≈ S_ref rtol=1e-4 atol=1e-5 + @test Array(S_gpu) ≈ S_ref rtol=1e-3 atol=1e-4 end @testset "ParallelMALASampler GPU logistic: posterior mean matches CPU" begin @@ -257,11 +268,11 @@ else y_gpu = CUDA.CuVector(_y_f32) sampler = ParallelMALASampler( - 0.1f0; + 0.05f0; T=16, - maxiter=50, - tol_abs=1e-4f0, - tol_rel=1e-3f0, + maxiter=200, + tol_abs=1.0f-4, + tol_rel=1.0f-3, damping=0.5f0, backend=ADTypes.AutoEnzyme(), ) diff --git a/test/test-Deer-vs-MALA.jl b/test/test-Deer-vs-MALA.jl index 75460d0..09281cb 100644 --- a/test/test-Deer-vs-MALA.jl +++ b/test/test-Deer-vs-MALA.jl @@ -1,10 +1,13 @@ using Test using Random using LinearAlgebra +using ADTypes +using Enzyme using ParallelMCMC const MALA = ParallelMCMC.MALA const DEER = ParallelMCMC.DEER +const _AD = ADTypes.AutoEnzyme() # Standard normal target in R^d logp_stdnormal(x) = -0.5 * dot(x, x) @@ -32,7 +35,7 @@ function _make_rec(tape, ε) step_fwd = (x, tt) -> MALA.mala_step_taped(logp_stdnormal, gradlogp_stdnormal, x, ε, tt.ξ, tt.u) - hvp_fn = (pt, dir) -> DEER._hvp_nopre(gradlogp_stdnormal, DEER.DEFAULT_BACKEND, pt, dir) + hvp_fn = (pt, dir) -> DEER._hvp_nopre(gradlogp_stdnormal, _AD, pt, dir) jvp = (x, tt, v) -> MALA.mala_step_surrogate_sigmoid_jvp( logp_stdnormal, gradlogp_stdnormal, x, ε, tt.ξ, tt.u, v, hvp_fn diff --git a/test/test-GPU-AD-HVP.jl b/test/test-GPU-AD-HVP.jl new file mode 100644 index 0000000..d28e167 --- /dev/null +++ b/test/test-GPU-AD-HVP.jl @@ -0,0 +1,228 @@ +using Test +using Random +using LinearAlgebra +using Statistics +using MCMCChains + +using ParallelMCMC +using ADTypes: ADTypes +using Enzyme +using Mooncake +using CUDA: CUDA + +#= +Coverage for the GPU AD-HVP path: when the user supplies a non-Turing model +with `gradlogp` / `grad_logdensity_batch` but no analytical `hvp`/`hvp_batch`, +the sampler should fall back to AD on the user's gradient and run on GPU +without crashing. +=# + +const _ADHVP_GPU_AVAILABLE = try + CUDA.functional() && (CUDA.CuArray([1.0f0]); true) +catch + false +end + +if !_ADHVP_GPU_AVAILABLE + @info "GPU AD-HVP test: CUDA not functional — skipping" +else + + #= + Multivariate Gaussian target with X'X/N perturbation: + logp(β) = -0.5 (||β||^2 + ||Xβ||^2 / N) + True mean = 0; we'll check the posterior mean is near zero. + + Calls go through `pmcmc_matmul` so Enzyme dispatches to the rule in + `ext/EnzymeExt.jl` and avoids cuBLAS gc-transition bundles, which + Enzyme's default `*` rule trips on. + + Gradients are written as a sequence of single-op broadcasts rather + than one fused expression like `-β .- pmcmc_matmul(...) ./ N`. Julia + fuses the second form into one CUDA broadcast kernel whose gc-transition + bundle (cuPointerGetAttribute on each input) Enzyme cannot lower — + splitting it into stages gives Enzyme one operation at a time and + keeps each kernel small enough to differentiate. + =# + function _logp_single(β, X) + Xβ = pmcmc_matmul(X, β) + N = oftype(zero(eltype(β)), size(X, 1)) + -oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N) + end + + function _gradlogp_single(β, X) + Xβ = pmcmc_matmul(X, β) + N = oftype(zero(eltype(β)), size(X, 1)) + Y = pmcmc_matmul(transpose(X), Xβ) + Y = Y ./ N + Y = Y .+ β + return -Y + end + + function _logp_batch(B, X) + XB = pmcmc_matmul(X, B) + N = oftype(zero(eltype(B)), size(X, 1)) + -oftype(zero(eltype(B)), 0.5) .* + (vec(sum(abs2, B; dims=1)) .+ vec(sum(abs2, XB; dims=1)) ./ N) + end + + function _gradlogp_batch(B, X) + XB = pmcmc_matmul(X, B) + N = oftype(zero(eltype(B)), size(X, 1)) + Y = pmcmc_matmul(transpose(X), XB) + Y = Y ./ N + Y = Y .+ B + return -Y + end + + @testset "GPU AD-HVP: ParallelMALASampler runs without analytical HVP" begin + D = 20 + N_data = 64 + rng = MersenneTwister(20251231) + X_cpu = randn(rng, Float32, N_data, D) + X_gpu = CUDA.CuMatrix(X_cpu) + + model = DensityModel( + β -> _logp_single(β, X_gpu), + β -> _gradlogp_single(β, X_gpu), + D; + logdensity_batch=B -> _logp_batch(B, X_gpu), + grad_logdensity_batch=B -> _gradlogp_batch(B, X_gpu), + ) + + sampler = ParallelMALASampler( + 0.05f0; + T=16, + maxiter=200, + tol_abs=1.0f-3, + tol_rel=1.0f-2, + damping=0.5f0, + backend=ADTypes.AutoEnzyme(), + ) + + n_samples, n_burn = 1600, 400 + x0 = CUDA.zeros(Float32, D) + + raw = sample( + MersenneTwister(42), + model, + sampler, + n_samples; + initial_params=x0, + progress=false, + ) + β_post = vec( + mean(reduce(hcat, [Array(s.x) for s in raw[(n_burn + 1):end]]); dims=2) + ) + + @test all(isfinite, β_post) + #= + Gaussian target → posterior mean is zero. Loose tolerance accounts for MC + variance with N_eff ~ 100-200; an actual divergence would blow past this. + =# + @test maximum(abs, β_post) < 0.4 + end + + @testset "GPU AD-HVP: matches CPU AD-HVP" begin + D = 12 + N_data = 48 + rng = MersenneTwister(20251231) + X_cpu = randn(rng, Float32, N_data, D) + X_gpu = CUDA.CuMatrix(X_cpu) + + model_cpu = DensityModel( + β -> _logp_single(β, X_cpu), + β -> _gradlogp_single(β, X_cpu), + D; + logdensity_batch=B -> _logp_batch(B, X_cpu), + grad_logdensity_batch=B -> _gradlogp_batch(B, X_cpu), + ) + model_gpu = DensityModel( + β -> _logp_single(β, X_gpu), + β -> _gradlogp_single(β, X_gpu), + D; + logdensity_batch=B -> _logp_batch(B, X_gpu), + grad_logdensity_batch=B -> _gradlogp_batch(B, X_gpu), + ) + + sampler = ParallelMALASampler( + 0.05f0; + T=16, + maxiter=200, + tol_abs=1.0f-3, + tol_rel=1.0f-2, + damping=0.5f0, + backend=ADTypes.AutoEnzyme(), + ) + + n_samples, n_burn = 4000, 1000 + raw_cpu = sample(MersenneTwister(7), model_cpu, sampler, n_samples; progress=false) + β_cpu = vec(mean(reduce(hcat, [s.x for s in raw_cpu[(n_burn + 1):end]]); dims=2)) + + raw_gpu = sample( + MersenneTwister(7), + model_gpu, + sampler, + n_samples; + initial_params=CUDA.zeros(Float32, D), + progress=false, + ) + β_gpu = vec( + mean(reduce(hcat, [Array(s.x) for s in raw_gpu[(n_burn + 1):end]]); dims=2) + ) + + @test maximum(abs, β_cpu .- β_gpu) < 0.25 + end + + #= + Mooncake on GPU. The whole point of routing through DI is that swapping + the backend should keep the same `ParallelMALASampler` API working. + `_hvp_strategy(::AutoMooncake) = ReverseOnGrad()`, so this exercises the + reverse-on-grad path through Mooncake's CUDA extension (no custom rules + of ours involved — Mooncake's defaults handle `*`, broadcast, `sum`, + `dot` on `CuArray` directly). + =# + @testset "GPU AD-HVP: ParallelMALASampler runs without analytical HVP (Mooncake)" begin + D = 20 + N_data = 64 + rng = MersenneTwister(20251231) + X_cpu = randn(rng, Float32, N_data, D) + X_gpu = CUDA.CuMatrix(X_cpu) + + model = DensityModel( + β -> _logp_single(β, X_gpu), + β -> _gradlogp_single(β, X_gpu), + D; + logdensity_batch=B -> _logp_batch(B, X_gpu), + grad_logdensity_batch=B -> _gradlogp_batch(B, X_gpu), + ) + + sampler = ParallelMALASampler( + 0.05f0; + T=16, + maxiter=200, + tol_abs=1.0f-3, + tol_rel=1.0f-2, + damping=0.5f0, + backend=ADTypes.AutoMooncake(; config=nothing), + ) + + n_samples, n_burn = 1600, 400 + x0 = CUDA.zeros(Float32, D) + + raw = sample( + MersenneTwister(42), + model, + sampler, + n_samples; + initial_params=x0, + progress=false, + ) + β_post = vec( + mean(reduce(hcat, [Array(s.x) for s in raw[(n_burn + 1):end]]); dims=2) + ) + + @test all(isfinite, β_post) + # Same Gaussian target as the AutoEnzyme test; same loose tolerance. + @test maximum(abs, β_post) < 0.4 + end +end # _ADHVP_GPU_AVAILABLE diff --git a/test/test-GPU-MALA.jl b/test/test-GPU-MALA.jl index f6d7179..8aee951 100644 --- a/test/test-GPU-MALA.jl +++ b/test/test-GPU-MALA.jl @@ -8,8 +8,10 @@ const MALA = ParallelMCMC.MALA using CUDA: CUDA -# Check if a real GPU is accessible by attempting a small allocation. -# CUDA.functional() only checks that the library loads, not that a device exists. +#= +Check if a real GPU is accessible by attempting a small allocation. +CUDA.functional() only checks that the library loads, not that a device exists. +=# const CUDA_AVAILABLE = try CUDA.CuArray([1.0f0]) true @@ -25,8 +27,10 @@ else logp_batch(X) = vec(-0.5f0 .* sum(abs2, X; dims=1)) gradlogp_batch(X) = -X - # Scaled normal: dimension i has variance σᵢ² = i (so std = sqrt(i)) - # logp = -0.5 sum_i x_i²/i, grad_i = -x_i/i + #= + Scaled normal: dimension i has variance σᵢ² = i (so std = sqrt(i)) + logp = -0.5 sum_i x_i²/i, grad_i = -x_i/i + =# function logp_scaled(X) D = size(X, 1) scales = CUDA.CuArray(Float32.(1:D)) # D-vector on GPU @@ -119,9 +123,11 @@ else D, N = 4, 64 X = CUDA.randn(Float32, D, N) Ξ = CUDA.randn(Float32, D, N) - # u very close to 1 ⟹ log(u) ≈ 0, forces rejection whenever logα ≤ 0. - # Some chains may still be accepted (logα > 0 when proposal lands at higher density), - # but for every rejected chain X_next must exactly equal X. + #= + u very close to 1 ⟹ log(u) ≈ 0, forces rejection whenever logα ≤ 0. + Some chains may still be accepted (logα > 0 when proposal lands at higher density), + but for every rejected chain X_next must exactly equal X. + =# u = CUDA.fill(1.0f0 - 1.0f-6, N) X_next, accepted = MALA.mala_step_batched( @@ -152,8 +158,10 @@ else end @testset "GPU acceptance rate in reasonable range" begin - # With ε=0.1 and standard normal, empirical acceptance rate should be - # well above 0 and below 1. + #= + With ε=0.1 and standard normal, empirical acceptance rate should be + well above 0 and below 1. + =# D, N, T = 5, 512, 200 n_accepted = 0 @@ -170,6 +178,7 @@ else end @testset "GPU stationary distribution (standard normal)" begin + CUDA.seed!(2025) D, N, T = 3, 512, 1_000 X = CUDA.randn(Float32, D, N) @@ -180,7 +189,7 @@ else end X_cpu = Array(X) - @test maximum(abs.(vec(mean(X_cpu; dims=2)))) < 0.15 + @test maximum(abs.(vec(mean(X_cpu; dims=2)))) < 0.2 @test maximum(abs.(vec(var(X_cpu; dims=2)) .- 1.0f0)) < 0.25 end diff --git a/test/test-GPU-Performance.jl b/test/test-GPU-Performance.jl new file mode 100644 index 0000000..2ef1217 --- /dev/null +++ b/test/test-GPU-Performance.jl @@ -0,0 +1,179 @@ +using Test +using Random +using LinearAlgebra +using Statistics + +using ParallelMCMC +using ADTypes: ADTypes +using CUDA: CUDA +using MCMCChains + +#= +On a problem large enough to amortize CUDA kernel-launch overhead, +ParallelMALASampler on GPU with the batched DEER path is meaningfully faster +than sequential `MALASampler` on CPU. + +This is a regression sanity check. + + 1. GPU samples/s ≥ 2× CPU samples/s + 2. GPU samples/s ≥ 1500 + +If GPU is unavailable, the test set is skipped. +=# + +const _PERF_GPU_AVAILABLE = try + CUDA.functional() && (CUDA.CuArray([1.0f0]); true) +catch + false +end + +if !_PERF_GPU_AVAILABLE + @info "GPU performance test: CUDA not functional — skipping" +else + + #= + Multivariate Gaussian target — well-conditioned, optimal MALA acceptance from + any start. Lets ε be set analytically so the chain actually moves and DEER + does real work. + =# + function _perf_logp(β, X, _) + Xβ = X * β + N = oftype(zero(eltype(β)), size(X, 1)) + return -oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N) + end + + function _perf_gradlogp(β, X, _) + Xβ = X * β + N = oftype(zero(eltype(β)), size(X, 1)) + return -β .- (X' * Xβ) ./ N + end + + function _perf_hvp(β, v, X, _) + Xv = X * v + N = oftype(zero(eltype(β)), size(X, 1)) + return -v .- (X' * Xv) ./ N + end + + function _perf_logp_batch(B, X, _) + XB = X * B + N = oftype(zero(eltype(B)), size(X, 1)) + return -oftype(zero(eltype(B)), 0.5) .* + (vec(sum(abs2, B; dims=1)) .+ vec(sum(abs2, XB; dims=1)) ./ N) + end + + function _perf_gradlogp_batch(B, X, _) + XB = X * B + N = oftype(zero(eltype(B)), size(X, 1)) + return -B .- (X' * XB) ./ N + end + + function _perf_hvp_batch(B, V, X, _) + XV = X * V + N = oftype(zero(eltype(B)), size(X, 1)) + return -V .- (X' * XV) ./ N + end + + function _bench(model, sampler, N; x0, reps=2) + # Warmup + sample( + MersenneTwister(0), + model, + sampler, + sampler isa ParallelMALASampler ? sampler.T : 100; + progress=false, + initial_params=x0, + ) + GC.gc() + if x0 isa CUDA.CuArray + CUDA.synchronize() + end + + ts = Float64[] + for _ in 1:reps + GC.gc() + if x0 isa CUDA.CuArray + CUDA.synchronize() + end + t0 = time_ns() + sample( + MersenneTwister(42), model, sampler, N; progress=false, initial_params=x0 + ) + if x0 isa CUDA.CuArray + CUDA.synchronize() + end + push!(ts, (time_ns() - t0) / 1e9) + end + return median(ts) + end + + @testset "GPU performance vs sequential CPU MALA" begin + #= + Sized so each per-step Sgemm is large enough to amortize CUDA kernel- + launch overhead. D<200 is launch-bound and CPU wins; this is the regime + users care about for the "ParallelMCMC speeds things up" claim. + =# + D = 300 + N_data = 4_000 + rng = MersenneTwister(20251231) + X_cpu = randn(rng, Float32, N_data, D) + y_cpu = randn(rng, Float32, N_data) # unused for Gaussian target + + ε = 0.07f0 + T = 1024 + maxiter = 200 + N_samples = 2_000 + + # CPU baseline + logp_cpu = β -> _perf_logp(β, X_cpu, y_cpu) + gradlogp_cpu = β -> _perf_gradlogp(β, X_cpu, y_cpu) + cpu_model = DensityModel(logp_cpu, gradlogp_cpu, D) + cpu_sampler = MALASampler(ε) + x0_cpu = zeros(Float32, D) + + cpu_time = _bench(cpu_model, cpu_sampler, N_samples; x0=x0_cpu) + cpu_sps = N_samples / cpu_time + @info "CPU baseline" cpu_sps cpu_time + + # GPU run (analytical-HVP batched path) + X_gpu = CUDA.CuMatrix(X_cpu) + y_gpu = CUDA.CuVector(y_cpu) + logp_gpu = β -> _perf_logp(β, X_gpu, y_gpu) + gradlogp_gpu = β -> _perf_gradlogp(β, X_gpu, y_gpu) + hvp_gpu = (β, v) -> _perf_hvp(β, v, X_gpu, y_gpu) + logp_batch_gpu = B -> _perf_logp_batch(B, X_gpu, y_gpu) + gradlogp_batch_gpu = B -> _perf_gradlogp_batch(B, X_gpu, y_gpu) + hvp_batch_gpu = (B, V) -> _perf_hvp_batch(B, V, X_gpu, y_gpu) + + gpu_model = DensityModel( + logp_gpu, + gradlogp_gpu, + D; + hvp=hvp_gpu, + logdensity_batch=logp_batch_gpu, + grad_logdensity_batch=gradlogp_batch_gpu, + hvp_batch=hvp_batch_gpu, + ) + gpu_sampler = ParallelMALASampler( + ε; + T=T, + maxiter=maxiter, + tol_abs=1.0f-3, + tol_rel=1.0f-2, + damping=0.5f0, + backend=ADTypes.AutoEnzyme(), + ) + x0_gpu = CUDA.CuArray(x0_cpu) + + gpu_time = _bench(gpu_model, gpu_sampler, N_samples; x0=x0_gpu) + gpu_sps = N_samples / gpu_time + speedup = gpu_sps / cpu_sps + @info "GPU run" gpu_sps gpu_time speedup + + #= + The thresholds are deliberately loose so flaky shared-cluster GPUs don't + red the build, but tight enough to catch a genuine perf regression. + =# + @test speedup ≥ 2.0 + @test gpu_sps ≥ 1_500.0 + end +end diff --git a/test/test-Jacobian-Estimator.jl b/test/test-Jacobian-Estimator.jl index a978b31..08679a7 100644 --- a/test/test-Jacobian-Estimator.jl +++ b/test/test-Jacobian-Estimator.jl @@ -2,10 +2,13 @@ using Test using Random using LinearAlgebra using StatsBase +using ADTypes +using Enzyme using ParallelMCMC const DEER = ParallelMCMC.DEER const MALA = ParallelMCMC.MALA +const _AD = ADTypes.AutoEnzyme() logp_stdnormal_B(x) = -0.5 * dot(x, x) gradlogp_stdnormal_B(x) = -x @@ -31,8 +34,10 @@ end Base.similar(x::TaggedVector, n::Int) = TaggedVector(Vector{eltype(x)}(undef, n)) Base.copy(x::TaggedVector) = TaggedVector(copy(x.data)) -# Reference log q(y|x) for MALA: -# y ~ Normal( x + ϵ∇logp(x), 2ϵ I ) +#= +Reference log q(y|x) for MALA: +y ~ Normal( x + ϵ∇logp(x), 2ϵ I ) +=# function logq_mala_ref_B( y::AbstractVector, x::AbstractVector, gradlogp_x::AbstractVector, ϵ::Real ) @@ -142,18 +147,10 @@ make_affine_tape(rng::AbstractRNG, D::Int, T::Int) = [randn(rng, D) for _ in 1:T s0 = randn(rng, D) ws = DEER.DEERWorkspace(S, s0) - @test_throws ArgumentError DEER.deer_update!( - ws, S_out, rec, s0, S; jacobian=:full - ) - @test_throws ArgumentError DEER.deer_update!( - ws, S_out, rec, s0, S; damping=0.0 - ) - @test_throws ArgumentError DEER.deer_update!( - ws, S_out, rec, s0, S; probes=0 - ) - @test_throws ArgumentError DEER.deer_update!( - ws, randn(rng, D, T + 1), rec, s0, S - ) + @test_throws ArgumentError DEER.deer_update!(ws, S_out, rec, s0, S; jacobian=:full) + @test_throws ArgumentError DEER.deer_update!(ws, S_out, rec, s0, S; damping=0.0) + @test_throws ArgumentError DEER.deer_update!(ws, S_out, rec, s0, S; probes=0) + @test_throws ArgumentError DEER.deer_update!(ws, randn(rng, D, T + 1), rec, s0, S) @test_throws ArgumentError DEER.solve(rec, s0; maxiter=0) @test_throws ArgumentError DEER.solve(rec, s0; tol_abs=-1.0) @@ -201,9 +198,7 @@ make_affine_tape(rng::AbstractRNG, D::Int, T::Int) = [randn(rng, D) for _ in 1:T (x, tt) -> MALA.mala_step_taped( logp_stdnormal_B, gradlogp_stdnormal_B, x, ϵ, tt.ξ, tt.u ) - hvp_fn = - (pt, dir) -> - DEER._hvp_nopre(gradlogp_stdnormal_B, DEER.DEFAULT_BACKEND, pt, dir) + hvp_fn = (pt, dir) -> DEER._hvp_nopre(gradlogp_stdnormal_B, _AD, pt, dir) jvp = (x, tt, v) -> MALA.mala_step_surrogate_sigmoid_jvp( logp_stdnormal_B, gradlogp_stdnormal_B, x, ϵ, tt.ξ, tt.u, v, hvp_fn diff --git a/test/test-MALA-Kernel.jl b/test/test-MALA-Kernel.jl index 4abc877..3da55ed 100644 --- a/test/test-MALA-Kernel.jl +++ b/test/test-MALA-Kernel.jl @@ -9,8 +9,10 @@ const MALA = ParallelMCMC.MALA logp_stdnormal_kernel(x) = -0.5 * dot(x, x) gradlogp_stdnormal_kernel(x) = -x -# Reference implementation of log q(y|x) for MALA proposal: -# y ~ Normal( x + ϵ∇logp(x), 2ϵ I ) +#= +Reference implementation of log q(y|x) for MALA proposal: +y ~ Normal( x + ϵ∇logp(x), 2ϵ I ) +=# function logq_mala_ref( y::AbstractVector, x::AbstractVector, gradlogp_x::AbstractVector, ϵ::Real ) @@ -30,8 +32,7 @@ function logq_mala_mass_ref( μ = x .+ ϵ .* (cholM.L * (adjoint(cholM.L) * gradlogp_x)) r = y .- μ d = length(x) - return -0.5 * dot(cholM \ r, r) / (2ϵ) - - (d / 2) * log(4π * ϵ) - 0.5 * logdet(cholM) + return -0.5 * dot(cholM \ r, r) / (2ϵ) - (d / 2) * log(4π * ϵ) - 0.5 * logdet(cholM) end # Build a tape (ξs, us) deterministically @@ -207,8 +208,9 @@ end y = MALA.mala_proposal( logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, ϵ, ξ; cholM=cholM ) - y_ref = x .+ ϵ .* (cholM.L * (adjoint(cholM.L) * gradlogp_stdnormal_kernel(x))) .+ - sqrt(2ϵ) .* (cholM.L * ξ) + y_ref = + x .+ ϵ .* (cholM.L * (adjoint(cholM.L) * gradlogp_stdnormal_kernel(x))) .+ + sqrt(2ϵ) .* (cholM.L * ξ) @test y ≈ y_ref atol=1e-12 rtol=1e-12 logq_impl = MALA.logq_mala(y, x, gradlogp_stdnormal_kernel(x), ϵ; cholM=cholM) @@ -217,15 +219,9 @@ end X = randn(rng, D, N) Ξ = randn(rng, D, N) - u = range(0.1, 0.9; length=N) |> collect + u = collect(range(0.1, 0.9; length=N)) X_next, accepted = MALA.mala_step_batched( - X -> vec(-0.5 .* sum(abs2, X; dims=1)), - X -> -X, - X, - ϵ, - Ξ, - u; - cholM=cholM, + X -> vec(-0.5 .* sum(abs2, X; dims=1)), X -> -X, X, ϵ, Ξ, u; cholM=cholM ) @test length(accepted) == N @@ -250,7 +246,11 @@ end logq_batch = MALA.logq_mala_batched(Y, X, G, ϵ; cholM=cholM) logq_cols = [ MALA.logq_mala( - copy(view(Y, :, j)), copy(view(X, :, j)), copy(view(G, :, j)), ϵ; cholM=cholM + copy(view(Y, :, j)), + copy(view(X, :, j)), + copy(view(G, :, j)), + ϵ; + cholM=cholM, ) for j in 1:N ] @test logq_batch ≈ logq_cols atol=1e-12 rtol=1e-12 @@ -273,7 +273,9 @@ end ) @test actual ≈ expected atol=1e-12 rtol=1e-12 - @test all(xi -> min(xi[1], xi[2]) - 1e-12 ≤ xi[3] ≤ max(xi[1], xi[2]) + 1e-12, - zip(x, y, actual)) + @test all( + xi -> min(xi[1], xi[2]) - 1e-12 ≤ xi[3] ≤ max(xi[1], xi[2]) + 1e-12, + zip(x, y, actual), + ) end end diff --git a/test/test-Owned-Matmul.jl b/test/test-Owned-Matmul.jl new file mode 100644 index 0000000..527fae1 --- /dev/null +++ b/test/test-Owned-Matmul.jl @@ -0,0 +1,325 @@ +using Test +using Random +using LinearAlgebra + +using ParallelMCMC +using ADTypes +using Enzyme +import ParallelMCMC.DEER as DEER +const DI = DEER.DI + +#= +Tests for the owned wrappers `pmcmc_matmul`, `pmcmc_dot`, `pmcmc_dotsum` and +the native Enzyme rules in `ext/EnzymeExt.jl`. We check: + 1. function values match the Base counterparts + 2. forward-mode JVPs match the analytical formulas + 3. reverse-mode pullbacks (gradients) match the analytical formulas + +The reverse cases are critical: they're what makes the GPU AD-HVP path work +without crashing on `cuMemcpyDtoHAsync_v2` gc-transition bundles. +=# + +const _AD_FWD = ADTypes.AutoEnzyme(; mode=Enzyme.Forward) +const _AD_REV = ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Const +) + +# ============================================================ +# pmcmc_matmul +# ============================================================ + +@testset "pmcmc_matmul value matches Base.*" begin + rng = MersenneTwister(0) + A = randn(rng, 4, 3) + B = randn(rng, 3, 5) + b = randn(rng, 3) + + @test pmcmc_matmul(A, B) ≈ A * B + @test pmcmc_matmul(A, b) ≈ A * b +end + +@testset "Enzyme forward JVP on pmcmc_matmul matches analytical" begin + rng = MersenneTwister(1) + M, K, N = 5, 4, 3 + A = randn(rng, M, K) + B = randn(rng, K, N) + dA = randn(rng, M, K) + dB = randn(rng, K, N) + + dY_expected = dA * B + A * dB + + f1 = ((A_, B_),) -> pmcmc_matmul(A_, B_) + prep = DI.prepare_pushforward(f1, _AD_FWD, (A, B), ((dA, dB),); strict=Val(false)) + dY = first(DI.pushforward(f1, prep, _AD_FWD, (A, B), ((dA, dB),))) + + @test dY ≈ dY_expected rtol=1e-12 atol=1e-12 +end + +@testset "Enzyme forward JVP on pmcmc_matmul: matrix-vector" begin + rng = MersenneTwister(2) + M, K = 6, 4 + A = randn(rng, M, K) + b = randn(rng, K) + dA = randn(rng, M, K) + db = randn(rng, K) + + dy_expected = dA * b + A * db + + f1 = ((A_, b_),) -> pmcmc_matmul(A_, b_) + prep = DI.prepare_pushforward(f1, _AD_FWD, (A, b), ((db, db),); strict=Val(false)) # eltype-stable + prep = DI.prepare_pushforward(f1, _AD_FWD, (A, b), ((dA, db),); strict=Val(false)) + dy = first(DI.pushforward(f1, prep, _AD_FWD, (A, b), ((dA, db),))) + + @test dy ≈ dy_expected rtol=1e-12 atol=1e-12 +end + +@testset "Enzyme reverse pullback on pmcmc_matmul (via grad of dot)" begin + # f(A, B) = dot(pmcmc_matmul(A, B), w) is scalar; dA = w * B', dB = A' * w + rng = MersenneTwister(3) + M, K = 5, 4 + A = randn(rng, M, K) + b = randn(rng, K) + w = randn(rng, M) + + f = ((A_, b_),) -> pmcmc_dot(pmcmc_matmul(A_, b_), w) + prep = DI.prepare_gradient(f, _AD_REV, (A, b); strict=Val(false)) + g = DI.gradient(f, prep, _AD_REV, (A, b)) + + dA_expected = w * b' + db_expected = A' * w + @test g[1] ≈ dA_expected rtol=1e-12 atol=1e-12 + @test g[2] ≈ db_expected rtol=1e-12 atol=1e-12 +end + +# ============================================================ +# pmcmc_dot +# ============================================================ + +@testset "pmcmc_dot value matches LinearAlgebra.dot" begin + rng = MersenneTwister(10) + a = randn(rng, 7) + b = randn(rng, 7) + @test pmcmc_dot(a, b) ≈ dot(a, b) +end + +@testset "Enzyme reverse pullback on pmcmc_dot" begin + # f(a, b) = pmcmc_dot(a, b); da = b, db = a + rng = MersenneTwister(11) + a = randn(rng, 6) + b = randn(rng, 6) + + f = ((a_, b_),) -> pmcmc_dot(a_, b_) + prep = DI.prepare_gradient(f, _AD_REV, (a, b); strict=Val(false)) + g = DI.gradient(f, prep, _AD_REV, (a, b)) + + @test g[1] ≈ b rtol=1e-12 atol=1e-12 + @test g[2] ≈ a rtol=1e-12 atol=1e-12 +end + +@testset "Enzyme forward JVP on pmcmc_dot matches analytical" begin + # d/dt pmcmc_dot(a + t*da, b + t*db)|t=0 = dot(da, b) + dot(a, db) + rng = MersenneTwister(12) + a = randn(rng, 6) + b = randn(rng, 6) + da = randn(rng, 6) + db = randn(rng, 6) + + dval_expected = dot(da, b) + dot(a, db) + + f1 = ((a_, b_),) -> pmcmc_dot(a_, b_) + prep = DI.prepare_pushforward(f1, _AD_FWD, (a, b), ((da, db),); strict=Val(false)) + dval = first(DI.pushforward(f1, prep, _AD_FWD, (a, b), ((da, db),))) + + @test dval ≈ dval_expected rtol=1e-12 atol=1e-12 +end + +# ============================================================ +# pmcmc_dotsum +# ============================================================ + +@testset "pmcmc_dotsum value matches sum(A .* B)" begin + rng = MersenneTwister(20) + A = randn(rng, 3, 4) + B = randn(rng, 3, 4) + @test pmcmc_dotsum(A, B) ≈ sum(A .* B) +end + +@testset "Enzyme reverse pullback on pmcmc_dotsum" begin + # f(A, B) = pmcmc_dotsum(A, B); dA = B, dB = A + rng = MersenneTwister(21) + A = randn(rng, 4, 3) + B = randn(rng, 4, 3) + + f = ((A_, B_),) -> pmcmc_dotsum(A_, B_) + prep = DI.prepare_gradient(f, _AD_REV, (A, B); strict=Val(false)) + g = DI.gradient(f, prep, _AD_REV, (A, B)) + + @test g[1] ≈ B rtol=1e-12 atol=1e-12 + @test g[2] ≈ A rtol=1e-12 atol=1e-12 +end + +@testset "Enzyme forward JVP on pmcmc_dotsum matches analytical" begin + # d/dt pmcmc_dotsum(A + t*dA, B + t*dB)|t=0 = sum(dA .* B) + sum(A .* dB) + rng = MersenneTwister(22) + A = randn(rng, 4, 3) + B = randn(rng, 4, 3) + dA = randn(rng, 4, 3) + dB = randn(rng, 4, 3) + + dval_expected = sum(dA .* B) + sum(A .* dB) + + f1 = ((A_, B_),) -> pmcmc_dotsum(A_, B_) + prep = DI.prepare_pushforward(f1, _AD_FWD, (A, B), ((dA, dB),); strict=Val(false)) + dval = first(DI.pushforward(f1, prep, _AD_FWD, (A, B), ((dA, dB),))) + + @test dval ≈ dval_expected rtol=1e-12 atol=1e-12 +end + +# ============================================================ +# Enzyme Const-annotation paths +# +# DI doesn't expose Enzyme `Const` directly on arguments — it activates the +# whole gradient subject. The rule bodies all special-case `a isa Const` / +# `B isa Const`; these tests hit those branches via `Enzyme.autodiff`. +# ============================================================ + +@testset "Enzyme Const-arg forward JVP — pmcmc_matmul" begin + rng = MersenneTwister(100) + M, K, N = 5, 4, 3 + A = randn(rng, M, K); B = randn(rng, K, N) + dA = randn(rng, M, K); dB = randn(rng, K, N) + + # Const(A), Duplicated(B): dY = A * dB + (dY1,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_matmul, Enzyme.Duplicated, + Enzyme.Const(A), Enzyme.Duplicated(B, dB), + ) + @test dY1 ≈ A * dB rtol=1e-12 atol=1e-12 + + # Duplicated(A), Const(B): dY = dA * B + (dY2,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_matmul, Enzyme.Duplicated, + Enzyme.Duplicated(A, dA), Enzyme.Const(B), + ) + @test dY2 ≈ dA * B rtol=1e-12 atol=1e-12 + + # Const(A), Const(B): tangent is zero + (dY0,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_matmul, Enzyme.Duplicated, + Enzyme.Const(A), Enzyme.Const(B), + ) + @test all(iszero, dY0) +end + +@testset "Enzyme Const-arg reverse pullback — pmcmc_matmul" begin + # f(A, B) = pmcmc_dot(pmcmc_matmul(A, B), w); dA = w * B', dB = A' * w + rng = MersenneTwister(101) + M, K = 5, 4 + A = randn(rng, M, K); b = randn(rng, K); w = randn(rng, M) + + # Const(A): only B accumulates; expect db = A' * w + db_buf = zero(b) + Enzyme.autodiff( + Enzyme.Reverse, + (A_, B_, w_) -> pmcmc_dot(pmcmc_matmul(A_, B_), w_), + Enzyme.Active, + Enzyme.Const(A), Enzyme.Duplicated(b, db_buf), Enzyme.Const(w), + ) + @test db_buf ≈ A' * w rtol=1e-12 atol=1e-12 + + # Const(b): only A accumulates; expect dA = w * b' + dA_buf = zero(A) + Enzyme.autodiff( + Enzyme.Reverse, + (A_, B_, w_) -> pmcmc_dot(pmcmc_matmul(A_, B_), w_), + Enzyme.Active, + Enzyme.Duplicated(A, dA_buf), Enzyme.Const(b), Enzyme.Const(w), + ) + @test dA_buf ≈ w * b' rtol=1e-12 atol=1e-12 +end + +@testset "Enzyme Const-arg forward JVP — pmcmc_dot" begin + rng = MersenneTwister(110) + a = randn(rng, 6); b = randn(rng, 6) + da = randn(rng, 6); db = randn(rng, 6) + + (dv1,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_dot, Enzyme.Duplicated, + Enzyme.Const(a), Enzyme.Duplicated(b, db), + ) + @test dv1 ≈ dot(a, db) rtol=1e-12 atol=1e-12 + + (dv2,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_dot, Enzyme.Duplicated, + Enzyme.Duplicated(a, da), Enzyme.Const(b), + ) + @test dv2 ≈ dot(da, b) rtol=1e-12 atol=1e-12 + + (dv0,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_dot, Enzyme.Duplicated, + Enzyme.Const(a), Enzyme.Const(b), + ) + @test dv0 == 0 +end + +@testset "Enzyme Const-arg reverse pullback — pmcmc_dot" begin + rng = MersenneTwister(111) + a = randn(rng, 6); b = randn(rng, 6) + + da_buf = zero(a) + Enzyme.autodiff( + Enzyme.Reverse, pmcmc_dot, Enzyme.Active, + Enzyme.Duplicated(a, da_buf), Enzyme.Const(b), + ) + @test da_buf ≈ b rtol=1e-12 atol=1e-12 + + db_buf = zero(b) + Enzyme.autodiff( + Enzyme.Reverse, pmcmc_dot, Enzyme.Active, + Enzyme.Const(a), Enzyme.Duplicated(b, db_buf), + ) + @test db_buf ≈ a rtol=1e-12 atol=1e-12 +end + +@testset "Enzyme Const-arg forward JVP — pmcmc_dotsum" begin + rng = MersenneTwister(120) + A = randn(rng, 4, 3); B = randn(rng, 4, 3) + dA = randn(rng, 4, 3); dB = randn(rng, 4, 3) + + (dv1,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_dotsum, Enzyme.Duplicated, + Enzyme.Const(A), Enzyme.Duplicated(B, dB), + ) + @test dv1 ≈ sum(A .* dB) rtol=1e-12 atol=1e-12 + + (dv2,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_dotsum, Enzyme.Duplicated, + Enzyme.Duplicated(A, dA), Enzyme.Const(B), + ) + @test dv2 ≈ sum(dA .* B) rtol=1e-12 atol=1e-12 + + (dv0,) = Enzyme.autodiff( + Enzyme.Forward, pmcmc_dotsum, Enzyme.Duplicated, + Enzyme.Const(A), Enzyme.Const(B), + ) + @test dv0 == 0 +end + +@testset "Enzyme Const-arg reverse pullback — pmcmc_dotsum" begin + rng = MersenneTwister(121) + A = randn(rng, 4, 3); B = randn(rng, 4, 3) + + dA_buf = zero(A) + Enzyme.autodiff( + Enzyme.Reverse, pmcmc_dotsum, Enzyme.Active, + Enzyme.Duplicated(A, dA_buf), Enzyme.Const(B), + ) + @test dA_buf ≈ B rtol=1e-12 atol=1e-12 + + dB_buf = zero(B) + Enzyme.autodiff( + Enzyme.Reverse, pmcmc_dotsum, Enzyme.Active, + Enzyme.Const(A), Enzyme.Duplicated(B, dB_buf), + ) + @test dB_buf ≈ A rtol=1e-12 atol=1e-12 +end diff --git a/test/test-Turing-Integration.jl b/test/test-Turing-Integration.jl index d812dfd..e79af9d 100644 --- a/test/test-Turing-Integration.jl +++ b/test/test-Turing-Integration.jl @@ -9,12 +9,16 @@ using ParallelMCMC using DynamicPPL using LogDensityProblems using ADTypes -using Distributions: Beta, Normal, MvNormal, product_distribution, Dirichlet - -# A simple 1-D normal likelihood: μ ~ N(0,1), y | μ ~ N(μ, 0.5) -# Posterior: μ | y=1.5 is N(μ_post, σ_post²) -# σ_post² = 1 / (1/1² + 1/0.5²) = 1 / (1 + 4) = 0.2 -# μ_post = σ_post² * (y / 0.5²) = 0.2 * (1.5 / 0.25) = 0.2 * 6 = 1.2 +using Enzyme +using ForwardDiff +using Distributions: Beta, Dirichlet, Normal, MvNormal, product_distribution + +#= +A simple 1-D normal likelihood: μ ~ N(0,1), y | μ ~ N(μ, 0.5) +Posterior: μ | y=1.5 is N(μ_post, σ_post²) +σ_post² = 1 / (1/1² + 1/0.5²) = 1 / (1 + 4) = 0.2 +μ_post = σ_post² * (y / 0.5²) = 0.2 * (1.5 / 0.25) = 0.2 * 6 = 1.2 +=# const TRUE_OBS = 1.5 const TRUE_MU_POST = 1.2 const TRUE_VAR_POST = 0.2 @@ -34,6 +38,14 @@ end x ~ Beta(2, 2) end +@model function mvnormal_2d_model() + x ~ MvNormal(zeros(2), I) +end + +@model function dirichlet_3_model() + x ~ Dirichlet(ones(3)) +end + @testset "directly passing LogDensityFunction" begin ld = DynamicPPL.LogDensityFunction( normal_model(TRUE_OBS), @@ -97,6 +109,7 @@ end tol_rel=1e-4, jacobian=jacobian, damping=0.5, + backend=ADTypes.AutoEnzyme(), ) trans, state = ParallelMCMC.AbstractMCMC.step( @@ -110,6 +123,58 @@ end end end +@testset "DynamicPPLExt: MvNormal(zeros(2), I) runs with ParallelMALA" begin + model = DensityModel(mvnormal_2d_model()) + + @test model.dim == 2 + @test isfinite(model.logdensity(zeros(2))) + @test all(isfinite, model.grad_logdensity(zeros(2))) + + sampler = ParallelMALASampler( + 0.2; T=8, maxiter=80, tol_abs=1e-4, tol_rel=1e-3, backend=ADTypes.AutoEnzyme() + ) + chain = sample( + MersenneTwister(3), + model, + sampler, + 800; + initial_params=zeros(2), + chain_type=MCMCChains.Chains, + progress=false, + ) + samples = Array(chain) + @test all(isfinite, samples) + # Standard normal in 2-D: posterior mean should be near zero. + @test maximum(abs, vec(mean(samples; dims=1))) < 0.25 +end + +@testset "DynamicPPLExt: Dirichlet(ones(3)) runs with ParallelMALA (linked space)" begin + #= + Dirichlet(ones(3)) lives on a 2-simplex, so its unconstrained + representation has dim 2. Bijectors handles the link/unlink. + =# + model = DensityModel(dirichlet_3_model()) + + @test model.dim == 2 + @test isfinite(model.logdensity(zeros(2))) + @test all(isfinite, model.grad_logdensity(zeros(2))) + + sampler = ParallelMALASampler( + 0.2; T=8, maxiter=80, tol_abs=1e-4, tol_rel=1e-3, backend=ADTypes.AutoEnzyme() + ) + chain = sample( + MersenneTwister(4), + model, + sampler, + 800; + initial_params=zeros(2), + chain_type=MCMCChains.Chains, + progress=false, + ) + @test chain isa MCMCChains.Chains + @test all(isfinite, Array(chain)) +end + @testset "DynamicPPLExt: named columns in Chains output" begin model = DensityModel(normal_model(TRUE_OBS))