Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ DynamicPPL = "0.40"
Enzyme = "0.13.131"
LinearAlgebra = "1"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
MCMCChains = "7.7.0"
Random = "1"
Statistics = "1"
Expand All @@ -30,7 +29,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
DynamicPPLExt = ["DynamicPPL", "LogDensityProblems", "LogDensityProblemsAD"]
DynamicPPLExt = ["DynamicPPL", "LogDensityProblems"]
LogDensityProblemsExt = "LogDensityProblems"

[extras]
Expand All @@ -39,4 +38,3 @@ CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
[weakdeps]
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
15 changes: 10 additions & 5 deletions docs/src/10-getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,21 @@ chain = sample(model, ParallelMALASampler(0.1; T=64), 500;

### Manual `LogDensityProblems` path

For explicit control over the AD backend, wrap the model yourself:
For explicit control over the AD backend, construct the `LogDensityFunction`
directly with DynamicPPL's `adtype` interface:

```julia
using Turing, LogDensityProblems, LogDensityProblemsAD, ADTypes
using Turing, LogDensityProblems, ADTypes
using ParallelMCMC, MCMCChains

ld = DynamicPPL.LogDensityFunction(normal_model(1.5))
ldg = LogDensityProblemsAD.ADgradient(ADTypes.AutoEnzyme(), ld)
ld = DynamicPPL.LogDensityFunction(
normal_model(1.5),
Comment on lines +117 to +121
DynamicPPL.getlogjoint_internal,
DynamicPPL.LinkAll();
adtype=ADTypes.AutoEnzyme(),
)

model = DensityModel(ldg; param_names=[:μ])
model = DensityModel(ld; param_names=[:μ])
```

This also accepts any other `LogDensityProblems`-compatible object.
Expand Down
21 changes: 12 additions & 9 deletions ext/DynamicPPLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ using ADTypes: ADTypes
using DynamicPPL: DynamicPPL
using Enzyme: Enzyme
using LogDensityProblems: LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

"""
DensityModel(turing_model::DynamicPPL.Model; ad_backend=ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Duplicated), hvp=nothing)

Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a
`DensityModel`, automatically extracting parameter names and wiring up gradient
computation via `LogDensityProblemsAD`.
computation via DynamicPPL's `adtype` interface.

Requires `DynamicPPL` and `LogDensityProblemsAD` to be loaded.
Requires `DynamicPPL` to be loaded.

# Example
```julia
Expand Down Expand Up @@ -45,24 +44,28 @@ function ParallelMCMC.DensityModel(
),
hvp=nothing,
)
# Build the LogDensityProblems-compatible gradient object
ld = DynamicPPL.LogDensityFunction(turing_model)
ldg = LogDensityProblemsAD.ADgradient(ad_backend, ld)
# Sample in linked/unconstrained space and let DynamicPPL provide the gradient.
ld = DynamicPPL.LogDensityFunction(
turing_model,
DynamicPPL.getlogjoint_internal,
DynamicPPL.LinkAll();
adtype=ad_backend,
)

caps = LogDensityProblems.capabilities(ldg)
caps = LogDensityProblems.capabilities(ld)
caps isa LogDensityProblems.LogDensityOrder{0} && error(
"AD gradient setup failed. The wrapped model must support gradients. " *
"Ensure your ad_backend is compatible.",
)

dim = LogDensityProblems.dimension(ldg)
dim = LogDensityProblems.dimension(ld)

# Try to extract parameter names; fall back to nothing on any error or mismatch.
param_names = _try_extract_param_names(turing_model, dim)

logp(x) = LogDensityProblems.logdensity(ld, x)
function gradlogp(x)
_, g = LogDensityProblems.logdensity_and_gradient(ldg, x)
_, g = LogDensityProblems.logdensity_and_gradient(ld, x)
return g
end

Expand Down
23 changes: 13 additions & 10 deletions ext/LogDensityProblemsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ Construct a `DensityModel` from any object implementing the
`ld` must support:
- `LogDensityProblems.capabilities(ld)` returning at least
`LogDensityProblems.LogDensityOrder{1}` (i.e. gradient available).
- `LogDensityProblems.dimension(ld)` `Int`
- `LogDensityProblems.logdensity_and_gradient(ld, x)` `(logp, grad)`
- `LogDensityProblems.dimension(ld)` -> `Int`
- `LogDensityProblems.logdensity_and_gradient(ld, x)` -> `(logp, grad)`

The optional `param_names` keyword accepts a `Vector{Symbol}` of parameter names
that will be used for the columns of the returned `MCMCChains.Chains` object.
Expand All @@ -22,31 +22,34 @@ to `sample(...)`.

# Turing.jl / DynamicPPL example
```julia
using Turing, LogDensityProblems, LogDensityProblemsAD, Enzyme, ParallelMCMC, MCMCChains
using Turing, LogDensityProblems, ADTypes, Enzyme, ParallelMCMC, MCMCChains

@model function mymodel(y)
μ ~ Normal(0, 1)
y ~ Normal(μ, 0.5)
end

obs = 1.5
ld = DynamicPPL.LogDensityFunction(mymodel(obs))
ldg = LogDensityProblemsAD.ADgradient(ADTypes.AutoEnzyme(), ld)
ld = DynamicPPL.LogDensityFunction(
mymodel(obs),
DynamicPPL.getlogjoint_internal,
DynamicPPL.LinkAll();
adtype=ADTypes.AutoEnzyme(),
Comment on lines +33 to +37
)

model = DensityModel(ldg; param_names=[:μ])
model = DensityModel(ld; param_names=[:μ])
chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000;
chain_type=MCMCChains.Chains, discard_warmup=true, progress=true)
```

If both DynamicPPL and LogDensityProblemsAD are loaded, the simpler one-step
constructor `DensityModel(mymodel(obs))` is also available and extracts parameter
names automatically.
If DynamicPPL is loaded, the simpler one-step constructor `DensityModel(mymodel(obs))`
is also available and extracts parameter names automatically.
"""
function ParallelMCMC.DensityModel(ld; param_names=nothing)
caps = LogDensityProblems.capabilities(ld)
caps isa LogDensityProblems.LogDensityOrder{0} && error(
"LogDensityProblems model must support gradients (LogDensityOrder{1} or higher). " *
"Wrap it with LogDensityProblemsAD.ADgradient first.",
"Construct it with gradient support enabled.",
)

dim = LogDensityProblems.dimension(ld)
Expand Down
1 change: 0 additions & 1 deletion test/test-DEER-Turing-Logistic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using ParallelMCMC

using DynamicPPL
using LogDensityProblems
using LogDensityProblemsAD
using ADTypes
using Distributions: MvNormal, Bernoulli

Expand Down
26 changes: 21 additions & 5 deletions test/test-Turing-Integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ using ParallelMCMC

using DynamicPPL
using LogDensityProblems
using LogDensityProblemsAD
using ADTypes
using Distributions: Normal, MvNormal
using Distributions: Beta, Normal, MvNormal

# A simple 1-D normal likelihood: μ ~ N(0,1), y | μ ~ N(μ, 0.5)
# Posterior: μ | y=1.5 is N(μ_post, σ_post²)
Expand All @@ -30,11 +29,19 @@ end
y ~ MvNormal(μ, 0.5 * I)
end

@model function beta_model()
x ~ Beta(2, 2)
end

@testset "LogDensityProblemsExt: param_names kwarg" begin
ld = DynamicPPL.LogDensityFunction(normal_model(TRUE_OBS))
ldg = LogDensityProblemsAD.ADgradient(ADTypes.AutoEnzyme(), ld)
ld = DynamicPPL.LogDensityFunction(
normal_model(TRUE_OBS),
DynamicPPL.getlogjoint_internal,
DynamicPPL.LinkAll();
adtype=ADTypes.AutoEnzyme(),
)

model = DensityModel(ldg; param_names=[:μ])
model = DensityModel(ld; param_names=[:μ])

@test model.dim == 1
@test model.param_names == [:μ]
Expand All @@ -60,6 +67,15 @@ end
@test isfinite(model.grad_logdensity([0.0])[1])
end

@testset "DynamicPPLExt: convenience constructor uses linked space for constrained models" begin
model = DensityModel(beta_model())

@test model.dim == 1
@test model.param_names == [:x]
@test isfinite(model.logdensity([-0.4]))
@test isfinite(model.grad_logdensity([-0.4])[1])
end

@testset "DynamicPPLExt: generic Turing model works with ParallelMALA and default Enzyme HVP" begin
model = DensityModel(normal_model(TRUE_OBS))

Expand Down
Loading