From 1bffd79fc5a3c887e3b262666c90d857d8f85335 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 21 Apr 2026 12:52:09 +0100 Subject: [PATCH 1/5] Use structural AbstractPPL AD prep --- Project.toml | 7 ++- docs/Project.toml | 5 ++ docs/make.jl | 1 + ext/DynamicPPLMooncakeExt.jl | 19 ++++--- src/logdensityfunction.jl | 84 +++++----------------------- src/test_utils/ad.jl | 1 - test/Project.toml | 3 + test/integration/enzyme/Project.toml | 1 + test/logdensityfunction.jl | 13 +++++ test/runtests.jl | 1 + 10 files changed, 53 insertions(+), 82 deletions(-) diff --git a/Project.toml b/Project.toml index b8923c715..4d1dbef55 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -28,6 +27,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} + [weakdeps] EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -42,7 +44,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLMooncakeExt = ["Mooncake", "DifferentiationInterface"] +DynamicPPLMooncakeExt = ["Mooncake"] DynamicPPLReverseDiffExt = ["ReverseDiff"] [compat] @@ -55,7 +57,6 @@ Bijectors = "0.15.17" Chairmarks = "1.3.1" Compat = "4" ConstructionBase = "1.5.4" -DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" diff --git a/docs/Project.toml b/docs/Project.toml index 288ae162a..0c716af15 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,6 +7,7 @@ BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -22,6 +23,9 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} + [compat] ADTypes = "1" AbstractMCMC = "5" @@ -31,6 +35,7 @@ BangBang = "0.4" Bijectors = "0.15.17" Chairmarks = "1" ChangesOfVariables = "0.1" +DifferentiationInterface = "0.6.41, 0.7" DimensionalData = "0.30" Distributions = "0.25" Documenter = "1" diff --git a/docs/make.jl b/docs/make.jl index 71fbfb2c0..fb705448c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,5 +1,6 @@ using Documenter using DocumenterInterLinks +using DifferentiationInterface using DynamicPPL using AbstractPPL # NOTE: This is necessary to ensure that if we print something from diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index b876df575..7a82e87fc 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,6 +1,7 @@ module DynamicPPLMooncakeExt using DynamicPPL: DynamicPPL, is_transformed +using AbstractPPL: AbstractPPL using Mooncake: Mooncake # These are purely optimisations (although quite significant ones sometimes, especially for @@ -15,17 +16,21 @@ Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ using DynamicPPL: @model, LinkAll, getlogjoint_internal, LogDensityFunction using ADTypes: AutoMooncake -import DifferentiationInterface using Distributions: Normal, InverseGamma, Beta using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin @compile_workload begin - for dist in (Normal(), InverseGamma(2, 3), Beta(2, 2)) - @model f() = x ~ dist - ldf = LogDensityFunction( - f(), getlogjoint_internal, LinkAll(); adtype=AutoMooncake() - ) - DynamicPPL.LogDensityProblems.logdensity_and_gradient(ldf, [0.5]) + # Julia does not guarantee transitive extensions are loaded while this + # extension precompiles, so skip the workload unless Mooncake's + # AbstractPPL methods are already available. + if !isnothing(Base.get_extension(AbstractPPL, :AbstractPPLMooncakeExt)) + for dist in (Normal(), InverseGamma(2, 3), Beta(2, 2)) + @model f() = x ~ dist + ldf = LogDensityFunction( + f(), getlogjoint_internal, LinkAll(); adtype=AutoMooncake() + ) + DynamicPPL.LogDensityProblems.logdensity_and_gradient(ldf, [0.5]) + end end end end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 4a0dbced9..4ee10f4a0 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -23,7 +23,6 @@ using ADTypes: ADTypes using BangBang: BangBang using AbstractPPL: AbstractPPL, VarName using LogDensityProblems: LogDensityProblems -import DifferentiationInterface as DI using Random: Random """ @@ -186,7 +185,7 @@ struct LogDensityFunction{ L<:AbstractTransformStrategy, F, VNT<:VarNamedTuple, - ADP<:Union{Nothing,DI.GradientPrep}, + ADP, # type of the vector passed to logdensity functions X<:AbstractVector, AC<:AccumulatorTuple, @@ -261,12 +260,10 @@ struct LogDensityFunction{ else # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, x) - args = (model, getlogdensity, ranges_and_transforms, transform_strategy, accs) - if _use_closure(adtype) - DI.prepare_gradient(LogDensityAt(args...), adtype, x) - else - DI.prepare_gradient(logdensity_at, adtype, x, map(DI.Constant, args)...) - end + problem = LogDensityAt( + model, getlogdensity, ranges_and_transforms, transform_strategy, accs + ) + AbstractPPL.prepare(adtype, problem, x) end return new{ typeof(model), @@ -450,6 +447,13 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) ) end +# Return the structural problem object itself so AD backends see a stable +# one-argument evaluator. This preserves compiled ReverseDiff tape reuse +# without relying on anonymous closures. +function AbstractPPL.prepare(problem::LogDensityAt, x::AbstractVector{<:AbstractFloat}) + return problem +end + function LogDensityProblems.logdensity( ldf::LogDensityFunction, params::AbstractVector{<:Real} ) @@ -469,32 +473,7 @@ function LogDensityProblems.logdensity_and_gradient( # `params` has to be converted to the same vector type that was used for AD preparation, # otherwise the preparation will not be valid. params = convert(get_input_vector_type(ldf), params) - return if _use_closure(ldf.adtype) - DI.value_and_gradient( - LogDensityAt( - ldf.model, - ldf._getlogdensity, - ldf._varname_ranges, - ldf.transform_strategy, - ldf._accs, - ), - ldf._adprep, - ldf.adtype, - params, - ) - else - DI.value_and_gradient( - logdensity_at, - ldf._adprep, - ldf.adtype, - params, - DI.Constant(ldf.model), - DI.Constant(ldf._getlogdensity), - DI.Constant(ldf._varname_ranges), - DI.Constant(ldf.transform_strategy), - DI.Constant(ldf._accs), - ) - end + return AbstractPPL.value_and_gradient(ldf._adprep, params) end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} @@ -525,43 +504,6 @@ By default, this just returns the input unchanged. """ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVector) = adtype -""" - _use_closure(adtype::ADTypes.AbstractADType) - -In LogDensityProblems, we want to calculate the derivative of `logdensity(f, x)` with -respect to x, where f is the model (in our case LogDensityFunction or its arguments ) and is -a constant. However, DifferentiationInterface generally expects a single-argument function -g(x) to differentiate. - -There are two ways of dealing with this: - -1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f) - -2. Use a constant DI.Context. This lets us pass a two-argument function to DI, as long as we - also give it the 'inactive argument' (i.e. the model) wrapped in `DI.Constant`. - -The relative performance of the two approaches, however, depends on the AD backend used. -Some benchmarks are provided here: https://github.com/TuringLang/DynamicPPL.jl/pull/1172 - -This function is used to determine whether a given AD backend should use a closure or a -constant. If `use_closure(adtype)` returns `true`, then the closure approach will be used. -By default, this function returns `false`, i.e. the constant approach will be used. -""" -# For these AD backends both closure and no closure work, but it is just faster to not use a -# closure (see link in the docstring). -_use_closure(::ADTypes.AutoForwardDiff) = false -_use_closure(::ADTypes.AutoMooncake) = false -_use_closure(::ADTypes.AutoMooncakeForward) = false -# For ReverseDiff, with the compiled tape, you _must_ use a closure because otherwise with -# DI.Constant arguments the tape will always be recompiled upon each call to -# value_and_gradient. For non-compiled ReverseDiff, it is faster to not use a closure. -_use_closure(::ADTypes.AutoReverseDiff{compile}) where {compile} = compile -# For AutoEnzyme it allows us to avoid setting function_annotation -_use_closure(::ADTypes.AutoEnzyme) = false -# Since for most backends it's faster to not use a closure, we set that as the default -# for unknown AD backends -_use_closure(::ADTypes.AbstractADType) = false - ###################################################### # Helper functions to extract ranges and link status # ###################################################### diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 434540cd5..4c145943d 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -2,7 +2,6 @@ module AD using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be -import DifferentiationInterface as DI using DocStringExtensions using DynamicPPL: DynamicPPL, diff --git a/test/Project.toml b/test/Project.toml index 73cff23ed..f8cc0e688 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -30,6 +30,9 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} + [compat] ADTypes = "1" AbstractMCMC = "5.10" diff --git a/test/integration/enzyme/Project.toml b/test/integration/enzyme/Project.toml index c26655fae..a0d578bfa 100644 --- a/test/integration/enzyme/Project.toml +++ b/test/integration/enzyme/Project.toml @@ -7,3 +7,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [sources] DynamicPPL = {path = "../../../"} +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 0b956dfc1..feadba2c1 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -485,6 +485,19 @@ end end end + @testset "ReverseDiff compiled prep is retained" begin + @model f() = x ~ Normal() + ldf = LogDensityFunction( + f(), getlogjoint_internal, LinkAll(); adtype=AutoReverseDiff(; compile=true) + ) + x = rand(ldf) + + @test hasfield(typeof(ldf._adprep.prep), :tape) + @test getfield(ldf._adprep.prep, :tape) !== nothing + @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Tuple + @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Tuple + end + # Test that various different ways of specifying array types as arguments work with all # ADTypes. @testset "Array argument types" begin diff --git a/test/runtests.jl b/test/runtests.jl index 1ba744c3f..edde98144 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using Documenter: Documenter +using DifferentiationInterface using DynamicPPL: DynamicPPL using Random: Random using Test: @testset, @test_throws From 663550fb5bd5dc2372a5673847bb37c70a9a81c6 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 21 Apr 2026 13:40:17 +0100 Subject: [PATCH 2/5] Fix Enzyme test project source pin --- test/integration/enzyme/Project.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/integration/enzyme/Project.toml b/test/integration/enzyme/Project.toml index a0d578bfa..c673319b1 100644 --- a/test/integration/enzyme/Project.toml +++ b/test/integration/enzyme/Project.toml @@ -1,10 +1,11 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [sources] -DynamicPPL = {path = "../../../"} -AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} +AbstractPPL = {rev = "evaluator-interface", url = "https://github.com/TuringLang/AbstractPPL.jl"} +DynamicPPL = {path = "../../.."} From dea9bace1358e8146add461a52854a37d7a19377 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 21 Apr 2026 20:43:02 +0100 Subject: [PATCH 3/5] Fix docstring typo, stale variable name, and redundant comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - logdensityfunction.jl: `fix_transform` → `fix_transforms` in docstring signature; tighten two internal comments - MarginalLogDensitiesExt: `oavi` → `accs` in example code; fix article typo "the an" → "an" - MCMCChainsExt: remove redundant comment after `|| continue` guard - vector_values.jl: collapse two-line comment to one (drop WHAT line) Co-Authored-By: Claude Sonnet 4.6 --- ext/DynamicPPLMCMCChainsExt.jl | 1 - ext/DynamicPPLMarginalLogDensitiesExt.jl | 4 ++-- src/accumulators/vector_values.jl | 4 +--- src/logdensityfunction.jl | 16 +++++++--------- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index d7987ffa5..6f3aee672 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -141,7 +141,6 @@ function AbstractMCMC.to_samples( # which is why it still pops up in get_varnames). We can check for that and skip # if it's no longer there. has_varname(chain, vn) || continue - # If it's there then we can add it into the VNT. top_sym = AbstractPPL.getsym(vn) val = getindex_varname(chain, sample_idx, vn, chain_idx) # This call to get() is type unstable, but I tried writing a generated function diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 905ee168b..ae5f28444 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -144,7 +144,7 @@ accs = DynamicPPL.OnlyAccsVarInfo(( DynamicPPL.RawValueAccumulator(false), # ... whatever else you need )) -_, accs = DynamicPPL.init!!(rng, model, oavi, init_strategy, DynamicPPL.UnlinkAll()) +_, accs = DynamicPPL.init!!(rng, model, accs, init_strategy, DynamicPPL.UnlinkAll()) ``` You can then extract all the updated data from `accs` using DynamicPPL's existing API (see @@ -178,7 +178,7 @@ retcode: Success u: 1-element Vector{Float64}: 4.88281250001733e-5 -julia> # Get the an initialisation strategy representing the mode of `y`. +julia> # Get an initialisation strategy representing the mode of `y`. init_strategy = InitFromVector(mld, opt_solution.u); julia> # Evaluate the model with this initialisation strategy. diff --git a/src/accumulators/vector_values.jl b/src/accumulators/vector_values.jl index 1973b03c4..3f5b1c04f 100644 --- a/src/accumulators/vector_values.jl +++ b/src/accumulators/vector_values.jl @@ -8,9 +8,7 @@ Generate a `TransformedValue` that always has a vector as its stored value. function _get_vector_tval( val, tval::TransformedValue{V,T}, logjac, vn, dist ) where {V<:AbstractVector,T} - # If it's already an AbstractVector transformed value, then we are done. - # `tval.transform` could be a DynamicLink(), Unlink(), or some fixed transform that - # vectorises; it doesn't matter. + # `tval.transform` could be DynamicLink(), Unlink(), or a fixed vectorising transform. return tval end function _get_vector_tval(val, tval::TransformedValue{V,T}, logjac, vn, dist) where {V,T} diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 4ee10f4a0..bc1bd42ec 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -32,7 +32,7 @@ using Random: Random vi_vnt_or_tfm_strategy=_default_vnt(model, UnlinkAll()), accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=DynamicPPL.ldf_accs(getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, - fix_transform::Bool=false, + fix_transforms::Bool=false, ) A struct which contains a model, along with all the information necessary to: @@ -216,10 +216,8 @@ struct LogDensityFunction{ # `dynamic_transform_strategy` might be LinkAll() or UnlinkAll(), for example. We # might need to convert this to a set of fixed transforms. transform_strategy = if fix_transforms - # Reevaluate model again to determine the fixed transforms. This is kind of - # wasteful: for example, we could tie this model evaluation to one of the - # previous ones, but it's fine, since it's only done once in the LDF - # constructor. + # One extra model evaluation to cache the fixed transforms. Only happens once + # at construction time. transforms_vnt = get_fixed_transforms(model, dynamic_transform_strategy) fixed_transform_strategy = WithTransforms( transforms_vnt, dynamic_transform_strategy @@ -233,8 +231,7 @@ struct LogDensityFunction{ end ranges_and_transforms = get_rangeandtransforms(vnt) - # Determine whether all transforms are fixed. This enables fast parameter - # extraction in ParamsWithStats without model re-evaluation. + # All-fixed enables the model-free fast path in ParamsWithStats. all_fixed = all( rat -> rat.transform isa FixedTransform, values(ranges_and_transforms) ) @@ -252,7 +249,6 @@ struct LogDensityFunction{ trial_x = internal_values_as_vector(vnt) dim, et = length(trial_x), eltype(trial_x) x = to_vector_params_inner(vnt, ranges_and_transforms, et, dim) - # convert to AccumulatorTuple if needed accs = AccumulatorTuple(accs) # Do AD prep if needed prep = if adtype === nothing @@ -263,7 +259,9 @@ struct LogDensityFunction{ problem = LogDensityAt( model, getlogdensity, ranges_and_transforms, transform_strategy, accs ) - AbstractPPL.prepare(adtype, problem, x) + # `x` was just constructed from the same range metadata stored in `problem`, + # so the AD wrapper can skip its own hot-path dimension validation. + AbstractPPL.prepare(adtype, problem, x; check_dims=false) end return new{ typeof(model), From c58fe2408e98db5d885d1e916d21cea3af4ad89b Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 21 Apr 2026 22:12:57 +0100 Subject: [PATCH 4/5] Normalize scalar Enzyme gradients in AD tests --- src/test_utils/ad.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 4c145943d..01c99ef42 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -339,8 +339,7 @@ function run_ad( # Calculate log-density and gradient with the backend of interest value, grad = logdensity_and_gradient(ldf, params) - # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 - grad = collect(grad) + grad = vec(collect(grad)) verbose && println(" actual : $((value, grad))") # Test correctness @@ -357,8 +356,7 @@ function run_ad( model, getlogdensity, transform_strategy; adtype=test.adtype ) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) - # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 - grad_true = collect(grad_true) + grad_true = vec(collect(grad_true)) end # Perform testing verbose && println(" expected : $((value_true, grad_true))") From 378f974319d1e92114cb1b8d4796854957c3b84a Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 22 Apr 2026 15:51:17 +0100 Subject: [PATCH 5/5] Refine LogDensityFunction AD prep and tests --- benchmarks/Project.toml | 2 +- docs/Project.toml | 2 -- docs/make.jl | 1 - ext/DynamicPPLMCMCChainsExt.jl | 1 + src/accumulators/vector_values.jl | 4 +++- src/logdensityfunction.jl | 19 ++++++++----------- src/test_utils/ad.jl | 2 -- test/Project.toml | 4 +++- test/floattypes/Project.toml | 2 +- test/logdensityfunction.jl | 31 ++++++++++++++++++++++++------- 10 files changed, 41 insertions(+), 27 deletions(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 6adb27efa..9a2b2afcc 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -18,7 +18,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [sources] -DynamicPPL = {path = "../"} +DynamicPPL = {path = ".."} [compat] ADTypes = "1.14.0" diff --git a/docs/Project.toml b/docs/Project.toml index 0c716af15..38261c0a4 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,7 +7,6 @@ BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -35,7 +34,6 @@ BangBang = "0.4" Bijectors = "0.15.17" Chairmarks = "1" ChangesOfVariables = "0.1" -DifferentiationInterface = "0.6.41, 0.7" DimensionalData = "0.30" Distributions = "0.25" Documenter = "1" diff --git a/docs/make.jl b/docs/make.jl index fb705448c..71fbfb2c0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,5 @@ using Documenter using DocumenterInterLinks -using DifferentiationInterface using DynamicPPL using AbstractPPL # NOTE: This is necessary to ensure that if we print something from diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 6f3aee672..d7987ffa5 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -141,6 +141,7 @@ function AbstractMCMC.to_samples( # which is why it still pops up in get_varnames). We can check for that and skip # if it's no longer there. has_varname(chain, vn) || continue + # If it's there then we can add it into the VNT. top_sym = AbstractPPL.getsym(vn) val = getindex_varname(chain, sample_idx, vn, chain_idx) # This call to get() is type unstable, but I tried writing a generated function diff --git a/src/accumulators/vector_values.jl b/src/accumulators/vector_values.jl index 3f5b1c04f..1973b03c4 100644 --- a/src/accumulators/vector_values.jl +++ b/src/accumulators/vector_values.jl @@ -8,7 +8,9 @@ Generate a `TransformedValue` that always has a vector as its stored value. function _get_vector_tval( val, tval::TransformedValue{V,T}, logjac, vn, dist ) where {V<:AbstractVector,T} - # `tval.transform` could be DynamicLink(), Unlink(), or a fixed vectorising transform. + # If it's already an AbstractVector transformed value, then we are done. + # `tval.transform` could be a DynamicLink(), Unlink(), or some fixed transform that + # vectorises; it doesn't matter. return tval end function _get_vector_tval(val, tval::TransformedValue{V,T}, logjac, vn, dist) where {V,T} diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index bc1bd42ec..524f1e676 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -216,8 +216,10 @@ struct LogDensityFunction{ # `dynamic_transform_strategy` might be LinkAll() or UnlinkAll(), for example. We # might need to convert this to a set of fixed transforms. transform_strategy = if fix_transforms - # One extra model evaluation to cache the fixed transforms. Only happens once - # at construction time. + # Reevaluate model again to determine the fixed transforms. This is kind of + # wasteful: for example, we could tie this model evaluation to one of the + # previous ones, but it's fine, since it's only done once in the LDF + # constructor. transforms_vnt = get_fixed_transforms(model, dynamic_transform_strategy) fixed_transform_strategy = WithTransforms( transforms_vnt, dynamic_transform_strategy @@ -231,7 +233,8 @@ struct LogDensityFunction{ end ranges_and_transforms = get_rangeandtransforms(vnt) - # All-fixed enables the model-free fast path in ParamsWithStats. + # Determine whether all transforms are fixed. This enables fast parameter + # extraction in ParamsWithStats without model re-evaluation. all_fixed = all( rat -> rat.transform isa FixedTransform, values(ranges_and_transforms) ) @@ -249,6 +252,7 @@ struct LogDensityFunction{ trial_x = internal_values_as_vector(vnt) dim, et = length(trial_x), eltype(trial_x) x = to_vector_params_inner(vnt, ranges_and_transforms, et, dim) + # convert to AccumulatorTuple if needed accs = AccumulatorTuple(accs) # Do AD prep if needed prep = if adtype === nothing @@ -260,7 +264,7 @@ struct LogDensityFunction{ model, getlogdensity, ranges_and_transforms, transform_strategy, accs ) # `x` was just constructed from the same range metadata stored in `problem`, - # so the AD wrapper can skip its own hot-path dimension validation. + # so the AD wrapper can skip its hot-path dimension validation. AbstractPPL.prepare(adtype, problem, x; check_dims=false) end return new{ @@ -445,13 +449,6 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) ) end -# Return the structural problem object itself so AD backends see a stable -# one-argument evaluator. This preserves compiled ReverseDiff tape reuse -# without relying on anonymous closures. -function AbstractPPL.prepare(problem::LogDensityAt, x::AbstractVector{<:AbstractFloat}) - return problem -end - function LogDensityProblems.logdensity( ldf::LogDensityFunction, params::AbstractVector{<:Real} ) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 01c99ef42..42ac9203d 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -339,7 +339,6 @@ function run_ad( # Calculate log-density and gradient with the backend of interest value, grad = logdensity_and_gradient(ldf, params) - grad = vec(collect(grad)) verbose && println(" actual : $((value, grad))") # Test correctness @@ -356,7 +355,6 @@ function run_ad( model, getlogdensity, transform_strategy; adtype=test.adtype ) value_true, grad_true = logdensity_and_gradient(ldf_reference, params) - grad_true = vec(collect(grad_true)) end # Perform testing verbose && println(" expected : $((value_true, grad_true))") diff --git a/test/Project.toml b/test/Project.toml index f8cc0e688..2c635ed19 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -31,7 +32,8 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [sources] -AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} +AbstractPPL = {rev = "evaluator-interface", url = "https://github.com/TuringLang/AbstractPPL.jl"} +DynamicPPL = {path = ".."} [compat] ADTypes = "1" diff --git a/test/floattypes/Project.toml b/test/floattypes/Project.toml index 02a770fe7..e47e1ebf4 100644 --- a/test/floattypes/Project.toml +++ b/test/floattypes/Project.toml @@ -7,4 +7,4 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [sources] -DynamicPPL = {path = "../../"} +DynamicPPL = {path = "../.."} diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index feadba2c1..5ccdb48a1 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -485,17 +485,34 @@ end end end - @testset "ReverseDiff compiled prep is retained" begin + # Compiled ReverseDiff prep should be observable as lower repeated-call allocations. + @testset "ReverseDiff compiled prep reduces repeated-call allocations" begin @model f() = x ~ Normal() - ldf = LogDensityFunction( + ldf_compiled = LogDensityFunction( f(), getlogjoint_internal, LinkAll(); adtype=AutoReverseDiff(; compile=true) ) - x = rand(ldf) + ldf_uncompiled = LogDensityFunction( + f(), getlogjoint_internal, LinkAll(); adtype=AutoReverseDiff(; compile=false) + ) + params = rand(ldf_compiled) + + LogDensityProblems.logdensity_and_gradient(ldf_compiled, params) + LogDensityProblems.logdensity_and_gradient(ldf_uncompiled, params) + + function repeated_call_allocs(ldf, params) + GC.gc() + before = Base.gc_num() + for _ in 1:100 + LogDensityProblems.logdensity_and_gradient(ldf, params) + end + after = Base.gc_num() + return Base.GC_Diff(after, before).allocd + end + + allocs_compiled = repeated_call_allocs(ldf_compiled, params) + allocs_uncompiled = repeated_call_allocs(ldf_uncompiled, params) - @test hasfield(typeof(ldf._adprep.prep), :tape) - @test getfield(ldf._adprep.prep, :tape) !== nothing - @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Tuple - @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Tuple + @test allocs_compiled < allocs_uncompiled end # Test that various different ways of specifying array types as arguments work with all