diff --git a/Project.toml b/Project.toml index 171dcd9..7c255d3 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ DynamicPPL = "0.40" Enzyme = "0.13.131" LinearAlgebra = "1" LogDensityProblems = "2" -LogDensityProblemsAD = "1" MCMCChains = "7.7.0" Random = "1" Statistics = "1" @@ -30,7 +29,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [extensions] -DynamicPPLExt = ["DynamicPPL", "LogDensityProblems", "LogDensityProblemsAD"] +DynamicPPLExt = ["DynamicPPL", "LogDensityProblems"] LogDensityProblemsExt = "LogDensityProblems" [extras] @@ -39,4 +38,3 @@ CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" [weakdeps] DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" diff --git a/docs/src/10-getting-started.md b/docs/src/10-getting-started.md index 5a5826f..b0a7d22 100644 --- a/docs/src/10-getting-started.md +++ b/docs/src/10-getting-started.md @@ -110,16 +110,21 @@ chain = sample(model, ParallelMALASampler(0.1; T=64), 500; ### Manual `LogDensityProblems` path -For explicit control over the AD backend, wrap the model yourself: +For explicit control over the AD backend, construct the `LogDensityFunction` +directly with DynamicPPL's `adtype` interface: ```julia -using Turing, LogDensityProblems, LogDensityProblemsAD, ADTypes +using Turing, LogDensityProblems, ADTypes using ParallelMCMC, MCMCChains -ld = DynamicPPL.LogDensityFunction(normal_model(1.5)) -ldg = LogDensityProblemsAD.ADgradient(ADTypes.AutoEnzyme(), ld) +ld = DynamicPPL.LogDensityFunction( + normal_model(1.5), + DynamicPPL.getlogjoint_internal, + DynamicPPL.LinkAll(); + adtype=ADTypes.AutoEnzyme(), +) -model = DensityModel(ldg; param_names=[:μ]) +model = DensityModel(ld; param_names=[:μ]) ``` This also accepts any other `LogDensityProblems`-compatible object. diff --git a/ext/DynamicPPLExt.jl b/ext/DynamicPPLExt.jl index 3ee8752..294bfba 100644 --- a/ext/DynamicPPLExt.jl +++ b/ext/DynamicPPLExt.jl @@ -5,16 +5,15 @@ using ADTypes: ADTypes using DynamicPPL: DynamicPPL using Enzyme: Enzyme using LogDensityProblems: LogDensityProblems -using LogDensityProblemsAD: LogDensityProblemsAD """ DensityModel(turing_model::DynamicPPL.Model; ad_backend=ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation=Enzyme.Duplicated), hvp=nothing) Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a `DensityModel`, automatically extracting parameter names and wiring up gradient -computation via `LogDensityProblemsAD`. +computation via DynamicPPL's `adtype` interface. -Requires `DynamicPPL` and `LogDensityProblemsAD` to be loaded. +Requires `DynamicPPL` to be loaded. # Example ```julia @@ -45,24 +44,28 @@ function ParallelMCMC.DensityModel( ), hvp=nothing, ) - # Build the LogDensityProblems-compatible gradient object - ld = DynamicPPL.LogDensityFunction(turing_model) - ldg = LogDensityProblemsAD.ADgradient(ad_backend, ld) + # Sample in linked/unconstrained space and let DynamicPPL provide the gradient. + ld = DynamicPPL.LogDensityFunction( + turing_model, + DynamicPPL.getlogjoint_internal, + DynamicPPL.LinkAll(); + adtype=ad_backend, + ) - caps = LogDensityProblems.capabilities(ldg) + caps = LogDensityProblems.capabilities(ld) caps isa LogDensityProblems.LogDensityOrder{0} && error( "AD gradient setup failed. The wrapped model must support gradients. " * "Ensure your ad_backend is compatible.", ) - dim = LogDensityProblems.dimension(ldg) + dim = LogDensityProblems.dimension(ld) # Try to extract parameter names; fall back to nothing on any error or mismatch. param_names = _try_extract_param_names(turing_model, dim) logp(x) = LogDensityProblems.logdensity(ld, x) function gradlogp(x) - _, g = LogDensityProblems.logdensity_and_gradient(ldg, x) + _, g = LogDensityProblems.logdensity_and_gradient(ld, x) return g end diff --git a/ext/LogDensityProblemsExt.jl b/ext/LogDensityProblemsExt.jl index 4430066..908b7d4 100644 --- a/ext/LogDensityProblemsExt.jl +++ b/ext/LogDensityProblemsExt.jl @@ -12,8 +12,8 @@ Construct a `DensityModel` from any object implementing the `ld` must support: - `LogDensityProblems.capabilities(ld)` returning at least `LogDensityProblems.LogDensityOrder{1}` (i.e. gradient available). -- `LogDensityProblems.dimension(ld)` → `Int` -- `LogDensityProblems.logdensity_and_gradient(ld, x)` → `(logp, grad)` +- `LogDensityProblems.dimension(ld)` -> `Int` +- `LogDensityProblems.logdensity_and_gradient(ld, x)` -> `(logp, grad)` The optional `param_names` keyword accepts a `Vector{Symbol}` of parameter names that will be used for the columns of the returned `MCMCChains.Chains` object. @@ -22,7 +22,7 @@ to `sample(...)`. # Turing.jl / DynamicPPL example ```julia -using Turing, LogDensityProblems, LogDensityProblemsAD, Enzyme, ParallelMCMC, MCMCChains +using Turing, LogDensityProblems, ADTypes, Enzyme, ParallelMCMC, MCMCChains @model function mymodel(y) μ ~ Normal(0, 1) @@ -30,23 +30,26 @@ using Turing, LogDensityProblems, LogDensityProblemsAD, Enzyme, ParallelMCMC, MC end obs = 1.5 -ld = DynamicPPL.LogDensityFunction(mymodel(obs)) -ldg = LogDensityProblemsAD.ADgradient(ADTypes.AutoEnzyme(), ld) +ld = DynamicPPL.LogDensityFunction( + mymodel(obs), + DynamicPPL.getlogjoint_internal, + DynamicPPL.LinkAll(); + adtype=ADTypes.AutoEnzyme(), +) -model = DensityModel(ldg; param_names=[:μ]) +model = DensityModel(ld; param_names=[:μ]) chain = sample(model, AdaptiveMALASampler(0.3; n_warmup=500), 2_000; chain_type=MCMCChains.Chains, discard_warmup=true, progress=true) ``` -If both DynamicPPL and LogDensityProblemsAD are loaded, the simpler one-step -constructor `DensityModel(mymodel(obs))` is also available and extracts parameter -names automatically. +If DynamicPPL is loaded, the simpler one-step constructor `DensityModel(mymodel(obs))` +is also available and extracts parameter names automatically. """ function ParallelMCMC.DensityModel(ld; param_names=nothing) caps = LogDensityProblems.capabilities(ld) caps isa LogDensityProblems.LogDensityOrder{0} && error( "LogDensityProblems model must support gradients (LogDensityOrder{1} or higher). " * - "Wrap it with LogDensityProblemsAD.ADgradient first.", + "Construct it with gradient support enabled.", ) dim = LogDensityProblems.dimension(ld) diff --git a/test/test-DEER-Turing-Logistic.jl b/test/test-DEER-Turing-Logistic.jl index 13e9fc9..8a5bca6 100644 --- a/test/test-DEER-Turing-Logistic.jl +++ b/test/test-DEER-Turing-Logistic.jl @@ -8,7 +8,6 @@ using ParallelMCMC using DynamicPPL using LogDensityProblems -using LogDensityProblemsAD using ADTypes using Distributions: MvNormal, Bernoulli diff --git a/test/test-Turing-Integration.jl b/test/test-Turing-Integration.jl index 635ddcd..d379396 100644 --- a/test/test-Turing-Integration.jl +++ b/test/test-Turing-Integration.jl @@ -8,9 +8,8 @@ using ParallelMCMC using DynamicPPL using LogDensityProblems -using LogDensityProblemsAD using ADTypes -using Distributions: Normal, MvNormal +using Distributions: Beta, Normal, MvNormal # A simple 1-D normal likelihood: μ ~ N(0,1), y | μ ~ N(μ, 0.5) # Posterior: μ | y=1.5 is N(μ_post, σ_post²) @@ -30,11 +29,19 @@ end y ~ MvNormal(μ, 0.5 * I) end +@model function beta_model() + x ~ Beta(2, 2) +end + @testset "LogDensityProblemsExt: param_names kwarg" begin - ld = DynamicPPL.LogDensityFunction(normal_model(TRUE_OBS)) - ldg = LogDensityProblemsAD.ADgradient(ADTypes.AutoEnzyme(), ld) + ld = DynamicPPL.LogDensityFunction( + normal_model(TRUE_OBS), + DynamicPPL.getlogjoint_internal, + DynamicPPL.LinkAll(); + adtype=ADTypes.AutoEnzyme(), + ) - model = DensityModel(ldg; param_names=[:μ]) + model = DensityModel(ld; param_names=[:μ]) @test model.dim == 1 @test model.param_names == [:μ] @@ -60,6 +67,15 @@ end @test isfinite(model.grad_logdensity([0.0])[1]) end +@testset "DynamicPPLExt: convenience constructor uses linked space for constrained models" begin + model = DensityModel(beta_model()) + + @test model.dim == 1 + @test model.param_names == [:x] + @test isfinite(model.logdensity([-0.4])) + @test isfinite(model.grad_logdensity([-0.4])[1]) +end + @testset "DynamicPPLExt: generic Turing model works with ParallelMALA and default Enzyme HVP" begin model = DensityModel(normal_model(TRUE_OBS))