Skip to content
43 changes: 43 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

## [0.0.2]

### Added

- DynamicPPL-backed `DensityModel`s now map samples back to the original
(possibly constrained) parameter space, with names taken directly from the
Turing model. This works correctly for distributions whose dimension changes
under linking (e.g. `Dirichlet`, `LKJ`, `product_distribution` with
`NamedTuple` keys).
- Chains from Turing models now expose `:logjoint`, `:logprior`, and
`:loglikelihood` as separate columns.
- `hvp` keyword argument is now forwarded through the `LogDensityProblems`-based
`DensityModel` constructor.

### Changed

- `DynamicPPL` compat bumped to `0.40.6, 0.41`.
- Parameter names for Turing models are now derived by reevaluating the model
via `DynamicPPL.ParamsWithStats` rather than extracted heuristically at
construction time. The `param_names` field on `DensityModel` is no longer
populated by the DynamicPPL convenience constructor.
- `model.logdensity` and `model.grad_logdensity` constructed from a
`LogDensityProblems` object are now `LogDensityProblemPrimal` /
`LogDensityProblemGradient` callable structs rather than anonymous closures.
Calling behaviour is unchanged; only the concrete type differs.

### Removed

- Heuristic prior-based parameter-name extraction (`_try_extract_param_names`)
and its warning fallback for dimension-changing bijectors.

## [0.0.1]

- Initial release.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
authors = ["Ryan Senne <rsenne@bu.edu>"]
name = "ParallelMCMC"
uuid = "1a970f40-4406-51c9-a967-cb3143c111e8"
version = "0.0.1"
version = "0.0.2"

[compat]
ADTypes = "1.21.0"
AbstractMCMC = "5.10.0"
CUDA = "5.11.0"
DifferentiationInterface = "0.7.13"
DynamicPPL = "0.40"
DynamicPPL = "0.40.6, 0.41"
Enzyme = "0.13.131"
LinearAlgebra = "1"
LogDensityProblems = "2"
Expand Down
11 changes: 9 additions & 2 deletions docs/src/10-getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ chain = sample(model, sampler, 500;

## Turing.jl integration

ParallelMCMC.jl integrates with Turing.jl models through the `LogDensityProblems` extension.
ParallelMCMC.jl integrates with Turing.jl models through the `DynamicPPL` and `LogDensityProblems` extensions.

### One-step convenience constructor

Load `DynamicPPL` (part of Turing.jl) and a single-argument `DensityModel` constructor becomes available. It extracts parameter names automatically:
Load `DynamicPPL` (part of Turing.jl) and a single-argument `DensityModel` constructor becomes available:

```julia
using Turing, ParallelMCMC, MCMCChains
Expand All @@ -108,6 +108,12 @@ chain = sample(model, ParallelMALASampler(0.1; T=64), 500;
chain_type=MCMCChains.Chains)
```

Much like Turing's own samplers, the resulting chain will always have parameters in the original (possibly constrained) space, even though the MCMC sampling itself is performed in unconstrained space.
Furthermore, parameter names are automatically extracted from the Turing model (and will always be the same as those when using Turing's own samplers).

Note that when sampling with a Turing model the returned chain will have `:logjoint`, `:logprior`, and `:loglikelihood` columns, since Turing models provide enough information to separate these contributions to the log-density.
This is in contrast to sampling with a manually constructed `DensityModel`, which returns only a single `:logp` column.

### Manual `LogDensityProblems` path

For explicit control over the AD backend, construct the `LogDensityFunction`
Expand All @@ -128,6 +134,7 @@ model = DensityModel(ld; param_names=[:μ])
```

This also accepts any other `LogDensityProblems`-compatible object.
As above, the returned chain will always contain parameters in the original space: the use of `LinkAll()` above only stipulates that MCMC sampling itself is to be performed in unconstrained space, and has no result on the form of the returned chain.

---

Expand Down
3 changes: 2 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ chain = sample(model, sampler, 2_000;

### Turing.jl integration

When `DynamicPPL` (part of Turing.jl) is loaded, a one-argument `DensityModel` constructor is available that wraps a `@model` directly, extracting parameter names automatically:
When `DynamicPPL` (part of Turing.jl) is loaded, a one-argument `DensityModel` constructor is available that wraps a `@model` directly.
Parameter names are automatically extracted, and values transformed back to the original model space:

```julia
using Turing, ParallelMCMC, MCMCChains
Expand Down
137 changes: 88 additions & 49 deletions ext/DynamicPPLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ module DynamicPPLExt
using ParallelMCMC
using ADTypes: ADTypes
using DynamicPPL: DynamicPPL
using AbstractMCMC: AbstractMCMC
using Enzyme: Enzyme
using MCMCChains: MCMCChains
using LogDensityProblems: LogDensityProblems

"""
Expand All @@ -28,13 +30,6 @@ model = DensityModel(mymodel(1.5))
chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000;
chain_type=MCMCChains.Chains, discard_warmup=true, progress=true)
```

# Notes
- Parameter names are extracted from the model's prior. For most common
distributions (Normal, MvNormal, Exponential, etc.) the names match the
unconstrained parameter space used by LogDensityProblems. If the extracted
names do not match the dimensionality (e.g. due to simplex constraints), the
constructor falls back to generic `x[1], x[2], ...` names with a warning.
"""
function ParallelMCMC.DensityModel(
turing_model::DynamicPPL.Model;
Expand All @@ -51,58 +46,102 @@ function ParallelMCMC.DensityModel(
DynamicPPL.LinkAll();
adtype=ad_backend,
)
# Requires LogDensityProblemsExt to be loaded
return ParallelMCMC.DensityModel(ld; hvp=hvp)
end

caps = LogDensityProblems.capabilities(ld)
caps isa LogDensityProblems.LogDensityOrder{0} && error(
"AD gradient setup failed. The wrapped model must support gradients. " *
"Ensure your ad_backend is compatible.",
)

dim = LogDensityProblems.dimension(ld)
######################
# Chain construction #
######################
# In this section, we define overloads for DynamicPPL-based models so that resulting chains
# are converted back into the original parameter space and contain the correct parameter
# names. This is done by converting the raw samples (vectors of parameters) back into
# `DynamicPPL.ParamsWithStats` objects.

# Try to extract parameter names; fall back to nothing on any error or mismatch.
param_names = _try_extract_param_names(turing_model, dim)
const ParallelMCMCTransitionTypes = Union{
ParallelMCMC.MALATransition,
ParallelMCMC.AdaptiveMALATransition,
ParallelMCMC.ParallelMALATransition,
}
# Types that represent LogDensityProblems objects that wrap DynamicPPL models.
const LDFPrimal = ParallelMCMC.LogDensityProblemPrimal{<:DynamicPPL.LogDensityFunction}
const LDFGradient = ParallelMCMC.LogDensityProblemGradient{<:DynamicPPL.LogDensityFunction}
const DensityModelLDF = ParallelMCMC.DensityModel{<:LDFPrimal,<:LDFGradient}

logp(x) = LogDensityProblems.logdensity(ld, x)
function gradlogp(x)
_, g = LogDensityProblems.logdensity_and_gradient(ld, x)
return g
end
"""
getstats(sample::ParallelMCMCTransitionTypes)

return ParallelMCMC.DensityModel(logp, gradlogp, dim; hvp=hvp, param_names=param_names)
Get a `NamedTuple` of stats from an MCMC transition.
"""
getstats(sample::ParallelMCMC.MALATransition) = (accepted=sample.accepted,)
function getstats(sample::ParallelMCMC.AdaptiveMALATransition)
return (
accepted=sample.accepted, step_size=sample.step_size, is_warmup=sample.is_warmup
)
end
getstats(::ParallelMCMCTransitionTypes) = (;)

"""
Extract flat parameter names from a DynamicPPL model by sampling from the prior.
Returns a `Vector{Symbol}` if the count matches `expected_dim`, otherwise `nothing`.
is_warmup(sample::ParallelMCMCTransitionTypes)

Check if a sample is from the warmup phase of MCMC sampling.
"""
function _try_extract_param_names(model::DynamicPPL.Model, expected_dim::Int)
try
vi = DynamicPPL.VarInfo(model)
names = Symbol[]
for vn in keys(vi)
val = vi[vn]
sym = DynamicPPL.getsym(vn)
if val isa Number
push!(names, Symbol(sym))
else
for i in 1:length(val)
push!(names, Symbol("$(sym)[$i]"))
end
end
end
if length(names) == expected_dim
return names
else
@warn "ParallelMCMC: parameter name extraction produced $(length(names)) names " *
"but model has $expected_dim unconstrained dimensions " *
"(likely due to bijector dimension changes, e.g. Dirichlet/LKJ constraints). " *
"Falling back to generic x[1], x[2], ... names."
return nothing
is_warmup(::ParallelMCMCTransitionTypes) = false
is_warmup(sample::ParallelMCMC.AdaptiveMALATransition) = sample.is_warmup

for (Ttrans, Tspl, Tstate) in (
(ParallelMCMC.MALATransition, ParallelMCMC.MALASampler, ParallelMCMC.MALAState),
(
ParallelMCMC.ParallelMALATransition,
ParallelMCMC.ParallelMALASampler,
ParallelMCMC.ParallelMALAState,
),
(
ParallelMCMC.AdaptiveMALATransition,
ParallelMCMC.AdaptiveMALASampler,
ParallelMCMC.AdaptiveMALAState,
),
)
@eval begin
function AbstractMCMC.bundle_samples(
ts::Vector{<:$Ttrans},
model::DensityModelLDF,
spl::$Tspl,
state::$Tstate,
chain_type::Type{MCMCChains.Chains};
discard_warmup::Bool=false,
kwargs...,
)
ts = discard_warmup ? filter(t -> !is_warmup(t), ts) : ts
return make_processed_dynamicppl_chain(MCMCChains.Chains, ts, model)
end
catch
return nothing
end
end

function make_processed_dynamicppl_chain(
::Type{Tchain}, ts::Vector{<:ParallelMCMCTransitionTypes}, model::DensityModelLDF
) where {Tchain}
pwss = map(ts) do t
# Note: This assumes that there is always a field called t.x. This is currently true
# of all samplers in ParallelMCMC
DynamicPPL.ParamsWithStats(t.x, model.logdensity.ld, getstats(t))
end
return AbstractMCMC.from_samples(Tchain, hcat(pwss))
end

function ParallelMCMC._construct_chain(
::Type{MCMCChains.Chains},
vals::AbstractMatrix{<:Real},
internals::AbstractMatrix{<:Real},
::Vector{Symbol},
internal_names::Vector{Symbol},
model::DensityModelLDF,
)
pwss = map(zip(eachrow(vals), eachrow(internals))) do (val, internal)
stats = NamedTuple{Tuple(internal_names)}(internal)
DynamicPPL.ParamsWithStats(val, model.logdensity.ld, stats)
Comment thread
rsenne marked this conversation as resolved.
end
return AbstractMCMC.from_samples(MCMCChains.Chains, hcat(pwss))
end
Comment thread
rsenne marked this conversation as resolved.

end # module
19 changes: 11 additions & 8 deletions ext/LogDensityProblemsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using ParallelMCMC
using LogDensityProblems: LogDensityProblems

"""
DensityModel(ld; param_names=nothing)
DensityModel(ld; param_names=nothing, hvp=nothing)

Construct a `DensityModel` from any object implementing the
[LogDensityProblems](https://github.com/tpapp/LogDensityProblems.jl) interface.
Expand All @@ -20,6 +20,8 @@ that will be used for the columns of the returned `MCMCChains.Chains` object.
If omitted, names default to `x[1], x[2], ...` unless you also pass `param_names`
to `sample(...)`.

The `hvp` keyword argument is forwarded to the main `DensityModel` constructor.

# Turing.jl / DynamicPPL example
```julia
using Turing, LogDensityProblems, ADTypes, Enzyme, ParallelMCMC, MCMCChains
Expand All @@ -45,7 +47,7 @@ chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000;
If DynamicPPL is loaded, the simpler one-step constructor `DensityModel(mymodel(obs))`
is also available and extracts parameter names automatically.
"""
function ParallelMCMC.DensityModel(ld; param_names=nothing)
function ParallelMCMC.DensityModel(ld; param_names=nothing, hvp=nothing)
caps = LogDensityProblems.capabilities(ld)
caps isa LogDensityProblems.LogDensityOrder{0} && error(
"LogDensityProblems model must support gradients (LogDensityOrder{1} or higher). " *
Expand All @@ -54,14 +56,15 @@ function ParallelMCMC.DensityModel(ld; param_names=nothing)

dim = LogDensityProblems.dimension(ld)

logp(x) = LogDensityProblems.logdensity(ld, x)
logp = ParallelMCMC.LogDensityProblemPrimal(ld)
gradlogp = ParallelMCMC.LogDensityProblemGradient(ld)

function gradlogp(x)
_, g = LogDensityProblems.logdensity_and_gradient(ld, x)
return g
end
return ParallelMCMC.DensityModel(logp, gradlogp, dim; param_names=param_names, hvp=hvp)
Comment thread
penelopeysm marked this conversation as resolved.
end

return ParallelMCMC.DensityModel(logp, gradlogp, dim; param_names=param_names)
(l::ParallelMCMC.LogDensityProblemPrimal)(x) = LogDensityProblems.logdensity(l.ld, x)
function (l::ParallelMCMC.LogDensityProblemGradient)(x)
return last(LogDensityProblems.logdensity_and_gradient(l.ld, x))
end

end # module
30 changes: 27 additions & 3 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ function DensityModel(
)
end

# Callable structs that allow us to dispatch on the type of the LogDensityProblems object in
# the postprocessing stage. Ideally these would be defined in the LogDensityProblemsExt.
# However, structs defined in extensions are hard to get hold of so we define them here.
# The callable behaviour itself is implemented in LogDensityProblemsExt.
struct LogDensityProblemPrimal{L}
ld::L
end
struct LogDensityProblemGradient{L}
ld::L
end

"""
MALASampler(epsilon; cholM=nothing)

Expand Down Expand Up @@ -557,12 +568,13 @@ function _sample_parallel_mala_chain(
rng::Random.AbstractRNG,
model::DensityModel,
sampler::ParallelMALASampler,
N::Int;
N::Int,
::Type{Tchn};
initial_params=nothing,
param_names=nothing,
progress=AbstractMCMC.PROGRESS[],
progressname="Sampling",
)
) where {Tchn}
D = model.dim
names = _parallel_mala_param_names(model, D, param_names)
internal_names = [:logp]
Expand Down Expand Up @@ -601,6 +613,17 @@ function _sample_parallel_mala_chain(
end
end

return _construct_chain(Tchn, vals, internals, names, internal_names, model)
end

function _construct_chain(
::Type{MCMCChains.Chains},
vals::AbstractMatrix{<:Real},
internals::AbstractMatrix{<:Real},
names::Vector{Symbol},
internal_names::Vector{Symbol},
model::DensityModel,
)
return MCMCChains.Chains(
hcat(vals, internals),
vcat(names, internal_names),
Expand Down Expand Up @@ -732,7 +755,8 @@ function AbstractMCMC.mcmcsample(
rng,
model,
sampler,
N_int;
N_int,
chain_type;
initial_params=initial_params,
param_names=param_names,
progress=progress,
Expand Down
Loading
Loading