Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,8 @@ LocalPreferences.toml
CLAUDE.md
AGENTS.md
CODEX.md
.gemini
.gemini

# random scripts for debugging
/scripts
/dev
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ version = "0.0.2"
ADTypes = "1.21.0"
AbstractMCMC = "5.10.0"
CUDA = "5.11.0"
CUDA_Runtime_jll = "0.21"
DifferentiationInterface = "0.7.13"
DynamicPPL = "0.40.6, 0.41"
Enzyme = "0.13.131"
Enzyme = "0.13.146"
ForwardDiff = "1"
LinearAlgebra = "1"
LogDensityProblems = "2"
MCMCChains = "7.7.0"
Mooncake = "0.5.26"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're interested in second-order Mooncake features through DI, you wanna track JuliaDiff/DifferentiationInterface.jl#990.

See also JuliaDiff/DifferentiationInterface.jl#986

Random = "1"
Statistics = "1"
julia = "1.10"
Expand All @@ -22,19 +25,22 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
DynamicPPLExt = ["DynamicPPL", "LogDensityProblems"]
DynamicPPLExt = ["DynamicPPL", "ForwardDiff", "LogDensityProblems"]
EnzymeExt = "Enzyme"
LogDensityProblemsExt = "LogDensityProblems"

[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"

[weakdeps]
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
7 changes: 4 additions & 3 deletions benchmarks/ParallelMCMCBenchmarks/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ParallelMCMC = "1a970f40-4406-51c9-a967-cb3143c111e8"
Expand All @@ -14,8 +15,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[sources]
ParallelMCMC = {path = "../.."}

[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"

[sources.ParallelMCMC]
path = "../.."
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ using Statistics

using AbstractMCMC: sample
using ParallelMCMC
using ADTypes
using Enzyme
using ParallelMCMCBenchmarks

const BayesLogReg = ParallelMCMCBenchmarks.BayesLogReg
Expand Down Expand Up @@ -135,6 +137,7 @@ if _cuda_ok
tol_rel=tol_rel,
damping=damping,
probes=probes,
backend=ADTypes.AutoEnzyme(),
)

println(" T=$T")
Expand Down
10 changes: 9 additions & 1 deletion benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ using Statistics

using AbstractMCMC: sample
using ParallelMCMC
using ADTypes
using Enzyme
using ParallelMCMCBenchmarks
const BayesLinReg = ParallelMCMCBenchmarks.BayesLinReg
const MALARunner = ParallelMCMCBenchmarks.MALARunner
Expand Down Expand Up @@ -46,7 +48,13 @@ x_warm, ϵ_tuned = MALARunner.tune_stepsize_mala(
mala_sampler = AdaptiveMALASampler(ϵ_tuned; n_warmup=500)

deer_sampler = ParallelMALASampler(
ϵ_tuned; T=64, maxiter=200, tol_abs=1e-6, tol_rel=1e-5, damping=0.5
ϵ_tuned;
T=64,
maxiter=200,
tol_abs=1e-6,
tol_rel=1e-5,
damping=0.5,
backend=ADTypes.AutoEnzyme(),
)

# Benchmark helper
Expand Down
13 changes: 10 additions & 3 deletions benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,10 @@ function write_markdown(path, rows; warn_ratio, fail_ratio)
"`.",
)
println(io)
println(io, "| Benchmark | Base median | PR median | Ratio | Allocs | Memory delta | Status |")
println(
io,
"| Benchmark | Base median | PR median | Ratio | Allocs | Memory delta | Status |",
)
println(io, "|---|---:|---:|---:|---:|---:|---|")
for row in rows
println(
Expand Down Expand Up @@ -167,8 +170,12 @@ function main(args=ARGS)
base === nothing && error("--base is required")
head === nothing && error("--head is required")

warn_ratio = parse(Float64, _option(args, "--warn-ratio", get(ENV, "PMCMC_BENCH_WARN_RATIO", "1.25")))
fail_ratio = parse(Float64, _option(args, "--fail-ratio", get(ENV, "PMCMC_BENCH_FAIL_RATIO", "1.75")))
warn_ratio = parse(
Float64, _option(args, "--warn-ratio", get(ENV, "PMCMC_BENCH_WARN_RATIO", "1.25"))
)
fail_ratio = parse(
Float64, _option(args, "--fail-ratio", get(ENV, "PMCMC_BENCH_FAIL_RATIO", "1.75"))
)
markdown = _option(args, "--markdown", "")

warn_ratio > 1 || error("--warn-ratio must be greater than 1")
Expand Down
4 changes: 3 additions & 1 deletion benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using Printf
using Statistics

using ParallelMCMC
using ADTypes
using Enzyme
using ParallelMCMCBenchmarks
using CUDA

Expand Down Expand Up @@ -79,7 +81,7 @@ function build_raw_deer_problem(
damping::Float32,
probes::Int,
cholM=nothing,
backend=DEER.DEFAULT_BACKEND,
backend=ADTypes.AutoEnzyme(),
)
FP = typeof(epsilon)
D = model.dim
Expand Down
4 changes: 3 additions & 1 deletion benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ function main(args=ARGS)
return 0
end

seconds = parse(Float64, _option(args, "--seconds", get(ENV, "PMCMC_BENCH_SECONDS", "0.5")))
seconds = parse(
Float64, _option(args, "--seconds", get(ENV, "PMCMC_BENCH_SECONDS", "0.5"))
)
samples = parse(Int, _option(args, "--samples", get(ENV, "PMCMC_BENCH_SAMPLES", "8")))
output = _option(args, "--output", "")
markdown = _option(args, "--markdown", "")
Expand Down
4 changes: 3 additions & 1 deletion benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ or:
using Random
using CUDA
using ParallelMCMC
using ADTypes
using Enzyme
using ParallelMCMCBenchmarks

const BayesLogReg = ParallelMCMCBenchmarks.BayesLogReg
Expand All @@ -52,7 +54,7 @@ function build_raw_deer_problem(
epsilon::Float32,
T::Int,
cholM=nothing,
backend=DEER.DEFAULT_BACKEND,
backend=ADTypes.AutoEnzyme(),
)
FP = typeof(epsilon)
D = model.dim
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using Statistics

using AbstractMCMC: sample
using ParallelMCMC
using ADTypes
using Enzyme
using ParallelMCMCBenchmarks
using CUDA

Expand Down Expand Up @@ -67,7 +69,9 @@ function build_problem()
end

function build_rec(model_gpu, x0_gpu, tape)
return ParallelMCMC._build_mala_deer_rec(model_gpu, epsilon, tape, x0_gpu;)
return ParallelMCMC._build_mala_deer_rec(
model_gpu, epsilon, tape, x0_gpu; backend=ADTypes.AutoEnzyme()
)
end

function solve_prebuilt(rec, x0_gpu, ws; seed=42, return_info=false)
Expand Down Expand Up @@ -131,6 +135,7 @@ deer_gpu = ParallelMALASampler(
tol_rel=tol_rel,
damping=damping,
probes=probes,
backend=ADTypes.AutoEnzyme(),
)

println("=" ^ 96)
Expand Down
8 changes: 5 additions & 3 deletions benchmarks/ParallelMCMCBenchmarks/src/models/bayes_linreg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ function make_problem(X::AbstractMatrix, y::AbstractVector; σ::Real=1.0, τ::Re
return g
end

# Analytic Gaussian posterior:
# Σ_post = (XᵀX/σ² + I/τ²)^{-1}
# μ_post = Σ_post * (Xᵀy/σ²)
#=
Analytic Gaussian posterior:
Σ_post = (XᵀX/σ² + I/τ²)^{-1}
μ_post = Σ_post * (Xᵀy/σ²)
=#
A = (X' * X) ./ σ2
A += Diagonal(fill(1.0 / τ2, p))
Σ_post = inv(Symmetric(A))
Expand Down
14 changes: 9 additions & 5 deletions benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ function make_problem(X::AbstractMatrix, y::AbstractVector)

function logp(β::AbstractVector)
mul!(logits, X, β)
# ll = sum(@. y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits))))
# We perform this in-place to avoid allocations
# Note: log1p(exp(x)) is the softplus function
#=
ll = sum(@. y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits))))
We perform this in-place to avoid allocations
Note: log1p(exp(x)) is the softplus function
=#
@. p = y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits)))
return sum(p) - 0.5 * sum(abs2, β)
end
Expand All @@ -47,8 +49,10 @@ function make_problem(X::AbstractMatrix, y::AbstractVector)
mul!(logits, X, β)
@. p = 1 / (1 + exp(-logits))
@. resid = y - p
# grad = X' * (y - p) - β
# We reuse the output allocation here by using mul!
#=
grad = X' * (y - p) - β
We reuse the output allocation here by using mul!
=#
grad = (X' * resid) .- β
return grad
end
Expand Down
34 changes: 11 additions & 23 deletions benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ using Statistics
using TOML

using ParallelMCMC
using ADTypes
using Enzyme

const MALA = ParallelMCMC.MALA
const DEER = ParallelMCMC.DEER
Expand Down Expand Up @@ -49,8 +51,9 @@ function _make_stdnormal_rec(rng::AbstractRNG, dim::Int, steps::Int, epsilon::Fl
tape = [(noise=randn(rng, dim), u=rand(rng)) for _ in 1:steps]

step_fwd =
(x, te) ->
MALA.mala_step_taped(_stdnormal_logp, _stdnormal_grad, x, epsilon, te.noise, te.u)
(x, te) -> MALA.mala_step_taped(
_stdnormal_logp, _stdnormal_grad, x, epsilon, te.noise, te.u
)
jvp =
(x, te, v) -> MALA.mala_step_surrogate_sigmoid_jvp(
_stdnormal_logp,
Expand Down Expand Up @@ -88,14 +91,7 @@ function _mala_step_case()
epsilon = 0.04

bench = @benchmarkable MALA.mala_step_with_logα!(
$x_next,
$workspace,
$_stdnormal_logp,
$_stdnormal_grad,
$x,
$epsilon,
$noise,
$u,
$x_next, $workspace, $_stdnormal_logp, $_stdnormal_grad, $x, $epsilon, $noise, $u
) evals = 1

return BenchmarkCase(
Expand Down Expand Up @@ -144,7 +140,7 @@ function _deer_solve_case()
jacobian=:diag,
damping=0.5,
rng=rng,
workspace=$workspace,
workspace=($workspace),
copy_result=false,
) setup = (rng = MersenneTwister(42)) evals = 1

Expand All @@ -170,16 +166,12 @@ function _parallel_mala_sample_case()
tol_rel=1e-5,
jacobian=:diag,
damping=0.5,
backend=ADTypes.AutoEnzyme(),
)
initial_params = zeros(dim)

bench = @benchmarkable sample(
rng,
$model,
$sampler,
64;
initial_params=$initial_params,
progress=false,
rng, $model, $sampler, 64; initial_params=($initial_params), progress=false
) setup = (rng = MersenneTwister(42)) evals = 1

return BenchmarkCase(
Expand All @@ -201,16 +193,12 @@ function _parallel_mala_sample_no_hvp_case()
tol_rel=1e-5,
jacobian=:diag,
damping=0.5,
backend=ADTypes.AutoEnzyme(),
)
initial_params = zeros(dim)

bench = @benchmarkable sample(
rng,
$model,
$sampler,
64;
initial_params=$initial_params,
progress=false,
rng, $model, $sampler, 64; initial_params=($initial_params), progress=false
) setup = (rng = MersenneTwister(42)) evals = 1

return BenchmarkCase(
Expand Down
25 changes: 24 additions & 1 deletion docs/src/10-getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ model = DensityModel(logp, grad_logp, 2; param_names=[:x1, :x2])

---

## ParallelMALASampler — the primary algorithm
## ParallelMALASampler

[`ParallelMALASampler`](@ref) reformulates a trajectory of `T` MALA steps as a fixed-point problem and solves it via Newton iterations, each of which costs $O(\log T)$ parallel work via an associative prefix scan. Wall-clock time per sample is therefore sublinear in chain length on multi-core CPUs and GPUs.

Expand Down Expand Up @@ -84,6 +84,29 @@ chain = sample(model, sampler, 500;
chain_type=MCMCChains.Chains)
```

#### AD-HVP fallback

DEER needs a Hessian-vector product (HVP) at every Newton step. If you
supply `hvp`/`hvp_batch` analytically, those run as plain kernels and
nothing else is required. If you only supply `grad_logdensity` /
`grad_logdensity_batch`, the sampler computes the HVP via **Mooncake**
reverse-mode applied to `x -> dot(grad_logdensity(x), v)`.

```julia
using ParallelMCMC, CUDA

const X = CUDA.CuMatrix(randn(Float32, N, D))

# Plain Julia / CUDA operations.
grad_logp(β) = -β .- (transpose(X) * (X * β)) ./ Float32(N)
grad_logp_batch(B) = -B .- (transpose(X) * (X * B)) ./ Float32(N)

model = DensityModel(logp, grad_logp, D;
logdensity_batch=logp_batch,
grad_logdensity_batch=grad_logp_batch)
sampler = ParallelMALASampler(0.05f0; T=64)
```

---

## Turing.jl integration
Expand Down
Loading
Loading