diff --git a/.github/workflows/Benchmarks.yml b/.github/workflows/Benchmarks.yml new file mode 100644 index 0000000..982dbb1 --- /dev/null +++ b/.github/workflows/Benchmarks.yml @@ -0,0 +1,109 @@ +name: Benchmarks + +on: + pull_request: + branches: + - main + paths: + - ".github/workflows/Benchmarks.yml" + - "Project.toml" + - "src/**" + - "ext/**" + - "benchmarks/**" + types: [opened, synchronize, reopened, ready_for_review] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +permissions: + contents: read + +jobs: + benchmark: + name: PR benchmark comparison + runs-on: ubuntu-latest + timeout-minutes: 60 + + env: + PMCMC_BENCH_SECONDS: "0.5" + PMCMC_BENCH_SAMPLES: "8" + PMCMC_BENCH_WARN_RATIO: "1.25" + PMCMC_BENCH_FAIL_RATIO: "1.75" + JULIA_NUM_PRECOMPILE_TASKS: "1" + + steps: + - name: Checkout base + uses: actions/checkout@v4 + with: + repository: ${{ github.event.pull_request.base.repo.full_name }} + ref: ${{ github.event.pull_request.base.sha }} + path: base + + - name: Checkout PR head + uses: actions/checkout@v4 + with: + repository: ${{ github.event.pull_request.head.repo.full_name }} + ref: ${{ github.event.pull_request.head.sha }} + path: head + + - uses: julia-actions/setup-julia@v2 + with: + version: "1" + arch: x64 + + - name: Use Julia cache + uses: julia-actions/cache@v3 + + - name: Instantiate base package in benchmark environment + working-directory: head/benchmarks/ParallelMCMCBenchmarks + run: | + julia --project=. 
-e 'using Pkg; Pkg.develop(PackageSpec(path=joinpath(ENV["GITHUB_WORKSPACE"], "base"))); Pkg.instantiate()' + + - name: Run base benchmarks + run: | + julia --project="$GITHUB_WORKSPACE/head/benchmarks/ParallelMCMCBenchmarks" \ + "$GITHUB_WORKSPACE/head/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl" \ + --seconds "$PMCMC_BENCH_SECONDS" \ + --samples "$PMCMC_BENCH_SAMPLES" \ + --output "$GITHUB_WORKSPACE/base-benchmarks.toml" \ + --markdown "$GITHUB_WORKSPACE/base-benchmarks.md" + + - name: Instantiate PR package in benchmark environment + working-directory: head/benchmarks/ParallelMCMCBenchmarks + run: | + julia --project=. -e 'using Pkg; Pkg.develop(PackageSpec(path=joinpath(ENV["GITHUB_WORKSPACE"], "head"))); Pkg.instantiate()' + + - name: Run PR benchmarks + run: | + julia --project="$GITHUB_WORKSPACE/head/benchmarks/ParallelMCMCBenchmarks" \ + "$GITHUB_WORKSPACE/head/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl" \ + --seconds "$PMCMC_BENCH_SECONDS" \ + --samples "$PMCMC_BENCH_SAMPLES" \ + --output "$GITHUB_WORKSPACE/head-benchmarks.toml" \ + --markdown "$GITHUB_WORKSPACE/head-benchmarks.md" + + - name: Compare benchmarks + run: | + set +e + julia "$GITHUB_WORKSPACE/head/benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl" \ + --base "$GITHUB_WORKSPACE/base-benchmarks.toml" \ + --head "$GITHUB_WORKSPACE/head-benchmarks.toml" \ + --warn-ratio "$PMCMC_BENCH_WARN_RATIO" \ + --fail-ratio "$PMCMC_BENCH_FAIL_RATIO" \ + --markdown "$GITHUB_WORKSPACE/benchmark-comparison.md" + status=$? 
+ cat "$GITHUB_WORKSPACE/benchmark-comparison.md" >> "$GITHUB_STEP_SUMMARY" + exit "$status" + + - name: Upload benchmark results + uses: actions/upload-artifact@v4 + if: always() + with: + name: pr-benchmark-results + path: | + base-benchmarks.toml + base-benchmarks.md + head-benchmarks.toml + head-benchmarks.md + benchmark-comparison.md diff --git a/.gitignore b/.gitignore index 2b0c91d..94aa3f1 100644 --- a/.gitignore +++ b/.gitignore @@ -6,8 +6,19 @@ .benchmarkci Manifest.toml benchmark/*.json +benchmarks/ParallelMCMCBenchmarks/*-benchmarks.toml +benchmarks/ParallelMCMCBenchmarks/*-benchmarks.md +benchmarks/ParallelMCMCBenchmarks/benchmark-comparison.md coverage docs/build/ env node_modules LocalPreferences.toml + +# ignore any AI files a contributor may have +.codex +.claude +CLAUDE.md +AGENTS.md +CODEX.md +.gemini \ No newline at end of file diff --git a/Project.toml b/Project.toml index 00f0e24..5e9e141 100644 --- a/Project.toml +++ b/Project.toml @@ -10,12 +10,12 @@ CUDA = "5.11.0" DifferentiationInterface = "0.7.13" DynamicPPL = "0.40" Enzyme = "0.13.131" -LinearAlgebra = "1.12.0" +LinearAlgebra = "1" LogDensityProblems = "2" LogDensityProblemsAD = "1" MCMCChains = "7.7.0" -Random = "1.11.0" -Statistics = "1.11.1" +Random = "1" +Statistics = "1" julia = "1.10" [deps] diff --git a/README.md b/README.md index 3d0f11c..1ad6eed 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,78 @@ # ParallelMCMC +

+ ParallelMCMC logo +

+ [![Stable Documentation](https://img.shields.io/badge/docs-stable-blue.svg)](https://rsenne.github.io/ParallelMCMC.jl/stable) [![Development documentation](https://img.shields.io/badge/docs-dev-blue.svg)](https://rsenne.github.io/ParallelMCMC.jl/dev) [![Test workflow status](https://github.com/rsenne/ParallelMCMC.jl/actions/workflows/Test.yml/badge.svg?branch=main)](https://github.com/rsenne/ParallelMCMC.jl/actions/workflows/Test.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/rsenne/ParallelMCMC.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/rsenne/ParallelMCMC.jl) [![Docs workflow Status](https://github.com/rsenne/ParallelMCMC.jl/actions/workflows/Docs.yml/badge.svg?branch=main)](https://github.com/rsenne/ParallelMCMC.jl/actions/workflows/Docs.yml?query=branch%3Amain) -[![DOI](https://zenodo.org/badge/DOI/FIXME)](https://doi.org/FIXME) [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](CODE_OF_CONDUCT.md) [![All Contributors](https://img.shields.io/github/all-contributors/rsenne/ParallelMCMC.jl?labelColor=5e1ec7&color=c0ffee&style=flat-square)](#contributors) [![BestieTemplate](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/JuliaBesties/BestieTemplate.jl/main/docs/src/assets/badge.json)](https://github.com/JuliaBesties/BestieTemplate.jl) +

+ DEER trajectory estimates improving on a Julia-logo-shaped posterior +

+ +

+ DEER iterates on a synthetic Julia-logo-shaped posterior: orange trajectory estimates move toward the taped MALA path over repeated trajectory solves. +

+ +## What this package does + +**ParallelMCMC.jl** implements *parallel-across-the-sequence* MCMC in Julia: instead of generating samples one at a time, an entire trajectory of $T$ correlated steps is solved *simultaneously*. This makes wall-clock time sublinear in chain length on multi-core CPUs and GPUs, where conventional sequential MCMC scales linearly. + +The flagship algorithm is **DEER** (Lim et al. 2024; Gonzalez et al. 2024), which reformulates a chain of $T$ MALA steps as a fixed-point problem and solves it with Newton iterations. Each iteration linearizes the per-step transition around the current trajectory guess and resolves the resulting linear recursion in $O(\log T)$ parallel depth via an associative prefix scan. With shared input randomness, DEER converges to the exact sequential MALA trace up to a numerical tolerance — typically in tens of iterations even for chains of tens of thousands of samples. + +The approach and its scaling tricks (stochastic Hutchinson Jacobian estimators, damping, sliding windows) are described in: + +> Zoltowski, D. M., Wu, S., Gonzalez, X., Kozachkov, L., & Linderman, S. W. (2025). +> **Parallelizing MCMC Across the Sequence Length.** *NeurIPS 2025.* +> [arXiv:2508.18413](https://arxiv.org/abs/2508.18413) + +### Samplers + +| Sampler | Role | +|---|---| +| [`ParallelMALASampler`](src/interface.jl) | **Primary** — parallel-across-sequence MALA via DEER; $O(\log T)$ parallel depth per Newton iteration | +| [`MALASampler`](src/interface.jl) | Baseline — sequential MALA with a fixed step size | +| [`AdaptiveMALASampler`](src/interface.jl) | Baseline — sequential MALA with dual-averaging step-size adaptation | + +All samplers implement the [AbstractMCMC](https://github.com/TuringLang/AbstractMCMC.jl) interface and return [`MCMCChains.Chains`](https://github.com/TuringLang/MCMCChains.jl) objects, so they slot into existing Turing.jl / AbstractMCMC workflows. 
+ + +### Quick start + +Install the package from GitHub with: + +```julia-repl +pkg> add https://github.com/rsenne/ParallelMCMC.jl +``` + +```julia +using ParallelMCMC, MCMCChains + +logp(x) = -0.5 * sum(abs2, x) # 2-D standard normal +grad_logp(x) = -x + +model = DensityModel(logp, grad_logp, 2; param_names=[:x1, :x2]) +sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag) + +chain = sample(model, sampler, 500; chain_type=MCMCChains.Chains) +``` + +See the [Getting Started guide](docs/src/10-getting-started.md) for worked examples, GPU usage, Turing.jl integration, and step-size tuning. + ## How to Cite If you use ParallelMCMC.jl in your work, please cite using the reference given in [CITATION.cff](https://github.com/rsenne/ParallelMCMC.jl/blob/main/CITATION.cff). ## Contributing -If you want to make contributions of any kind, please first that a look into our [contributing guide directly on GitHub](docs/src/90-contributing.md) or the [contributing page on the website](https://rsenne.github.io/ParallelMCMC.jl/dev/90-contributing/) +If you want to contribute, start with the [contributing guide on GitHub](docs/src/90-contributing.md) or the [documentation site](https://rsenne.github.io/ParallelMCMC.jl/dev/90-contributing/). 
--- diff --git a/benchmarks/ParallelMCMCBenchmarks/Project.toml b/benchmarks/ParallelMCMCBenchmarks/Project.toml index 75f662d..0111b50 100644 --- a/benchmarks/ParallelMCMCBenchmarks/Project.toml +++ b/benchmarks/ParallelMCMCBenchmarks/Project.toml @@ -12,6 +12,7 @@ ParallelMCMC = "1a970f40-4406-51c9-a967-cb3143c111e8" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [extras] CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" diff --git a/benchmarks/ParallelMCMCBenchmarks/README.md b/benchmarks/ParallelMCMCBenchmarks/README.md new file mode 100644 index 0000000..5f2be8c --- /dev/null +++ b/benchmarks/ParallelMCMCBenchmarks/README.md @@ -0,0 +1,41 @@ +# ParallelMCMC Benchmarks + +This package holds reproducible benchmark workloads for `ParallelMCMC.jl`. + +## PR Regression Suite + +The pull request workflow runs a small CPU-only suite against both the PR head +and the PR base commit, then compares median runtimes. It covers: + +- one allocation-light MALA transition, +- the diagonal affine scan used by quasi-DEER, +- a full DEER block solve on a taped Gaussian MALA trajectory, +- the public `ParallelMALASampler` path on a small Bayesian logistic regression. + +Run the same suite locally from the repository root with: + +```bash +julia --project=benchmarks/ParallelMCMCBenchmarks \ + benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl \ + --output pr-benchmarks.toml \ + --markdown pr-benchmarks.md +``` + +To compare two result files: + +```bash +julia benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl \ + --base base-benchmarks.toml \ + --head head-benchmarks.toml \ + --markdown benchmark-comparison.md +``` + +CI marks a benchmark as `watch` above a 1.25x median-time ratio and fails above +1.75x. The thresholds can be adjusted with `PMCMC_BENCH_WARN_RATIO` and +`PMCMC_BENCH_FAIL_RATIO`. 
+ +## Manual GPU Sweeps + +The existing scripts in `scripts/` are still intended for deeper manual +throughput and GPU investigations. They are not part of the PR gate because +standard GitHub-hosted runners do not provide CUDA hardware. diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl index eb4f5f1..4e578a3 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl @@ -29,7 +29,7 @@ const BayesLogReg = ParallelMCMCBenchmarks.BayesLogReg function _parse_t_vals() raw = get(ENV, "PMCMC_T_VALS", "") - isempty(strip(raw)) && return [128, 256, 512, 1024, 2048] + isempty(strip(raw)) && return [512, 1024, 2048, 4096, 8192] return parse.(Int, strip.(split(raw, ","))) end @@ -109,8 +109,9 @@ if _cuda_ok y_gpu = CUDA.CuVector(y_f32) logp_gpu, gradlogp_gpu, hvp_gpu = BayesLogReg.make_problem_with_hvp(X_gpu, y_gpu) - logp_gpu_batch, gradlogp_gpu_batch, hvp_gpu_batch = - BayesLogReg.make_problem_batched_with_hvp(X_gpu, y_gpu) + logp_gpu_batch, gradlogp_gpu_batch, hvp_gpu_batch = BayesLogReg.make_problem_batched_with_hvp( + X_gpu, y_gpu + ) model_gpu = DensityModel( logp_gpu, diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl index 371acef..5fd7f16 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl @@ -19,16 +19,16 @@ using AbstractMCMC: sample using ParallelMCMC using ParallelMCMCBenchmarks const BayesLinReg = ParallelMCMCBenchmarks.BayesLinReg -const MALARunner = ParallelMCMCBenchmarks.MALARunner +const MALARunner = ParallelMCMCBenchmarks.MALARunner # Problem setup rng = MersenneTwister(20251231) n, p = 200, 16 -X = randn(rng, n, p) +X = randn(rng, n, p) β_true = randn(rng, p) -σ = 1.0 -y = X * β_true .+ σ .* randn(rng, n) +σ 
= 1.0 +y = X * β_true .+ σ .* randn(rng, n) logpost, gradlogpost, μ_post, _ = BayesLinReg.make_problem(X, y; σ=σ, τ=10.0) model = DensityModel(logpost, gradlogpost, p) @@ -46,12 +46,7 @@ x_warm, ϵ_tuned = MALARunner.tune_stepsize_mala( mala_sampler = AdaptiveMALASampler(ϵ_tuned; n_warmup=500) deer_sampler = ParallelMALASampler( - ϵ_tuned; - T=64, - maxiter=200, - tol_abs=1e-6, - tol_rel=1e-5, - damping=0.5, + ϵ_tuned; T=64, maxiter=200, tol_abs=1e-6, tol_rel=1e-5, damping=0.5 ) # Benchmark helper @@ -79,7 +74,9 @@ println("AdaptiveMALASampler (n_warmup=500, Float64)") println("Model: Bayesian linear regression n=$n p=$p") println("=" ^ 60, "\n") for (n_samples, reps, label) in configs - results[("MALA", label)] = run_bench(model, mala_sampler, n_samples; reps, label, sampler_name="MALA") + results[("MALA", label)] = run_bench( + model, mala_sampler, n_samples; reps, label, sampler_name="MALA" + ) end # ParallelMALA (DEER) @@ -89,7 +86,9 @@ println("ParallelMALASampler (T=64, AutoEnzyme, Float64)") println("Model: Bayesian linear regression n=$n p=$p") println("=" ^ 60, "\n") for (n_samples, reps, label) in configs - results[("DEER", label)] = run_bench(model, deer_sampler, n_samples; reps, label, sampler_name="DEER") + results[("DEER", label)] = run_bench( + model, deer_sampler, n_samples; reps, label, sampler_name="DEER" + ) end # Summary table diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl new file mode 100644 index 0000000..d811efa --- /dev/null +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl @@ -0,0 +1,188 @@ +#!/usr/bin/env julia + +using Printf +using TOML + +function _option(args, name, default=nothing) + index = findfirst(==(name), args) + index === nothing && return default + index < length(args) || error("missing value for $name") + return args[index + 1] +end + +function _has_flag(args, name) + return any(==(name), args) +end + 
+function _usage() + return """ + Usage: + julia compare_pr_benchmarks.jl --base base.toml --head head.toml [options] + + Options: + --base PATH Base-branch benchmark TOML. + --head PATH PR-head benchmark TOML. + --markdown PATH Write a Markdown comparison table. + --fail-ratio N Fail when head/base median time exceeds N. Default: 1.75 + --warn-ratio N Mark warning when head/base median time exceeds N. Default: 1.25 + --help Show this help message. + """ +end + +function _format_time(ns::Real) + ns < 1e3 && return @sprintf("%.0f ns", ns) + ns < 1e6 && return @sprintf("%.2f us", ns / 1e3) + ns < 1e9 && return @sprintf("%.2f ms", ns / 1e6) + return @sprintf("%.2f s", ns / 1e9) +end + +function _format_ratio(ratio::Real) + return @sprintf("%.2fx", ratio) +end + +function _format_delta(bytes::Real) + sign = bytes >= 0 ? "+" : "-" + return sign * string(abs(round(Int, bytes))) * " B" +end + +function _status(ratio::Real, warn_ratio::Real, fail_ratio::Real) + ratio >= fail_ratio && return "regression" + ratio >= warn_ratio && return "watch" + ratio <= inv(warn_ratio) && return "faster" + return "ok" +end + +function compare_results(base_path, head_path; warn_ratio=1.25, fail_ratio=1.75) + base = TOML.parsefile(base_path) + head = TOML.parsefile(head_path) + base_benchmarks = base["benchmarks"] + head_benchmarks = head["benchmarks"] + + names = sort(collect(intersect(keys(base_benchmarks), keys(head_benchmarks)))) + isempty(names) && error("no common benchmark names found") + + rows = Vector{Dict{String,Any}}() + failed = false + + for name in names + base_result = base_benchmarks[name] + head_result = head_benchmarks[name] + base_ns = Float64(base_result["median_ns"]) + head_ns = Float64(head_result["median_ns"]) + ratio = head_ns / base_ns + status = _status(ratio, warn_ratio, fail_ratio) + failed |= status == "regression" + + push!( + rows, + Dict{String,Any}( + "name" => name, + "base_ns" => base_ns, + "head_ns" => head_ns, + "ratio" => ratio, + "base_allocs" => 
Int(base_result["median_allocs"]), + "head_allocs" => Int(head_result["median_allocs"]), + "memory_delta" => + Float64(head_result["median_memory_bytes"]) - + Float64(base_result["median_memory_bytes"]), + "status" => status, + ), + ) + end + + return rows, failed +end + +function write_markdown(path, rows; warn_ratio, fail_ratio) + mkpath(dirname(path)) + open(path, "w") do io + println(io, "# ParallelMCMC Benchmark Comparison") + println(io) + println( + io, + "Median runtime is compared against the pull request base commit. ", + "A benchmark fails the job when `head / base > ", + fail_ratio, + "` and is marked `watch` above `", + warn_ratio, + "`.", + ) + println(io) + println(io, "| Benchmark | Base median | PR median | Ratio | Allocs | Memory delta | Status |") + println(io, "|---|---:|---:|---:|---:|---:|---|") + for row in rows + println( + io, + "| `", + row["name"], + "` | ", + _format_time(row["base_ns"]), + " | ", + _format_time(row["head_ns"]), + " | ", + _format_ratio(row["ratio"]), + " | ", + row["base_allocs"], + " -> ", + row["head_allocs"], + " | ", + _format_delta(row["memory_delta"]), + " | ", + row["status"], + " |", + ) + end + println(io) + end + return path +end + +function print_summary(rows) + println("ParallelMCMC benchmark comparison") + for row in rows + println( + " ", + row["name"], + ": ", + _format_time(row["base_ns"]), + " -> ", + _format_time(row["head_ns"]), + " (", + _format_ratio(row["ratio"]), + ", ", + row["status"], + ")", + ) + end +end + +function main(args=ARGS) + if _has_flag(args, "--help") + print(_usage()) + return 0 + end + + base = _option(args, "--base") + head = _option(args, "--head") + base === nothing && error("--base is required") + head === nothing && error("--head is required") + + warn_ratio = parse(Float64, _option(args, "--warn-ratio", get(ENV, "PMCMC_BENCH_WARN_RATIO", "1.25"))) + fail_ratio = parse(Float64, _option(args, "--fail-ratio", get(ENV, "PMCMC_BENCH_FAIL_RATIO", "1.75"))) + markdown = 
_option(args, "--markdown", "") + + warn_ratio > 1 || error("--warn-ratio must be greater than 1") + fail_ratio > warn_ratio || error("--fail-ratio must be greater than --warn-ratio") + + rows, failed = compare_results(base, head; warn_ratio=warn_ratio, fail_ratio=fail_ratio) + print_summary(rows) + + if !isempty(markdown) + write_markdown(markdown, rows; warn_ratio=warn_ratio, fail_ratio=fail_ratio) + println("Wrote Markdown comparison to ", markdown) + end + + return failed ? 1 : 0 +end + +exit(main()) diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl index 9c66151..7fbf2b6 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl @@ -41,8 +41,9 @@ y_gpu = CUDA.CuVector(y_f32) logp_cpu, gradlogp_cpu = BayesLogReg.make_problem(X_cpu, y_cpu) logp_gpu, gradlogp_gpu, hvp_gpu = BayesLogReg.make_problem_with_hvp(X_gpu, y_gpu) -logp_gpu_batch, gradlogp_gpu_batch, hvp_gpu_batch = - BayesLogReg.make_problem_batched_with_hvp(X_gpu, y_gpu) +logp_gpu_batch, gradlogp_gpu_batch, hvp_gpu_batch = BayesLogReg.make_problem_batched_with_hvp( + X_gpu, y_gpu +) model_gpu = DensityModel( logp_gpu, @@ -89,12 +90,7 @@ function build_raw_deer_problem( end rec = ParallelMCMC._build_mala_deer_rec( - model, - epsilon, - tape, - x0; - cholM=cholM, - backend=backend, + model, epsilon, tape, x0; cholM=cholM, backend=backend ) return rec @@ -126,6 +122,7 @@ function solve_raw_deer_block_prebuilt( probes=probes, rng=rng, workspace=ws, + copy_result=false, ) end @@ -143,13 +140,7 @@ function solve_seq_mala_block( ξs = [randn(rng, FP, D) for _ in 1:T] us = rand(rng, FP, T) return ParallelMCMC.MALA.run_mala_sequential_taped( - logp, - gradlogp, - x0, - epsilon, - ξs, - us; - cholM=cholM, + logp, gradlogp, x0, epsilon, ξs, us; cholM=cholM ) end @@ -175,13 +166,13 @@ function bench_raw_deer_build(model, x0; T, reps) MersenneTwister(42), $model, $x0; - 
epsilon=$epsilon, - T=$T, - maxiter=$maxiter, - tol_abs=$tol_abs, - tol_rel=$tol_rel, - damping=$damping, - probes=$probes, + epsilon=($epsilon), + T=($T), + maxiter=($maxiter), + tol_abs=($tol_abs), + tol_rel=($tol_rel), + damping=($damping), + probes=($probes), ) CUDA.synchronize() rec @@ -231,11 +222,11 @@ function bench_raw_deer_solve_only(model, x0; T, reps) $rec, $x0, $ws; - tol_abs=$tol_abs, - tol_rel=$tol_rel, - maxiter=$maxiter, - damping=$damping, - probes=$probes, + tol_abs=($tol_abs), + tol_rel=($tol_rel), + maxiter=($maxiter), + damping=($damping), + probes=($probes), ) CUDA.synchronize() S @@ -256,22 +247,12 @@ function bench_seq_mala_cpu(logp, gradlogp, x0; T, reps) println(" [CPU] sequential taped MALA, T=$T") xs_warm = solve_seq_mala_block( - MersenneTwister(42), - logp, - gradlogp, - x0; - epsilon=epsilon, - T=T, + MersenneTwister(42), logp, gradlogp, x0; epsilon=epsilon, T=T ) b = @benchmark begin xs = solve_seq_mala_block( - MersenneTwister(42), - $logp, - $gradlogp, - $x0; - epsilon=$epsilon, - T=$T, + MersenneTwister(42), $logp, $gradlogp, $x0; epsilon=($epsilon), T=($T) ) xs end samples=reps evals=1 @@ -337,7 +318,13 @@ seq_results = Dict{Int,BenchmarkTools.Trial}() logp_results = Dict{Int,BenchmarkTools.Trial}() for T in T_vals - reps = T <= 16 ? 8 : T <= 32 ? 
6 : 4 + reps = if T <= 16 + 8 + elseif T <= 32 + 6 + else + 4 + end build_results[T] = bench_raw_deer_build(model_gpu, x0_gpu; T=T, reps=reps) solve_results[T], rec, ws = bench_raw_deer_solve_only(model_gpu, x0_gpu; T=T, reps=reps) seq_results[T] = bench_seq_mala_cpu(logp_cpu, gradlogp_cpu, x0_cpu; T=T, reps=reps) diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl new file mode 100644 index 0000000..acc2e84 --- /dev/null +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl @@ -0,0 +1,61 @@ +#!/usr/bin/env julia + +using ParallelMCMCBenchmarks + +const PRSuite = ParallelMCMCBenchmarks.ParallelMCMCPrBenchmarks + +function _option(args, name, default=nothing) + index = findfirst(==(name), args) + index === nothing && return default + index < length(args) || error("missing value for $name") + return args[index + 1] +end + +function _has_flag(args, name) + return any(==(name), args) +end + +function _usage() + return """ + Usage: + julia --project benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl [options] + + Options: + --output PATH Write machine-readable TOML benchmark results. + --markdown PATH Write a Markdown summary table. + --seconds N Target seconds per benchmark case. Default: 0.5 + --samples N Minimum samples per benchmark case. Default: 8 + --help Show this help message. 
+ """ +end + +function main(args=ARGS) + if _has_flag(args, "--help") + print(_usage()) + return 0 + end + + seconds = parse(Float64, _option(args, "--seconds", get(ENV, "PMCMC_BENCH_SECONDS", "0.5"))) + samples = parse(Int, _option(args, "--samples", get(ENV, "PMCMC_BENCH_SAMPLES", "8"))) + output = _option(args, "--output", "") + markdown = _option(args, "--markdown", "") + + seconds > 0 || error("--seconds must be positive") + samples > 0 || error("--samples must be positive") + + results = PRSuite.run_pr_benchmarks(; seconds=seconds, samples=samples) + + if !isempty(output) + PRSuite.write_results(output, results) + println("Wrote TOML results to ", output) + end + + if !isempty(markdown) + PRSuite.write_markdown(markdown, results) + println("Wrote Markdown summary to ", markdown) + end + + return 0 +end + +exit(main()) diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl index 0c55910..98fc4fa 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl @@ -63,12 +63,7 @@ function build_raw_deer_problem( end rec = ParallelMCMC._build_mala_deer_rec( - model, - epsilon, - tape, - x0; - cholM=cholM, - backend=backend, + model, epsilon, tape, x0; cholM=cholM, backend=backend ) return rec end @@ -84,8 +79,9 @@ function build_problem(; seed=20251231) y_gpu = CUDA.CuVector(y_f32) logp_gpu, gradlogp_gpu, hvp_gpu = BayesLogReg.make_problem_with_hvp(X_gpu, y_gpu) - logp_gpu_batch, gradlogp_gpu_batch, hvp_gpu_batch = - BayesLogReg.make_problem_batched_with_hvp(X_gpu, y_gpu) + logp_gpu_batch, gradlogp_gpu_batch, hvp_gpu_batch = BayesLogReg.make_problem_batched_with_hvp( + X_gpu, y_gpu + ) model_gpu = DensityModel( logp_gpu, @@ -100,11 +96,7 @@ function build_problem(; seed=20251231) x0_gpu = CUDA.zeros(Float32, D) rec = build_raw_deer_problem( - MersenneTwister(42), - model_gpu, - x0_gpu; - epsilon=epsilon, - T=T, + 
MersenneTwister(42), model_gpu, x0_gpu; epsilon=epsilon, T=T ) ws = DEER.DEERWorkspace(x0_gpu, T) @@ -127,6 +119,7 @@ function solve_once(; seed=42) probes=probes, rng=rng, workspace=PROF_STATE.ws, + copy_result=false, ) CUDA.synchronize() return S diff --git a/benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl b/benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl index e4f18b3..46336fd 100644 --- a/benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl +++ b/benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl @@ -47,8 +47,9 @@ function build_problem() y_gpu = CUDA.CuVector(y_f32) logp_gpu, gradlogp_gpu, hvp_gpu = BayesLogReg.make_problem_with_hvp(X_gpu, y_gpu) - logp_gpu_batch, gradlogp_gpu_batch, hvp_gpu_batch = - BayesLogReg.make_problem_batched_with_hvp(X_gpu, y_gpu) + logp_gpu_batch, gradlogp_gpu_batch, hvp_gpu_batch = BayesLogReg.make_problem_batched_with_hvp( + X_gpu, y_gpu + ) model_gpu = DensityModel( logp_gpu, @@ -66,12 +67,7 @@ function build_problem() end function build_rec(model_gpu, x0_gpu, tape) - return ParallelMCMC._build_mala_deer_rec( - model_gpu, - epsilon, - tape, - x0_gpu; - ) + return ParallelMCMC._build_mala_deer_rec(model_gpu, epsilon, tape, x0_gpu;) end function solve_prebuilt(rec, x0_gpu, ws; seed=42, return_info=false) @@ -87,6 +83,7 @@ function solve_prebuilt(rec, x0_gpu, ws; seed=42, return_info=false) rng=MersenneTwister(seed), workspace=ws, return_info=return_info, + copy_result=false, ) CUDA.synchronize() return out @@ -110,7 +107,8 @@ end function print_trial(name, trial) t_ms = median(trial).time / 1e6 - @printf "%-32s %12.3f %12d %12.3f\n" name t_ms median(trial).allocs median(trial).memory / 1024^2 + @printf "%-32s %12.3f %12d %12.3f\n" name t_ms median(trial).allocs median(trial).memory / + 1024^2 return t_ms end @@ -140,7 +138,9 @@ println("GPU ParallelMALA component profile") println("Model: Bayesian logistic regression D=$D 
N_data=$N_data") println("T=$T N_samples=$N_samples blocks=$(cld(N_samples, T))") println("epsilon=$epsilon maxiter=$maxiter tol_abs=$tol_abs tol_rel=$tol_rel") -println("Warmup solve converged=$(info.converged), iters=$(info.iters), metric=$(info.metric)") +println( + "Warmup solve converged=$(info.converged), iters=$(info.iters), metric=$(info.metric)" +) println("=" ^ 96) println() @@ -191,7 +191,7 @@ b_sample = @benchmark begin $deer_gpu, $N_samples; progress=false, - initial_params=$x0_gpu, + initial_params=($x0_gpu), ) CUDA.synchronize() xs diff --git a/benchmarks/ParallelMCMCBenchmarks/src/ParallelMCMCBenchmarks.jl b/benchmarks/ParallelMCMCBenchmarks/src/ParallelMCMCBenchmarks.jl index a77b179..b43d01e 100644 --- a/benchmarks/ParallelMCMCBenchmarks/src/ParallelMCMCBenchmarks.jl +++ b/benchmarks/ParallelMCMCBenchmarks/src/ParallelMCMCBenchmarks.jl @@ -1,7 +1,10 @@ module ParallelMCMCBenchmarks +export ParallelMCMCPrBenchmarks + include("models/bayes_linreg.jl") include("models/bayes_logreg.jl") include("runners/mala_runner.jl") +include("pr_suite.jl") end # module diff --git a/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl b/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl index 3b8bc56..0af562a 100644 --- a/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl +++ b/benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl @@ -3,15 +3,16 @@ module BayesLogReg using Random using LinearAlgebra +function _sum_columns!(out::AbstractVector, row::AbstractMatrix, A::AbstractMatrix) + sum!(row, A) + copyto!(out, row) + return out +end + """ make_data(rng, N, D; β_scale=1.0) -> (X, y, β_true) -Generate synthetic Bayesian logistic regression data: - β_true ~ N(0, β_scale² I_D) - X_i ~ N(0, I_D) - y_i ~ Bernoulli(sigmoid(X_i β_true)) - -Returns Float64 arrays. +Generate synthetic Bayesian logistic regression data. 
""" function make_data(rng::AbstractRNG, N::Int, D::Int; β_scale::Real=1.0) β_true = randn(rng, D) .* float(β_scale) @@ -24,24 +25,32 @@ end """ make_problem(X, y) -> (logp, gradlogp) -Return the log-posterior and its gradient for: - β ~ N(0, I_D) - y_i | β ~ Bernoulli(sigmoid(Xᵢβ)) - -Both closures accept an AbstractVector and work with any eltype, so they -are compatible with ForwardDiff dual numbers and GPU arrays alike. +Optimized to reuse memory buffers for GPU performance. """ function make_problem(X::AbstractMatrix, y::AbstractVector) + # Pre-allocate workspaces in the closure scope + N = size(X, 1) + logits = similar(X, N) + p = similar(X, N) + resid = similar(X, N) + function logp(β::AbstractVector) - logits = X * β - ll = sum(@. y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits)))) - return ll - oftype(ll, 0.5) * sum(abs2, β) + mul!(logits, X, β) + # ll = sum(@. y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits)))) + # We perform this in-place to avoid allocations + # Note: log1p(exp(x)) is the softplus function + @. p = y * (-log1p(exp(-logits))) + (1 - y) * (-log1p(exp(logits))) + return sum(p) - 0.5 * sum(abs2, β) end function gradlogp(β::AbstractVector) - logits = X * β - p = @. 1 / (1 + exp(-logits)) - return X' * (y .- p) .- β + mul!(logits, X, β) + @. p = 1 / (1 + exp(-logits)) + @. resid = y - p + # grad = X' * (y - p) - β + # We reuse the output allocation here by using mul! + grad = (X' * resid) .- β + return grad end return logp, gradlogp @@ -49,23 +58,24 @@ end """ make_problem_with_hvp(X, y) -> (logp, gradlogp, hvp) - -Return the log-posterior, its gradient, and an analytic Hessian-vector product -for Bayesian logistic regression. 
""" function make_problem_with_hvp(X::AbstractMatrix, y::AbstractVector) logp, gradlogp = make_problem(X, y) + # Workspace buffers + N = size(X, 1) + logits = similar(X, N) + p = similar(X, N) + w = similar(X, N) + Xv = similar(X, N) + function hvp(β::AbstractVector, v::AbstractVector) - logits = X * β - T = eltype(logits) - oneT = one(T) - p = oneT ./ (oneT .+ exp.(-logits)) - w = p .* (oneT .- p) - Xv = X * v - tmp = w .* Xv - Xt_tmp = X' * tmp - return Xt_tmp .* (-oneT) .- v + mul!(logits, X, β) + @. p = 1 / (1 + exp(-logits)) + @. w = p * (1 - p) + mul!(Xv, X, v) + @. Xv = w * Xv + return -(X' * Xv) .- v end return logp, gradlogp, hvp @@ -73,33 +83,70 @@ end """ make_problem_batched(X, y) -> (logp_batch, gradlogp_batch) - -Batched counterparts of `make_problem` that accept D×M matrices and return M-vectors -or D×M matrices. """ function make_problem_batched(X::AbstractMatrix, y::AbstractVector) + N, D = size(X) + y_col = reshape(y, N, 1) + + logits = Ref(similar(X, N, 1)) + res = Ref(similar(X, N, 1)) + sq = Ref(similar(X, D, 1)) + grad = Ref(similar(X, D, 1)) + out = Ref(similar(X, eltype(X), 1)) + quad = Ref(similar(X, eltype(X), 1)) + out_row = Ref(similar(X, eltype(X), 1, 1)) + quad_row = Ref(similar(X, eltype(X), 1, 1)) + + function ensure_workspace!(M::Int) + if size(logits[], 2) != M + logits[] = similar(X, N, M) + res[] = similar(X, N, M) + sq[] = similar(X, D, M) + grad[] = similar(X, D, M) + out[] = similar(X, eltype(X), M) + quad[] = similar(X, eltype(X), M) + out_row[] = similar(X, eltype(X), 1, M) + quad_row[] = similar(X, eltype(X), 1, M) + end + return nothing + end + function logp_batch(B::AbstractMatrix) - logits = X * B - T = eltype(logits) - oneT = one(T) - halfT = T(0.5) - - neg_logits = logits .* (-oneT) - ll_yes = y .* (-log1p.(exp.(neg_logits))) - ll_no = (oneT .- y) .* (-log1p.(exp.(logits))) - ll_mat = ll_yes .+ ll_no - ll = vec(sum(ll_mat; dims=1)) - quad = vec(sum(abs2, B; dims=1)) - return ll .- halfT .* quad + M = size(B, 2) + 
size(B, 1) == D || throw(DimensionMismatch("B must have size (D, M)")) + ensure_workspace!(M) + + logits_buf = logits[] + res_buf = res[] + sq_buf = sq[] + out_buf = out[] + quad_buf = quad[] + + mul!(logits_buf, X, B) + @. res_buf = + y_col * (-log1p(exp(-logits_buf))) + (1 - y_col) * (-log1p(exp(logits_buf))) + @. sq_buf = abs2(B) + _sum_columns!(out_buf, out_row[], res_buf) + _sum_columns!(quad_buf, quad_row[], sq_buf) + @. out_buf = out_buf - 0.5 * quad_buf + return out_buf end function gradlogp_batch(B::AbstractMatrix) - logits = X * B - T = eltype(logits) - oneT = one(T) - p = oneT ./ (oneT .+ exp.(logits .* (-oneT))) - resid = y .- p - return X' * resid .- B + M = size(B, 2) + size(B, 1) == D || throw(DimensionMismatch("B must have size (D, M)")) + ensure_workspace!(M) + + logits_buf = logits[] + res_buf = res[] + grad_buf = grad[] + + mul!(logits_buf, X, B) + @. logits_buf = 1 / (1 + exp(-logits_buf)) + @. res_buf = y_col - logits_buf + mul!(grad_buf, adjoint(X), res_buf) + @. grad_buf = grad_buf - B + return grad_buf end return logp_batch, gradlogp_batch @@ -107,23 +154,48 @@ end """ make_problem_batched_with_hvp(X, y) -> (logp_batch, gradlogp_batch, hvp_batch) - -Batched log-posterior, gradient, and analytic Hessian-vector product for -Bayesian logistic regression. Inputs `B` and `V` are both D×M matrices. 
""" function make_problem_batched_with_hvp(X::AbstractMatrix, y::AbstractVector) logp_batch, gradlogp_batch = make_problem_batched(X, y) + N, D = size(X) + + logits = Ref(similar(X, N, 1)) + w = Ref(similar(X, N, 1)) + Xv = Ref(similar(X, N, 1)) + out = Ref(similar(X, D, 1)) + + function ensure_hvp_workspace!(M::Int) + if size(logits[], 2) != M + logits[] = similar(X, N, M) + w[] = similar(X, N, M) + Xv[] = similar(X, N, M) + out[] = similar(X, D, M) + end + return nothing + end function hvp_batch(B::AbstractMatrix, V::AbstractMatrix) size(B) == size(V) || throw(DimensionMismatch("B and V must have the same size")) - logits = X * B - T = eltype(logits) - oneT = one(T) - p = oneT ./ (oneT .+ exp.(logits .* (-oneT))) - w = p .* (oneT .- p) - Xv = X * V - tmp = w .* Xv - return -(X' * tmp) .- V + M = size(B, 2) + size(B, 1) == D || throw(DimensionMismatch("B and V must have size (D, M)")) + ensure_hvp_workspace!(M) + + logits_buf = logits[] + w_buf = w[] + Xv_buf = Xv[] + out_buf = out[] + + mul!(logits_buf, X, B) + + @. logits_buf = 1 / (1 + exp(-logits_buf)) + @. w_buf = logits_buf * (1 - logits_buf) + + mul!(Xv_buf, X, V) + @. Xv_buf = w_buf * Xv_buf + + mul!(out_buf, adjoint(X), Xv_buf) + @. 
out_buf = -out_buf - V + return out_buf end return logp_batch, gradlogp_batch, hvp_batch diff --git a/benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl b/benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl new file mode 100644 index 0000000..c281d42 --- /dev/null +++ b/benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl @@ -0,0 +1,332 @@ +module ParallelMCMCPrBenchmarks + +using AbstractMCMC: sample +using BenchmarkTools +using LinearAlgebra +using Printf +using Random +using Statistics +using TOML + +using ParallelMCMC + +const MALA = ParallelMCMC.MALA +const DEER = ParallelMCMC.DEER +const DEERScan = ParallelMCMC.DEERScan +if isdefined(parentmodule(@__MODULE__), :BayesLogReg) + const BayesLogReg = getfield(parentmodule(@__MODULE__), :BayesLogReg) +else + using ParallelMCMCBenchmarks: BayesLogReg +end + +struct BenchmarkCase + name::String + description::String + benchmark::Any +end + +function _stdnormal_logp(x) + return -0.5 * dot(x, x) +end + +function _stdnormal_grad(x) + return -x +end + +function _stdnormal_hvp(x, v) + return -v +end + +function _quartic_logp(x) + return -0.5 * dot(x, x) - 0.05 * sum(abs2.(x) .^ 2) +end + +function _quartic_grad(x) + return @. 
-x - 0.2 * x^3 +end + +function _make_stdnormal_rec(rng::AbstractRNG, dim::Int, steps::Int, epsilon::Float64) + tape = [(noise=randn(rng, dim), u=rand(rng)) for _ in 1:steps] + + step_fwd = + (x, te) -> + MALA.mala_step_taped(_stdnormal_logp, _stdnormal_grad, x, epsilon, te.noise, te.u) + jvp = + (x, te, v) -> MALA.mala_step_surrogate_sigmoid_jvp( + _stdnormal_logp, + _stdnormal_grad, + x, + epsilon, + te.noise, + te.u, + v, + _stdnormal_hvp, + ) + fwd_and_jvp = + (x, te, v) -> MALA.mala_step_taped_and_jvp( + _stdnormal_logp, + _stdnormal_grad, + x, + epsilon, + te.noise, + te.u, + v, + _stdnormal_hvp, + ) + + return DEER.TapedRecursion(step_fwd, jvp, tape; fwd_and_jvp=fwd_and_jvp) +end + +function _mala_step_case() + rng = MersenneTwister(9001) + dim = 16 + x = randn(rng, dim) + x_next = similar(x) + noise = randn(rng, dim) + u = rand(rng) + workspace = MALA.MALAWorkspace(x) + epsilon = 0.04 + + bench = @benchmarkable MALA.mala_step_with_logα!( + $x_next, + $workspace, + $_stdnormal_logp, + $_stdnormal_grad, + $x, + $epsilon, + $noise, + $u, + ) evals = 1 + + return BenchmarkCase( + "mala.step_stdnormal_16d", + "One allocation-light MALA step with log-acceptance on a 16-D Gaussian target.", + bench, + ) +end + +function _affine_scan_case() + rng = MersenneTwister(9002) + dim = 32 + steps = 256 + A = 0.92 .+ 0.02 .* randn(rng, dim, steps) + B = 0.10 .* randn(rng, dim, steps) + s0 = randn(rng, dim) + output = similar(A) + workspace = DEERScan.AffineScanWorkspace(A) + + bench = @benchmarkable DEERScan.solve_affine_scan_diag!( + $output, $A, $B, $s0, $workspace + ) evals = 1 + + return BenchmarkCase( + "deer.affine_scan_32x256", + "Diagonal affine prefix scan used by quasi-DEER, D=32 and T=256.", + bench, + ) +end + +function _deer_solve_case() + rng = MersenneTwister(9003) + dim = 8 + steps = 64 + epsilon = 0.06 + s0 = randn(rng, dim) + rec = _make_stdnormal_rec(rng, dim, steps, epsilon) + workspace = DEER.DEERWorkspace(s0, steps) + + bench = @benchmarkable 
DEER.solve( + $rec, + $s0; + tol_abs=1e-8, + tol_rel=1e-6, + maxiter=80, + jacobian=:diag, + damping=0.5, + rng=rng, + workspace=$workspace, + copy_result=false, + ) setup = (rng = MersenneTwister(42)) evals = 1 + + return BenchmarkCase( + "deer.solve_stdnormal_8x64", + "A full DEER block solve against a taped 8-D Gaussian MALA trajectory, T=64.", + bench, + ) +end + +function _parallel_mala_sample_case() + rng = MersenneTwister(9004) + dim = 4 + n_obs = 64 + X, y, _ = BayesLogReg.make_data(rng, n_obs, dim) + logp, gradlogp, hvp = BayesLogReg.make_problem_with_hvp(X, y) + model = ParallelMCMC.DensityModel(logp, gradlogp, dim; hvp=hvp) + sampler = ParallelMCMC.ParallelMALASampler( + 0.045; + T=16, + maxiter=60, + tol_abs=1e-6, + tol_rel=1e-5, + jacobian=:diag, + damping=0.5, + ) + initial_params = zeros(dim) + + bench = @benchmarkable sample( + rng, + $model, + $sampler, + 64; + initial_params=$initial_params, + progress=false, + ) setup = (rng = MersenneTwister(42)) evals = 1 + + return BenchmarkCase( + "sampler.parallel_mala_logreg_4d_64samples", + "Public ParallelMALASampler path on a small Bayesian logistic regression problem.", + bench, + ) +end + +function _parallel_mala_sample_no_hvp_case() + rng = MersenneTwister(9005) + dim = 4 + model = ParallelMCMC.DensityModel(_quartic_logp, _quartic_grad, dim) + sampler = ParallelMCMC.ParallelMALASampler( + 0.04; + T=16, + maxiter=60, + tol_abs=1e-6, + tol_rel=1e-5, + jacobian=:diag, + damping=0.5, + ) + initial_params = zeros(dim) + + bench = @benchmarkable sample( + rng, + $model, + $sampler, + 64; + initial_params=$initial_params, + progress=false, + ) setup = (rng = MersenneTwister(42)) evals = 1 + + return BenchmarkCase( + "sampler.parallel_mala_quartic_no_hvp_4d_64samples", + "Public ParallelMALASampler path on a non-Gaussian target with no model-provided HVP.", + bench, + ) +end + +function make_pr_benchmarks() + return [ + _mala_step_case(), + _affine_scan_case(), + _deer_solve_case(), + 
_parallel_mala_sample_case(), + _parallel_mala_sample_no_hvp_case(), + ] +end + +function _trial_summary(trial::BenchmarkTools.Trial) + med = median(trial) + min_est = minimum(trial) + return Dict{String,Any}( + "median_ns" => med.time, + "minimum_ns" => min_est.time, + "median_memory_bytes" => med.memory, + "median_allocs" => med.allocs, + "samples" => length(trial.times), + ) +end + +function _format_time(ns::Real) + ns < 1e3 && return @sprintf("%.0f ns", ns) + ns < 1e6 && return @sprintf("%.2f us", ns / 1e3) + ns < 1e9 && return @sprintf("%.2f ms", ns / 1e6) + return @sprintf("%.2f s", ns / 1e9) +end + +function run_pr_benchmarks(; seconds::Real=0.5, samples::Int=8) + cases = make_pr_benchmarks() + results = Dict{String,Any}() + + println("Running ParallelMCMC PR benchmarks") + println(" Julia: ", VERSION) + println(" seconds per case: ", seconds) + println(" minimum samples per case: ", samples) + println() + + for case in cases + println("==> ", case.name) + println(" ", case.description) + trial = run(case.benchmark; seconds=seconds, samples=samples) + summary = _trial_summary(trial) + summary["description"] = case.description + results[case.name] = summary + + println( + " median: ", + _format_time(summary["median_ns"]), + " min: ", + _format_time(summary["minimum_ns"]), + " allocs: ", + summary["median_allocs"], + " memory: ", + summary["median_memory_bytes"], + " bytes", + ) + println() + end + + return Dict{String,Any}( + "metadata" => Dict{String,Any}( + "julia_version" => string(VERSION), + "seconds" => Float64(seconds), + "samples" => samples, + ), + "benchmarks" => results, + ) +end + +function write_results(path::AbstractString, results::Dict{String,Any}) + mkpath(dirname(path)) + open(path, "w") do io + TOML.print(io, results; sorted=true) + end + return path +end + +function write_markdown(path::AbstractString, results::Dict{String,Any}) + benchmarks = results["benchmarks"] + mkpath(dirname(path)) + open(path, "w") do io + println(io, "# 
ParallelMCMC PR Benchmarks") + println(io) + println(io, "| Benchmark | Median | Minimum | Allocs | Memory |") + println(io, "|---|---:|---:|---:|---:|") + for name in sort(collect(keys(benchmarks))) + result = benchmarks[name] + println( + io, + "| `", + name, + "` | ", + _format_time(result["median_ns"]), + " | ", + _format_time(result["minimum_ns"]), + " | ", + result["median_allocs"], + " | ", + result["median_memory_bytes"], + " B |", + ) + end + println(io) + end + return path +end + +end # module diff --git a/benchmarks/ParallelMCMCBenchmarks/src/runners/mala_runner.jl b/benchmarks/ParallelMCMCBenchmarks/src/runners/mala_runner.jl index 3e9dcc3..ab99273 100644 --- a/benchmarks/ParallelMCMCBenchmarks/src/runners/mala_runner.jl +++ b/benchmarks/ParallelMCMCBenchmarks/src/runners/mala_runner.jl @@ -39,6 +39,8 @@ function tune_stepsize_mala( ξs, us = make_tape(rng, D, Twarm) x = copy(x0) + x_next = similar(x) + ws = MALA.MALAWorkspace(x) logϵ = log(ϵ0) # Step size schedule; keep conservative to avoid oscillation @@ -47,10 +49,13 @@ function tune_stepsize_mala( for t in 1:Twarm ϵ = exp(logϵ) # mala_step_with_logα shares primal computation with the step: 2 gradlogp evals total. - x, accepted, _ = MALA.mala_step_with_logα(logp, gradlogp, x, ϵ, ξs[t], us[t]) + x_next, accepted, _ = MALA.mala_step_with_logα!( + x_next, ws, logp, gradlogp, x, ϵ, ξs[t], us[t] + ) # Robbins–Monro update on log ϵ η = η0 / sqrt(t) logϵ += η * (Float64(accepted) - target_accept) + x, x_next = x_next, x end return x, exp(logϵ) @@ -78,12 +83,17 @@ function run_taped_mala_with_accepts( accepts = Vector{Float64}(undef, T) xs[1] = copy(x0) - x = x0 + x = copy(x0) + x_next = similar(x) + ws = MALA.MALAWorkspace(x) for t in 1:T # mala_step_full shares primal computation: 2 gradlogp evals instead of 5. 
- x, accepted = MALA.mala_step_full(logp, gradlogp, x, ϵ, ξs[t], us[t]) + x_next, accepted, _ = MALA.mala_step_with_logα!( + x_next, ws, logp, gradlogp, x, ϵ, ξs[t], us[t] + ) accepts[t] = Float64(accepted) - xs[t + 1] = x + xs[t + 1] = copy(x_next) + x, x_next = x_next, x end return xs, accepts diff --git a/docs/make.jl b/docs/make.jl index 0f4b508..79c38fb 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,6 @@ using ParallelMCMC using Documenter +using Documenter.Remotes: GitHub DocMeta.setdocmeta!(ParallelMCMC, :DocTestSetup, :(using ParallelMCMC); recursive=true) @@ -61,10 +62,15 @@ end makedocs(; modules=[ParallelMCMC], authors="Ryan Senne ", - repo="https://github.com/rsenne/ParallelMCMC.jl/blob/{commit}{path}#{line}", + repo=GitHub("rsenne", "ParallelMCMC.jl"), sitename="ParallelMCMC.jl", - format=Documenter.HTML(; canonical="https://rsenne.github.io/ParallelMCMC.jl"), + checkdocs=:none, + format=Documenter.HTML(; + canonical="https://rsenne.github.io/ParallelMCMC.jl", + repolink="https://github.com/rsenne/ParallelMCMC.jl", + edit_link="main", + ), pages=list_pages(), ) -deploydocs(; repo="github.com/rsenne/ParallelMCMC.jl") +deploydocs(; repo="github.com/rsenne/ParallelMCMC.jl", devbranch="main") diff --git a/docs/src/10-getting-started.md b/docs/src/10-getting-started.md index 80fdc4e..5a5826f 100644 --- a/docs/src/10-getting-started.md +++ b/docs/src/10-getting-started.md @@ -35,7 +35,7 @@ model = DensityModel(logp, grad_logp, 2; param_names=[:x1, :x2]) [`ParallelMALASampler`](@ref) reformulates a trajectory of `T` MALA steps as a fixed-point problem and solves it via Newton iterations, each of which costs $O(\log T)$ parallel work via an associative prefix scan. Wall-clock time per sample is therefore sublinear in chain length on multi-core CPUs and GPUs. 
```julia -sampler = ParallelMALASampler(0.1; T=64, jacobian=:diag, damping=0.5) +sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag, damping=0.5) chain = sample(model, sampler, 500; chain_type=MCMCChains.Chains, progress=true) @@ -53,9 +53,8 @@ The `jacobian` keyword controls how the per-step Jacobian is approximated during | Mode | Cost per step | Notes | |---|---|---| -| `:diag` (default) | 1 full Jacobian | Exact diagonal; good default | -| `:stoch_diag` | `probes` JVPs | Hutchinson estimator; better in high dimensions | -| `:full` | 1 full Jacobian | Full matrix step; rarely needed | +| `:stoch_diag` (default) | `probes` JVPs | Hutchinson estimator; scalable default for high dimensions | +| `:diag` | `D` JVPs | Exact diagonal; useful for low-dimensional checks | For high-dimensional targets, `:stoch_diag` with a small number of `probes` is a good trade-off: @@ -118,7 +117,7 @@ using Turing, LogDensityProblems, LogDensityProblemsAD, ADTypes using ParallelMCMC, MCMCChains ld = DynamicPPL.LogDensityFunction(normal_model(1.5)) -ldg = LogDensityProblemsAD.ADgradient(ADTypes.AutoMooncake(; config=nothing), ld) +ldg = LogDensityProblemsAD.ADgradient(ADTypes.AutoEnzyme(), ld) model = DensityModel(ldg; param_names=[:μ]) ``` diff --git a/docs/src/20-algorithms.md b/docs/src/20-algorithms.md index 21e27ec..df9ab9a 100644 --- a/docs/src/20-algorithms.md +++ b/docs/src/20-algorithms.md @@ -70,7 +70,7 @@ The associative operator for combining two adjacent segments $(\alpha_1, \beta_1 (\alpha_2, \beta_2) \circ (\alpha_1, \beta_1) = (\alpha_2 \odot \alpha_1,\; \alpha_2 \odot \beta_1 + \beta_2). ``` -This associativity means the recurrence can be solved by an inclusive parallel-prefix scan in $O(\log T)$ levels, each level consisting of a single broadcast over all $T$ columns — no per-timestep loops. The implementation ([`DEER.solve_affine_scan_diag`](@ref DEER.solve_affine_scan_diag)) is array-type-agnostic and runs identically on CPU `Matrix` and GPU `CuMatrix`. 
Though, the CPU approach is included only for correctness testing — you will almost certainly be slower than regular sequential MCMC algorithms without a GPU. +This associativity means the recurrence can be solved by an inclusive parallel-prefix scan in $O(\log T)$ levels, each level consisting of a single broadcast over all $T$ columns and no per-timestep loops. The implementation in `ParallelMCMC.DEERScan.solve_affine_scan_diag!` is array-type-agnostic and runs identically on CPU `Matrix` and GPU `CuMatrix`. The CPU path is mainly useful for correctness checks; without a GPU or substantial parallel hardware, sequential MCMC is usually faster wall-clock. --- @@ -80,9 +80,11 @@ This associativity means the recurrence can be solved by an inclusive parallel-p Use the full $D \times D$ Jacobian $J_t$ at each timestep. The linear recursion becomes a sequence of dense matrix multiplications, solved sequentially (no scan shortcut for the general case). Cost per iteration: $O(TD^3)$. Memory: $O(TD^2)$. Accurate but impractical for large $D$. +This package currently exposes the diagonal scan variants below as `jacobian` modes; full-matrix DEER is discussed here for context but is not available through the `jacobian` keyword. + ### Quasi-DEER (`:diag`) -Replace $J_t$ with $\mathrm{diag}(J_t)$, retaining only the diagonal. The recursion reduces to the scalar affine scan described above, solved in $O(TD \log T)$ total work. The exact diagonal is computed via automatic differentiation (one full Jacobian, then `diag`). This is the default mode. +Replace $J_t$ with $\mathrm{diag}(J_t)$, retaining only the diagonal. The recursion reduces to the scalar affine scan described above, solved in $O(TD \log T)$ total work. The exact diagonal is computed with `D` Jacobian-vector products. This mode is useful for low-dimensional checks and reference runs. 
### Stochastic quasi-DEER (`:stoch_diag`) @@ -93,7 +95,7 @@ Computing the exact diagonal of $J_t$ requires a full Jacobian (or $D$ JVPs), wh \quad z^{(k)}_i \overset{\text{iid}}{\sim} \mathrm{Rademacher}(\pm 1). ``` -Each probe $z^{(k)}$ costs a single Jacobian-vector product (one forward-mode or reverse-mode pass), so the total cost is $O(KTD)$ — linear in $D$ and $T$. In the limit $K \to \infty$ the estimate converges to the true diagonal. See Zoltowski et al. (2025) [^1] for convergence analysis. +Each probe $z^{(k)}$ costs a single Jacobian-vector product (one forward-mode or reverse-mode pass), so the total cost is $O(KTD)$ — linear in $D$ and $T$. In the limit $K \to \infty$ the estimate converges to the true diagonal. This is the default mode. See Zoltowski et al. (2025) [^1] for convergence analysis. In practice $K=1$ or $K=2$ probes works well; controlled by the `probes` argument of [`ParallelMALASampler`](@ref). @@ -126,7 +128,7 @@ x_t^\text{surrogate} = \hat{g}_t \, \tilde{x}_t + (1 - \hat{g}_t)\, x_{t-1}, where $\sigma$ is the logistic function. The stop-gradient term makes $\hat{g}_t$ equal to the exact indicator in the forward pass while routing gradients through $\sigma$ during the backward pass. This gives a well-defined, smooth Jacobian whose value at the operating point equals that of the relaxed step. -In the implementation, the accept indicator $g_t$ is pre-computed from the previous-iterate state and passed as a **frozen constant** to the surrogate ([`MALA.mala_step_surrogate`](@ref MALA.mala_step_surrogate)), so differentiation never touches the discontinuity. +In the implementation, the accept indicator $g_t$ is pre-computed from the previous-iterate state and passed as a **frozen constant** to the surrogate ([`MALA.mala_step_surrogate_sigmoid`](@ref MALA.mala_step_surrogate_sigmoid)), so differentiation never touches the discontinuity. 
### Summary of the DEER–MALA loop @@ -136,8 +138,10 @@ In the implementation, the accept indicator $g_t$ is pre-computed from the previ 4. **Convergence check.** If the change is below tolerance, return $S^{(i+1)}$; otherwise go to step 2. 5. **Sample delivery.** Return the $T$ columns of the converged trajectory as individual MCMC samples. +For low-level callers using `DEER.solve(...; workspace=ws)`, the returned trajectory is copied by default so it remains valid after later solves reuse `ws`. Set `copy_result=false` only when you are intentionally accepting workspace-owned output that may be overwritten by a later call. + --- ## References -[^1]: Zoltowski, D., Wu, Y., Gonzalez, D., Kozachkov, L., & Linderman, S. (2025). *Parallelizing MCMC Across the Sequence Length*. NeurIPS 2025. [arXiv:2508.18413](https://arxiv.org/abs/2508.18413) +[^1]: Zoltowski, D. M., Wu, S., Gonzalez, X., Kozachkov, L., & Linderman, S. W. (2025). *Parallelizing MCMC Across the Sequence Length*. NeurIPS 2025. [arXiv:2508.18413](https://arxiv.org/abs/2508.18413) diff --git a/docs/src/90-contributing.md b/docs/src/90-contributing.md index 7557faf..d73ce8b 100644 --- a/docs/src/90-contributing.md +++ b/docs/src/90-contributing.md @@ -1,25 +1,25 @@ # [Contributing guidelines](@id contributing) -First of all, thanks for the interest! +Thanks for your interest in improving ParallelMCMC.jl. -We welcome all kinds of contribution, including, but not limited to code, documentation, examples, configuration, issue creating, etc. +We welcome contributions of all kinds: code, documentation, examples, benchmarks, issue reports, and review feedback. -Be polite and respectful, and follow the code of conduct. +Please be respectful and follow the [Code of Conduct](https://github.com/rsenne/ParallelMCMC.jl/blob/main/CODE_OF_CONDUCT.md). ## Bug reports and discussions If you think you found a bug, feel free to open an [issue](https://github.com/rsenne/ParallelMCMC.jl/issues). 
-Focused suggestions and requests can also be opened as issues. -Before opening a pull request, start an issue or a discussion on the topic, please. +Focused suggestions and feature requests are also welcome there. +For larger changes, please start with an issue or discussion before opening a pull request. ## Working on an issue -If you found an issue that interests you, comment on that issue what your plans are. +If you found an issue that interests you, leave a comment describing what you plan to work on. If the solution to the issue is clear, you can immediately create a pull request (see below). Otherwise, say what your proposed solution is and wait for a discussion around it. !!! tip Feel free to ping us after a few days if there are no responses. -If your solution involves code (or something that requires running the package locally), check the [developer documentation](91-developer.md). +If your solution involves code, tests, benchmarks, or documentation builds, check the [developer documentation](91-developer.md). Otherwise, you can use the GitHub interface directly to create your pull request. diff --git a/docs/src/91-developer.md b/docs/src/91-developer.md index 12b71e9..9f71f6a 100644 --- a/docs/src/91-developer.md +++ b/docs/src/91-developer.md @@ -134,38 +134,57 @@ We try to keep a linear history in this repo, so it is important to keep your br ## Building and viewing the documentation locally -Following the latest suggestions, we recommend using `LiveServer` to build the documentation. -Here is how you do it: +The CI workflow builds docs from the `docs/` project after developing the current checkout into that environment. To match CI locally: -1. Run `julia --project=docs` to open Julia in the environment of the docs. -1. If this is the first time building the docs - 1. Press `]` to enter `pkg` mode - 1. Run `pkg> dev .` to use the development version of your package - 1. Press backspace to leave `pkg` mode -1. Run `julia> using LiveServer` -1. 
Run `julia> servedocs()` +1. Instantiate the docs environment and develop the package into it: + + ```bash + julia --project=docs -e ' + using Pkg + Pkg.develop(Pkg.PackageSpec(path=pwd())) + Pkg.instantiate()' + ``` + +2. Build the docs once: + + ```bash + julia --project=docs docs/make.jl + ``` + +3. For live preview while editing, start a docs Julia session and serve it: + + ```bash + julia --project=docs + ``` + + ```julia + using LiveServer + servedocs() + ``` + +4. If you want the extra CI parity check, run the doctests too: + + ```bash + julia --project=docs -e ' + using Documenter: DocMeta, doctest + using ParallelMCMC + DocMeta.setdocmeta!(ParallelMCMC, :DocTestSetup, :(using ParallelMCMC); recursive=true) + doctest(ParallelMCMC)' + ``` + +If you update the landing-page animation, regenerate it with: + +```bash +julia docs/src/assets/make_julia_deer_gif.jl +``` ## Making a new release -To create a new release, you can follow these simple steps: - -- Create a branch `release-x.y.z` -- Update `version` in `Project.toml` -- Update the `CHANGELOG.md`: - - Rename the section "Unreleased" to "[x.y.z] - yyyy-mm-dd" (i.e., version under brackets, dash, and date in ISO format) - - Add a new section on top of it named "Unreleased" - - Add a new link in the bottom for version "x.y.z" - - Change the "[unreleased]" link to use the latest version - end of line, `vx.y.z ... HEAD`. -- Create a commit "Release vx.y.z", push, create a PR, wait for it to pass, merge the PR. -- Go back to main screen and click on the latest commit (link: ) -- At the bottom, write `@JuliaRegistrator register` - -After that, you only need to wait and verify: - -- Wait for the bot to comment (should take < 1m) with a link to a PR to the registry -- Follow the link and wait for a comment on the auto-merge -- The comment should said all is well and auto-merge should occur shortly -- After the merge happens, TagBot will trigger and create a new GitHub tag. 
Check on -- After the release is create, a "docs" GitHub action will start for the tag. -- After it passes, a deploy action will run. -- After that runs, the [stable docs](https://rsenne.github.io/ParallelMCMC.jl/stable) should be updated. Check them and look for the version number. +To create a new release: + +1. Create a release branch such as `release-x.y.z`. +2. Update `version` in `Project.toml`. +3. Update any release-facing docs you want to ship with the tag. +4. Open and merge the release PR after CI passes. +5. Comment `@JuliaRegistrator register` on the merge commit or release PR, then wait for the registry PR and auto-merge. +6. After registration, verify that TagBot creates the GitHub tag and that the docs workflow updates the [stable docs](https://rsenne.github.io/ParallelMCMC.jl/stable). diff --git a/docs/src/95-reference.md b/docs/src/95-reference.md index 2ebceac..9624633 100644 --- a/docs/src/95-reference.md +++ b/docs/src/95-reference.md @@ -1,7 +1,20 @@ # [API Reference](@id reference) +```@meta +CurrentModule = ParallelMCMC +``` + This page documents all public types and functions exported by ParallelMCMC.jl. +## Extension constructors + +`DensityModel` also has extension constructors for common probabilistic-programming interfaces: + +- `DensityModel(ld; param_names=nothing)` for `LogDensityProblems` models with gradients +- `DensityModel(turing_model)` for `DynamicPPL` / Turing models when the relevant extension packages are loaded + +See [Getting Started](10-getting-started.md) for end-to-end examples of both. + ## Model ```@docs @@ -30,6 +43,21 @@ ParallelMALAState ParallelMALATransition ``` +## Low-level namespaces + +These lower-level building blocks power the public samplers and are useful if you +want to work with taped recursions or the diagonal affine scan directly. 
+ +```@docs +DEER.TapedRecursion +DEER.DEERWorkspace +DEER.solve +ParallelMCMC.DEERScan.AffineScanWorkspace +MALA.mala_step_surrogate_sigmoid +``` + +The diagonal scan implementation itself lives in `ParallelMCMC.DEERScan.solve_affine_scan_diag!`. + ## Index ```@index diff --git a/docs/src/assets/julia_deer_posterior.gif b/docs/src/assets/julia_deer_posterior.gif new file mode 100644 index 0000000..b57b5b9 Binary files /dev/null and b/docs/src/assets/julia_deer_posterior.gif differ diff --git a/docs/src/assets/logo.png b/docs/src/assets/logo.png new file mode 100644 index 0000000..fbeb237 Binary files /dev/null and b/docs/src/assets/logo.png differ diff --git a/docs/src/assets/make_julia_deer_gif.jl b/docs/src/assets/make_julia_deer_gif.jl new file mode 100644 index 0000000..2f61130 --- /dev/null +++ b/docs/src/assets/make_julia_deer_gif.jl @@ -0,0 +1,386 @@ +#!/usr/bin/env julia + +# Regenerate `julia_deer_posterior.gif` for the README and docs landing page. +# +# This script intentionally includes the local MALA and affine-scan primitives +# instead of `using ParallelMCMC`, so the docs asset can be generated without +# downloading GPU/AD artifacts. 
+ +using LinearAlgebra +using Printf +using Random + +const REPO_ROOT = normpath(joinpath(@__DIR__, "..", "..", "..")) +include(joinpath(REPO_ROOT, "src", "MALA", "MALA.jl")) +include(joinpath(REPO_ROOT, "src", "DEER", "DEERScan.jl")) + +const WIDTH = 560 +const HEIGHT = 420 +const X_RANGE = (-2.25, 2.25) +const Y_RANGE = (-1.70, 1.78) + +const CENTERS = [ + (-1.05, -0.55), # green + (1.05, -0.55), # red + (0.0, 1.02), # purple + (0.0, -0.16), # blue +] +const SIGMAS = [0.38, 0.38, 0.38, 0.25] +const LOG_WEIGHTS = log.([1.0, 1.0, 1.0, 0.50]) +const LOGO_COLORS = [ + (56, 152, 38), + (203, 60, 51), + (149, 88, 178), + (64, 99, 216), +] + +struct TapeStep + xi::Vector{Float64} + u::Float64 +end + +function log_components(x) + vals = Vector{Float64}(undef, length(CENTERS)) + for k in eachindex(CENTERS) + cx, cy = CENTERS[k] + sig2 = SIGMAS[k]^2 + dx = x[1] - cx + dy = x[2] - cy + vals[k] = LOG_WEIGHTS[k] - log(2 * pi * sig2) - 0.5 * (dx^2 + dy^2) / sig2 + end + return vals +end + +function responsibilities_from_logs(vals) + offset = maximum(vals) + weights = exp.(vals .- offset) + total = sum(weights) + return weights ./ total, offset + log(total) +end + +function logposterior(x) + _, lp = responsibilities_from_logs(log_components(x)) + return lp +end + +function gradposterior(x) + resp, _ = responsibilities_from_logs(log_components(x)) + grad = zeros(2) + for k in eachindex(CENTERS) + cx, cy = CENTERS[k] + sig2 = SIGMAS[k]^2 + grad[1] += resp[k] * (cx - x[1]) / sig2 + grad[2] += resp[k] * (cy - x[2]) / sig2 + end + return grad +end + +function hvp_posterior(x, v) + resp, _ = responsibilities_from_logs(log_components(x)) + component_grads = Matrix{Float64}(undef, 2, length(CENTERS)) + grad = zeros(2) + hv = zeros(2) + + for k in eachindex(CENTERS) + cx, cy = CENTERS[k] + sig2 = SIGMAS[k]^2 + g1 = (cx - x[1]) / sig2 + g2 = (cy - x[2]) / sig2 + component_grads[1, k] = g1 + component_grads[2, k] = g2 + grad[1] += resp[k] * g1 + grad[2] += resp[k] * g2 + hv[1] -= resp[k] 
* v[1] / sig2 + hv[2] -= resp[k] * v[2] / sig2 + end + + for k in eachindex(CENTERS) + dg1 = component_grads[1, k] - grad[1] + dg2 = component_grads[2, k] - grad[2] + projection = dg1 * v[1] + dg2 * v[2] + hv[1] += resp[k] * dg1 * projection + hv[2] += resp[k] * dg2 * projection + end + + return hv +end + +function make_tape(rng, dim, steps) + return [TapeStep(randn(rng, dim), rand(rng)) for _ in 1:steps] +end + +function make_recursion(tape, epsilon) + step_fwd = + (x, step) -> + MALA.mala_step_taped(logposterior, gradposterior, x, epsilon, step.xi, step.u) + jvp = + (x, step, v) -> MALA.mala_step_surrogate_sigmoid_jvp( + logposterior, gradposterior, x, epsilon, step.xi, step.u, v, hvp_posterior + ) + return (; step_fwd, jvp, tape) +end + +function deer_diag_update!( + output, A, B, scan_ws, rec, s0, current; damping=0.55 +) + dim, steps = size(current) + basis = zeros(dim) + + for t in 1:steps + xbar = t == 1 ? s0 : view(current, :, t - 1) + ft = rec.step_fwd(xbar, rec.tape[t]) + + for j in 1:dim + fill!(basis, 0.0) + basis[j] = 1.0 + jv = rec.jvp(xbar, rec.tape[t], basis) + A[j, t] = jv[j] + end + + @views B[:, t] .= ft .- A[:, t] .* xbar + end + + DEERScan.solve_affine_scan_diag!(output, A, B, s0, scan_ws) + @. 
output = (1 - damping) * current + damping * output + return output +end + +function record_iterates(rec, s0; steps, maxiter=24, damping=0.55) + dim = length(s0) + current = repeat(reshape(s0, dim, 1), 1, steps) + next_state = similar(current) + A = similar(current) + B = similar(current) + scan_ws = DEERScan.AffineScanWorkspace(A) + iterates = [copy(current)] + metrics = Float64[] + + for _ in 1:maxiter + deer_diag_update!(next_state, A, B, scan_ws, rec, s0, current; damping=damping) + delta = maximum(abs.(next_state .- current)) + scale = 1e-5 + 1e-4 * maximum(abs.(next_state)) + push!(metrics, delta / scale) + push!(iterates, copy(next_state)) + current, next_state = next_state, current + end + + return iterates, metrics +end + +clamp_u8(x) = UInt8(clamp(round(Int, x), 0, 255)) + +function blend_pixel!(image, px, py, color, alpha) + 1 <= px <= WIDTH || return image + 1 <= py <= HEIGHT || return image + + base = image[py, px] + image[py, px] = ( + clamp_u8((1 - alpha) * base[1] + alpha * color[1]), + clamp_u8((1 - alpha) * base[2] + alpha * color[2]), + clamp_u8((1 - alpha) * base[3] + alpha * color[3]), + ) + return image +end + +function data_to_pixel(x, y) + px = round(Int, 1 + (x - X_RANGE[1]) / (X_RANGE[2] - X_RANGE[1]) * (WIDTH - 1)) + py = round(Int, 1 + (Y_RANGE[2] - y) / (Y_RANGE[2] - Y_RANGE[1]) * (HEIGHT - 1)) + return px, py +end + +function pixel_to_data(px, py) + x = X_RANGE[1] + (px - 1) / (WIDTH - 1) * (X_RANGE[2] - X_RANGE[1]) + y = Y_RANGE[2] - (py - 1) / (HEIGHT - 1) * (Y_RANGE[2] - Y_RANGE[1]) + return [x, y] +end + +function weighted_logo_color(resp) + r = g = b = 0.0 + for k in eachindex(LOGO_COLORS) + color = LOGO_COLORS[k] + r += resp[k] * color[1] + g += resp[k] * color[2] + b += resp[k] * color[3] + end + return (r, g, b) +end + +function make_background() + image = fill((UInt8(252), UInt8(252), UInt8(250)), HEIGHT, WIDTH) + logps = Matrix{Float64}(undef, HEIGHT, WIDTH) + maxlogp = -Inf + + for py in 1:HEIGHT, px in 1:WIDTH + _, lp = 
responsibilities_from_logs(log_components(pixel_to_data(px, py))) + logps[py, px] = lp + maxlogp = max(maxlogp, lp) + end + + for py in 1:HEIGHT, px in 1:WIDTH + x = pixel_to_data(px, py) + resp, _ = responsibilities_from_logs(log_components(x)) + intensity = exp(logps[py, px] - maxlogp)^0.38 + alpha = 0.08 + 0.72 * intensity + blend_pixel!(image, px, py, weighted_logo_color(resp), alpha) + end + + return image +end + +function draw_disk!(image, cx, cy, radius, color, alpha) + xmin = floor(Int, cx - radius - 1) + xmax = ceil(Int, cx + radius + 1) + ymin = floor(Int, cy - radius - 1) + ymax = ceil(Int, cy + radius + 1) + + for py in ymin:ymax, px in xmin:xmax + dist = hypot(px - cx, py - cy) + if dist <= radius + 0.5 + edge = clamp(radius + 0.5 - dist, 0.0, 1.0) + blend_pixel!(image, px, py, color, alpha * edge) + end + end + + return image +end + +function draw_line!(image, p1, p2, radius, color, alpha) + dx = p2[1] - p1[1] + dy = p2[2] - p1[2] + steps = max(1, ceil(Int, 1.7 * max(abs(dx), abs(dy)))) + + for i in 0:steps + t = i / steps + px = p1[1] + t * dx + py = p1[2] + t * dy + draw_disk!(image, px, py, radius, color, alpha) + end + + return image +end + +function draw_trajectory!(image, trajectory; ghost=false) + steps = size(trajectory, 2) + points = Vector{Tuple{Int,Int}}(undef, steps) + + for t in 1:steps + points[t] = data_to_pixel(trajectory[1, t], trajectory[2, t]) + end + + if ghost + for t in 2:steps + draw_line!(image, points[t - 1], points[t], 0.9, (25, 31, 40), 0.13) + end + for t in 1:3:steps + draw_disk!(image, points[t][1], points[t][2], 1.45, (25, 31, 40), 0.18) + end + else + for t in 2:steps + draw_line!(image, points[t - 1], points[t], 1.15, (28, 36, 50), 0.48) + end + for t in 1:steps + draw_disk!(image, points[t][1], points[t][2], 3.0, (28, 36, 50), 0.42) + draw_disk!(image, points[t][1], points[t][2], 2.05, (247, 166, 38), 0.88) + end + draw_disk!(image, points[1][1], points[1][2], 4.2, (255, 255, 255), 0.82) + draw_disk!(image, 
points[1][1], points[1][2], 3.0, (56, 152, 38), 0.95) + draw_disk!(image, points[end][1], points[end][2], 4.2, (255, 255, 255), 0.82) + draw_disk!(image, points[end][1], points[end][2], 3.0, (203, 60, 51), 0.95) + end + + return image +end + +function draw_rect!(image, x1, y1, x2, y2, color, alpha) + for py in max(1, y1):min(HEIGHT, y2), px in max(1, x1):min(WIDTH, x2) + blend_pixel!(image, px, py, color, alpha) + end + return image +end + +function draw_progress!(image, frame_index, frame_count) + margin = 44 + y = HEIGHT - 22 + width = WIDTH - 2margin + filled = round(Int, width * (frame_index - 1) / max(1, frame_count - 1)) + + draw_rect!(image, margin, y, margin + width, y + 6, (24, 30, 39), 0.16) + draw_rect!(image, margin, y, margin + filled, y + 6, (247, 166, 38), 0.90) + draw_disk!(image, margin + filled, y + 3, 4.0, (247, 166, 38), 0.95) + return image +end + +function write_ppm(path, image) + open(path, "w") do io + write(io, "P6\n$WIDTH $HEIGHT\n255\n") + for py in 1:HEIGHT, px in 1:WIDTH + color = image[py, px] + write(io, color[1], color[2], color[3]) + end + end + return path +end + +function render_frames(frame_dir, iterates, final_trajectory) + background = make_background() + frame_paths = String[] + + for (i, trajectory) in enumerate(iterates) + image = copy(background) + draw_trajectory!(image, final_trajectory; ghost=true) + draw_trajectory!(image, trajectory) + draw_progress!(image, i, length(iterates)) + + path = joinpath(frame_dir, @sprintf("frame_%03d.ppm", i)) + write_ppm(path, image) + push!(frame_paths, path) + end + + return frame_paths +end + +function main() + rng = MersenneTwister(20260428) + steps = 1000 + maxiter = 256 + epsilon = 0.095 + damping = 0.55 + x0 = [ + rand(rng) * (X_RANGE[2] - X_RANGE[1]) + X_RANGE[1], + rand(rng) * (Y_RANGE[2] - Y_RANGE[1]) + Y_RANGE[1], + ] + + + tape = make_tape(rng, 2, steps) + rec = make_recursion(tape, epsilon) + iterates, metrics = record_iterates(rec, x0; steps=steps, maxiter=maxiter, 
damping=damping) + + selected = unique(round.(Int, range(0, maxiter; length=256))) + selected_iterates = [iterates[i + 1] for i in selected] + + noise = [step.xi for step in tape] + uniforms = [step.u for step in tape] + sequential = MALA.run_mala_sequential_taped(logposterior, gradposterior, x0, epsilon, noise, uniforms) + final_trajectory = reduce(hcat, sequential[2:end]) + + output = isempty(ARGS) ? joinpath(@__DIR__, "julia_deer_posterior.gif") : ARGS[1] + mkpath(dirname(output)) + + convert = Sys.which("convert") + convert === nothing && error("ImageMagick `convert` is required to build the GIF") + + mktempdir() do frame_dir + frame_paths = render_frames(frame_dir, selected_iterates, final_trajectory) + animation_paths = vcat(frame_paths, fill(last(frame_paths), 8)) + tmp_output = joinpath(frame_dir, "julia_deer_posterior.gif") + run(`$convert -delay 7 -loop 0 $animation_paths -layers Optimize $tmp_output`) + cp(tmp_output, output; force=true) + end + + final_error = maximum(abs.(last(iterates) .- final_trajectory)) + println("wrote ", output) + println("last DEER metric: ", @sprintf("%.3g", last(metrics))) + println("max error vs sequential taped MALA: ", @sprintf("%.3g", final_error)) +end + +main() diff --git a/docs/src/index.md b/docs/src/index.md index b593bf9..f267b7d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -4,13 +4,29 @@ CurrentModule = ParallelMCMC # ParallelMCMC.jl +```@raw html +

+ ParallelMCMC logo +

+``` + **ParallelMCMC.jl** is a Julia package for **parallel-across-the-sequence** MCMC — algorithms that solve an entire trajectory of $T$ correlated steps *simultaneously* rather than one at a time. +```@raw html +

+ DEER trajectory estimates improving on a Julia-logo-shaped posterior +

+ +

+ DEER iterates on a synthetic Julia-logo-shaped posterior: orange trajectory estimates move toward the taped MALA path over repeated trajectory solves. +

+``` + The flagship algorithm is **DEER** (Deterministic Equivalent-Expectation Recursion), which reformulates a chain of $T$ MALA steps as a fixed-point problem and solves it via Newton iterations, each costing $O(\log T)$ parallel work via an associative prefix scan. The result is that wall-clock time per sample is *sublinear* in chain length on multi-core CPUs and GPUs. The algorithm is described in: -> Zoltowski, D., Wu, Y., Gonzalez, D., Kozachkov, L., & Linderman, S. (2025). +> Zoltowski, D. M., Wu, S., Gonzalez, X., Kozachkov, L., & Linderman, S. W. (2025). > **Parallelizing MCMC Across the Sequence Length.** > *NeurIPS 2025.* [arXiv:2508.18413](https://arxiv.org/abs/2508.18413) @@ -28,10 +44,10 @@ All samplers implement the [AbstractMCMC](https://github.com/TuringLang/Abstract ## Installation -ParallelMCMC.jl is a registered Julia package: +To install the package into your current environment: ```julia-repl -julia> ] add ParallelMCMC +pkg> add https://github.com/rsenne/ParallelMCMC.jl ``` ## Quick start @@ -54,7 +70,7 @@ model = DensityModel(logp, grad_logp, 2; ### DEER — parallel-across-sequence (primary algorithm) ```julia -sampler = ParallelMALASampler(0.1; T=64, jacobian=:diag) +sampler = ParallelMALASampler(0.1; T=64, jacobian=:stoch_diag) chain = sample(model, sampler, 500; chain_type=MCMCChains.Chains) diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index 96833d1..3ee8752 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -8,7 +8,7 @@ using LogDensityProblems: LogDensityProblems using LogDensityProblemsAD: LogDensityProblemsAD """ - DensityModel(turing_model::DynamicPPL.Model; ad_backend=ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse))) + DensityModel(turing_model::DynamicPPL.Model; ad_backend=ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Duplicated), hvp=nothing) Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a `DensityModel`, 
automatically extracting parameter names and wiring up gradient @@ -39,7 +39,11 @@ chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; """ function ParallelMCMC.DensityModel( turing_model::DynamicPPL.Model; - ad_backend=ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), + ad_backend=ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Duplicated, + ), + hvp=nothing, ) # Build the LogDensityProblems-compatible gradient object ld = DynamicPPL.LogDensityFunction(turing_model) @@ -56,13 +60,13 @@ function ParallelMCMC.DensityModel( # Try to extract parameter names; fall back to nothing on any error or mismatch. param_names = _try_extract_param_names(turing_model, dim) - logp(x) = LogDensityProblems.logdensity(ldg, x) + logp(x) = LogDensityProblems.logdensity(ld, x) function gradlogp(x) _, g = LogDensityProblems.logdensity_and_gradient(ldg, x) return g end - return ParallelMCMC.DensityModel(logp, gradlogp, dim; param_names=param_names) + return ParallelMCMC.DensityModel(logp, gradlogp, dim; hvp=hvp, param_names=param_names) end """ diff --git a/src/DEER/DEER.jl b/src/DEER/DEER.jl index 3e0c00e..1dcb15b 100644 --- a/src/DEER/DEER.jl +++ b/src/DEER/DEER.jl @@ -5,23 +5,35 @@ using DifferentiationInterface using ADTypes: ADTypes, AbstractADType import Enzyme: Enzyme using Random +using CUDA: CUDA include("DEERScan.jl") using .DEERScan const DI = DifferentiationInterface -const DEFAULT_BACKEND = ADTypes.AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Duplicated) +const DEFAULT_BACKEND = ADTypes.AutoEnzyme(; + mode=Enzyme.Forward, function_annotation=Enzyme.Duplicated +) const DEFAULT_HVP_BACKEND = DI.SecondOrder( DEFAULT_BACKEND, - ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse)), + # The log-density function itself is non-active in the reverse pass. + # Marking it duplicated makes DynamicPPL closure fields look active to Enzyme. 
+ ADTypes.AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), ) export TapedRecursion, - DEERWorkspace, - deer_update!, deer_update, - solve, - _prepare_hvp, _hvp_prepared, _hvp_nopre, - DEFAULT_BACKEND, DEFAULT_HVP_BACKEND + DEERWorkspace, + deer_update!, + deer_update, + solve, + _prepare_hvp, + _hvp_prepared, + _hvp_nopre, + DEFAULT_BACKEND, + DEFAULT_HVP_BACKEND """ Deterministic recursion driven by a pre-generated tape. @@ -40,7 +52,9 @@ struct TapedRecursion{Ff,Fj,Ffj,Ffb,Tt} tape::Vector{Tt} end -function TapedRecursion(step_fwd, jvp, tape::Vector; fwd_and_jvp=nothing, fwd_and_jvp_batch=nothing) +function TapedRecursion( + step_fwd, jvp, tape::Vector; fwd_and_jvp=nothing, fwd_and_jvp_batch=nothing +) return TapedRecursion(step_fwd, jvp, fwd_and_jvp, fwd_and_jvp_batch, tape) end @@ -51,14 +65,17 @@ All buffers are created with the same array type / device placement as `S_templa (or `s0_template` for vector buffers), so the workspace is GPU-compatible when constructed from `CuArray` templates. """ -struct DEERWorkspace{M,V,SW} +struct DEERWorkspace{M,V,SW,HZ,H} A::M B::M Xbar::M Z::M + Zhost::HZ + S_work::M S_tmp::M diff_buf::M zbuf::V + zhost::H jt_buf::V xbar_buf::V scan::SW @@ -69,13 +86,18 @@ function DEERWorkspace(S_template::AbstractMatrix, s0_template::AbstractVector) B = similar(S_template) Xbar = similar(S_template) Z = similar(S_template) + Zhost = Z isa CUDA.CuArray ? Matrix{eltype(S_template)}(undef, size(Z)...) : nothing + S_work = similar(S_template) S_tmp = similar(S_template) diff_buf = similar(S_template) zbuf = similar(s0_template) + zhost = zbuf isa CUDA.CuArray ? 
Vector{eltype(s0_template)}(undef, length(s0_template)) : nothing jt_buf = similar(s0_template) xbar_buf = similar(s0_template) scan = DEERScan.AffineScanWorkspace(S_template) - return DEERWorkspace(A, B, Xbar, Z, S_tmp, diff_buf, zbuf, jt_buf, xbar_buf, scan) + return DEERWorkspace( + A, B, Xbar, Z, Zhost, S_work, S_tmp, diff_buf, zbuf, zhost, jt_buf, xbar_buf, scan + ) end function DEERWorkspace(s0_template::AbstractVector, T::Integer) @@ -104,7 +126,9 @@ function _tangent_like(x::AbstractArray, v::AbstractArray) return v_copy end -function _hvp_prepared(f, prep, backend::AbstractADType, x::AbstractVector, v::AbstractVector) +function _hvp_prepared( + f, prep, backend::AbstractADType, x::AbstractVector, v::AbstractVector +) x_exec = _materialize_ad_vector(x) v_exec = _tangent_like(x_exec, v) res = DI.pushforward(f, prep, backend, x_exec, (v_exec,)) @@ -135,11 +159,7 @@ function _prepare_logdensity_hvp(f, backend::AbstractADType, x_template::Abstrac end function _logdensity_hvp_prepared( - f, - prep, - backend::AbstractADType, - x::AbstractVector, - v::AbstractVector, + f, prep, backend::AbstractADType, x::AbstractVector, v::AbstractVector ) x_exec = _materialize_ad_vector(x) v_exec = _tangent_like(x_exec, v) @@ -148,10 +168,7 @@ function _logdensity_hvp_prepared( end function _logdensity_hvp_nopre( - f, - backend::AbstractADType, - x::AbstractVector, - v::AbstractVector, + f, backend::AbstractADType, x::AbstractVector, v::AbstractVector ) x_exec = _materialize_ad_vector(x) v_exec = _tangent_like(x_exec, v) @@ -161,9 +178,7 @@ function _logdensity_hvp_nopre( end function _prepare_batch_hvp_from_grad( - grad_batch, - backend::AbstractADType, - X_template::AbstractMatrix, + grad_batch, backend::AbstractADType, X_template::AbstractMatrix ) V_template = similar(X_template) fill!(V_template, zero(eltype(X_template))) @@ -173,11 +188,7 @@ function _prepare_batch_hvp_from_grad( end function _batch_hvp_from_grad_prepared( - grad_batch, - prep, - backend::AbstractADType, - 
X::AbstractMatrix, - V::AbstractMatrix, + grad_batch, prep, backend::AbstractADType, X::AbstractMatrix, V::AbstractMatrix ) X_exec = _materialize_ad_matrix(X) V_exec = _tangent_like(X_exec, V) @@ -185,26 +196,34 @@ function _batch_hvp_from_grad_prepared( return res isa Tuple ? first(res) : res end -@inline function _rademacher!(z::AbstractVector{T}, rng::AbstractRNG) where {T} - D = length(z) - bits = rand(rng, Bool, D) - vals = Vector{T}(undef, D) - @inbounds for i in 1:D - vals[i] = bits[i] ? one(T) : -one(T) +@inline function _rademacher!(z::AbstractArray{T}, rng::AbstractRNG) where {T} + @inbounds for i in eachindex(z) + z[i] = rand(rng, Bool) ? one(T) : -one(T) end - copyto!(z, vals) return z end -@inline function _rademacher_matrix!(Z::AbstractMatrix{T}, rng::AbstractRNG) where {T} - D, n = size(Z) - bits = rand(rng, Bool, D, n) - vals = Matrix{T}(undef, D, n) - @. vals = ifelse(bits, one(T), -one(T)) - copyto!(Z, vals) - return Z +@inline function _rademacher!(z::CUDA.CuArray{T}, rng::AbstractRNG) where {T} + host = Vector{T}(undef, length(z)) + _rademacher!(z, rng, host) + return z end +@inline function _rademacher!( + z::CUDA.CuArray{T}, rng::AbstractRNG, host::AbstractArray{T} +) where {T} + length(host) == length(z) || throw(DimensionMismatch("host buffer must match z")) + _rademacher!(host, rng) + copyto!(z, host) + return z +end + +@inline _rademacher!(z::AbstractArray, rng::AbstractRNG, ::Nothing) = _rademacher!(z, rng) + +@inline _rademacher_matrix!(Z::AbstractMatrix, rng::AbstractRNG) = _rademacher!(Z, rng) +@inline _rademacher_matrix!(Z::AbstractMatrix, rng::AbstractRNG, host) = + _rademacher!(Z, rng, host) + function jac_diag_via_jvps(rec::TapedRecursion, x::AbstractVector, t::Int) D = length(x) FT = eltype(x) @@ -234,7 +253,7 @@ function jac_diag_stoch( FT = float(eltype(x)) probes ≥ 1 || throw(ArgumentError("probes must be ≥ 1")) - z = zbuf === nothing ? Vector{FT}(undef, D) : zbuf + z = zbuf === nothing ? 
similar(x, FT, D) : zbuf length(z) == D || throw(ArgumentError("zbuf must have length D")) d = fill!(similar(z, D), zero(FT)) @@ -280,17 +299,17 @@ function deer_update!( Xbar = ws.Xbar Z = ws.Z - copyto!(view(Xbar, :, 1), s0) + copyto!(Xbar, 1, s0, 1, D) if T > 1 - @views Xbar[:, 2:end] .= S_in[:, 1:(T - 1)] + copyto!(Xbar, D + 1, S_in, 1, D * (T - 1)) end - _rademacher_matrix!(Z, rng) + _rademacher_matrix!(Z, rng, ws.Zhost) FT, Jt = rec.fwd_and_jvp_batch(Xbar, Z) @. A = Z * Jt for _ in 2:probes - _rademacher_matrix!(Z, rng) + _rademacher_matrix!(Z, rng, ws.Zhost) _, Jt = rec.fwd_and_jvp_batch(Xbar, Z) @. A += Z * Jt end @@ -307,12 +326,12 @@ function deer_update!( xbar = if t == 1 s0 else - copyto!(xbar_buf, view(S_in, :, t - 1)) + copyto!(xbar_buf, 1, S_in, (t - 2) * D + 1, D) xbar_buf end if jacobian === :stoch_diag - _rademacher!(zbuf, rng) + _rademacher!(zbuf, rng, ws.zhost) ft, jvp1 = if rec.fwd_and_jvp !== nothing rec.fwd_and_jvp(xbar, rec.tape[t], zbuf) else @@ -321,7 +340,7 @@ function deer_update!( @. jt_buf = zbuf * jvp1 for _ in 2:probes - _rademacher!(zbuf, rng) + _rademacher!(zbuf, rng, ws.zhost) jvp_k = rec.jvp(xbar, rec.tape[t], zbuf) @. jt_buf += zbuf * jvp_k end @@ -329,14 +348,18 @@ function deer_update!( jt_buf .*= one(eltype(jt_buf)) / probes end - view(A, :, t) .= jt_buf - @views @. B[:, t] = ft - jt_buf * xbar + colstart = (t - 1) * D + 1 + copyto!(A, colstart, jt_buf, 1, D) + @. xbar_buf = ft - jt_buf * xbar + copyto!(B, colstart, xbar_buf, 1, D) elseif jacobian === :diag ft = rec.step_fwd(xbar, rec.tape[t]) jt = jac_diag_via_jvps(rec, xbar, t) - view(A, :, t) .= jt - @views @. B[:, t] = ft - jt * xbar + colstart = (t - 1) * D + 1 + copyto!(A, colstart, jt, 1, D) + @. 
xbar_buf = ft - jt * xbar + copyto!(B, colstart, xbar_buf, 1, D) else throw(ArgumentError("jacobian must be :diag or :stoch_diag")) @@ -362,7 +385,9 @@ function deer_update( ) ws = DEERWorkspace(S, vec(s0_in)) S_out = similar(S) - deer_update!(ws, S_out, rec, s0_in, S; jacobian=jacobian, damping=damping, probes=probes, rng=rng) + deer_update!( + ws, S_out, rec, s0_in, S; jacobian=jacobian, damping=damping, probes=probes, rng=rng + ) return S_out end @@ -371,6 +396,12 @@ Run DEER iterations until convergence. When no workspace is supplied, one is created automatically. Supplying a pre-allocated `DEERWorkspace` is the intended path for repeated GPU solves. + +If `workspace` is supplied, `copy_result` defaults to `true` so the returned +trajectory is owned by the caller and will not be overwritten by a later solve +using the same workspace. Set `copy_result=false` for allocation-sensitive +internal loops; in that mode the returned trajectory may alias workspace-owned +buffers such as `workspace.S_tmp`. """ function solve( rec::TapedRecursion, @@ -385,6 +416,7 @@ function solve( rng::AbstractRNG=Random.default_rng(), return_info::Bool=false, workspace::Union{Nothing,DEERWorkspace}=nothing, + copy_result::Union{Nothing,Bool}=nothing, ) s0 = vec(s0_in) D = length(s0) @@ -395,16 +427,22 @@ function solve( tol_rel ≥ 0 || throw(ArgumentError("tol_rel must be ≥ 0")) damping > 0 && damping ≤ 1 || throw(ArgumentError("damping must be in (0,1]")) + copy_result === nothing && (copy_result = workspace !== nothing) + ws = workspace === nothing ? 
DEERWorkspace(s0, T) : workspace - size(ws.S_tmp) == (D, T) || throw(ArgumentError("workspace state buffer has wrong size")) + size(ws.S_work) == (D, T) || + throw(ArgumentError("workspace state buffer has wrong size")) + size(ws.S_tmp) == (D, T) || + throw(ArgumentError("workspace state buffer has wrong size")) S = if init === nothing - S0 = similar(s0, D, T) - S0 .= reshape(s0, D, 1) + S0 = ws.S_work + S0 .= s0 S0 else size(init) == (D, T) || throw(ArgumentError("init must be size (D,T)")) - copy(init) + copyto!(ws.S_work, init) + ws.S_work end S_new = ws.S_tmp @@ -414,7 +452,17 @@ function solve( for iter in 1:maxiter iters = iter - deer_update!(ws, S_new, rec, s0, S; jacobian=jacobian, damping=damping, probes=probes, rng=rng) + deer_update!( + ws, + S_new, + rec, + s0, + S; + jacobian=jacobian, + damping=damping, + probes=probes, + rng=rng, + ) @. ws.diff_buf = abs(S_new - S) Δ_max = maximum(ws.diff_buf) @@ -432,8 +480,10 @@ function solve( S, S_new = S_new, S end - return_info || return S - return S, + S_result = copy_result ? copy(S) : S + + return_info || return S_result + return S_result, ( converged=converged, iters=iters, diff --git a/src/DEER/DEERScan.jl b/src/DEER/DEERScan.jl index a13a67c..601af90 100644 --- a/src/DEER/DEERScan.jl +++ b/src/DEER/DEERScan.jl @@ -1,9 +1,12 @@ module DEERScan export AffineScanWorkspace, - solve_affine_seq!, solve_affine_seq, - solve_affine_scan_diag!, solve_affine_scan_diag, - check_affine_scan, affine_scan_residual + solve_affine_seq!, + solve_affine_seq, + solve_affine_scan_diag!, + solve_affine_scan_diag, + check_affine_scan, + affine_scan_residual """ Reusable workspace for the diagonal affine scan. @@ -45,7 +48,9 @@ Inputs: This is the ground-truth implementation to compare against the parallel scan path. It is intentionally simple and should be used in tests. 
""" -function solve_affine_seq!(S::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, s0::AbstractVector) +function solve_affine_seq!( + S::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, s0::AbstractVector +) D, T = size(A) size(B) == (D, T) || throw(DimensionMismatch("B must have the same size as A")) size(S) == (D, T) || throw(DimensionMismatch("S must have size (D, T)")) @@ -108,10 +113,14 @@ function solve_affine_scan_diag!( size(B) == (D, T) || throw(DimensionMismatch("B must have the same size as A")) size(S) == (D, T) || throw(DimensionMismatch("S must have size (D, T)")) length(s0) == D || throw(DimensionMismatch("length(s0) must equal size(A,1)")) - size(ws.alpha) == (D, T) || throw(DimensionMismatch("workspace has wrong matrix size")) - size(ws.beta) == (D, T) || throw(DimensionMismatch("workspace has wrong matrix size")) - size(ws.alpha_new) == (D, T) || throw(DimensionMismatch("workspace has wrong matrix size")) - size(ws.beta_new) == (D, T) || throw(DimensionMismatch("workspace has wrong matrix size")) + size(ws.alpha) == (D, T) || + throw(DimensionMismatch("workspace has wrong matrix size")) + size(ws.beta) == (D, T) || + throw(DimensionMismatch("workspace has wrong matrix size")) + size(ws.alpha_new) == (D, T) || + throw(DimensionMismatch("workspace has wrong matrix size")) + size(ws.beta_new) == (D, T) || + throw(DimensionMismatch("workspace has wrong matrix size")) Base.mightalias(S, A) && throw(ArgumentError("S must not alias A")) Base.mightalias(S, B) && throw(ArgumentError("S must not alias B")) end @@ -128,14 +137,27 @@ function solve_affine_scan_diag!( offset = 1 while offset < T + last_level = (offset << 1) >= T @views begin - alpha_new[:, 1:offset] .= alpha[:, 1:offset] - beta_new[:, 1:offset] .= beta[:, 1:offset] + if !last_level + # The destination already has the older prefix up to offset ÷ 2; + # only the newly exposed unchanged segment needs refreshing. + copy_start = offset == 1 ? 
1 : (offset >> 1) + 1 + alpha_new[:, copy_start:offset] .= alpha[:, copy_start:offset] + beta_new[:, copy_start:offset] .= beta[:, copy_start:offset] + end - if offset < T - alpha_new[:, (offset + 1):T] .= alpha[:, (offset + 1):T] .* alpha[:, 1:(T - offset)] - beta_new[:, (offset + 1):T] .= - alpha[:, (offset + 1):T] .* beta[:, 1:(T - offset)] .+ beta[:, (offset + 1):T] + alpha_new[:, (offset + 1):T] .= + alpha[:, (offset + 1):T] .* alpha[:, 1:(T - offset)] + beta_new[:, (offset + 1):T] .= + alpha[:, (offset + 1):T] .* beta[:, 1:(T - offset)] .+ + beta[:, (offset + 1):T] + + if last_level + S[:, 1:offset] .= alpha[:, 1:offset] .* s0 .+ beta[:, 1:offset] + S[:, (offset + 1):T] .= + alpha_new[:, (offset + 1):T] .* s0 .+ beta_new[:, (offset + 1):T] + return S end end @@ -144,7 +166,7 @@ function solve_affine_scan_diag!( offset <<= 1 end - S .= alpha .* reshape(s0, D, 1) .+ beta + S .= alpha .* s0 .+ beta return S end @@ -176,7 +198,9 @@ Return the maximum absolute residual of the recurrence Useful for debugging scan correctness independent of DEER. """ -function affine_scan_residual(S::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, s0::AbstractVector) +function affine_scan_residual( + S::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, s0::AbstractVector +) D, T = size(A) size(B) == (D, T) || throw(DimensionMismatch("B must have the same size as A")) size(S) == (D, T) || throw(DimensionMismatch("S must have size (D, T)")) @@ -207,25 +231,27 @@ Returns a named tuple with: This is intended as a lightweight validation helper when plugging the scan into DEER. 
""" -function check_affine_scan(A::AbstractMatrix, B::AbstractMatrix, s0::AbstractVector; atol=1e-6, rtol=1e-6) +function check_affine_scan( + A::AbstractMatrix, B::AbstractMatrix, s0::AbstractVector; atol=1e-6, rtol=1e-6 +) S_seq = solve_affine_seq(A, B, s0) S_scan = solve_affine_scan_diag(A, B, s0) diff = abs.(S_seq .- S_scan) scale = atol .+ rtol .* abs.(S_seq) - max_abs_err = maximum(diff) - max_rel_err = maximum(diff ./ scale) + max_abs_err = isempty(diff) ? zero(eltype(diff)) : maximum(diff) + max_rel_err = isempty(diff) ? zero(eltype(diff)) : maximum(diff ./ scale) residual_seq = affine_scan_residual(S_seq, A, B, s0) residual_scan = affine_scan_residual(S_scan, A, B, s0) return ( - ok = all(diff .<= scale), - max_abs_err = max_abs_err, - max_rel_err = max_rel_err, - residual_seq = residual_seq, - residual_scan = residual_scan, - S_seq = S_seq, - S_scan = S_scan, + ok=all(diff .<= scale), + max_abs_err=max_abs_err, + max_rel_err=max_rel_err, + residual_seq=residual_seq, + residual_scan=residual_scan, + S_seq=S_seq, + S_scan=S_scan, ) end diff --git a/src/MALA/MALA.jl b/src/MALA/MALA.jl index 9c183bd..abf1016 100644 --- a/src/MALA/MALA.jl +++ b/src/MALA/MALA.jl @@ -3,31 +3,44 @@ module MALA using Random, LinearAlgebra export MALAWorkspace, - logq_mala, logq_mala!, - mala_step_taped, mala_step_taped!, - run_mala_sequential_taped, - mala_proposal, mala_proposal!, - mala_logα, mala_logα!, - mala_accept_indicator, - mala_step_full, - mala_step_with_logα, - mala_step_surrogate_sigmoid, - mala_step_surrogate_sigmoid_jvp, - mala_step_taped_and_jvp, - mala_step_taped_and_jvp!, - mala_step_batched_fwd_and_jvp + MALABatchedWorkspace, + logq_mala, + logq_mala!, + mala_step_taped, + mala_step_taped!, + run_mala_sequential_taped, + mala_proposal, + mala_proposal!, + mala_logα, + mala_logα!, + mala_accept_indicator, + mala_step_full, + mala_step_with_logα, + mala_step_with_logα!, + mala_step_surrogate_sigmoid, + mala_step_surrogate_sigmoid_jvp, + mala_step_taped_and_jvp, + 
mala_step_taped_and_jvp!, + mala_step_batched, + mala_step_batched!, + mala_step_batched_fwd_and_jvp, + mala_step_batched_fwd_and_jvp! # Preconditioner dispatch helpers. cholM is either nothing (identity) or a Cholesky factor. -_apply_M(g, ::Nothing) = g -_apply_M(g, cholM::Cholesky) = cholM.L * (cholM.L' * g) +_apply_M!(out, g, ::Nothing, tmp=out) = copyto!(out, g) +function _apply_M!(out, g, cholM::Cholesky, tmp) + mul!(tmp, adjoint(cholM.L), g) + mul!(out, cholM.L, tmp) + return out +end -_apply_L(ξ, ::Nothing) = ξ -_apply_L(ξ, cholM::Cholesky) = cholM.L * ξ +_apply_L!(out, ξ, ::Nothing) = (out .= ξ) +_apply_L!(out, ξ, cholM::Cholesky) = mul!(out, cholM.L, ξ) -_quad_Minv(r, ::Nothing) = dot(r, r) -function _quad_Minv(r, cholM::Cholesky) - w = cholM.L \ r - return dot(w, w) +_quad_Minv!(tmp, r, ::Nothing) = dot(r, r) +function _quad_Minv!(tmp, r, cholM::Cholesky) + ldiv!(tmp, cholM.L, r) + return dot(tmp, tmp) end _logdet_M(::Nothing) = false # Bool promotes to any numeric type without widening @@ -50,20 +63,105 @@ struct MALAWorkspace{V} w::V dr::V jvp_out::V + solve_buf::V end -function MALAWorkspace(x_template::AbstractVector) - g_x = similar(x_template) - g_y = similar(x_template) - y = similar(x_template) - μ = similar(x_template) - r = similar(x_template) - Hv_x = similar(x_template) - Hv_y = similar(x_template) - w = similar(x_template) - dr = similar(x_template) - jvp_out = similar(x_template) - return MALAWorkspace(g_x, g_y, y, μ, r, Hv_x, Hv_y, w, dr, jvp_out) +function MALAWorkspace(x::AbstractVector) + return MALAWorkspace( + similar(x), + similar(x), + similar(x), + similar(x), + similar(x), + similar(x), + similar(x), + similar(x), + similar(x), + similar(x), + similar(x), + ) +end + +""" +Reusable scratch buffers for batched MALA primitives. + +Matrix buffers have the same size, eltype, and device placement as `X_template`. +Vector buffers have length `size(X_template, 2)`. 
+""" +struct MALABatchedWorkspace{M,V,RM,BV,BM} + G_X::M + G_Y::M + Y::M + MG_X::M + MG_Y::M + LΞ::M + R::M + dR::M + Hv_X::M + Hv_Y::M + W::M + M_Hv_X::M + M_Hv_Y::M + Minv_R::M + solve_tmp::M + prod::M + lp_X::V + lp_Y::V + lq_YX::V + lq_XY::V + logα::V + logu::V + g::V + dlogα::V + dg::V + dot1::V + dot2::V + dot3::V + row::RM + accepted::BV + accepted_row::BM +end + +function MALABatchedWorkspace(X::AbstractMatrix) + N = size(X, 2) + mat() = similar(X) + vec() = similar(X, eltype(X), N) + row() = similar(X, eltype(X), 1, N) + accepted = similar(X, Bool, N) + accepted_row = similar(X, Bool, 1, N) + return MALABatchedWorkspace( + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + mat(), + vec(), + vec(), + vec(), + vec(), + vec(), + vec(), + vec(), + vec(), + vec(), + vec(), + vec(), + vec(), + row(), + accepted, + accepted_row, + ) end """ @@ -85,21 +183,17 @@ function logq_mala!( if cholM === nothing @. ws.μ = x + ϵ * gradlogp_x else - tmp = _apply_M(gradlogp_x, cholM) - @. ws.μ = x + ϵ * tmp + _apply_M!(ws.w, gradlogp_x, cholM, ws.solve_buf) + @. ws.μ = x + ϵ * ws.w end @. ws.r = y - ws.μ - return -T(0.5) * _quad_Minv(ws.r, cholM) / (2ϵ) - (T(d) / 2) * log(T(4π) * ϵ) - - T(0.5) * _logdet_M(cholM) + return -T(0.5) * _quad_Minv!(ws.solve_buf, ws.r, cholM) / (2ϵ) - + (T(d) / 2) * log(T(4π) * ϵ) - T(0.5) * _logdet_M(cholM) end function logq_mala( - y::AbstractVector, - x::AbstractVector, - gradlogp_x::AbstractVector, - ϵ::Real; - cholM=nothing, + y::AbstractVector, x::AbstractVector, gradlogp_x::AbstractVector, ϵ::Real; cholM=nothing ) ws = MALAWorkspace(x) return logq_mala!(ws, y, x, gradlogp_x, ϵ; cholM=cholM) @@ -127,20 +221,15 @@ function mala_proposal!( if cholM === nothing @. y = x + ϵ * ws.g_x + sqrt2ϵ * ξ else - Mgx = _apply_M(ws.g_x, cholM) - Lξ = _apply_L(ξ, cholM) - @. 
y = x + ϵ * Mgx + sqrt2ϵ * Lξ + _apply_M!(ws.μ, ws.g_x, cholM, ws.solve_buf) + _apply_L!(ws.r, ξ, cholM) + @. y = x + ϵ * ws.μ + sqrt2ϵ * ws.r end return y end function mala_proposal( - logp, - gradlogp, - x::AbstractVector, - ϵ::Real, - ξ::AbstractVector; - cholM=nothing, + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector; cholM=nothing ) ws = MALAWorkspace(x) y = similar(x) @@ -172,12 +261,7 @@ function mala_logα!( end function mala_logα( - logp, - gradlogp, - x::AbstractVector, - y::AbstractVector, - ϵ::Real; - cholM=nothing, + logp, gradlogp, x::AbstractVector, y::AbstractVector, ϵ::Real; cholM=nothing ) ws = MALAWorkspace(x) return mala_logα!(ws, logp, gradlogp, x, y, ϵ; cholM=cholM) @@ -221,13 +305,7 @@ function mala_step_taped!( end function mala_step_taped( - logp, - gradlogp, - x::AbstractVector, - ϵ::Real, - ξ::AbstractVector, - u::Real; - cholM=nothing, + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; cholM=nothing ) ws = MALAWorkspace(x) x_next = similar(x) @@ -272,13 +350,7 @@ Primal accept indicator for a taped MALA step. Returns a float in {0, 1} matching the precision of `u`. 
""" function mala_accept_indicator( - logp, - gradlogp, - x::AbstractVector, - ϵ::Real, - ξ::AbstractVector, - u::Real; - cholM=nothing, + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; cholM=nothing ) ws = MALAWorkspace(x) mala_proposal!(ws.y, ws, logp, gradlogp, x, ϵ, ξ; cholM=cholM) @@ -288,19 +360,23 @@ function mala_accept_indicator( end function mala_step_full( - logp, - gradlogp, - x::AbstractVector, - ϵ::Real, - ξ::AbstractVector, - u::Real; - cholM=nothing, + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; cholM=nothing ) x_next, accepted, _ = mala_step_with_logα(logp, gradlogp, x, ϵ, ξ, u; cholM=cholM) return x_next, accepted end function mala_step_with_logα( + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; cholM=nothing +) + ws = MALAWorkspace(x) + x_next = similar(x) + return mala_step_with_logα!(x_next, ws, logp, gradlogp, x, ϵ, ξ, u; cholM=cholM) +end + +function mala_step_with_logα!( + x_next::AbstractVector, + ws::MALAWorkspace, logp, gradlogp, x::AbstractVector, @@ -310,9 +386,12 @@ function mala_step_with_logα( cholM=nothing, ) length(x) == length(ξ) || throw(DimensionMismatch("x and ξ must have the same length")) + length(x_next) == length(x) || + throw(DimensionMismatch("x_next must have the same length as x")) + length(ws.y) == length(x) || + throw(DimensionMismatch("workspace must match length of x")) 0.0 < u < 1.0 || throw(ArgumentError("u must be in (0, 1)")) - ws = MALAWorkspace(x) mala_proposal!(ws.y, ws, logp, gradlogp, x, ϵ, ξ; cholM=cholM) logp_x = logp(x) logp_y = logp(ws.y) @@ -323,7 +402,11 @@ function mala_step_with_logα( logα = (logp_y + logq_xy) - (logp_x + logq_yx) accepted = log(u) < logα - x_next = accepted ? copy(ws.y) : copy(x) + if accepted + copyto!(x_next, ws.y) + else + copyto!(x_next, x) + end return x_next, accepted, logα end @@ -331,13 +414,7 @@ end Stop-gradient surrogate step used for Jacobians. 
""" function mala_step_surrogate_sigmoid( - logp, - gradlogp, - x::AbstractVector, - ϵ::Real, - ξ::AbstractVector, - u::Real; - cholM=nothing, + logp, gradlogp, x::AbstractVector, ϵ::Real, ξ::AbstractVector, u::Real; cholM=nothing ) ws = MALAWorkspace(x) mala_proposal!(ws.y, ws, logp, gradlogp, x, ϵ, ξ; cholM=cholM) @@ -350,38 +427,67 @@ function mala_step_surrogate_sigmoid( end # Apply mass matrix to a D×N matrix of gradient columns. -_apply_M_batched(G::AbstractMatrix, ::Nothing) = G -_apply_M_batched(G::AbstractMatrix, cholM::Cholesky) = cholM.L * (cholM.L' * G) - -_apply_L_batched(Ξ::AbstractMatrix, ::Nothing) = Ξ -_apply_L_batched(Ξ::AbstractMatrix, cholM::Cholesky) = cholM.L * Ξ - -function _quad_Minv_batched(R::AbstractMatrix, ::Nothing) - return vec(sum(abs2, R; dims=1)) +_apply_M_batched!(out, G, ::Nothing, tmp=out) = copyto!(out, G) +function _apply_M_batched!(out, G, cholM::Cholesky, tmp) + mul!(tmp, adjoint(cholM.L), G) + mul!(out, cholM.L, tmp) + return out end -function _quad_Minv_batched(R::AbstractMatrix, cholM::Cholesky) - W = cholM.L \ R - return vec(sum(abs2, W; dims=1)) +_apply_L_batched!(out, Ξ, ::Nothing) = copyto!(out, Ξ) +_apply_L_batched!(out, Ξ, cholM::Cholesky) = mul!(out, cholM.L, Ξ) + +function _sum_columns!(out::AbstractVector, row::AbstractMatrix, A::AbstractMatrix) + sum!(row, A) + copyto!(out, row) + return out end -function logq_mala_batched( +function logq_mala_batched!( + out::AbstractVector, + ws::MALABatchedWorkspace, Y::AbstractMatrix, X::AbstractMatrix, gradlogp_X::AbstractMatrix, ε::Real; cholM=nothing, ) - T = typeof(ε) - D = size(X, 1) - μ = X .+ ε .* _apply_M_batched(gradlogp_X, cholM) - R = Y .- μ - q = _quad_Minv_batched(R, cholM) - ldet = _logdet_M(cholM) - return @. 
-T(0.5) * q / (2ε) - (T(D) / 2) * log(T(4π) * ε) - T(0.5) * ldet + D, N = size(X) + size(Y) == (D, N) || throw(DimensionMismatch("X and Y must have the same size")) + size(gradlogp_X) == (D, N) || + throw(DimensionMismatch("gradlogp_X must have the same size as X")) + length(out) == N || throw(DimensionMismatch("out must have length size(X,2)")) + + εT = convert(eltype(X), ε) + if cholM === nothing + @. ws.R = Y - X - εT * gradlogp_X + @. ws.R = abs2(ws.R) + _sum_columns!(out, ws.row, ws.R) + else + _apply_M_batched!(ws.MG_X, gradlogp_X, cholM, ws.solve_tmp) + @. ws.R = Y - X - εT * ws.MG_X + ldiv!(ws.Minv_R, cholM.L, ws.R) + @. ws.Minv_R = abs2(ws.Minv_R) + _sum_columns!(out, ws.row, ws.Minv_R) + end + + FT = typeof(εT) + ldet = cholM === nothing ? zero(FT) : FT(_logdet_M(cholM)) + c = (FT(D) / 2) * log(FT(4π) * εT) + FT(0.5) * ldet + @. out = -FT(0.5) * out / (2εT) - c + return out end -function mala_step_batched( +function logq_mala_batched(Y, X, gradlogp_X, ε; cholM=nothing) + ws = MALABatchedWorkspace(X) + out = similar(X, eltype(X), size(X, 2)) + return logq_mala_batched!(out, ws, Y, X, gradlogp_X, ε; cholM=cholM) +end + +function mala_step_batched!( + X_next::AbstractMatrix, + accepted::AbstractVector, + ws::MALABatchedWorkspace, logp_batch, gradlogp_batch, X::AbstractMatrix, @@ -392,26 +498,53 @@ function mala_step_batched( ) D, N = size(X) size(Ξ) == (D, N) || throw(DimensionMismatch("X and Ξ must have the same size")) + size(X_next) == (D, N) || + throw(DimensionMismatch("X_next must have the same size as X")) length(u) == N || throw(DimensionMismatch("u must have length N = size(X,2)")) + length(accepted) == N || + throw(DimensionMismatch("accepted must have length N = size(X,2)")) - ε_T = eltype(X)(ε) + εT = convert(eltype(X), ε) + sqrt2εT = sqrt(2 * εT) - G_X = gradlogp_batch(X) - Y = X .+ ε_T .* _apply_M_batched(G_X, cholM) .+ sqrt(2 * ε_T) .* _apply_L_batched(Ξ, cholM) + copyto!(ws.G_X, gradlogp_batch(X)) + _apply_M_batched!(ws.MG_X, ws.G_X, cholM, 
ws.solve_tmp) + _apply_L_batched!(ws.LΞ, Ξ, cholM) + @. ws.Y = X + εT * ws.MG_X + sqrt2εT * ws.LΞ - lp_X = logp_batch(X) - lp_Y = logp_batch(Y) - G_Y = gradlogp_batch(Y) + copyto!(ws.lp_X, logp_batch(X)) + copyto!(ws.lp_Y, logp_batch(ws.Y)) + copyto!(ws.G_Y, gradlogp_batch(ws.Y)) - lq_YX = logq_mala_batched(Y, X, G_X, ε_T; cholM=cholM) - lq_XY = logq_mala_batched(X, Y, G_Y, ε_T; cholM=cholM) + logq_mala_batched!(ws.lq_YX, ws, ws.Y, X, ws.G_X, εT; cholM=cholM) + logq_mala_batched!(ws.lq_XY, ws, X, ws.Y, ws.G_Y, εT; cholM=cholM) - logα = @. (lp_Y + lq_XY) - (lp_X + lq_YX) - accepted = @. log(u) < logα + @. ws.logα = (ws.lp_Y + ws.lq_XY) - (ws.lp_X + ws.lq_YX) + @. ws.logu = log(u) + @. accepted = ws.logu < ws.logα + copyto!(ws.accepted_row, accepted) + @. X_next = ifelse(ws.accepted_row, ws.Y, X) + return X_next, accepted +end - mask = reshape(accepted, 1, N) - X_next = @. ifelse(mask, Y, X) - return X_next, vec(accepted) +function mala_step_batched( + logp_batch, + gradlogp_batch, + X::AbstractMatrix, + ε::Real, + Ξ::AbstractMatrix, + u::AbstractVector; + cholM=nothing, +) + D, N = size(X) + size(Ξ) == (D, N) || throw(DimensionMismatch("X and Ξ must have the same size")) + length(u) == N || throw(DimensionMismatch("u must have length N = size(X,2)")) + ws = MALABatchedWorkspace(X) + X_next = similar(X) + accepted = similar(X, Bool, N) + return mala_step_batched!( + X_next, accepted, ws, logp_batch, gradlogp_batch, X, ε, Ξ, u; cholM=cholM + ) end """ @@ -420,7 +553,10 @@ Batched fused surrogate forward step and JVP over columns. `X`, `Ξ`, and `Z` are all D×T, and `u` has length T. Returns `(FT, Jt)` where both outputs are D×T. 
""" -function mala_step_batched_fwd_and_jvp( +function mala_step_batched_fwd_and_jvp!( + FT::AbstractMatrix, + Jt::AbstractMatrix, + ws::MALABatchedWorkspace, logp_batch, gradlogp_batch, hvp_batch, @@ -434,53 +570,99 @@ function mala_step_batched_fwd_and_jvp( D, T = size(X) size(Ξ) == (D, T) || throw(DimensionMismatch("X and Ξ must have the same size")) size(Z) == (D, T) || throw(DimensionMismatch("X and Z must have the same size")) - length(u) == T || throw(DimensionMismatch("u must have length T")) + size(FT) == (D, T) || throw(DimensionMismatch("FT must have the same size as X")) + size(Jt) == (D, T) || throw(DimensionMismatch("Jt must have the same size as X")) + length(u) == T || throw(DimensionMismatch("u must have length size(X,2)")) εT = convert(eltype(X), ε) oneT = one(eltype(X)) - sqrt2εT = sqrt(2 * εT) - G_X = gradlogp_batch(X) - Y = X .+ εT .* _apply_M_batched(G_X, cholM) .+ sqrt2εT .* _apply_L_batched(Ξ, cholM) + copyto!(ws.G_X, gradlogp_batch(X)) + _apply_M_batched!(ws.MG_X, ws.G_X, cholM, ws.solve_tmp) + _apply_L_batched!(ws.LΞ, Ξ, cholM) + @. ws.Y = X + εT * ws.MG_X + sqrt(2 * εT) * ws.LΞ - lp_X = logp_batch(X) - lp_Y = logp_batch(Y) - G_Y = gradlogp_batch(Y) + copyto!(ws.lp_X, logp_batch(X)) + copyto!(ws.lp_Y, logp_batch(ws.Y)) + copyto!(ws.G_Y, gradlogp_batch(ws.Y)) - lq_YX = logq_mala_batched(Y, X, G_X, εT; cholM=cholM) - lq_XY = logq_mala_batched(X, Y, G_Y, εT; cholM=cholM) - logα = @. (lp_Y + lq_XY) - (lp_X + lq_YX) + logq_mala_batched!(ws.lq_YX, ws, ws.Y, X, ws.G_X, εT; cholM=cholM) + logq_mala_batched!(ws.lq_XY, ws, X, ws.Y, ws.G_Y, εT; cholM=cholM) - u_like = similar(logα, eltype(logα), T) - copyto!(u_like, u) - logu = log.(u_like) + @. ws.logα = (ws.lp_Y + ws.lq_XY) - (ws.lp_X + ws.lq_YX) + @. ws.logu = log(u) + @. ws.g = oneT / (oneT + exp(-(ws.logα - ws.logu))) + @. ws.accepted = ws.logu < ws.logα + copyto!(ws.accepted_row, ws.accepted) + @. FT = ifelse(ws.accepted_row, ws.Y, X) - g̃ = logα .- logu - g = @. 
oneT / (oneT + exp(-g̃)) - g_row = reshape(g, 1, T) + copyto!(ws.Hv_X, hvp_batch(X, Z)) + _apply_M_batched!(ws.M_Hv_X, ws.Hv_X, cholM, ws.solve_tmp) + @. ws.W = Z + εT * ws.M_Hv_X - accepted = @. logu < logα - accepted_row = reshape(accepted, 1, T) - FT = @. ifelse(accepted_row, Y, X) + copyto!(ws.Hv_Y, hvp_batch(ws.Y, ws.W)) + _apply_M_batched!(ws.MG_Y, ws.G_Y, cholM, ws.solve_tmp) + _apply_M_batched!(ws.M_Hv_Y, ws.Hv_Y, cholM, ws.solve_tmp) + @. ws.R = X - ws.Y - εT * ws.MG_Y + @. ws.dR = Z - ws.W - εT * ws.M_Hv_Y - Hv_X = hvp_batch(X, Z) - W = Z .+ εT .* _apply_M_batched(Hv_X, cholM) + if cholM === nothing + @. ws.prod = ws.R * ws.dR + else + ldiv!(ws.Minv_R, cholM, ws.R) + @. ws.prod = ws.Minv_R * ws.dR + end + _sum_columns!(ws.dot3, ws.row, ws.prod) + @. ws.prod = ws.G_Y * ws.W + _sum_columns!(ws.dot1, ws.row, ws.prod) + @. ws.prod = ws.G_X * Z + _sum_columns!(ws.dot2, ws.row, ws.prod) - Hv_Y = hvp_batch(Y, W) - R = X .- Y .- εT .* _apply_M_batched(G_Y, cholM) - dR = Z .- W .- εT .* _apply_M_batched(Hv_Y, cholM) + @. ws.dlogα = ws.dot1 - ws.dot2 - inv(2 * εT) * ws.dot3 + @. ws.dg = ws.g * (oneT - ws.g) * ws.dlogα - Minv_R = cholM === nothing ? R : (cholM \ R) + copyto!(ws.row, ws.dg) + @. Jt = ifelse(ws.accepted_row, ws.W, Z) + (ws.Y - X) * ws.row - dlogα = vec(sum(G_Y .* W; dims=1)) .- - vec(sum(G_X .* Z; dims=1)) .- - inv(2 * εT) .* vec(sum(Minv_R .* dR; dims=1)) + return FT, Jt +end - dg = g .* (oneT .- g) .* dlogα - dg_row = reshape(dg, 1, T) +function mala_step_batched_fwd_and_jvp!( + FT::AbstractMatrix, + Jt::AbstractMatrix, + logp_batch, + gradlogp_batch, + hvp_batch, + X::AbstractMatrix, + ε::Real, + Ξ::AbstractMatrix, + u::AbstractVector, + Z::AbstractMatrix; + cholM=nothing, +) + ws = MALABatchedWorkspace(X) + return mala_step_batched_fwd_and_jvp!( + FT, Jt, ws, logp_batch, gradlogp_batch, hvp_batch, X, ε, Ξ, u, Z; cholM=cholM + ) +end - Jt = @. 
ifelse(accepted_row, W, Z) + (Y - X) * dg_row - return FT, Jt +function mala_step_batched_fwd_and_jvp( + logp_batch, + gradlogp_batch, + hvp_batch, + X::AbstractMatrix, + ε::Real, + Ξ::AbstractMatrix, + u::AbstractVector, + Z::AbstractMatrix; + cholM=nothing, +) + FT = similar(X) + Jt = similar(X) + ws = MALABatchedWorkspace(X) + return mala_step_batched_fwd_and_jvp!( + FT, Jt, ws, logp_batch, gradlogp_batch, hvp_batch, X, ε, Ξ, u, Z; cholM=cholM + ) end """ @@ -515,8 +697,8 @@ function mala_step_surrogate_sigmoid_jvp( if cholM === nothing @. ws.w = v + ε * ws.Hv_x else - MHv_x = _apply_M(ws.Hv_x, cholM) - @. ws.w = v + ε * MHv_x + _apply_M!(ws.μ, ws.Hv_x, cholM, ws.solve_buf) + @. ws.w = v + ε * ws.μ end copyto!(ws.Hv_y, hvp_fn(ws.y, ws.w)) @@ -525,11 +707,12 @@ function mala_step_surrogate_sigmoid_jvp( @. ws.dr = v - ws.w - ε * ws.Hv_y Minv_r = ws.r else - Mg_y = _apply_M(ws.g_y, cholM) - MHv_y = _apply_M(ws.Hv_y, cholM) - @. ws.r = x - ws.y - ε * Mg_y - @. ws.dr = v - ws.w - ε * MHv_y - Minv_r = cholM \ ws.r + _apply_M!(ws.μ, ws.g_y, cholM, ws.solve_buf) + _apply_M!(ws.jvp_out, ws.Hv_y, cholM, ws.solve_buf) + @. ws.r = x - ws.y - ε * ws.μ + @. ws.dr = v - ws.w - ε * ws.jvp_out + ldiv!(ws.solve_buf, cholM, ws.r) + Minv_r = ws.solve_buf end dlogα = dot(ws.g_y, ws.w) - dot(ws.g_x, v) - inv(2 * ε) * dot(Minv_r, ws.dr) @@ -584,8 +767,8 @@ function mala_step_taped_and_jvp!( if cholM === nothing @. ws.w = v + ε * ws.Hv_x else - MHv_x = _apply_M(ws.Hv_x, cholM) - @. ws.w = v + ε * MHv_x + _apply_M!(ws.μ, ws.Hv_x, cholM, ws.solve_buf) + @. ws.w = v + ε * ws.μ end copyto!(ws.Hv_y, hvp_fn(ws.y, ws.w)) @@ -594,11 +777,12 @@ function mala_step_taped_and_jvp!( @. ws.dr = v - ws.w - ε * ws.Hv_y Minv_r = ws.r else - Mg_y = _apply_M(ws.g_y, cholM) - MHv_y = _apply_M(ws.Hv_y, cholM) - @. ws.r = x - ws.y - ε * Mg_y - @. ws.dr = v - ws.w - ε * MHv_y - Minv_r = cholM \ ws.r + _apply_M!(ws.μ, ws.g_y, cholM, ws.solve_buf) + _apply_M!(ws.jvp_out, ws.Hv_y, cholM, ws.solve_buf) + @. 
ws.r = x - ws.y - ε * ws.μ + @. ws.dr = v - ws.w - ε * ws.jvp_out + ldiv!(ws.solve_buf, cholM, ws.r) + Minv_r = ws.solve_buf end dlogα = dot(ws.g_y, ws.w) - dot(ws.g_x, v) - inv(2 * ε) * dot(Minv_r, ws.dr) @@ -623,18 +807,7 @@ function mala_step_taped_and_jvp( x_next = similar(x) jvp_out = similar(x) mala_step_taped_and_jvp!( - x_next, - jvp_out, - ws, - logp, - gradlogp, - x, - ε, - ξ, - u, - v, - hvp_fn; - cholM=cholM, + x_next, jvp_out, ws, logp, gradlogp, x, ε, ξ, u, v, hvp_fn; cholM=cholM ) return x_next, jvp_out end diff --git a/src/interface.jl b/src/interface.jl index a61ae1a..d68efc6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -56,9 +56,14 @@ function DensityModel( hvp_batch=nothing, ) return DensityModel( - logdensity, grad_logdensity, hvp, - logdensity_batch, grad_logdensity_batch, hvp_batch, - dim, param_names, + logdensity, + grad_logdensity, + hvp, + logdensity_batch, + grad_logdensity_batch, + hvp_batch, + dim, + param_names, ) end @@ -82,17 +87,46 @@ function MALASampler(epsilon::Real; cholM=nothing) return MALASampler{typeof(eps_f),typeof(cholM)}(eps_f, cholM) end -struct MALAState{V<:AbstractVector,L<:Real} +""" +State for a `MALASampler` chain. +""" +struct MALAState{V<:AbstractVector,L<:Real,W,NV<:AbstractVector,H} x::V logp::L + workspace::W + noise::NV + noise_host::H end +""" +One `MALASampler` sample: parameter vector `x`, its log-density `logp`, and an +accept/reject flag. +""" struct MALATransition{V<:AbstractVector,L<:Real} x::V logp::L accepted::Bool end +function _make_noise_buffer(x::AbstractVector, ::Type{FP}, D::Int) where {FP} + ξ = similar(x, FP, D) + host = ξ isa CUDA.CuArray ? 
Vector{FP}(undef, D) : nothing + return ξ, host +end + +function _randn_like!( + rng::Random.AbstractRNG, ξ::AbstractVector{FP}, host::Union{Nothing,AbstractVector{FP}} +) where {FP} + if ξ isa CUDA.CuArray + host === nothing && error("CuArray normal noise requires a reusable host buffer") + randn!(rng, host) + copyto!(ξ, host) + else + randn!(rng, ξ) + end + return ξ +end + function AbstractMCMC.step( rng::Random.AbstractRNG, model::DensityModel, @@ -106,8 +140,10 @@ function AbstractMCMC.step( randn(rng, FP, model.dim) end logp_val = model.logdensity(x) + ws = MALA.MALAWorkspace(x) + noise, noise_host = _make_noise_buffer(x, FP, model.dim) t = MALATransition(x, logp_val, true) - s = MALAState(x, logp_val) + s = MALAState(x, logp_val, ws, noise, noise_host) return t, s end @@ -122,16 +158,25 @@ function AbstractMCMC.step( ϵ = sampler.epsilon D = model.dim - ξ = randn(rng, eltype(x), D) + ξ = _randn_like!(rng, state.noise, state.noise_host) u = rand(rng) - x_next, accepted = MALA.mala_step_full( - model.logdensity, model.grad_logdensity, x, ϵ, ξ, u; cholM=sampler.cholM + x_next = similar(x) + x_next, accepted, _ = MALA.mala_step_with_logα!( + x_next, + state.workspace, + model.logdensity, + model.grad_logdensity, + x, + ϵ, + ξ, + u; + cholM=sampler.cholM, ) logp_val = accepted ? model.logdensity(x_next) : state.logp t = MALATransition(x_next, logp_val, accepted) - s = MALAState(x_next, logp_val) + s = MALAState(x_next, logp_val, state.workspace, state.noise, state.noise_host) return t, s end @@ -190,6 +235,9 @@ end ParallelMALASampler(epsilon; T, maxiter, tol_abs, tol_rel, jacobian, damping, probes, cholM, backend) DEER-parallelized MALA sampler. + +Supported Jacobian modes are `:stoch_diag` (the default Hutchinson diagonal +estimator) and `:diag` (exact diagonal via `D` JVPs). 
""" struct ParallelMALASampler{FP<:AbstractFloat,CM,AD} <: AbstractMCMC.AbstractSampler epsilon::FP @@ -217,6 +265,8 @@ function ParallelMALASampler( backend=DEER.DEFAULT_BACKEND, ) epsilon > 0 || throw(ArgumentError("epsilon must be > 0, got $epsilon")) + (jacobian === :stoch_diag || jacobian === :diag) || + throw(ArgumentError("jacobian must be :stoch_diag or :diag")) eps_f = float(epsilon) FP = typeof(eps_f) return ParallelMALASampler{FP,typeof(cholM),typeof(backend)}( @@ -254,6 +304,53 @@ struct ParallelMALATransition{V<:AbstractVector,L<:Real} logp::L end +struct ParallelMALABlockSamples{B<:AbstractVector,L<:AbstractVector} <: + AbstractVector{ParallelMALATransition} + blocks::B + logps::L + n::Int + T::Int +end + +Base.size(samples::ParallelMALABlockSamples) = (samples.n,) +Base.length(samples::ParallelMALABlockSamples) = samples.n +Base.IndexStyle(::Type{<:ParallelMALABlockSamples}) = IndexLinear() + +function Base.getindex(samples::ParallelMALABlockSamples, i::Int) + 1 <= i <= samples.n || throw(BoundsError(samples, i)) + block_idx = fld(i - 1, samples.T) + 1 + t = mod(i - 1, samples.T) + 1 + return ParallelMALATransition( + samples.blocks[block_idx][:, t], samples.logps[block_idx][t] + ) +end + +function _make_mala_tape_block( + rng::Random.AbstractRNG, x0::AbstractVector, ::Type{FP}, D::Int, T::Int +) where {FP} + Xi = similar(x0, FP, D, T) + U_host = Vector{FP}(undef, T) + + if Xi isa CUDA.CuArray + Xi_host = Matrix{FP}(undef, D, T) + for t in 1:T + randn!(rng, view(Xi_host, :, t)) + U_host[t] = FP(rand(rng)) + end + copyto!(Xi, Xi_host) + else + for t in 1:T + randn!(rng, view(Xi, :, t)) + U_host[t] = FP(rand(rng)) + end + end + + U = similar(x0, FP, T) + copyto!(U, U_host) + tape = [MALATapeElement(view(Xi, :, t), U_host[t]) for t in 1:T] + return tape, Xi, U +end + function _build_mala_deer_rec( model::DensityModel, ε::Real, @@ -261,6 +358,8 @@ function _build_mala_deer_rec( x0_like::AbstractVector; cholM=nothing, backend=DEER.DEFAULT_BACKEND, + 
tape_noise=nothing, + tape_uniforms=nothing, ) logp = model.logdensity gradlogp = model.grad_logdensity @@ -272,9 +371,7 @@ function _build_mala_deer_rec( else hvp_backend = DEER._hvp_second_order_backend(backend) prep_hvp = DEER._prepare_logdensity_hvp(logp, hvp_backend, x0_like) - (pt, dir) -> DEER._logdensity_hvp_prepared( - logp, prep_hvp, hvp_backend, pt, dir - ) + (pt, dir) -> DEER._logdensity_hvp_prepared(logp, prep_hvp, hvp_backend, pt, dir) end # Exact forward step. @@ -282,63 +379,85 @@ function _build_mala_deer_rec( (x, te) -> MALA.mala_step_taped(logp, gradlogp, x, ε, te.ξ, te.u; cholM=cholM) # Explicit analytical JVP of the surrogate step. - jvp = (x, te, v) -> MALA.mala_step_surrogate_sigmoid_jvp( - logp, gradlogp, x, ε, te.ξ, te.u, v, hvp_fn; cholM=cholM - ) + jvp = + (x, te, v) -> MALA.mala_step_surrogate_sigmoid_jvp( + logp, gradlogp, x, ε, te.ξ, te.u, v, hvp_fn; cholM=cholM + ) # Fused forward step + JVP. - fwd_and_jvp = (x, te, v) -> MALA.mala_step_taped_and_jvp( - logp, gradlogp, x, ε, te.ξ, te.u, v, hvp_fn; cholM=cholM - ) - - fwd_and_jvp_batch = if model.logdensity_batch !== nothing && - model.grad_logdensity_batch !== nothing - D = length(x0_like) - T = length(tape) + fwd_and_jvp = + (x, te, v) -> MALA.mala_step_taped_and_jvp( + logp, gradlogp, x, ε, te.ξ, te.u, v, hvp_fn; cholM=cholM + ) - Ξ = similar(x0_like, D, T) - U_host = Vector{typeof(tape[1].u)}(undef, T) - U = similar(x0_like, typeof(tape[1].u), T) + fwd_and_jvp_batch = + if model.logdensity_batch !== nothing && model.grad_logdensity_batch !== nothing + D = length(x0_like) + T = length(tape) - for t in 1:T - copyto!(view(Ξ, :, t), tape[t].ξ) - U_host[t] = tape[t].u - end - copyto!(U, U_host) - - X_template = similar(x0_like, D, T) - fill!(X_template, zero(eltype(X_template))) - hvp_batch = if model.hvp_batch !== nothing - model.hvp_batch - else - hvp_batch_backend = DEER._hvp_backend(backend) - prep_hvp_batch = DEER._prepare_batch_hvp_from_grad( - model.grad_logdensity_batch, 
hvp_batch_backend, X_template + (tape_noise === nothing) == (tape_uniforms === nothing) || throw( + ArgumentError("tape_noise and tape_uniforms must be provided together") ) - (X, V) -> DEER._batch_hvp_from_grad_prepared( - model.grad_logdensity_batch, prep_hvp_batch, hvp_batch_backend, X, V + Xi, U = if tape_noise === nothing + Xi_local = similar(x0_like, D, T) + U_host = Vector{typeof(tape[1].u)}(undef, T) + U_local = similar(x0_like, typeof(tape[1].u), T) + + for t in 1:T + copyto!(Xi_local, (t - 1) * D + 1, tape[t].ξ, 1, D) + U_host[t] = tape[t].u + end + copyto!(U_local, U_host) + Xi_local, U_local + else + size(tape_noise) == (D, T) || + throw(DimensionMismatch("tape_noise must have size (D, T)")) + length(tape_uniforms) == T || + throw(DimensionMismatch("tape_uniforms must have length T")) + tape_noise, tape_uniforms + end + + X_template = similar(x0_like, D, T) + fill!(X_template, zero(eltype(X_template))) + batch_ws = MALA.MALABatchedWorkspace(X_template) + FT = similar(X_template) + Jt = similar(X_template) + hvp_batch = if model.hvp_batch !== nothing + model.hvp_batch + else + hvp_batch_backend = DEER._hvp_backend(backend) + prep_hvp_batch = DEER._prepare_batch_hvp_from_grad( + model.grad_logdensity_batch, hvp_batch_backend, X_template + ) + (X, V) -> DEER._batch_hvp_from_grad_prepared( + model.grad_logdensity_batch, + prep_hvp_batch, + hvp_batch_backend, + X, + V, + ) + end + + (Xbar, Z) -> MALA.mala_step_batched_fwd_and_jvp!( + FT, + Jt, + batch_ws, + model.logdensity_batch, + model.grad_logdensity_batch, + hvp_batch, + Xbar, + ε, + Xi, + U, + Z; + cholM=cholM, ) + else + nothing end - (Xbar, Z) -> MALA.mala_step_batched_fwd_and_jvp( - model.logdensity_batch, - model.grad_logdensity_batch, - hvp_batch, - Xbar, - ε, - Ξ, - U, - Z; - cholM=cholM, - ) - else - nothing - end - return DEER.TapedRecursion( - step_fwd, jvp, tape; - fwd_and_jvp=fwd_and_jvp, - fwd_and_jvp_batch=fwd_and_jvp_batch, + step_fwd, jvp, tape; fwd_and_jvp=fwd_and_jvp, 
fwd_and_jvp_batch=fwd_and_jvp_batch ) end @@ -353,13 +472,17 @@ function _deer_solve_new_tape( T = sampler.T FP = typeof(sampler.epsilon) - tape = map(1:T) do _ - ξ = copyto!(similar(x0, D), randn(rng, FP, D)) - MALATapeElement(ξ, FP(rand(rng))) - end + tape, Xi, U = _make_mala_tape_block(rng, x0, FP, D, T) rec = _build_mala_deer_rec( - model, sampler.epsilon, tape, x0; cholM=sampler.cholM, backend=sampler.backend + model, + sampler.epsilon, + tape, + x0; + cholM=sampler.cholM, + backend=sampler.backend, + tape_noise=Xi, + tape_uniforms=U, ) ws = workspace === nothing ? DEER.DEERWorkspace(x0, T) : workspace @@ -375,6 +498,7 @@ function _deer_solve_new_tape( probes=sampler.probes, rng=rng, workspace=ws, + copy_result=false, ) return S, tape, ws end @@ -388,6 +512,251 @@ function _trajectory_logps(model::DensityModel, S::AbstractMatrix) return [model.logdensity(S[:, t]) for t in 1:T] end +function _parallel_mala_param_names(model::DensityModel, D::Int, param_names) + if param_names !== nothing + return param_names + elseif model.param_names !== nothing + return model.param_names + else + return [Symbol("x[$i]") for i in 1:D] + end +end + +function _parallel_mala_initial_x( + rng::Random.AbstractRNG, model::DensityModel, ::ParallelMALASampler{FP}, initial_params +) where {FP} + return initial_params !== nothing ? 
copy(initial_params) : randn(rng, FP, model.dim) +end + +function _parallel_mala_progress(progress, progressname) + progress === true && return AbstractMCMC.CreateNewProgressBar(progressname) + progress === false && return AbstractMCMC.NoLogging() + return progress +end + +function _parallel_mala_update_progress!( + progress, nsteps::Int, Ntotal::Int, next_update::Real, threshold::Real +) + if nsteps >= next_update && next_update <= Ntotal + AbstractMCMC.update_progress!(progress, min(nsteps / Ntotal, 1)) + next_update += threshold + end + return next_update +end + +function _copy_trajectory_rows!( + vals::AbstractMatrix{Float64}, first_row::Int, S::AbstractMatrix, ncols::Int +) + rows = first_row:(first_row + ncols - 1) + S_host = Array(view(S, :, 1:ncols)) + vals[rows, :] .= transpose(S_host) + return vals +end + +function _sample_parallel_mala_chain( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::ParallelMALASampler, + N::Int; + initial_params=nothing, + param_names=nothing, + progress=AbstractMCMC.PROGRESS[], + progressname="Sampling", +) + D = model.dim + names = _parallel_mala_param_names(model, D, param_names) + internal_names = [:logp] + + vals = Matrix{Float64}(undef, N, D) + internals = Matrix{Float64}(undef, N, 1) + + progress = _parallel_mala_progress(progress, progressname) + x0 = _parallel_mala_initial_x(rng, model, sampler, initial_params) + ws = nothing + nsteps = 0 + next_update = N / AbstractMCMC.get_n_updates(progress) + threshold = next_update + + AbstractMCMC.@maybewithricherlogger begin + AbstractMCMC.init_progress!(progress) + try + while nsteps < N + S, tape, ws = _deer_solve_new_tape(rng, model, sampler, x0; workspace=ws) + logps = _trajectory_logps(model, S) + nkeep = min(sampler.T, N - nsteps) + first_row = nsteps + 1 + rows = first_row:(first_row + nkeep - 1) + + _copy_trajectory_rows!(vals, first_row, S, nkeep) + internals[rows, 1] .= view(logps, 1:nkeep) + + x0 = copy(view(S, :, nkeep)) + nsteps += nkeep + next_update = 
_parallel_mala_update_progress!( + progress, nsteps, N, next_update, threshold + ) + end + finally + AbstractMCMC.finish_progress!(progress) + end + end + + return MCMCChains.Chains( + hcat(vals, internals), + vcat(names, internal_names), + Dict(:parameters => names, :internals => internal_names), + ) +end + +function _sample_parallel_mala_blocks( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::ParallelMALASampler, + N::Int; + initial_params=nothing, + progress=AbstractMCMC.PROGRESS[], + progressname="Sampling", +) + blocks = Vector{AbstractMatrix}(undef, 0) + logp_blocks = Vector{AbstractVector}(undef, 0) + sizehint!(blocks, cld(N, sampler.T)) + sizehint!(logp_blocks, cld(N, sampler.T)) + + progress = _parallel_mala_progress(progress, progressname) + x0 = _parallel_mala_initial_x(rng, model, sampler, initial_params) + ws = nothing + nsteps = 0 + next_update = N / AbstractMCMC.get_n_updates(progress) + threshold = next_update + final_state = nothing + + AbstractMCMC.@maybewithricherlogger begin + AbstractMCMC.init_progress!(progress) + try + while nsteps < N + S, tape, ws = _deer_solve_new_tape(rng, model, sampler, x0; workspace=ws) + logps = _trajectory_logps(model, S) + nkeep = min(sampler.T, N - nsteps) + S_keep = copy(view(S, :, 1:nkeep)) + logps_keep = collect(view(logps, 1:nkeep)) + x0 = copy(view(S_keep, :, nkeep)) + + push!(blocks, S_keep) + push!(logp_blocks, logps_keep) + final_state = ParallelMALAState( + x0, logps_keep[nkeep], S_keep, logps_keep, ws, tape, nkeep + ) + + nsteps += nkeep + next_update = _parallel_mala_update_progress!( + progress, nsteps, N, next_update, threshold + ) + end + finally + AbstractMCMC.finish_progress!(progress) + end + end + + return ParallelMALABlockSamples(blocks, logp_blocks, N, sampler.T), final_state +end + +function _default_parallel_mala_mcmcsample( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::ParallelMALASampler, + N::Integer; + kwargs..., +) + return invoke( + AbstractMCMC.mcmcsample, + 
Tuple{ + Random.AbstractRNG, + AbstractMCMC.AbstractModel, + AbstractMCMC.AbstractSampler, + Integer, + }, + rng, + model, + sampler, + N; + kwargs..., + ) +end + +function AbstractMCMC.mcmcsample( + rng::Random.AbstractRNG, + model::DensityModel, + sampler::ParallelMALASampler, + N::Integer; + progress=AbstractMCMC.PROGRESS[], + progressname="Sampling", + callback=nothing, + num_warmup::Int=0, + discard_initial::Int=num_warmup, + thinning=1, + chain_type::Type=Any, + initial_state=nothing, + initial_params=nothing, + param_names=nothing, + kwargs..., +) + if callback !== nothing || + num_warmup != 0 || + discard_initial != 0 || + thinning != 1 || + initial_state !== nothing + return _default_parallel_mala_mcmcsample( + rng, + model, + sampler, + N; + progress=progress, + progressname=progressname, + callback=callback, + num_warmup=num_warmup, + discard_initial=discard_initial, + thinning=thinning, + chain_type=chain_type, + initial_state=initial_state, + initial_params=initial_params, + param_names=param_names, + kwargs..., + ) + end + + N > 0 || error("the number of samples must be ≥ 1") + N_int = Int(N) + + if chain_type === MCMCChains.Chains + return _sample_parallel_mala_chain( + rng, + model, + sampler, + N_int; + initial_params=initial_params, + param_names=param_names, + progress=progress, + progressname=progressname, + ) + end + + samples, state = _sample_parallel_mala_blocks( + rng, + model, + sampler, + N_int; + initial_params=initial_params, + progress=progress, + progressname=progressname, + ) + chain_type === Any && return samples + + sample_vec = [samples[i] for i in 1:length(samples)] + return AbstractMCMC.bundle_samples( + sample_vec, model, sampler, state, chain_type; param_names=param_names, kwargs... 
+ ) +end + function AbstractMCMC.step( rng::Random.AbstractRNG, model::DensityModel, @@ -521,15 +890,26 @@ function AdaptiveMALASampler( ) end -struct AdaptiveMALAState{V<:AbstractVector,FP<:AbstractFloat} +""" +State for an `AdaptiveMALASampler` chain, including dual-averaging adaptation +statistics. +""" +struct AdaptiveMALAState{V<:AbstractVector,FP<:AbstractFloat,W,NV<:AbstractVector,H} x::V logp::FP epsilon::FP epsilon_bar::FP H_bar::FP step::Int + workspace::W + noise::NV + noise_host::H end +""" +One `AdaptiveMALASampler` sample: parameter vector `x`, its log-density `logp`, +the step size used for that transition, and whether the sample came from warmup. +""" struct AdaptiveMALATransition{V<:AbstractVector,FP<:AbstractFloat} x::V logp::FP @@ -575,9 +955,19 @@ function AbstractMCMC.step( randn(rng, FP, model.dim) end logp_val = FP(model.logdensity(x)) + ws = MALA.MALAWorkspace(x) + noise, noise_host = _make_noise_buffer(x, FP, model.dim) trans = AdaptiveMALATransition(x, logp_val, true, sampler.epsilon_init, true) state = AdaptiveMALAState( - x, logp_val, sampler.epsilon_init, sampler.epsilon_init, zero(FP), 0 + x, + logp_val, + sampler.epsilon_init, + sampler.epsilon_init, + zero(FP), + 0, + ws, + noise, + noise_host, ) return trans, state end @@ -593,11 +983,20 @@ function AbstractMCMC.step( in_warmup = state.step < sampler.n_warmup ε = in_warmup ? state.epsilon : state.epsilon_bar - ξ = randn(rng, eltype(state.x), D) + ξ = _randn_like!(rng, state.noise, state.noise_host) u = rand(rng) - x_next, accepted, logα = MALA.mala_step_with_logα( - model.logdensity, model.grad_logdensity, state.x, ε, ξ, u; cholM=sampler.cholM + x_next = similar(state.x) + x_next, accepted, logα = MALA.mala_step_with_logα!( + x_next, + state.workspace, + model.logdensity, + model.grad_logdensity, + state.x, + ε, + ξ, + u; + cholM=sampler.cholM, ) logp_next = accepted ? 
FP(model.logdensity(x_next)) : state.logp @@ -617,7 +1016,17 @@ function AbstractMCMC.step( end trans = AdaptiveMALATransition(x_next, logp_next, accepted, ε, in_warmup) - new_state = AdaptiveMALAState(x_next, logp_next, ε_new, ε_bar_new, H_bar_new, m_new) + new_state = AdaptiveMALAState( + x_next, + logp_next, + ε_new, + ε_bar_new, + H_bar_new, + m_new, + state.workspace, + state.noise, + state.noise_host, + ) return trans, new_state end @@ -663,11 +1072,3 @@ function AbstractMCMC.bundle_samples( Dict(:parameters => names, :internals => internal_names), ) end - -""" - DensityModel(ld) - -Construct a `DensityModel` from any object `ld` that implements the -LogDensityProblems interface. -""" -function DensityModel end diff --git a/test/Project.toml b/test/Project.toml index 075ba10..09c274c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +ParallelMCMC = "1a970f40-4406-51c9-a967-cb3143c111e8" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/test/test-AbstractMCMC-Interface.jl b/test/test-AbstractMCMC-Interface.jl index f7ab675..b01f19f 100644 --- a/test/test-AbstractMCMC-Interface.jl +++ b/test/test-AbstractMCMC-Interface.jl @@ -5,6 +5,7 @@ using Statistics using MCMCChains using ParallelMCMC +const MALA_iface = ParallelMCMC.MALA logp_iface(x) = -0.5 * dot(x, x) gradlogp_iface(x) = -x @@ -38,6 +39,7 @@ gradlogp_iface(x) = -x @test transition.logp == logp_iface(transition.x) @test state.logp == transition.logp @test transition.accepted == true + @test state.workspace isa MALA_iface.MALAWorkspace end @testset "initial step respects initial_params" begin @@ -81,6 +83,7 @@ gradlogp_iface(x) = -x @test 
isfinite(t2.logp) @test s2.x == t2.x @test s2.logp == t2.logp + @test s2.workspace === s1.workspace end @testset "step determinism with fixed rng" begin diff --git a/test/test-Adaptive-MALA.jl b/test/test-Adaptive-MALA.jl index a0e1e8d..fa6a987 100644 --- a/test/test-Adaptive-MALA.jl +++ b/test/test-Adaptive-MALA.jl @@ -29,6 +29,29 @@ gradlogp_adapt(x) = -x end end +@testset "mala_step_with_logα! matches allocating wrapper" begin + rng = MersenneTwister(2) + D = 4 + x = randn(rng, D) + ξ = randn(rng, D) + u = rand(rng) + + x_ref, accepted_ref, logα_ref = MALA.mala_step_with_logα( + logp_adapt, gradlogp_adapt, x, 0.1, ξ, u + ) + + ws = MALA.MALAWorkspace(x) + x_next = similar(x) + x_out, accepted, logα = MALA.mala_step_with_logα!( + x_next, ws, logp_adapt, gradlogp_adapt, x, 0.1, ξ, u + ) + + @test x_out === x_next + @test x_next == x_ref + @test accepted == accepted_ref + @test logα ≈ logα_ref +end + @testset "AdaptiveMALASampler construction" begin s = AdaptiveMALASampler(0.1) @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler @@ -61,6 +84,7 @@ end @test state.step == 0 @test state.epsilon == sampler.epsilon_init @test state.epsilon_bar == sampler.epsilon_init + @test state.workspace isa MALA.MALAWorkspace end @testset "AdaptiveMALASampler initial step respects initial_params" begin diff --git a/test/test-DEER-Interface.jl b/test/test-DEER-Interface.jl index 2675740..5cfee5d 100644 --- a/test/test-DEER-Interface.jl +++ b/test/test-DEER-Interface.jl @@ -11,6 +11,10 @@ gradlogp_deer(x) = -x logp_batch_deer(X) = vec(-0.5 .* sum(abs2, X; dims=1)) gradlogp_batch_deer(X) = -X +logp_quartic_deer(x) = -0.5 * dot(x, x) - 0.1 * sum(abs2.(x) .^ 2) +gradlogp_quartic_deer(x) = @. -x - 0.4 * x^3 +hvp_quartic_deer(x, v) = @. 
(-1 - 1.2 * x^2) * v + @testset "ParallelMALASampler construction" begin s = ParallelMALASampler(0.05) @test s isa ParallelMCMC.AbstractMCMC.AbstractSampler @@ -24,6 +28,7 @@ gradlogp_batch_deer(X) = -X @test s2.T == 32 @test s2.jacobian === :stoch_diag @test s2.damping == 0.8 + @test_throws ArgumentError ParallelMALASampler(0.1; jacobian=:full) end @testset "ParallelMALASampler initial step" begin @@ -61,15 +66,43 @@ end @test isfinite(trans.logp) end +@testset "scalar DEER path can AD the logdensity HVP with Enzyme when hvp is omitted" begin + rng = MersenneTwister(2027) + D, T = 3, 5 + ε = 0.03 + + tape = [ParallelMCMC.MALATapeElement(randn(rng, D), rand(rng)) for _ in 1:T] + model = DensityModel(logp_quartic_deer, gradlogp_quartic_deer, D) + + rec_ad = ParallelMCMC._build_mala_deer_rec( + model, ε, tape, zeros(D); backend=ParallelMCMC.DEER.DEFAULT_BACKEND + ) + x = randn(rng, D) + v = randn(rng, D) + te = tape[1] + + f_ad, jvp_ad = rec_ad.fwd_and_jvp(x, te, v) + f_ref, jvp_ref = ParallelMCMC.MALA.mala_step_taped_and_jvp( + logp_quartic_deer, + gradlogp_quartic_deer, + x, + ε, + te.ξ, + te.u, + v, + hvp_quartic_deer, + ) + + @test f_ad ≈ f_ref atol=1e-10 rtol=1e-10 + @test jvp_ad ≈ jvp_ref atol=1e-10 rtol=1e-10 +end + @testset "batched DEER path can AD the batched gradient when hvp_batch is omitted" begin rng = MersenneTwister(2026) D, T = 3, 6 ε = 0.05 - tape = [ - ParallelMCMC.MALATapeElement(randn(rng, D), rand(rng)) - for _ in 1:T - ] + tape = [ParallelMCMC.MALATapeElement(randn(rng, D), rand(rng)) for _ in 1:T] model = DensityModel( logp_deer, gradlogp_deer, @@ -89,14 +122,7 @@ end U = [te.u for te in tape] hvp_batch_ref = (X, V) -> -V FT_ref, Jt_ref = ParallelMCMC.MALA.mala_step_batched_fwd_and_jvp( - logp_batch_deer, - gradlogp_batch_deer, - hvp_batch_ref, - X, - ε, - Ξ, - U, - Z, + logp_batch_deer, gradlogp_batch_deer, hvp_batch_ref, X, ε, Ξ, U, Z ) @test FT_ad ≈ FT_ref atol=1e-10 rtol=1e-10 @@ -143,9 +169,7 @@ end ) sampler = ParallelMALASampler(0.05; 
T=8) - _, state = ParallelMCMC.AbstractMCMC.step( - rng, model, sampler; initial_params=zeros(2) - ) + _, state = ParallelMCMC.AbstractMCMC.step(rng, model, sampler; initial_params=zeros(2)) @test batch_calls[] ≥ 1 scalar_after_solve = scalar_calls[] diff --git a/test/test-DEER-Turing-Logistic.jl b/test/test-DEER-Turing-Logistic.jl index 20ea153..13e9fc9 100644 --- a/test/test-DEER-Turing-Logistic.jl +++ b/test/test-DEER-Turing-Logistic.jl @@ -54,6 +54,13 @@ function _gradlogp_lr(β, X, y) return X' * (y .- p) .- β end +function _hvp_lr(β, v, X, y) + logits = X * β + p = @. 1 / (1 + exp(-logits)) + w = @. p * (1 - p) + return -(X' * (w .* (X * v))) .- v +end + @model function _deer_logistic_regression(X, y) D = size(X, 2) β ~ MvNormal(zeros(D), I) @@ -63,10 +70,17 @@ end end end +function _deer_logistic_turing_density_model() + return DensityModel( + _deer_logistic_regression(_LR_X, _LR_y); + hvp=(β, v) -> _hvp_lr(β, v, _LR_X, _LR_y), + ) +end + # Turing integration tests (CPU) @testset "ParallelMALASampler Turing logistic: param names extracted correctly" begin - model = DensityModel(_deer_logistic_regression(_LR_X, _LR_y)) + model = _deer_logistic_turing_density_model() @test model.dim == _LR_D @test model.param_names == [Symbol("β[1]"), Symbol("β[2]")] @@ -75,14 +89,17 @@ end end @testset "ParallelMALASampler Turing logistic: chains output well-formed" begin - model = DensityModel(_deer_logistic_regression(_LR_X, _LR_y)) - sampler = ParallelMALASampler(0.1; T=16, maxiter=50, tol_abs=1e-4, tol_rel=1e-3, damping=0.5) + model = _deer_logistic_turing_density_model() + sampler = ParallelMALASampler( + 0.05; T=16, maxiter=80, tol_abs=1e-4, tol_rel=1e-3, jacobian=:diag, damping=0.5 + ) chain = sample( MersenneTwister(42), model, sampler, 400; + initial_params=zeros(_LR_D), chain_type=MCMCChains.Chains, progress=false, ) @@ -95,14 +112,17 @@ end end @testset "ParallelMALASampler Turing logistic: posterior sign correct" begin - model = 
DensityModel(_deer_logistic_regression(_LR_X, _LR_y)) - sampler = ParallelMALASampler(0.1; T=16, maxiter=50, tol_abs=1e-4, tol_rel=1e-3, damping=0.5) + model = _deer_logistic_turing_density_model() + sampler = ParallelMALASampler( + 0.05; T=16, maxiter=80, tol_abs=1e-4, tol_rel=1e-3, jacobian=:diag, damping=0.5 + ) chain = sample( MersenneTwister(99), model, sampler, 800; + initial_params=zeros(_LR_D), chain_type=MCMCChains.Chains, progress=false, ) @@ -120,7 +140,12 @@ end @testset "ParallelMALASampler logistic: posterior mean matches AdaptiveMALA reference" begin X64, y64 = Float64.(_LR_X), Float64.(_LR_y) - model = DensityModel(β -> _logp_lr(β, X64, y64), β -> _gradlogp_lr(β, X64, y64), _LR_D) + model = DensityModel( + β -> _logp_lr(β, X64, y64), + β -> _gradlogp_lr(β, X64, y64), + _LR_D; + hvp=(β, v) -> _hvp_lr(β, v, X64, y64), + ) mala_chain = sample( MersenneTwister(2025), @@ -194,21 +219,18 @@ else gradlogp_gpu = β -> _gradlogp_lr(β, X_gpu, y_gpu) tape = [(ξ=CUDA.CuArray(ξs_cpu[t]), u=us[t]) for t in 1:T] - _backend = ADTypes.AutoEnzyme() step_fwd = (x, te) -> ParallelMCMC.MALA.mala_step_taped(logp_gpu, gradlogp_gpu, x, ε, te.ξ, te.u) - hvp_fn_gpu = (pt, dir) -> ParallelMCMC.DEER._hvp_nopre(gradlogp_gpu, _backend, pt, dir) + hvp_fn_gpu = (pt, dir) -> _hvp_lr(pt, dir, X_gpu, y_gpu) jvp = - (x, te, v) -> - ParallelMCMC.MALA.mala_step_surrogate_sigmoid_jvp( - logp_gpu, gradlogp_gpu, x, ε, te.ξ, te.u, v, hvp_fn_gpu - ) + (x, te, v) -> ParallelMCMC.MALA.mala_step_surrogate_sigmoid_jvp( + logp_gpu, gradlogp_gpu, x, ε, te.ξ, te.u, v, hvp_fn_gpu + ) fwd_and_jvp = - (x, te, v) -> - ParallelMCMC.MALA.mala_step_taped_and_jvp( - logp_gpu, gradlogp_gpu, x, ε, te.ξ, te.u, v, hvp_fn_gpu - ) + (x, te, v) -> ParallelMCMC.MALA.mala_step_taped_and_jvp( + logp_gpu, gradlogp_gpu, x, ε, te.ξ, te.u, v, hvp_fn_gpu + ) rec = ParallelMCMC.DEER.TapedRecursion(step_fwd, jvp, tape; fwd_and_jvp=fwd_and_jvp) @@ -247,10 +269,16 @@ else ) model_cpu = DensityModel( - β -> _logp_lr(β, _X_f32, 
_y_f32), β -> _gradlogp_lr(β, _X_f32, _y_f32), _LR_D + β -> _logp_lr(β, _X_f32, _y_f32), + β -> _gradlogp_lr(β, _X_f32, _y_f32), + _LR_D; + hvp=(β, v) -> _hvp_lr(β, v, _X_f32, _y_f32), ) model_gpu = DensityModel( - β -> _logp_lr(β, X_gpu, y_gpu), β -> _gradlogp_lr(β, X_gpu, y_gpu), _LR_D + β -> _logp_lr(β, X_gpu, y_gpu), + β -> _gradlogp_lr(β, X_gpu, y_gpu), + _LR_D; + hvp=(β, v) -> _hvp_lr(β, v, X_gpu, y_gpu), ) n_samples, n_burn = 400, 100 diff --git a/test/test-DEERScan.jl b/test/test-DEERScan.jl index 9dc1030..9bac8e7 100644 --- a/test/test-DEERScan.jl +++ b/test/test-DEERScan.jl @@ -34,6 +34,79 @@ function make_affine_problem(rng::AbstractRNG, D::Int, T::Int; FT=Float32, stabl return A, B, s0 end +@testset "DEERScan CPU scan matches sequential reference" begin + rng = MersenneTwister(11) + + for (D, T) in ((1, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 17), (8, 31), (8, 64)) + A, B, s0 = make_affine_problem(rng, D, T; FT=Float64, stable=true) + S_ref = DEERScan.solve_affine_seq(A, B, s0) + + ws = DEERScan.AffineScanWorkspace(A) + S = similar(A) + out = DEERScan.solve_affine_scan_diag!(S, A, B, s0, ws) + + @test out === S + @test S ≈ S_ref atol=1e-12 rtol=1e-12 + + fill!(S, NaN) + DEERScan.solve_affine_scan_diag!(S, A, B, s0, ws) + @test S ≈ S_ref atol=1e-12 rtol=1e-12 + end +end + +@testset "DEERScan affine scan validation helpers" begin + rng = MersenneTwister(12) + D, T = 4, 9 + A, B, s0 = make_affine_problem(rng, D, T; FT=Float64, stable=true) + + S_seq = DEERScan.solve_affine_seq(A, B, s0) + S_scan = DEERScan.solve_affine_scan_diag(A, B, s0) + + @test DEERScan.affine_scan_residual(S_seq, A, B, s0) ≤ 1e-12 + @test DEERScan.affine_scan_residual(S_scan, A, B, s0) ≤ 1e-12 + + S_bad_first = copy(S_seq) + S_bad_first[2, 1] += 0.25 + @test DEERScan.affine_scan_residual(S_bad_first, A, B, s0) ≈ 0.25 + + S_bad_later = copy(S_seq) + S_bad_later[3, 6] -= 0.5 + @test DEERScan.affine_scan_residual(S_bad_later, A, B, s0) ≈ 0.5 + + chk = DEERScan.check_affine_scan(A, B, s0; 
atol=1e-12, rtol=1e-12) + @test chk.ok + @test chk.max_abs_err ≤ 1e-12 + @test chk.max_rel_err ≤ 1 + @test chk.residual_seq ≤ 1e-12 + @test chk.residual_scan ≤ 1e-12 + @test chk.S_seq ≈ S_seq atol=0 rtol=0 + @test chk.S_scan ≈ S_scan atol=0 rtol=0 + + A0 = zeros(Float64, D, 0) + B0 = similar(A0) + s00 = randn(rng, D) + S0 = similar(A0) + chk0 = DEERScan.check_affine_scan(A0, B0, s00) + @test DEERScan.affine_scan_residual(S0, A0, B0, s00) == 0.0 + @test chk0.ok + @test chk0.max_abs_err == 0.0 + @test chk0.max_rel_err == 0.0 + @test size(chk0.S_seq) == (D, 0) + @test size(chk0.S_scan) == (D, 0) +end + +@testset "DEERScan affine scan helper shape validation" begin + rng = MersenneTwister(13) + A, B, s0 = make_affine_problem(rng, 3, 5; FT=Float64, stable=true) + S = DEERScan.solve_affine_seq(A, B, s0) + + @test_throws DimensionMismatch DEERScan.affine_scan_residual(S, B[:, 1:4], B, s0) + @test_throws DimensionMismatch DEERScan.affine_scan_residual(S[:, 1:4], A, B, s0) + @test_throws DimensionMismatch DEERScan.affine_scan_residual(S, A, B, randn(rng, 4)) + @test_throws DimensionMismatch DEERScan.check_affine_scan(A, B[:, 1:4], s0) + @test_throws DimensionMismatch DEERScan.check_affine_scan(A, B, randn(rng, 4)) +end + @testset "DEERScan primitive (GPU-first)" begin if !CUDA_AVAILABLE @info "No CUDA GPU detected — skipping DEERScan GPU tests." 
@@ -42,7 +115,9 @@ end rng = MersenneTwister(1) for (D, T) in ((1, 1), (1, 17), (4, 16), (8, 31), (16, 64)) - A_cpu, B_cpu, s0_cpu = make_affine_problem(rng, D, T; FT=Float32, stable=true) + A_cpu, B_cpu, s0_cpu = make_affine_problem( + rng, D, T; FT=Float32, stable=true + ) # CPU oracle: strictly sequential reference S_ref = DEERScan.solve_affine_seq(A_cpu, B_cpu, s0_cpu) @@ -56,7 +131,7 @@ end @test S_gpu isa CUDA.CuMatrix @test size(S_gpu) == (D, T) - @test Array(S_gpu) ≈ S_ref atol=1f-5 rtol=1f-5 + @test Array(S_gpu) ≈ S_ref atol=1.0f-5 rtol=1.0f-5 end end @@ -76,7 +151,7 @@ end @test out === S_gpu @test S_gpu isa CUDA.CuMatrix - @test Array(S_gpu) ≈ S_ref atol=1f-5 rtol=1f-5 + @test Array(S_gpu) ≈ S_ref atol=1.0f-5 rtol=1.0f-5 end @testset "GPU scan satisfies the recurrence residual" begin @@ -93,7 +168,7 @@ end # Residual can be checked on CPU after materializing. resid = DEERScan.affine_scan_residual(Array(S_gpu), A_cpu, B_cpu, s0_cpu) - @test resid ≤ 1f-5 + @test resid ≤ 1.0f-5 end @testset "GPU scan agrees with CPU on edge cases" begin @@ -111,7 +186,7 @@ end CUDA.CuArray(A_cpu), CUDA.CuArray(B_cpu), CUDA.CuArray(s0_cpu) ) - @test Array(S_gpu) ≈ S_ref atol=1f-6 rtol=1f-6 + @test Array(S_gpu) ≈ S_ref atol=1.0f-6 rtol=1.0f-6 end # zero multiplicative term => S[:,t] = B[:,t] @@ -126,8 +201,8 @@ end CUDA.CuArray(A_cpu), CUDA.CuArray(B_cpu), CUDA.CuArray(s0_cpu) ) - @test Array(S_gpu) ≈ S_ref atol=1f-6 rtol=1f-6 - @test Array(S_gpu) ≈ B_cpu atol=1f-6 rtol=1f-6 + @test Array(S_gpu) ≈ S_ref atol=1.0f-6 rtol=1.0f-6 + @test Array(S_gpu) ≈ B_cpu atol=1.0f-6 rtol=1.0f-6 end # T = 1 @@ -143,7 +218,7 @@ end ) @test size(S_gpu) == (D, 1) - @test Array(S_gpu) ≈ S_ref atol=1f-6 rtol=1f-6 + @test Array(S_gpu) ≈ S_ref atol=1.0f-6 rtol=1.0f-6 end end @@ -189,12 +264,16 @@ end B_bad_gpu = CUDA.CuArray(randn(Float32, 4, 9)) s0_gpu = CUDA.CuArray(s0_cpu) - @test_throws DimensionMismatch DEERScan.solve_affine_scan_diag(A_gpu, B_bad_gpu, s0_gpu) + @test_throws DimensionMismatch 
DEERScan.solve_affine_scan_diag( + A_gpu, B_bad_gpu, s0_gpu + ) B_gpu = CUDA.CuArray(B_cpu) s0_bad_gpu = CUDA.CuArray(randn(Float32, 5)) - @test_throws DimensionMismatch DEERScan.solve_affine_scan_diag(A_gpu, B_gpu, s0_bad_gpu) + @test_throws DimensionMismatch DEERScan.solve_affine_scan_diag( + A_gpu, B_gpu, s0_bad_gpu + ) end @testset "GPU scan helper check_affine_scan still agrees with CPU" begin @@ -203,12 +282,12 @@ end D, T = 4, 24 A_cpu, B_cpu, s0_cpu = make_affine_problem(rng, D, T; FT=Float32, stable=true) - chk = DEERScan.check_affine_scan(A_cpu, B_cpu, s0_cpu; atol=1f-6, rtol=1f-6) + chk = DEERScan.check_affine_scan(A_cpu, B_cpu, s0_cpu; atol=1.0f-6, rtol=1.0f-6) @test chk.ok - @test chk.max_abs_err ≤ 1f-6 - @test chk.residual_seq ≤ 1f-6 - @test chk.residual_scan ≤ 1f-6 + @test chk.max_abs_err ≤ 1.0f-6 + @test chk.residual_seq ≤ 1.0f-6 + @test chk.residual_scan ≤ 1.0f-6 end end -end \ No newline at end of file +end diff --git a/test/test-Deer-vs-MALA.jl b/test/test-Deer-vs-MALA.jl index ffcc2e4..75460d0 100644 --- a/test/test-Deer-vs-MALA.jl +++ b/test/test-Deer-vs-MALA.jl @@ -30,9 +30,9 @@ end function _make_rec(tape, ε) step_fwd = - (x, tt) -> MALA.mala_step_taped(logp_stdnormal, gradlogp_stdnormal, x, ε, tt.ξ, tt.u) - hvp_fn = - (pt, dir) -> DEER._hvp_nopre(gradlogp_stdnormal, DEER.DEFAULT_BACKEND, pt, dir) + (x, tt) -> + MALA.mala_step_taped(logp_stdnormal, gradlogp_stdnormal, x, ε, tt.ξ, tt.u) + hvp_fn = (pt, dir) -> DEER._hvp_nopre(gradlogp_stdnormal, DEER.DEFAULT_BACKEND, pt, dir) jvp = (x, tt, v) -> MALA.mala_step_surrogate_sigmoid_jvp( logp_stdnormal, gradlogp_stdnormal, x, ε, tt.ξ, tt.u, v, hvp_fn @@ -133,6 +133,28 @@ end @test isapprox(S_scan, S_ref; rtol=1e-12, atol=1e-12) end +@testset "DEER.solve copies workspace result by default" begin + D, T = 2, 4 + tape = collect(1:T) + step_fwd = (x, t) -> x + jvp = (x, t, v) -> v + rec = DEER.TapedRecursion(step_fwd, jvp, tape) + + s0 = [1.0, -1.0] + ws = DEER.DEERWorkspace(s0, T) + + S = 
DEER.solve(rec, s0; jacobian=:diag, workspace=ws) + S_saved = copy(S) + + @test S !== ws.S_tmp + + DEER.solve(rec, [2.0, 3.0]; jacobian=:diag, workspace=ws) + @test S == S_saved + + S_alias = DEER.solve(rec, s0; jacobian=:diag, workspace=ws, copy_result=false) + @test S_alias === ws.S_tmp +end + @testset "DEER stochastic diagonal (Hutchinson) recovers MALA trajectory" begin rng = MersenneTwister(777) @@ -182,7 +204,7 @@ end xbar = xs_seq[tcheck] diag_exact = DEER.jac_diag_via_jvps(rec, xbar, tcheck) - diag_est = DEER.jac_diag_stoch( + diag_est = DEER.jac_diag_stoch( rec, xbar, tcheck; probes=200, rng=MersenneTwister(123), zbuf=zeros(D) ) diff --git a/test/test-GPU-DEER.jl b/test/test-GPU-DEER.jl index b7bdf3a..19a12eb 100644 --- a/test/test-GPU-DEER.jl +++ b/test/test-GPU-DEER.jl @@ -72,16 +72,19 @@ else function _make_gpu_rec(tape, ε, backend) step_fwd = (x, te) -> MALA.mala_step_taped(_logp, _gradlogp, x, ε, te.ξ, te.u) hvp_fn = (pt, dir) -> DEER._hvp_nopre(_gradlogp, backend, pt, dir) - jvp = (x, te, v) -> MALA.mala_step_surrogate_sigmoid_jvp( - _logp, _gradlogp, x, ε, te.ξ, te.u, v, hvp_fn - ) - fwd_and_jvp = (x, te, v) -> MALA.mala_step_taped_and_jvp( - _logp, _gradlogp, x, ε, te.ξ, te.u, v, hvp_fn - ) + jvp = + (x, te, v) -> MALA.mala_step_surrogate_sigmoid_jvp( + _logp, _gradlogp, x, ε, te.ξ, te.u, v, hvp_fn + ) + fwd_and_jvp = + (x, te, v) -> + MALA.mala_step_taped_and_jvp(_logp, _gradlogp, x, ε, te.ξ, te.u, v, hvp_fn) return DEER.TapedRecursion(step_fwd, jvp, tape; fwd_and_jvp=fwd_and_jvp) end - @testset "DEER.solve GPU backend=$bname (stoch_diag)" for (bname, backend) in _gpu_backends + @testset "DEER.solve GPU backend=$bname (stoch_diag)" for (bname, backend) in + _gpu_backends + rng = MersenneTwister(7) D, T = 4, 32 ε = 0.05f0 diff --git a/test/test-GPU-MALA.jl b/test/test-GPU-MALA.jl index b255ced..f6d7179 100644 --- a/test/test-GPU-MALA.jl +++ b/test/test-GPU-MALA.jl @@ -32,7 +32,7 @@ else scales = CUDA.CuArray(Float32.(1:D)) # D-vector on GPU return 
vec(-0.5f0 .* sum(X .^ 2 ./ scales; dims=1)) end - + function gradlogp_scaled(X) D = size(X, 1) scales = CUDA.CuArray(Float32.(1:D)) diff --git a/test/test-Jacobian-Estimator.jl b/test/test-Jacobian-Estimator.jl index b0c2517..a978b31 100644 --- a/test/test-Jacobian-Estimator.jl +++ b/test/test-Jacobian-Estimator.jl @@ -10,6 +10,27 @@ const MALA = ParallelMCMC.MALA logp_stdnormal_B(x) = -0.5 * dot(x, x) gradlogp_stdnormal_B(x) = -x +struct TaggedVector{T} <: AbstractVector{T} + data::Vector{T} +end + +Base.IndexStyle(::Type{<:TaggedVector}) = IndexLinear() +Base.size(x::TaggedVector) = size(x.data) +Base.axes(x::TaggedVector) = axes(x.data) +Base.getindex(x::TaggedVector, i::Int) = x.data[i] +Base.setindex!(x::TaggedVector, v, i::Int) = setindex!(x.data, v, i) +function Base.similar(x::TaggedVector, ::Type{T}, dims::Dims{1}) where {T} + TaggedVector(Vector{T}(undef, dims[1])) +end +function Base.similar(x::TaggedVector, ::Type{T}, n::Int) where {T} + TaggedVector(Vector{T}(undef, n)) +end +function Base.similar(x::TaggedVector, dims::Dims{1}) + TaggedVector(Vector{eltype(x)}(undef, dims[1])) +end +Base.similar(x::TaggedVector, n::Int) = TaggedVector(Vector{eltype(x)}(undef, n)) +Base.copy(x::TaggedVector) = TaggedVector(copy(x.data)) + # Reference log q(y|x) for MALA: # y ~ Normal( x + ϵ∇logp(x), 2ϵ I ) function logq_mala_ref_B( @@ -71,14 +92,75 @@ make_affine_tape(rng::AbstractRNG, D::Int, T::Int) = [randn(rng, D) for _ in 1:T return mean(mses) end - mse1 = mse_for_probes(1; nrep=25) - mse8 = mse_for_probes(8; nrep=25) + mse1 = mse_for_probes(1; nrep=25) + mse8 = mse_for_probes(8; nrep=25) mse64 = mse_for_probes(64; nrep=25) - @test mse8 < mse1 + @test mse8 < mse1 @test mse64 < mse8 end + @testset "jac_diag_stoch default probe uses x-like storage" begin + x = TaggedVector([1.0, 2.0, 3.0]) + tape = [1] + step_fwd = (x, t) -> x + jvp = function (x, t, v) + v isa TaggedVector || throw(ArgumentError("expected TaggedVector tangent")) + return copy(v) + end + rec = 
DEER.TapedRecursion(step_fwd, jvp, tape) + + d = DEER.jac_diag_stoch(rec, x, 1; rng=MersenneTwister(1)) + + @test d isa TaggedVector + @test collect(d) == ones(3) + end + + @testset "Rademacher and DEER argument validation paths" begin + rng = MersenneTwister(20260203) + z = zeros(Float64, 12) + out = DEER._rademacher!(z, rng) + @test out === z + @test all(abs.(z) .== 1) + + D, T = 3, 4 + tape = collect(1:T) + step_fwd = (x, t) -> 0.8 .* x .+ t + jvp = (x, t, v) -> 0.8 .* v + rec = DEER.TapedRecursion(step_fwd, jvp, tape) + x = randn(rng, D) + + @test_throws ArgumentError DEER.jac_diag_stoch( + rec, x, 1; probes=0, rng=MersenneTwister(1) + ) + @test_throws ArgumentError DEER.jac_diag_stoch( + rec, x, 1; zbuf=zeros(D + 1), rng=MersenneTwister(1) + ) + + S = randn(rng, D, T) + S_out = similar(S) + s0 = randn(rng, D) + ws = DEER.DEERWorkspace(S, s0) + + @test_throws ArgumentError DEER.deer_update!( + ws, S_out, rec, s0, S; jacobian=:full + ) + @test_throws ArgumentError DEER.deer_update!( + ws, S_out, rec, s0, S; damping=0.0 + ) + @test_throws ArgumentError DEER.deer_update!( + ws, S_out, rec, s0, S; probes=0 + ) + @test_throws ArgumentError DEER.deer_update!( + ws, randn(rng, D, T + 1), rec, s0, S + ) + + @test_throws ArgumentError DEER.solve(rec, s0; maxiter=0) + @test_throws ArgumentError DEER.solve(rec, s0; tol_abs=-1.0) + @test_throws ArgumentError DEER.solve(rec, s0; tol_rel=-1.0) + @test_throws ArgumentError DEER.solve(rec, s0; init=randn(rng, D, T + 1)) + end + @testset "batched stoch_diag uses z .* Jz in DEER update" begin rng = MersenneTwister(20251231) D, T = 7, 13 @@ -98,14 +180,7 @@ make_affine_tape(rng::AbstractRNG, D::Int, T::Int) = [randn(rng, D) for _ in 1:T ws = DEER.DEERWorkspace(S_in, s0) S_out = similar(S_in) DEER.deer_update!( - ws, - S_out, - rec, - s0, - S_in; - jacobian=:stoch_diag, - probes=1, - rng=MersenneTwister(1), + ws, S_out, rec, s0, S_in; jacobian=:stoch_diag, probes=1, rng=MersenneTwister(1) ) S_ref = DEER.solve_affine_seq(Atrue, 
Btrue, s0) diff --git a/test/test-MALA-Kernel.jl b/test/test-MALA-Kernel.jl index 95dedc8..4abc877 100644 --- a/test/test-MALA-Kernel.jl +++ b/test/test-MALA-Kernel.jl @@ -20,6 +20,20 @@ function logq_mala_ref( return -0.5 * dot(r, r) / (2ϵ) - (d / 2) * log(4π * ϵ) end +function logq_mala_mass_ref( + y::AbstractVector, + x::AbstractVector, + gradlogp_x::AbstractVector, + ϵ::Real, + cholM::Cholesky, +) + μ = x .+ ϵ .* (cholM.L * (adjoint(cholM.L) * gradlogp_x)) + r = y .- μ + d = length(x) + return -0.5 * dot(cholM \ r, r) / (2ϵ) - + (d / 2) * log(4π * ϵ) - 0.5 * logdet(cholM) +end + # Build a tape (ξs, us) deterministically function make_tape(rng::AbstractRNG, D::Int, T::Int) ξs = [randn(rng, D) for _ in 1:T] @@ -38,10 +52,10 @@ end ξs, us = make_tape(rng, D, T) xs1 = MALA.run_mala_sequential_taped( - logp_stdnormal, gradlogp_stdnormal, x0, ϵ, ξs, us + logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x0, ϵ, ξs, us ) xs2 = MALA.run_mala_sequential_taped( - logp_stdnormal, gradlogp_stdnormal, x0, ϵ, ξs, us + logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x0, ϵ, ξs, us ) @test length(xs1) == T + 1 @@ -54,10 +68,10 @@ end x = copy(x0) for t in 1:T x_next_1 = MALA.mala_step_taped( - logp_stdnormal, gradlogp_stdnormal, x, ϵ, ξs[t], us[t] + logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, ϵ, ξs[t], us[t] ) x_next_2 = MALA.mala_step_taped( - logp_stdnormal, gradlogp_stdnormal, x, ϵ, ξs[t], us[t] + logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, ϵ, ξs[t], us[t] ) @test x_next_1 == x_next_2 x = x_next_1 @@ -77,7 +91,7 @@ end ξs, us = make_tape(rng, D, T) xs = MALA.run_mala_sequential_taped( - logp_stdnormal, gradlogp_stdnormal, x0, ϵ, ξs, us + logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x0, ϵ, ξs, us ) # Collect post-burn samples @@ -159,22 +173,107 @@ end x = randn(rng, D) ξ = randn(rng, D) - y = MALA.mala_proposal(logp_stdnormal, gradlogp_stdnormal, x, ϵ, ξ) + y = MALA.mala_proposal(logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, ϵ, ξ) - logα_impl = 
MALA.mala_logα(logp_stdnormal, gradlogp_stdnormal, x, y, ϵ) + logα_impl = MALA.mala_logα( + logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, y, ϵ + ) # Independent recomputation: - gx = gradlogp_stdnormal(x) - gy = gradlogp_stdnormal(y) + gx = gradlogp_stdnormal_kernel(x) + gy = gradlogp_stdnormal_kernel(y) logq_y_given_x = logq_mala_ref(y, x, gx, ϵ) logq_x_given_y = logq_mala_ref(x, y, gy, ϵ) logα_ref = - (logp_stdnormal(y) + logq_x_given_y) - (logp_stdnormal(x) + logq_y_given_x) + (logp_stdnormal_kernel(y) + logq_x_given_y) - + (logp_stdnormal_kernel(x) + logq_y_given_x) @test isfinite(logα_impl) @test isfinite(logα_ref) @test logα_impl ≈ logα_ref atol=1e-10 rtol=1e-10 end + + @testset "mass-matrix scalar and batched paths match columnwise references" begin + rng = MersenneTwister(20260201) + D, N = 3, 5 + ϵ = 0.08 + M = Symmetric([1.7 0.2 -0.1; 0.2 1.3 0.15; -0.1 0.15 1.5]) + cholM = cholesky(M) + + x = randn(rng, D) + ξ = randn(rng, D) + y = MALA.mala_proposal( + logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, ϵ, ξ; cholM=cholM + ) + y_ref = x .+ ϵ .* (cholM.L * (adjoint(cholM.L) * gradlogp_stdnormal_kernel(x))) .+ + sqrt(2ϵ) .* (cholM.L * ξ) + @test y ≈ y_ref atol=1e-12 rtol=1e-12 + + logq_impl = MALA.logq_mala(y, x, gradlogp_stdnormal_kernel(x), ϵ; cholM=cholM) + logq_ref = logq_mala_mass_ref(y, x, gradlogp_stdnormal_kernel(x), ϵ, cholM) + @test logq_impl ≈ logq_ref atol=1e-12 rtol=1e-12 + + X = randn(rng, D, N) + Ξ = randn(rng, D, N) + u = range(0.1, 0.9; length=N) |> collect + X_next, accepted = MALA.mala_step_batched( + X -> vec(-0.5 .* sum(abs2, X; dims=1)), + X -> -X, + X, + ϵ, + Ξ, + u; + cholM=cholM, + ) + + @test length(accepted) == N + for j in 1:N + xj = copy(view(X, :, j)) + ξj = copy(view(Ξ, :, j)) + x_ref, a_ref = MALA.mala_step_full( + logp_stdnormal_kernel, + gradlogp_stdnormal_kernel, + xj, + ϵ, + ξj, + u[j]; + cholM=cholM, + ) + @test accepted[j] == a_ref + @test view(X_next, :, j) ≈ x_ref atol=1e-12 rtol=1e-12 + end + + Y = 
randn(rng, D, N) + G = -X + logq_batch = MALA.logq_mala_batched(Y, X, G, ϵ; cholM=cholM) + logq_cols = [ + MALA.logq_mala( + copy(view(Y, :, j)), copy(view(X, :, j)), copy(view(G, :, j)), ϵ; cholM=cholM + ) for j in 1:N + ] + @test logq_batch ≈ logq_cols atol=1e-12 rtol=1e-12 + end + + @testset "standalone surrogate sigmoid matches explicit relaxed blend" begin + rng = MersenneTwister(20260202) + x = randn(rng, 4) + ξ = randn(rng, 4) + ϵ = 0.15 + u = 0.37 + + y = MALA.mala_proposal(logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, ϵ, ξ) + logα = MALA.mala_logα(logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, y, ϵ) + g = inv(1 + exp(-(logα - log(u)))) + expected = g .* y .+ (1 - g) .* x + + actual = MALA.mala_step_surrogate_sigmoid( + logp_stdnormal_kernel, gradlogp_stdnormal_kernel, x, ϵ, ξ, u + ) + + @test actual ≈ expected atol=1e-12 rtol=1e-12 + @test all(xi -> min(xi[1], xi[2]) - 1e-12 ≤ xi[3] ≤ max(xi[1], xi[2]) + 1e-12, + zip(x, y, actual)) + end end diff --git a/test/test-Turing-Integration.jl b/test/test-Turing-Integration.jl index 11e76ec..635ddcd 100644 --- a/test/test-Turing-Integration.jl +++ b/test/test-Turing-Integration.jl @@ -60,6 +60,34 @@ end @test isfinite(model.grad_logdensity([0.0])[1]) end +@testset "DynamicPPLExt: generic Turing model works with ParallelMALA and default Enzyme HVP" begin + model = DensityModel(normal_model(TRUE_OBS)) + + @test model.hvp === nothing + + for jacobian in (:diag, :stoch_diag) + sampler = ParallelMALASampler( + 0.03; + T=4, + maxiter=80, + tol_abs=1e-5, + tol_rel=1e-4, + jacobian=jacobian, + damping=0.5, + ) + + trans, state = ParallelMCMC.AbstractMCMC.step( + MersenneTwister(11), model, sampler; initial_params=[0.0] + ) + + @test trans isa ParallelMALATransition + @test state isa ParallelMALAState + @test length(trans.x) == 1 + @test isfinite(trans.logp) + @test all(isfinite, state.trajectory) + end +end + @testset "DynamicPPLExt: named columns in Chains output" begin model = 
DensityModel(normal_model(TRUE_OBS))