From 0c44689fb640549a24b52ace7ef3818f00dd8ec4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 3 May 2026 23:55:01 +0100 Subject: [PATCH 01/12] Convert DynamicPPL samples back to original space --- Project.toml | 2 +- ext/DynamicPPLExt.jl | 108 ++++++++++++++++-------------- ext/LogDensityProblemsExt.jl | 15 +++-- src/interface.jl | 55 ++++++++++++--- test/test-DEER-Turing-Logistic.jl | 3 +- test/test-Turing-Integration.jl | 48 ++++++++----- 6 files changed, 143 insertions(+), 88 deletions(-) diff --git a/Project.toml b/Project.toml index 7c255d3..756be13 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,7 @@ ADTypes = "1.21.0" AbstractMCMC = "5.10.0" CUDA = "5.11.0" DifferentiationInterface = "0.7.13" -DynamicPPL = "0.40" +DynamicPPL = "0.40.6, 0.41" Enzyme = "0.13.131" LinearAlgebra = "1" LogDensityProblems = "2" diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index 294bfba..8021ac9 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -3,7 +3,9 @@ module DynamicPPLExt using ParallelMCMC using ADTypes: ADTypes using DynamicPPL: DynamicPPL +using AbstractMCMC: AbstractMCMC using Enzyme: Enzyme +using MCMCChains: MCMCChains using LogDensityProblems: LogDensityProblems """ @@ -28,13 +30,6 @@ model = DensityModel(mymodel(1.5)) chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; chain_type=MCMCChains.Chains, discard_warmup=true, progress=true) ``` - -# Notes -- Parameter names are extracted from the model's prior. For most common - distributions (Normal, MvNormal, Exponential, etc.) the names match the - unconstrained parameter space used by LogDensityProblems. If the extracted - names do not match the dimensionality (e.g. due to simplex constraints), the - constructor falls back to generic `x[1], x[2], ...` names with a warning. 
""" function ParallelMCMC.DensityModel( turing_model::DynamicPPL.Model; @@ -51,58 +46,69 @@ function ParallelMCMC.DensityModel( DynamicPPL.LinkAll(); adtype=ad_backend, ) + # Requires LogDensityProblemsExt to be loaded + return ParallelMCMC.DensityModel(ld; hvp=hvp) +end - caps = LogDensityProblems.capabilities(ld) - caps isa LogDensityProblems.LogDensityOrder{0} && error( - "AD gradient setup failed. The wrapped model must support gradients. " * - "Ensure your ad_backend is compatible.", - ) +# Types that represent LogDensityProblems objects that wrap DynamicPPL models. +const LDFPrimal = ParallelMCMC.LogDensityProblemPrimal{<:DynamicPPL.LogDensityFunction} +const LDFGradient = ParallelMCMC.LogDensityProblemGradient{<:DynamicPPL.LogDensityFunction} +const DensityModelLDF = ParallelMCMC.DensityModel{<:LDFPrimal,<:LDFGradient} - dim = LogDensityProblems.dimension(ld) +""" + postprocess_sample(model::DensityModel, sample) - # Try to extract parameter names; fall back to nothing on any error or mismatch. - param_names = _try_extract_param_names(turing_model, dim) +Converts a raw transition (e.g. `MALATransition`) into a `DynamicPPL.ParamsWithStats` object +by reevaluating the DynamicPPL model with the vectorised parameters. This requires that the `DensityModel` object was constructed with a `LogDensityFunctionPrimal` and `LogDensityFunctionGradient` that wrap a DynamicPPL model. 
+""" +function ParallelMCMC.postprocess_sample( + model::DensityModelLDF, sample::ParallelMCMC.MALATransition +) + stats = (accepted=sample.accepted,) + return DynamicPPL.ParamsWithStats(sample.x, model.logdensity.ld, stats) +end +function ParallelMCMC.postprocess_sample( + model::DensityModelLDF, sample::ParallelMCMC.ParallelMALATransition +) + return DynamicPPL.ParamsWithStats(sample.x, model.logdensity.ld) +end +function ParallelMCMC.postprocess_sample( + model::DensityModelLDF, sample::ParallelMCMC.AdaptiveMALATransition +) + stats = ( + accepted=sample.accepted, step_size=sample.step_size, is_warmup=sample.is_warmup + ) + return DynamicPPL.ParamsWithStats(sample.x, model.logdensity.ld, stats) +end - logp(x) = LogDensityProblems.logdensity(ld, x) - function gradlogp(x) - _, g = LogDensityProblems.logdensity_and_gradient(ld, x) - return g +function AbstractMCMC.bundle_samples( + ts::Vector{<:DynamicPPL.ParamsWithStats}, + model::DensityModel, + spl::AbstractMCMC.AbstractSampler, + state, + chain_type::Type{MCMCChains.Chains}; + discard_warmup=false, + kwargs..., +) + if discard_warmup + ts = filter(t -> hasproperty(t.stats, :is_warmup) && !t.stats.is_warmup, ts) end - - return ParallelMCMC.DensityModel(logp, gradlogp, dim; hvp=hvp, param_names=param_names) + return AbstractMCMC.from_samples(MCMCChains.Chains, hcat(ts)) end -""" -Extract flat parameter names from a DynamicPPL model by sampling from the prior. -Returns a `Vector{Symbol}` if the count matches `expected_dim`, otherwise `nothing`. 
-""" -function _try_extract_param_names(model::DynamicPPL.Model, expected_dim::Int) - try - vi = DynamicPPL.VarInfo(model) - names = Symbol[] - for vn in keys(vi) - val = vi[vn] - sym = DynamicPPL.getsym(vn) - if val isa Number - push!(names, Symbol(sym)) - else - for i in 1:length(val) - push!(names, Symbol("$(sym)[$i]")) - end - end - end - if length(names) == expected_dim - return names - else - @warn "ParallelMCMC: parameter name extraction produced $(length(names)) names " * - "but model has $expected_dim unconstrained dimensions " * - "(likely due to bijector dimension changes, e.g. Dirichlet/LKJ constraints). " * - "Falling back to generic x[1], x[2], ... names." - return nothing - end - catch - return nothing +function ParallelMCMC._construct_chain( + ::Type{MCMCChains.Chains}, + vals::AbstractMatrix{Float64}, + internals::AbstractMatrix{Float64}, + names::Vector{Symbol}, + internal_names::Vector{Symbol}, + model::DensityModelLDF, +) + pwss = map(zip(eachrow(vals), eachrow(internals))) do (val, internal) + stats = NamedTuple{Tuple(internal_names)}(internal) + DynamicPPL.ParamsWithStats(val, model.logdensity.ld, stats) end + return AbstractMCMC.from_samples(MCMCChains.Chains, hcat(pwss)) end end # module diff --git a/ext/LogDensityProblemsExt.jl b/ext/LogDensityProblemsExt.jl index 908b7d4..2660ee4 100644 --- a/ext/LogDensityProblemsExt.jl +++ b/ext/LogDensityProblemsExt.jl @@ -45,7 +45,7 @@ chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; If DynamicPPL is loaded, the simpler one-step constructor `DensityModel(mymodel(obs))` is also available and extracts parameter names automatically. """ -function ParallelMCMC.DensityModel(ld; param_names=nothing) +function ParallelMCMC.DensityModel(ld; param_names=nothing, hvp=nothing) caps = LogDensityProblems.capabilities(ld) caps isa LogDensityProblems.LogDensityOrder{0} && error( "LogDensityProblems model must support gradients (LogDensityOrder{1} or higher). 
" * @@ -54,14 +54,15 @@ function ParallelMCMC.DensityModel(ld; param_names=nothing) dim = LogDensityProblems.dimension(ld) - logp(x) = LogDensityProblems.logdensity(ld, x) + logp = ParallelMCMC.LogDensityProblemPrimal(ld) + gradlogp = ParallelMCMC.LogDensityProblemGradient(ld) - function gradlogp(x) - _, g = LogDensityProblems.logdensity_and_gradient(ld, x) - return g - end + return ParallelMCMC.DensityModel(logp, gradlogp, dim; param_names=param_names, hvp=hvp) +end - return ParallelMCMC.DensityModel(logp, gradlogp, dim; param_names=param_names) +(l::ParallelMCMC.LogDensityProblemPrimal)(x) = LogDensityProblems.logdensity(l.ld, x) +function (l::ParallelMCMC.LogDensityProblemGradient)(x) + return last(LogDensityProblems.logdensity_and_gradient(l.ld, x)) end end # module diff --git a/src/interface.jl b/src/interface.jl index d68efc6..8fbda03 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -67,6 +67,28 @@ function DensityModel( ) end +# Callable structs that allow us to dispatch on the type of the LogDensityProblems object in +# the postprocessing stage. Ideally these would be defined in the LogDensityProblemsExt. +# However, structs defined in extensions are hard to get hold of so we define them here. +# The callable behaviour itself is implemented in LogDensityProblemsExt. +struct LogDensityProblemPrimal{L} + ld::L +end +struct LogDensityProblemGradient{L} + ld::L +end + +""" + postprocess_sample(model::DensityModel, transition) + +Optional step to postprocess raw transitions from the sampler. Overloading +this allows us to, for example, transform samples from unconstrained space +back to the original parameter space when wrapping a DynamicPPL model. + +By default, this function returns the transition unchanged. 
+""" +postprocess_sample(::DensityModel, transition) = transition + """ MALASampler(epsilon; cholM=nothing) @@ -144,7 +166,7 @@ function AbstractMCMC.step( noise, noise_host = _make_noise_buffer(x, FP, model.dim) t = MALATransition(x, logp_val, true) s = MALAState(x, logp_val, ws, noise, noise_host) - return t, s + return postprocess_sample(model, t), s end function AbstractMCMC.step( @@ -177,7 +199,7 @@ function AbstractMCMC.step( logp_val = accepted ? model.logdensity(x_next) : state.logp t = MALATransition(x_next, logp_val, accepted) s = MALAState(x_next, logp_val, state.workspace, state.noise, state.noise_host) - return t, s + return postprocess_sample(model, t), s end function AbstractMCMC.bundle_samples( @@ -557,12 +579,13 @@ function _sample_parallel_mala_chain( rng::Random.AbstractRNG, model::DensityModel, sampler::ParallelMALASampler, - N::Int; + N::Int, + ::Type{Tchn}; initial_params=nothing, param_names=nothing, progress=AbstractMCMC.PROGRESS[], progressname="Sampling", -) +) where {Tchn} D = model.dim names = _parallel_mala_param_names(model, D, param_names) internal_names = [:logp] @@ -601,6 +624,17 @@ function _sample_parallel_mala_chain( end end + return _construct_chain(Tchn, vals, internals, names, internal_names, model) +end + +function _construct_chain( + ::Type{MCMCChains.Chains}, + vals::AbstractMatrix{Float64}, + internals::AbstractMatrix{Float64}, + names::Vector{Symbol}, + internal_names::Vector{Symbol}, + model::DensityModel, +) return MCMCChains.Chains( hcat(vals, internals), vcat(names, internal_names), @@ -732,7 +766,8 @@ function AbstractMCMC.mcmcsample( rng, model, sampler, - N_int; + N_int, + chain_type; initial_params=initial_params, param_names=param_names, progress=progress, @@ -776,7 +811,7 @@ function AbstractMCMC.step( logp1 = logps[1] trans = ParallelMALATransition(x1, logp1) state = ParallelMALAState(x1, logp1, S, logps, ws, tape, 1) - return trans, state + return postprocess_sample(model, trans), state end function 
AbstractMCMC.step( @@ -802,7 +837,7 @@ function AbstractMCMC.step( state.tape, t_next, ) - return trans, new_state + return postprocess_sample(model, trans), new_state else x0 = state.trajectory[:, T] S_new, tape, ws = _deer_solve_new_tape( @@ -813,7 +848,7 @@ function AbstractMCMC.step( logp_new = logps[1] trans = ParallelMALATransition(x_new, logp_new) new_state = ParallelMALAState(x_new, logp_new, S_new, logps, ws, tape, 1) - return trans, new_state + return postprocess_sample(model, trans), new_state end end @@ -969,7 +1004,7 @@ function AbstractMCMC.step( noise, noise_host, ) - return trans, state + return postprocess_sample(model, trans), state end function AbstractMCMC.step( @@ -1027,7 +1062,7 @@ function AbstractMCMC.step( state.noise, state.noise_host, ) - return trans, new_state + return postprocess_sample(model, trans), new_state end function AbstractMCMC.bundle_samples( diff --git a/test/test-DEER-Turing-Logistic.jl b/test/test-DEER-Turing-Logistic.jl index 8a5bca6..e6e5ff3 100644 --- a/test/test-DEER-Turing-Logistic.jl +++ b/test/test-DEER-Turing-Logistic.jl @@ -72,7 +72,7 @@ end function _deer_logistic_turing_density_model() return DensityModel( _deer_logistic_regression(_LR_X, _LR_y); - hvp=(β, v) -> _hvp_lr(β, v, _LR_X, _LR_y), + hvp=(β, v) -> _hvp_lr(β, v, _LR_X, _LR_y) ) end @@ -82,7 +82,6 @@ end model = _deer_logistic_turing_density_model() @test model.dim == _LR_D - @test model.param_names == [Symbol("β[1]"), Symbol("β[2]")] @test isfinite(model.logdensity(zeros(_LR_D))) @test all(isfinite, model.grad_logdensity(zeros(_LR_D))) end diff --git a/test/test-Turing-Integration.jl b/test/test-Turing-Integration.jl index d379396..c5560a7 100644 --- a/test/test-Turing-Integration.jl +++ b/test/test-Turing-Integration.jl @@ -9,7 +9,7 @@ using ParallelMCMC using DynamicPPL using LogDensityProblems using ADTypes -using Distributions: Beta, Normal, MvNormal +using Distributions: Beta, Normal, MvNormal, product_distribution, Dirichlet # A simple 1-D normal 
likelihood: μ ~ N(0,1), y | μ ~ N(μ, 0.5) # Posterior: μ | y=1.5 is N(μ_post, σ_post²) @@ -25,15 +25,16 @@ const TRUE_VAR_POST = 0.2 end @model function mv_model(y) - μ ~ MvNormal(zeros(2), I) - y ~ MvNormal(μ, 0.5 * I) + c ~ Dirichlet(ones(3)) # to test constraints + μ ~ product_distribution((a=Normal(), b=Normal())) + y ~ MvNormal([μ.a, μ.b], 0.5 * I) end @model function beta_model() x ~ Beta(2, 2) end -@testset "LogDensityProblemsExt: param_names kwarg" begin +@testset "directly passing LogDensityFunction" begin ld = DynamicPPL.LogDensityFunction( normal_model(TRUE_OBS), DynamicPPL.getlogjoint_internal, @@ -41,10 +42,9 @@ end adtype=ADTypes.AutoEnzyme(), ) - model = DensityModel(ld; param_names=[:μ]) + model = DensityModel(ld) @test model.dim == 1 - @test model.param_names == [:μ] chain = sample( MersenneTwister(1), @@ -54,24 +54,31 @@ end chain_type=MCMCChains.Chains, progress=false, ) - @test :μ in names(chain, :parameters) - @test !(Symbol("x[1]") in names(chain, :parameters)) + @test only(names(chain, :parameters)) == :μ end @testset "DynamicPPLExt: convenience constructor" begin model = DensityModel(normal_model(TRUE_OBS)) @test model.dim == 1 - @test model.param_names == [:μ] @test isfinite(model.logdensity([0.0])) @test isfinite(model.grad_logdensity([0.0])[1]) + + chain = sample( + MersenneTwister(1), + model, + AdaptiveMALASampler(0.3; n_warmup=200), + 600; + chain_type=MCMCChains.Chains, + progress=false, + ) + @test only(names(chain, :parameters)) == :μ end @testset "DynamicPPLExt: convenience constructor uses linked space for constrained models" begin model = DensityModel(beta_model()) @test model.dim == 1 - @test model.param_names == [:x] @test isfinite(model.logdensity([-0.4])) @test isfinite(model.grad_logdensity([-0.4])[1]) end @@ -96,10 +103,10 @@ end MersenneTwister(11), model, sampler; initial_params=[0.0] ) - @test trans isa ParallelMALATransition + @test trans isa DynamicPPL.ParamsWithStats @test state isa ParallelMALAState - @test 
length(trans.x) == 1 - @test isfinite(trans.logp) + @test only(keys(trans.params)) == @varname(μ) + @test isfinite(trans.stats.logjoint) @test all(isfinite, state.trajectory) end end @@ -176,8 +183,8 @@ end obs = [1.0, -1.0] model = DensityModel(mv_model(obs)) - @test model.dim == 2 - @test model.param_names == [Symbol("μ[1]"), Symbol("μ[2]")] + # 2 from linked Dirichlet + 2 from product_distribution + @test model.dim == 4 chain = sample( MersenneTwister(7), @@ -187,6 +194,13 @@ end chain_type=MCMCChains.Chains, progress=false, ) - @test Symbol("μ[1]") in names(chain, :parameters) - @test Symbol("μ[2]") in names(chain, :parameters) + + # Check that the chain contains parameters in original space. + # The Dirichlet parameter should have length 3. + @test Set(names(chain, :parameters)) == + Set(Symbol.(["c[1]", "c[2]", "c[3]", "μ.a", "μ.b"])) + for i in 1:3 + # Dirichlet samples should be non-negative + @test all(chain[Symbol("c[$i]")] .>= 0.0) + end end From 036003b024f9984f9e5342fd342b14fa642f2724 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 May 2026 00:59:17 +0100 Subject: [PATCH 02/12] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 756be13..04cc852 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ authors = ["Ryan Senne "] name = "ParallelMCMC" uuid = "1a970f40-4406-51c9-a967-cb3143c111e8" -version = "0.0.1" +version = "0.0.2" [compat] ADTypes = "1.21.0" From 61f7e6258e20d0bc8368965174cf2175b178d5f1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 May 2026 01:46:07 +0100 Subject: [PATCH 03/12] Fix is_warmup check Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- ext/DynamicPPLExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index 8021ac9..a9534fc 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -91,7 +91,7 @@ function 
AbstractMCMC.bundle_samples( kwargs..., ) if discard_warmup - ts = filter(t -> hasproperty(t.stats, :is_warmup) && !t.stats.is_warmup, ts) + ts = filter(t -> !hasproperty(t.stats, :is_warmup) || !t.stats.is_warmup, ts) end return AbstractMCMC.from_samples(MCMCChains.Chains, hcat(ts)) end From 819094191e316ffe6d71017aa091de88b7686637 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 May 2026 01:47:23 +0100 Subject: [PATCH 04/12] Fix docstring typo --- ext/DynamicPPLExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index a9534fc..633d9f8 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -59,7 +59,9 @@ const DensityModelLDF = ParallelMCMC.DensityModel{<:LDFPrimal,<:LDFGradient} postprocess_sample(model::DensityModel, sample) Converts a raw transition (e.g. `MALATransition`) into a `DynamicPPL.ParamsWithStats` object -by reevaluating the DynamicPPL model with the vectorised parameters. This requires that the `DensityModel` object was constructed with a `LogDensityFunctionPrimal` and `LogDensityFunctionGradient` that wrap a DynamicPPL model. +by reevaluating the DynamicPPL model with the vectorised parameters. This requires that the +`DensityModel` object was constructed with a `LogDensityProblemPrimal` and +`LogDensityProblemGradient` that wrap a DynamicPPL model. 
""" function ParallelMCMC.postprocess_sample( model::DensityModelLDF, sample::ParallelMCMC.MALATransition From c0a3836c55816c861420a01c23330c5522c0b800 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 May 2026 01:50:05 +0100 Subject: [PATCH 05/12] Fix docstring --- ext/LogDensityProblemsExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/LogDensityProblemsExt.jl b/ext/LogDensityProblemsExt.jl index 2660ee4..366f737 100644 --- a/ext/LogDensityProblemsExt.jl +++ b/ext/LogDensityProblemsExt.jl @@ -4,7 +4,7 @@ using ParallelMCMC using LogDensityProblems: LogDensityProblems """ - DensityModel(ld; param_names=nothing) + DensityModel(ld; param_names=nothing, hvp=nothing) Construct a `DensityModel` from any object implementing the [LogDensityProblems](https://github.com/tpapp/LogDensityProblems.jl) interface. @@ -20,6 +20,8 @@ that will be used for the columns of the returned `MCMCChains.Chains` object. If omitted, names default to `x[1], x[2], ...` unless you also pass `param_names` to `sample(...)`. +The `hvp` keyword argument is forwarded to the main `DensityModel` constructor. 
+ # Turing.jl / DynamicPPL example ```julia using Turing, LogDensityProblems, ADTypes, Enzyme, ParallelMCMC, MCMCChains From 0a9bef5171e5b9096bdc81c2d7c44fb358f7f5de Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 6 May 2026 11:43:33 +0100 Subject: [PATCH 06/12] Broaden type parameter --- src/interface.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 8fbda03..e780e5e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -629,8 +629,8 @@ end function _construct_chain( ::Type{MCMCChains.Chains}, - vals::AbstractMatrix{Float64}, - internals::AbstractMatrix{Float64}, + vals::AbstractMatrix{Real}, + internals::AbstractMatrix{Real}, names::Vector{Symbol}, internal_names::Vector{Symbol}, model::DensityModel, From f1a3d69c79a749f55702b3397a8176874d50e4f8 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 9 May 2026 02:25:24 +0100 Subject: [PATCH 07/12] Overload `bundle_samples` directly --- ext/DynamicPPLExt.jl | 95 ++++++++++++++++++++++++++++---------------- src/interface.jl | 29 +++++--------- 2 files changed, 70 insertions(+), 54 deletions(-) diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index 633d9f8..46dbbbe 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -50,59 +50,86 @@ function ParallelMCMC.DensityModel( return ParallelMCMC.DensityModel(ld; hvp=hvp) end +###################### +# Chain construction # +###################### +# In this section, we define overloads for DynamicPPL-based models so that resulting chains +# are converted back into the original parameter space and contain the correct parameter +# names. This is done by converting the raw samples (vectors of parameters) back into +# `DynamicPPL.ParamsWithStats` objects. 
+ +const ParallelMCMCTransitionTypes = Union{ + ParallelMCMC.MALATransition, + ParallelMCMC.AdaptiveMALATransition, + ParallelMCMC.ParallelMALATransition, +} # Types that represent LogDensityProblems objects that wrap DynamicPPL models. const LDFPrimal = ParallelMCMC.LogDensityProblemPrimal{<:DynamicPPL.LogDensityFunction} const LDFGradient = ParallelMCMC.LogDensityProblemGradient{<:DynamicPPL.LogDensityFunction} const DensityModelLDF = ParallelMCMC.DensityModel{<:LDFPrimal,<:LDFGradient} -""" - postprocess_sample(model::DensityModel, sample) - -Converts a raw transition (e.g. `MALATransition`) into a `DynamicPPL.ParamsWithStats` object -by reevaluating the DynamicPPL model with the vectorised parameters. This requires that the -`DensityModel` object was constructed with a `LogDensityProblemPrimal` and -`LogDensityProblemGradient` that wrap a DynamicPPL model. -""" -function ParallelMCMC.postprocess_sample( - model::DensityModelLDF, sample::ParallelMCMC.MALATransition +function AbstractMCMC.bundle_samples( + ts::Vector{<:ParallelMCMC.MALATransition}, + model::DensityModelLDF, + spl::ParallelMCMC.MALASampler, + state::ParallelMCMC.MALAState, + chain_type::Type{MCMCChains.Chains}; + kwargs..., ) - stats = (accepted=sample.accepted,) - return DynamicPPL.ParamsWithStats(sample.x, model.logdensity.ld, stats) + return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) end -function ParallelMCMC.postprocess_sample( - model::DensityModelLDF, sample::ParallelMCMC.ParallelMALATransition +function AbstractMCMC.bundle_samples( + ts::Vector{<:ParallelMCMC.ParallelMALATransition}, + model::DensityModelLDF, + spl::ParallelMCMC.ParallelMALASampler, + state::ParallelMCMC.ParallelMALAState, + chain_type::Type{MCMCChains.Chains}; + kwargs..., ) - return DynamicPPL.ParamsWithStats(sample.x, model.logdensity.ld) + return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) end -function ParallelMCMC.postprocess_sample( - model::DensityModelLDF, 
sample::ParallelMCMC.AdaptiveMALATransition +function AbstractMCMC.bundle_samples( + ts::Vector{<:ParallelMCMC.AdaptiveMALATransition}, + model::DensityModelLDF, + spl::ParallelMCMC.AdaptiveMALASampler, + state::ParallelMCMC.AdaptiveMALAState, + chain_type::Type{MCMCChains.Chains}; + discard_warmup::Bool=false, + kwargs..., ) - stats = ( + ts = discard_warmup ? filter(t -> !t.is_warmup, ts) : ts + return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) +end + +""" + getstats(sample::ParallelMCMCTransitionTypes) + +Get a `NamedTuple` of stats from an MCMC transition. +""" +getstats(sample::ParallelMCMC.MALATransition) = (accepted=sample.accepted,) +function getstats(sample::ParallelMCMC.AdaptiveMALATransition) + return ( accepted=sample.accepted, step_size=sample.step_size, is_warmup=sample.is_warmup ) - return DynamicPPL.ParamsWithStats(sample.x, model.logdensity.ld, stats) end +getstats(::ParallelMCMCTransitionTypes) = (;) -function AbstractMCMC.bundle_samples( - ts::Vector{<:DynamicPPL.ParamsWithStats}, - model::DensityModel, - spl::AbstractMCMC.AbstractSampler, - state, - chain_type::Type{MCMCChains.Chains}; - discard_warmup=false, - kwargs..., -) - if discard_warmup - ts = filter(t -> !hasproperty(t.stats, :is_warmup) || !t.stats.is_warmup, ts) +function make_processed_dynamicppl_chain( + ::Type{Tchain}, ts::Vector{<:ParallelMCMCTransitionTypes}, model::DensityModelLDF +) where {Tchain} + pwss = map(ts) do t + # Note: This assumes that there is always a field called t.x. 
This is currently true + # of all samplers in ParallelMCMC + DynamicPPL.ParamsWithStats(t.x, model.logdensity.ld, getstats(t)) end - return AbstractMCMC.from_samples(MCMCChains.Chains, hcat(ts)) + return AbstractMCMC.from_samples(Tchain, hcat(pwss)) end function ParallelMCMC._construct_chain( ::Type{MCMCChains.Chains}, - vals::AbstractMatrix{Float64}, - internals::AbstractMatrix{Float64}, - names::Vector{Symbol}, + vals::AbstractMatrix{<:Real}, + internals::AbstractMatrix{<:Real}, + ::Vector{Symbol}, internal_names::Vector{Symbol}, model::DensityModelLDF, ) diff --git a/src/interface.jl b/src/interface.jl index e780e5e..4ad6fd4 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -78,17 +78,6 @@ struct LogDensityProblemGradient{L} ld::L end -""" - postprocess_sample(model::DensityModel, transition) - -Optional step to postprocess raw transitions from the sampler. Overloading -this allows us to, for example, transform samples from unconstrained space -back to the original parameter space when wrapping a DynamicPPL model. - -By default, this function returns the transition unchanged. -""" -postprocess_sample(::DensityModel, transition) = transition - """ MALASampler(epsilon; cholM=nothing) @@ -166,7 +155,7 @@ function AbstractMCMC.step( noise, noise_host = _make_noise_buffer(x, FP, model.dim) t = MALATransition(x, logp_val, true) s = MALAState(x, logp_val, ws, noise, noise_host) - return postprocess_sample(model, t), s + return t, s end function AbstractMCMC.step( @@ -199,7 +188,7 @@ function AbstractMCMC.step( logp_val = accepted ? 
model.logdensity(x_next) : state.logp t = MALATransition(x_next, logp_val, accepted) s = MALAState(x_next, logp_val, state.workspace, state.noise, state.noise_host) - return postprocess_sample(model, t), s + return t, s end function AbstractMCMC.bundle_samples( @@ -629,8 +618,8 @@ end function _construct_chain( ::Type{MCMCChains.Chains}, - vals::AbstractMatrix{Real}, - internals::AbstractMatrix{Real}, + vals::AbstractMatrix{<:Real}, + internals::AbstractMatrix{<:Real}, names::Vector{Symbol}, internal_names::Vector{Symbol}, model::DensityModel, @@ -811,7 +800,7 @@ function AbstractMCMC.step( logp1 = logps[1] trans = ParallelMALATransition(x1, logp1) state = ParallelMALAState(x1, logp1, S, logps, ws, tape, 1) - return postprocess_sample(model, trans), state + return trans, state end function AbstractMCMC.step( @@ -837,7 +826,7 @@ function AbstractMCMC.step( state.tape, t_next, ) - return postprocess_sample(model, trans), new_state + return trans, new_state else x0 = state.trajectory[:, T] S_new, tape, ws = _deer_solve_new_tape( @@ -848,7 +837,7 @@ function AbstractMCMC.step( logp_new = logps[1] trans = ParallelMALATransition(x_new, logp_new) new_state = ParallelMALAState(x_new, logp_new, S_new, logps, ws, tape, 1) - return postprocess_sample(model, trans), new_state + return trans, new_state end end @@ -1004,7 +993,7 @@ function AbstractMCMC.step( noise, noise_host, ) - return postprocess_sample(model, trans), state + return trans, state end function AbstractMCMC.step( @@ -1062,7 +1051,7 @@ function AbstractMCMC.step( state.noise, state.noise_host, ) - return postprocess_sample(model, trans), new_state + return trans, new_state end function AbstractMCMC.bundle_samples( From 3bfcb5c23da6ea2fabd8a180b16f225b22b7798b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 9 May 2026 02:29:33 +0100 Subject: [PATCH 08/12] Update tests --- test/test-Turing-Integration.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test-Turing-Integration.jl 
b/test/test-Turing-Integration.jl index c5560a7..d812dfd 100644 --- a/test/test-Turing-Integration.jl +++ b/test/test-Turing-Integration.jl @@ -103,10 +103,9 @@ end MersenneTwister(11), model, sampler; initial_params=[0.0] ) - @test trans isa DynamicPPL.ParamsWithStats + @test trans isa ParallelMALATransition @test state isa ParallelMALAState - @test only(keys(trans.params)) == @varname(μ) - @test isfinite(trans.stats.logjoint) + @test isfinite(trans.logp) @test all(isfinite, state.trajectory) end end From dc089e75cf47c79c3302e9cd8f23348985065df0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 9 May 2026 02:32:25 +0100 Subject: [PATCH 09/12] reduce code duplication via eval --- ext/DynamicPPLExt.jl | 70 +++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index 46dbbbe..07f8603 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -68,39 +68,6 @@ const LDFPrimal = ParallelMCMC.LogDensityProblemPrimal{<:DynamicPPL.LogDensityFu const LDFGradient = ParallelMCMC.LogDensityProblemGradient{<:DynamicPPL.LogDensityFunction} const DensityModelLDF = ParallelMCMC.DensityModel{<:LDFPrimal,<:LDFGradient} -function AbstractMCMC.bundle_samples( - ts::Vector{<:ParallelMCMC.MALATransition}, - model::DensityModelLDF, - spl::ParallelMCMC.MALASampler, - state::ParallelMCMC.MALAState, - chain_type::Type{MCMCChains.Chains}; - kwargs..., -) - return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) -end -function AbstractMCMC.bundle_samples( - ts::Vector{<:ParallelMCMC.ParallelMALATransition}, - model::DensityModelLDF, - spl::ParallelMCMC.ParallelMALASampler, - state::ParallelMCMC.ParallelMALAState, - chain_type::Type{MCMCChains.Chains}; - kwargs..., -) - return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) -end -function AbstractMCMC.bundle_samples( - ts::Vector{<:ParallelMCMC.AdaptiveMALATransition}, - model::DensityModelLDF, - 
spl::ParallelMCMC.AdaptiveMALASampler, - state::ParallelMCMC.AdaptiveMALAState, - chain_type::Type{MCMCChains.Chains}; - discard_warmup::Bool=false, - kwargs..., -) - ts = discard_warmup ? filter(t -> !t.is_warmup, ts) : ts - return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) -end - """ getstats(sample::ParallelMCMCTransitionTypes) @@ -114,6 +81,43 @@ function getstats(sample::ParallelMCMC.AdaptiveMALATransition) end getstats(::ParallelMCMCTransitionTypes) = (;) +""" + is_warmup(sample::ParallelMCMCTransitionTypes) + +Check if a sample is from the warmup phase of MCMC sampling. +""" +is_warmup(::ParallelMCMCTransitionTypes) = false +is_warmup(sample::ParallelMCMC.AdaptiveMALATransition) = sample.is_warmup + +for (Ttrans, Tspl, Tstate) in ( + (ParallelMCMC.MALATransition, ParallelMCMC.MALASampler, ParallelMCMC.MALAState), + ( + ParallelMCMC.ParallelMALATransition, + ParallelMCMC.ParallelMALASampler, + ParallelMCMC.ParallelMALAState, + ), + ( + ParallelMCMC.AdaptiveMALATransition, + ParallelMCMC.AdaptiveMALASampler, + ParallelMCMC.AdaptiveMALAState, + ), +) + @eval begin + function AbstractMCMC.bundle_samples( + ts::Vector{<:$Ttrans}, + model::DensityModelLDF, + spl::$Tspl, + state::$Tstate, + chain_type::Type{MCMCChains.Chains}; + discard_warmup::Bool=false, + kwargs..., + ) + ts = discard_warmup ? 
filter(is_warmup, ts) : ts + return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) + end + end +end + function make_processed_dynamicppl_chain( ::Type{Tchain}, ts::Vector{<:ParallelMCMCTransitionTypes}, model::DensityModelLDF ) where {Tchain} From fae2babad8f57ebd10ddeb1265032afc58e56606 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 9 May 2026 02:37:24 +0100 Subject: [PATCH 10/12] Fix incorrect negation --- ext/DynamicPPLExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index 07f8603..ae0c6cb 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -112,7 +112,7 @@ for (Ttrans, Tspl, Tstate) in ( discard_warmup::Bool=false, kwargs..., ) - ts = discard_warmup ? filter(is_warmup, ts) : ts + ts = discard_warmup ? filter(t -> !is_warmup(t), ts) : ts return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model) end end From d308dcc10c79a32638a9e8846a8151df546285a4 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 9 May 2026 02:44:18 +0100 Subject: [PATCH 11/12] Add explanations in docs --- docs/src/10-getting-started.md | 11 +++++++++-- docs/src/index.md | 3 ++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docs/src/10-getting-started.md b/docs/src/10-getting-started.md index b0a7d22..159afa0 100644 --- a/docs/src/10-getting-started.md +++ b/docs/src/10-getting-started.md @@ -88,11 +88,11 @@ chain = sample(model, sampler, 500; ## Turing.jl integration -ParallelMCMC.jl integrates with Turing.jl models through the `LogDensityProblems` extension. +ParallelMCMC.jl integrates with Turing.jl models through the `DynamicPPL` and `LogDensityProblems` extensions. ### One-step convenience constructor -Load `DynamicPPL` (part of Turing.jl) and a single-argument `DensityModel` constructor becomes available. 
It extracts parameter names automatically: +Load `DynamicPPL` (part of Turing.jl) and a single-argument `DensityModel` constructor becomes available: ```julia using Turing, ParallelMCMC, MCMCChains @@ -108,6 +108,12 @@ chain = sample(model, ParallelMALASampler(0.1; T=64), 500; chain_type=MCMCChains.Chains) ``` +Much like Turing's own samplers, the resulting chain will always have parameters in the original (possibly constrained) space, even though the MCMC sampling itself is performed in unconstrained space. +Furthermore, parameter names are automatically extracted from the Turing model (and will always be the same as those obtained when using Turing's own samplers). + +Note that when sampling with a Turing model the returned chain will have `:logjoint`, `:logprior`, and `:loglikelihood` columns, since Turing models provide enough information to separate these contributions to the log-density. +This is in contrast to sampling with a manually constructed `DensityModel`, which returns only a single `:logp` column. + ### Manual `LogDensityProblems` path For explicit control over the AD backend, construct the `LogDensityFunction` @@ -128,6 +134,7 @@ model = DensityModel(ld; param_names=[:μ]) ``` This also accepts any other `LogDensityProblems`-compatible object. +As above, the returned chain will always contain parameters in the original space: the use of `LinkAll()` above only stipulates that MCMC sampling itself is to be performed in unconstrained space, and has no effect on the form of the returned chain.
--- diff --git a/docs/src/index.md b/docs/src/index.md index f267b7d..be54b16 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -91,7 +91,8 @@ chain = sample(model, sampler, 2_000; ### Turing.jl integration -When `DynamicPPL` (part of Turing.jl) is loaded, a one-argument `DensityModel` constructor is available that wraps a `@model` directly, extracting parameter names automatically: +When `DynamicPPL` (part of Turing.jl) is loaded, a one-argument `DensityModel` constructor is available that wraps a `@model` directly. +Parameter names are automatically extracted, and values transformed back to the original model space: ```julia using Turing, ParallelMCMC, MCMCChains From f58b89fe5a299d76b2c7cce5df07e181c645cbf9 Mon Sep 17 00:00:00 2001 From: Ryan Senne <50930199+rsenne@users.noreply.github.com> Date: Sat, 9 May 2026 11:07:50 -0400 Subject: [PATCH 12/12] add changelog --- CHANGELOG.md | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..7ba0801 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,43 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.0.2] + +### Added + +- DynamicPPL-backed `DensityModel`s now map samples back to the original + (possibly constrained) parameter space, with names taken directly from the + Turing model. This works correctly for distributions whose dimension changes + under linking (e.g. `Dirichlet`, `LKJ`, `product_distribution` with + `NamedTuple` keys). +- Chains from Turing models now expose `:logjoint`, `:logprior`, and + `:loglikelihood` as separate columns. 
+- `hvp` keyword argument is now forwarded through the `LogDensityProblems`-based + `DensityModel` constructor. + +### Changed + +- `DynamicPPL` compat bumped to `0.40.6, 0.41`. +- Parameter names for Turing models are now derived by reevaluating the model + via `DynamicPPL.ParamsWithStats` rather than extracted heuristically at + construction time. The `param_names` field on `DensityModel` is no longer + populated by the DynamicPPL convenience constructor. +- `model.logdensity` and `model.grad_logdensity` constructed from a + `LogDensityProblems` object are now `LogDensityProblemPrimal` / + `LogDensityProblemGradient` callable structs rather than anonymous closures. + Calling behaviour is unchanged; only the concrete type differs. + +### Removed + +- Heuristic prior-based parameter-name extraction (`_try_extract_param_names`) + and its warning fallback for dimension-changing bijectors. + +## [0.0.1] + +- Initial release.