From f689bd8324f1aaa7c089ea8343cd20bfd8751d17 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Fri, 24 Apr 2026 16:36:28 +0100 Subject: [PATCH 1/8] Migrate from DifferentiationInterface to AbstractPPL evaluator interface Replace `DifferentiationInterface` with `AbstractPPL.prepare` / `AbstractPPL.value_and_gradient` throughout. Key changes: - `_prepare_gradient` / `_value_and_gradient!` now wrap an `AbstractPPL` prepared evaluator in a `_VIGradPrep` struct that holds an `aux_ref` so auxiliary inputs can be swapped without re-preparing. - `DynamicPPLModelLogDensityFunction` stores `model_ref` and `loglikeadj_ref` as `Ref`s; `subsample` mutates them in-place instead of creating a new struct via `@set`, keeping the prepared evaluator valid across subsampling steps. - Drop second-order (Hessian) support; `use_hessian=true` now warns and is ignored. - Pin dev branches via `[sources]`: `AbstractPPL@evaluator-interface` and `DynamicPPL@adproblems-interface`. Co-Authored-By: Claude Sonnet 4.6 --- Project.toml | 10 +- ext/AdvancedVIDynamicPPLExt.jl | 149 ++++++------------ src/AdvancedVI.jl | 50 +++--- src/algorithms/abstractobjective.jl | 2 +- src/algorithms/common.jl | 2 +- src/algorithms/fisherminbatchmatch.jl | 18 +-- src/algorithms/klminnaturalgraddescent.jl | 10 +- src/algorithms/klminsqrtnaturalgraddescent.jl | 6 +- src/algorithms/klminwassfwdbwd.jl | 6 +- src/algorithms/subsampledobjective.jl | 2 +- src/optimization/rules.jl | 8 +- src/reshuffling.jl | 2 +- test/Project.toml | 6 + test/integration/dynamicppl.jl | 14 +- test/runtests.jl | 1 + 15 files changed, 129 insertions(+), 157 deletions(-) diff --git a/Project.toml b/Project.toml index 28d4a3035..886401ebe 100644 --- a/Project.toml +++ b/Project.toml @@ -4,10 +4,10 @@ version = "0.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -29,14 +29,14 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"] AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"] AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"] -AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "DifferentiationInterface", "LogDensityProblems"] +AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "LogDensityProblems"] [compat] ADTypes = "1" +AbstractPPL = "0.14" Accessors = "0.1" ChainRulesCore = "1" DiffResults = "1" -DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" DynamicPPL = "0.40, 0.41" @@ -63,3 +63,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Pkg", "Test"] + +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} diff --git a/ext/AdvancedVIDynamicPPLExt.jl b/ext/AdvancedVIDynamicPPLExt.jl index eaff70ae7..08c4bcf4d 100644 --- a/ext/AdvancedVIDynamicPPLExt.jl +++ b/ext/AdvancedVIDynamicPPLExt.jl @@ -3,48 +3,20 @@ module AdvancedVIDynamicPPLExt using ADTypes: ADTypes using Accessors using AdvancedVI: AdvancedVI -using DifferentiationInterface: DifferentiationInterface +using AbstractPPL: AbstractPPL using Distributions: Distributions using DynamicPPL: DynamicPPL using LogDensityProblems: LogDensityProblems using Random -adtype_capabilities(::Type{Nothing}) = LogDensityProblems.LogDensityOrder{0}() +function adtype_capabilities(::Type{Nothing}) + return LogDensityProblems.LogDensityOrder{0}() +end function adtype_capabilities(::Type{<:ADTypes.AbstractADType}) return LogDensityProblems.LogDensityOrder{1}() end -function adtype_capabilities( - ::Type{ - <:Union{ - <:ADTypes.AutoForwardDiff, - <:ADTypes.AutoReverseDiff, - <:ADTypes.AutoMooncake, - <:ADTypes.AutoEnzyme, - <:DifferentiationInterface.SecondOrder, - }, - }, -) - return LogDensityProblems.LogDensityOrder{2}() -end - -struct DynamicPPLModelLogDensityFunction{ - Model<:DynamicPPL.Model, - LogLikeAdj<:Real, - VarInfo<:DynamicPPL.AbstractVarInfo, - ADType<:ADTypes.AbstractADType, - PrepGrad<:Union{Nothing,DifferentiationInterface.GradientPrep}, - PrepHess<:Union{Nothing,DifferentiationInterface.HessianPrep}, -} - model::Model - loglikeadj::LogLikeAdj - varinfo::VarInfo - adtype::ADType - prep_grad::PrepGrad - prep_hess::PrepHess -end - function logdensity_impl( params, model::DynamicPPL.Model, loglikeadj::Real, varinfo::DynamicPPL.AbstractVarInfo ) @@ -63,14 +35,31 @@ function subsample_dynamicpplmodel( return DynamicPPL.Model{Threaded}(model.f, model.args, new_kwargs, model.context) end +struct DynamicPPLModelLogDensityFunction{ + Model<:DynamicPPL.Model, + VarInfo<:DynamicPPL.AbstractVarInfo, + ADType<:Union{Nothing,ADTypes.AbstractADType}, + PrepGrad, +} + model::Model + varinfo::VarInfo + adtype::ADType + # Refs are updated in-place by subsample; the prepared AD evaluator reads + # through them on every call, so the prep remains valid across subsampling. + model_ref::Ref{Any} + loglikeadj_ref::Ref{Float64} + prep_grad::PrepGrad +end + function DynamicPPLModelLogDensityFunction( model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; - use_hessian::Bool=true, + use_hessian::Bool=false, adtype::Union{Nothing,ADTypes.AbstractADType}=nothing, loglikeadj::Real=1.0, subsampling::Union{Nothing,AdvancedVI.AbstractSubsampling}=nothing, ) + use_hessian && @warn "`use_hessian` is no longer supported and will be ignored." model_sub = if isnothing(subsampling) model else @@ -82,92 +71,45 @@ function DynamicPPLModelLogDensityFunction( params = [val for val in varinfo[:]] cap = adtype_capabilities(typeof(adtype)) + + model_ref = Ref{Any}(model_sub) + loglikeadj_ref = Ref{Float64}(float(loglikeadj)) + prep_grad = if cap >= LogDensityProblems.LogDensityOrder{1}() - DifferentiationInterface.prepare_gradient( - logdensity_impl, - DifferentiationInterface.inner(adtype), + AbstractPPL.prepare( + adtype, + params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo), params, - DifferentiationInterface.Constant(model_sub), - DifferentiationInterface.Constant(loglikeadj), - DifferentiationInterface.Constant(varinfo), ) else nothing end - prep_hess = if cap >= LogDensityProblems.LogDensityOrder{2}() && use_hessian - try - DifferentiationInterface.prepare_hessian( - logdensity_impl, - adtype, - params, - DifferentiationInterface.Constant(model_sub), - DifferentiationInterface.Constant(loglikeadj), - DifferentiationInterface.Constant(varinfo), - ) - catch - @warn "The selected AD backend has second-order capabilities but `DifferentiationInterface.prepare_hessian` failed. AdvancedVI will treat the model to only have first-order capability." - nothing - end - else - nothing - end - return DynamicPPLModelLogDensityFunction{ - typeof(model), - typeof(loglikeadj), - typeof(varinfo), - typeof(adtype), - typeof(prep_grad), - typeof(prep_hess), - }( - model, loglikeadj, varinfo, adtype, prep_grad, prep_hess + + return DynamicPPLModelLogDensityFunction( + model, varinfo, adtype, model_ref, loglikeadj_ref, prep_grad ) end function LogDensityProblems.logdensity(prob::DynamicPPLModelLogDensityFunction, params) - (; model, loglikeadj, varinfo) = prob - return logdensity_impl(params, model, loglikeadj, varinfo) + return logdensity_impl(params, prob.model_ref[], prob.loglikeadj_ref[], prob.varinfo) end function LogDensityProblems.logdensity_and_gradient( prob::DynamicPPLModelLogDensityFunction, params ) - (; model, adtype, loglikeadj, varinfo, prep_grad) = prob - return DifferentiationInterface.value_and_gradient( - logdensity_impl, - prep_grad, - DifferentiationInterface.inner(adtype), - params, - DifferentiationInterface.Constant(model), - DifferentiationInterface.Constant(loglikeadj), - DifferentiationInterface.Constant(varinfo), - ) + return AbstractPPL.value_and_gradient(prob.prep_grad, params) end -function LogDensityProblems.logdensity_gradient_and_hessian( - prob::DynamicPPLModelLogDensityFunction, params -) - (; model, adtype, loglikeadj, varinfo, prep_hess) = prob - return DifferentiationInterface.value_gradient_and_hessian( - logdensity_impl, - prep_hess, - adtype, - params, - DifferentiationInterface.Constant(model), - DifferentiationInterface.Constant(loglikeadj), - DifferentiationInterface.Constant(varinfo), - ) +function LogDensityProblems.capabilities( + ::Type{<:DynamicPPLModelLogDensityFunction{M,V,Nothing,G}} +) where {M,V,G} + return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPLModelLogDensityFunction{M,L,V,ADType,PG,PH}} -) where {M,L,V,ADType<:ADTypes.AbstractADType,PG,PH} - return if PH != Nothing - LogDensityProblems.LogDensityOrder{2}() - elseif PG != Nothing - LogDensityProblems.LogDensityOrder{1}() - else - LogDensityProblems.LogDensityOrder{0}() - end + ::Type{<:DynamicPPLModelLogDensityFunction{M,V,<:ADTypes.AbstractADType,G}} +) where {M,V,G} + return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.dimension(prob::DynamicPPLModelLogDensityFunction) @@ -180,7 +122,7 @@ function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch) if !haskey(model.defaults, :datapoints) throw( ArgumentError( - "Subsampling is turned on, but the model does not have have a `datapoints` keyword argument.", + "Subsampling is turned on, but the model does not have a `datapoints` keyword argument.", ), ) end @@ -190,9 +132,10 @@ function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch) model_sub = subsample_dynamicpplmodel(model, batch) loglikeadj = n_datapoints / batchsize - prob′ = @set prob.model = model_sub - prob′′ = @set prob′.loglikeadj = loglikeadj - return prob′′ + prob.model_ref[] = model_sub + prob.loglikeadj_ref[] = loglikeadj + + return prob end end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 1d07fd975..a0ce225de 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -17,13 +17,20 @@ using LogDensityProblems using ADTypes using DiffResults -using DifferentiationInterface +using AbstractPPL: AbstractPPL using ChainRulesCore: ChainRulesCore using FillArrays using StatsBase +# Holds the AbstractPPL prepared evaluator together with the aux Ref so that +# _value_and_gradient! can update aux before every evaluation. +struct _VIGradPrep{P,R} + prepared::P + aux_ref::R +end + # Derivatives """ _value_and_gradient!(f, out, ad, x, aux) @@ -33,9 +40,9 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif `f` may receive auxiliary input as `f(x,aux)`. # Arguments -- `ad::ADTypes.AbstractADType`: +- `ad::ADTypes.AbstractADType`: automatic differentiation backend. Currently supports - `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, `ADTypes.ReverseDiff()`, + `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, `ADTypes.ReverseDiff()`, `ADTypes.AutoMooncake()` and `ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), @@ -45,31 +52,36 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif - `f`: Function subject to differentiation. - `x`: The point to evaluate the gradient. - `aux`: Auxiliary input passed to `f`. -- `prep`: Output of `DifferentiationInterface.prepare_gradient`. +- `prep`: Output of `_prepare_gradient`. - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ function _value_and_gradient!( f, out::DiffResults.MutableDiffResult, ad::ADTypes.AbstractADType, x, aux ) - grad_buf = DiffResults.gradient(out) - y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux)) - DiffResults.value!(out, y) + prepared = AbstractPPL.prepare(ad, Base.Fix2(f, aux), x) + val, grad = AbstractPPL.value_and_gradient(prepared, x) + DiffResults.value!(out, val) + copyto!(DiffResults.gradient(out), grad) return out end function _value_and_gradient!( - f, out::DiffResults.MutableDiffResult, prep, ad::ADTypes.AbstractADType, x, aux + f, + out::DiffResults.MutableDiffResult, + prep::_VIGradPrep, + ad::ADTypes.AbstractADType, + x, + aux, ) - grad_buf = DiffResults.gradient(out) - y, _ = DifferentiationInterface.value_and_gradient!( - f, grad_buf, prep, ad, x, Constant(aux) - ) - DiffResults.value!(out, y) + prep.aux_ref[] = aux + val, grad = AbstractPPL.value_and_gradient(prep.prepared, x) + DiffResults.value!(out, val) + copyto!(DiffResults.gradient(out), grad) return out end """ - _prepare_gradient!(f, ad, x, aux) + _prepare_gradient(f, ad, x, aux) Prepare AD backend for taking gradients of a function `f` at `x` using the automatic differentiation backend `ad`. @@ -80,7 +92,9 @@ Prepare AD backend for taking gradients of a function `f` at `x` using the autom - `aux`: Auxiliary input passed to `f`. """ function _prepare_gradient(f, ad::ADTypes.AbstractADType, x, aux) - return DifferentiationInterface.prepare_gradient(f, ad, x, Constant(aux)) + aux_ref = Ref(aux) + prepared = AbstractPPL.prepare(ad, x -> f(x, aux_ref[]), x) + return _VIGradPrep(prepared, aux_ref) end """ @@ -238,7 +252,7 @@ function step( objargs...; kwargs..., ) - nothing + return nothing end """ @@ -276,11 +290,11 @@ Please refer to the respective documentation of each algorithm for more info. function estimate_objective( ::Random.AbstractRNG, ::AbstractVariationalAlgorithm, q, prob; kwargs... ) - nothing + return nothing end function estimate_objective(alg::AbstractVariationalAlgorithm, q, prob; kwargs...) - estimate_objective(Random.default_rng(), alg, q, prob; kwargs...) + return estimate_objective(Random.default_rng(), alg, q, prob; kwargs...) end export estimate_objective diff --git a/src/algorithms/abstractobjective.jl b/src/algorithms/abstractobjective.jl index 65316c7e7..ff027bb54 100644 --- a/src/algorithms/abstractobjective.jl +++ b/src/algorithms/abstractobjective.jl @@ -31,7 +31,7 @@ function init( ::Any, ::Any, ) - nothing + return nothing end """ diff --git a/src/algorithms/common.jl b/src/algorithms/common.jl index 0b99ff0d0..96187fe6a 100644 --- a/src/algorithms/common.jl +++ b/src/algorithms/common.jl @@ -116,5 +116,5 @@ function step( ) info = !isnothing(info′) ? merge(info′, info) : info end - state, false, info + return state, false, info end diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl index b794a12af..f3ea6fd8d 100644 --- a/src/algorithms/fisherminbatchmatch.jl +++ b/src/algorithms/fisherminbatchmatch.jl @@ -89,14 +89,14 @@ function rand_batch_match_samples_with_objective!( μ = q.location C = q.scale u = Random.randn!(rng, u_buf) - z = C*u .+ μ + z = C * u .+ μ logπ_sum = zero(eltype(μ)) for b in 1:n_samples logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b)) grad_buf[:, b] = gb logπ_sum += logπb end - logπ_avg = logπ_sum/n_samples + logπ_avg = logπ_sum / n_samples # Estimate objective values # @@ -105,7 +105,7 @@ function rand_batch_match_samples_with_objective!( # = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2] (z = Cu + μ) # = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2] # = E[| -u - C'∇logπ(z)) |^2] - fisher = sum(abs2, -u_buf - (C'*grad_buf))/n_samples + fisher = sum(abs2, -u_buf - (C' * grad_buf)) / n_samples return u_buf, z, grad_buf, fisher, logπ_avg end @@ -145,13 +145,13 @@ function step( gbar, Γ = mean_and_cov(grad_buf, 2) μmz = μ - zbar - λ = convert(eltype(μ), d*n_samples / iteration) + λ = convert(eltype(μ), d * n_samples / iteration) - U = Symmetric(λ*Γ + (λ/(1 + λ)*gbar)*gbar') - V = Symmetric(Σ + λ*C + (λ/(1 + λ)*μmz)*μmz') + U = Symmetric(λ * Γ + (λ / (1 + λ) * gbar) * gbar') + V = Symmetric(Σ + λ * C + (λ / (1 + λ) * μmz) * μmz') - Σ′ = Hermitian(2*V/(I + real(sqrt(I + 4*U*V)))) - μ′ = 1/(1 + λ)*μ + λ/(1 + λ)*(Σ′*gbar + zbar) + Σ′ = Hermitian(2 * V / (I + real(sqrt(I + 4 * U * V)))) + μ′ = 1 / (1 + λ) * μ + λ / (1 + λ) * (Σ′ * gbar + zbar) q′ = MvLocationScale(μ′[:, 1], cholesky(Σ′).L, q.dist) elbo = logπ_avg + entropy(q) @@ -163,7 +163,7 @@ function step( info′ = callback(; rng, iteration, q, state) info = !isnothing(info′) ? merge(info′, info) : info end - state, false, info + return state, false, info end """ diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index c101537a4..596cd372c 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -81,10 +81,10 @@ function init( grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) scale = q_init.scale - qcov = Hermitian(scale*scale') + qcov = Hermitian(scale * scale') scale_inv = inv(scale) prec_chol = scale_inv' - prec = Hermitian(prec_chol*prec_chol') + prec = Hermitian(prec_chol * prec_chol') return KLMinNaturalGradDescentState( q_init, prob, prec, qcov, 0, sub_st, grad_buf, hess_buf ) @@ -127,7 +127,7 @@ function step( # Handling the positive-definite constraint in the Bayesian learning rule. # In ICML 2020. G_hat = S - (-hess_buf) - Hermitian(S - η*G_hat + η^2/2*G_hat*qcov*G_hat) + Hermitian(S - η * G_hat + η^2 / 2 * G_hat * qcov * G_hat) else Hermitian(((1 - η) * S + η * (-hess_buf))) end @@ -136,7 +136,7 @@ function step( prec_chol = cholesky(S′).L prec_chol_inv = inv(prec_chol) scale = prec_chol_inv' - qcov = Hermitian(scale*scale') + qcov = Hermitian(scale * scale') q′ = MvLocationScale(m′, scale, q.dist) state = KLMinNaturalGradDescentState( @@ -149,7 +149,7 @@ function step( info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end - state, false, info + return state, false, info end """ diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index 622c111c3..3b29c3b2b 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -105,8 +105,8 @@ function step( rng, q, n_samples, grad_buf, hess_buf, prob_sub ) - CtHCmI = C'*(-hess_buf)*C - I - CtHCmI_tril = LowerTriangular(tril(CtHCmI) - Diagonal(diag(CtHCmI))/2) + CtHCmI = C' * (-hess_buf) * C - I + CtHCmI_tril = LowerTriangular(tril(CtHCmI) - Diagonal(diag(CtHCmI)) / 2) m′ = m - η * C * (C' * -grad_buf) C′ = C - η * C * CtHCmI_tril @@ -123,7 +123,7 @@ function step( info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end - state, false, info + return state, false, info end """ diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index 602f4be41..40577bfc4 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -104,10 +104,10 @@ function step( m′ = m - η * (-grad_buf) M = I - η * (-hess_buf') - Σ_half = Hermitian(M*Σ*M') + Σ_half = Hermitian(M * Σ * M') # Compute the JKO proximal operator - Σ′ = (Σ_half + 2*η*I + sqrt(Hermitian(Σ_half*(Σ_half + 4*η*I))))/2 + Σ′ = (Σ_half + 2 * η * I + sqrt(Hermitian(Σ_half * (Σ_half + 4 * η * I)))) / 2 q′ = MvLocationScale(m′, cholesky(Σ′).L, q.dist) state = KLMinWassFwdBwdState(q′, prob, Σ′, iteration, sub_st′, grad_buf, hess_buf) @@ -118,7 +118,7 @@ function step( info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end - state, false, info + return state, false, info end """ diff --git a/src/algorithms/subsampledobjective.jl b/src/algorithms/subsampledobjective.jl index 21734f7ca..ad1f858ef 100644 --- a/src/algorithms/subsampledobjective.jl +++ b/src/algorithms/subsampledobjective.jl @@ -32,7 +32,7 @@ function init( sub_st = init(rng, subsampling) # This is necessary to ensure that `init` sees the type "conditioned" on a minibatch - # when calling `DifferentiationInterface.prepare_*` inside it. + # so that any prepared AD evaluator inside it sees the correct batch-subsampled type. batch, _, _ = step(rng, subsampling, sub_st, true) prob_sub = subsample(prob, batch) q_init_sub = subsample(q_init, batch) diff --git a/src/optimization/rules.jl b/src/optimization/rules.jl index 2025632e3..5015f4ebe 100644 --- a/src/optimization/rules.jl +++ b/src/optimization/rules.jl @@ -18,7 +18,9 @@ Optimisers.@def struct DoWG <: Optimisers.AbstractRule alpha = 1e-6 end -Optimisers.init(o::DoWG, x::AbstractArray{T}) where {T} = (copy(x), zero(T), T(o.alpha)*(1 + norm(x))) +function Optimisers.init(o::DoWG, x::AbstractArray{T}) where {T} + return (copy(x), zero(T), T(o.alpha) * (1 + norm(x))) +end function Optimisers.apply!(::DoWG, state, x::AbstractArray{T}, dx) where {T} x0, v, r = state @@ -47,7 +49,9 @@ Optimisers.@def struct DoG <: Optimisers.AbstractRule alpha = 1e-6 end -Optimisers.init(o::DoG, x::AbstractArray{T}) where {T} = (copy(x), zero(T), T(o.alpha)*(1 + norm(x))) +function Optimisers.init(o::DoG, x::AbstractArray{T}) where {T} + return (copy(x), zero(T), T(o.alpha) * (1 + norm(x))) +end function Optimisers.apply!(::DoG, state, x::AbstractArray{T}, dx) where {T} x0, v, r = state diff --git a/src/reshuffling.jl b/src/reshuffling.jl index e0e50cfbd..dd7f95aa8 100644 --- a/src/reshuffling.jl +++ b/src/reshuffling.jl @@ -49,7 +49,7 @@ function step( # Ignore the trailing batch if its size is smaller than `batchsize`. # This should only be used when estimating gradients during optimization. # This is necessary to ensure that all batches have the same size. - # Otherwise, `DifferentiationInterface.prepare_*` behaves incorrectly. + # Otherwise, prepared AD evaluators may see inconsistent batch sizes. (sub_step, batch), iterator = Iterators.peel(iterator) end epoch = epoch + 1 diff --git a/test/Project.toml b/test/Project.toml index ce8ea455b..6aad881ab 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -23,6 +25,10 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} + [compat] ADTypes = "0.2.1, 1" DiffResults = "1" diff --git a/test/integration/dynamicppl.jl b/test/integration/dynamicppl.jl index 43a38d3b8..a3ed01d20 100644 --- a/test/integration/dynamicppl.jl +++ b/test/integration/dynamicppl.jl @@ -1,7 +1,7 @@ @testset "DynamicPPL" begin DynamicPPL.@model function normal(μ) - x ~ MvNormal(μ, I) + return x ~ MvNormal(μ, I) end DynamicPPL.@model function normal_subsampled(μs; datapoints=1:size(μs, 2)) @@ -22,18 +22,18 @@ alg = KLMinRepGradProxDescent(AD) d = LogDensityProblems.dimension(prob) - q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d))) + q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6 * I, d, d))) q, _, _ = AdvancedVI.optimize(alg, 1000, prob, q0; show_progress=false) Δλ0 = sum(abs2, q0.location - μ_true) Δλ = sum(abs2, q.location - μ_true) - @test Δλ ≤ Δλ0/2 + @test Δλ ≤ Δλ0 / 2 end @testset "subsampling" begin n_data = 32 - μs = 3*randn(2, n_data) - μ_true = mean(μs, dims=2)[:, 1] + μs = 3 * randn(2, n_data) + μ_true = mean(μs; dims=2)[:, 1] model = normal_subsampled(μs) vi = DynamicPPL.VarInfo(model) @@ -48,11 +48,11 @@ alg = KLMinRepGradProxDescent(AD; subsampling) d = LogDensityProblems.dimension(prob) - q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d))) + q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6 * I, d, d))) q, _, _ = AdvancedVI.optimize(alg, 1000, prob, q0; show_progress=false) Δλ0 = sum(abs2, q0.location - μ_true) Δλ = sum(abs2, q.location - μ_true) - @test Δλ ≤ Δλ0/2 + @test Δλ ≤ Δλ0 / 2 end end diff --git a/test/runtests.jl b/test/runtests.jl index e1e230eb9..4fc054147 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,7 @@ using Random, StableRNGs using Statistics using StatsBase +using DifferentiationInterface using AdvancedVI const PROGRESS = haskey(ENV, "PROGRESS") From 79cc048af6d30055a72aa202896d2e5ce7263999 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 19 May 2026 18:02:56 +0100 Subject: [PATCH 2/8] Bump AbstractPPL@0.15, add LDP+Hessian on prepared evaluators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Switch `value_and_gradient` → `value_and_gradient!!` per AbstractPPL 0.15. - In AdvancedVI core, give every `AbstractPPL.Prepared{<:AbstractADType,<:VectorEvaluator}` a `LogDensityProblems` interface (order-1 fallback) so any AD-backed prep acts as a `LogDensityProblem` without backend-specific wiring. - In `AdvancedVIMooncakeExt`, promote Mooncake-prepared evaluators to order-2 and add `logdensity_gradient_and_hessian` via forward-over-reverse Mooncake. - Simplify `DynamicPPLModelLogDensityFunction` to delegate LDP calls to its inner prep; `capabilities` reads off the prep's own capability so the Hessian branch is exposed exactly when the backend supports it. - Bump compat: AbstractPPL 0.15, DynamicPPL 0.42 (with branch pin until release); add Bijectors branch pin in `test/` for the same reason. Co-Authored-By: Claude Opus 4.7 (1M context) --- Project.toml | 5 ++- ext/AdvancedVIDynamicPPLExt.jl | 56 +++++++++++++++------------------- ext/AdvancedVIMooncakeExt.jl | 22 +++++++++++++ src/AdvancedVI.jl | 36 +++++++++++++++++++--- test/Project.toml | 5 +-- 5 files changed, 83 insertions(+), 41 deletions(-) diff --git a/Project.toml b/Project.toml index 886401ebe..8cffa4dcf 100644 --- a/Project.toml +++ b/Project.toml @@ -33,13 +33,13 @@ AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "LogDensi [compat] ADTypes = "1" -AbstractPPL = "0.14" +AbstractPPL = "0.15" Accessors = "0.1" ChainRulesCore = "1" DiffResults = "1" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" -DynamicPPL = "0.40, 0.41" +DynamicPPL = "0.42" Enzyme = "0.13" FillArrays = "1.3" Functors = "0.4, 0.5" @@ -65,5 +65,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" test = ["Pkg", "Test"] [sources] -AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} diff --git a/ext/AdvancedVIDynamicPPLExt.jl b/ext/AdvancedVIDynamicPPLExt.jl index 08c4bcf4d..6e5f86eef 100644 --- a/ext/AdvancedVIDynamicPPLExt.jl +++ b/ext/AdvancedVIDynamicPPLExt.jl @@ -1,22 +1,12 @@ module AdvancedVIDynamicPPLExt using ADTypes: ADTypes -using Accessors using AdvancedVI: AdvancedVI using AbstractPPL: AbstractPPL -using Distributions: Distributions using DynamicPPL: DynamicPPL using LogDensityProblems: LogDensityProblems using Random -function adtype_capabilities(::Type{Nothing}) - return LogDensityProblems.LogDensityOrder{0}() -end - -function adtype_capabilities(::Type{<:ADTypes.AbstractADType}) - return LogDensityProblems.LogDensityOrder{1}() -end - function logdensity_impl( params, model::DynamicPPL.Model, loglikeadj::Real, varinfo::DynamicPPL.AbstractVarInfo ) @@ -35,31 +25,29 @@ function subsample_dynamicpplmodel( return DynamicPPL.Model{Threaded}(model.f, model.args, new_kwargs, model.context) end +# `LogDensityProblems.capabilities` and the gradient/Hessian methods dispatch +# off `Prep`, so the AD backend's `Prepared` type drives the LDP capability. struct DynamicPPLModelLogDensityFunction{ Model<:DynamicPPL.Model, VarInfo<:DynamicPPL.AbstractVarInfo, ADType<:Union{Nothing,ADTypes.AbstractADType}, - PrepGrad, + Prep, } model::Model varinfo::VarInfo adtype::ADType - # Refs are updated in-place by subsample; the prepared AD evaluator reads - # through them on every call, so the prep remains valid across subsampling. model_ref::Ref{Any} loglikeadj_ref::Ref{Float64} - prep_grad::PrepGrad + prep::Prep end function DynamicPPLModelLogDensityFunction( model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; - use_hessian::Bool=false, adtype::Union{Nothing,ADTypes.AbstractADType}=nothing, loglikeadj::Real=1.0, subsampling::Union{Nothing,AdvancedVI.AbstractSubsampling}=nothing, ) - use_hessian && @warn "`use_hessian` is no longer supported and will be ignored." model_sub = if isnothing(subsampling) model else @@ -69,24 +57,20 @@ function DynamicPPLModelLogDensityFunction( subsample_dynamicpplmodel(model, batch) end - params = [val for val in varinfo[:]] - cap = adtype_capabilities(typeof(adtype)) + params = collect(varinfo[:]) model_ref = Ref{Any}(model_sub) loglikeadj_ref = Ref{Float64}(float(loglikeadj)) - prep_grad = if cap >= LogDensityProblems.LogDensityOrder{1}() - AbstractPPL.prepare( - adtype, - params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo), - params, - ) - else + prep = if isnothing(adtype) nothing + else + f = params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo) + AbstractPPL.prepare(adtype, f, params) end return DynamicPPLModelLogDensityFunction( - model, varinfo, adtype, model_ref, loglikeadj_ref, prep_grad + model, varinfo, adtype, model_ref, loglikeadj_ref, prep ) end @@ -97,19 +81,25 @@ end function LogDensityProblems.logdensity_and_gradient( prob::DynamicPPLModelLogDensityFunction, params ) - return AbstractPPL.value_and_gradient(prob.prep_grad, params) + return LogDensityProblems.logdensity_and_gradient(prob.prep, params) +end + +function LogDensityProblems.logdensity_gradient_and_hessian( + prob::DynamicPPLModelLogDensityFunction, params +) + return LogDensityProblems.logdensity_gradient_and_hessian(prob.prep, params) end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPLModelLogDensityFunction{M,V,Nothing,G}} -) where {M,V,G} + ::Type{<:DynamicPPLModelLogDensityFunction{M,V,Nothing,P}} +) where {M,V,P} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPLModelLogDensityFunction{M,V,<:ADTypes.AbstractADType,G}} -) where {M,V,G} - return LogDensityProblems.LogDensityOrder{1}() + ::Type{<:DynamicPPLModelLogDensityFunction{M,V,A,P}} +) where {M,V,A<:ADTypes.AbstractADType,P} + return LogDensityProblems.capabilities(P) end function LogDensityProblems.dimension(prob::DynamicPPLModelLogDensityFunction) @@ -132,6 +122,8 @@ function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch) model_sub = subsample_dynamicpplmodel(model, batch) loglikeadj = n_datapoints / batchsize + # Mutates the refs so the previously prepared AD evaluator keeps reading + # the latest batch without needing a re-prepare. prob.model_ref[] = model_sub prob.loglikeadj_ref[] = loglikeadj diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl index 605f77bfa..9b562aad8 100644 --- a/ext/AdvancedVIMooncakeExt.jl +++ b/ext/AdvancedVIMooncakeExt.jl @@ -1,5 +1,8 @@ module AdvancedVIMooncakeExt +using ADTypes: AutoMooncake, AutoMooncakeForward +using AbstractPPL: AbstractPPL +using AbstractPPL.Evaluators: Prepared, VectorEvaluator using AdvancedVI using LogDensityProblems using Mooncake @@ -31,4 +34,23 @@ function Mooncake.rrule!!( return Mooncake.zero_fcodual(ℓπ), logdensity_pb end +const _MooncakePrepared = Prepared{<:AutoMooncake,<:VectorEvaluator} + +# Order-1 LDP methods are inherited from the AbstractADType fallback in +# AdvancedVI core. +function LogDensityProblems.capabilities(::Type{<:_MooncakePrepared}) + LogDensityProblems.LogDensityOrder{2}() +end + +# Mooncake forward-over-reverse Hessian: a fresh forward-mode Jacobian cache +# is built per call, so this is fine for occasional use but costly inside a +# tight per-sample loop. +function LogDensityProblems.logdensity_gradient_and_hessian(p::_MooncakePrepared, x) + val, grad = LogDensityProblems.logdensity_and_gradient(p, x) + grad_fn = y -> LogDensityProblems.logdensity_and_gradient(p, y)[2] + fwd_jac = AbstractPPL.prepare(AutoMooncakeForward(), grad_fn, x) + _, H = AbstractPPL.value_and_jacobian!!(fwd_jac, x) + return val, grad, copy(H) +end + end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index a0ce225de..82c865a5e 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -18,14 +18,15 @@ using LogDensityProblems using ADTypes using DiffResults using AbstractPPL: AbstractPPL +using AbstractPPL.Evaluators: Prepared, VectorEvaluator using ChainRulesCore: ChainRulesCore using FillArrays using StatsBase -# Holds the AbstractPPL prepared evaluator together with the aux Ref so that -# _value_and_gradient! can update aux before every evaluation. +# `aux` is captured by Ref so the same prepared evaluator can be reused after +# aux changes — re-preparing per call would defeat the cache. struct _VIGradPrep{P,R} prepared::P aux_ref::R @@ -59,7 +60,7 @@ function _value_and_gradient!( f, out::DiffResults.MutableDiffResult, ad::ADTypes.AbstractADType, x, aux ) prepared = AbstractPPL.prepare(ad, Base.Fix2(f, aux), x) - val, grad = AbstractPPL.value_and_gradient(prepared, x) + val, grad = AbstractPPL.value_and_gradient!!(prepared, x) DiffResults.value!(out, val) copyto!(DiffResults.gradient(out), grad) return out @@ -74,7 +75,7 @@ function _value_and_gradient!( aux, ) prep.aux_ref[] = aux - val, grad = AbstractPPL.value_and_gradient(prep.prepared, x) + val, grad = AbstractPPL.value_and_gradient!!(prep.prepared, x) DiffResults.value!(out, val) copyto!(DiffResults.gradient(out), grad) return out @@ -110,6 +111,33 @@ This is an indirection for handling the type stability of `restructure`, as some """ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) +# Gradient-only LDP fallback for any AD-prepared evaluator; backend extensions +# override `capabilities` and add `logdensity_gradient_and_hessian` if they can. +function LogDensityProblems.capabilities( + ::Type{<:Prepared{<:ADTypes.AbstractADType,<:VectorEvaluator}} +) + LogDensityProblems.LogDensityOrder{1}() +end + +function LogDensityProblems.dimension( + p::Prepared{<:ADTypes.AbstractADType,<:VectorEvaluator} +) + p.evaluator.dim +end + +function LogDensityProblems.logdensity( + p::Prepared{<:ADTypes.AbstractADType,<:VectorEvaluator}, x +) + p(x) +end + +function LogDensityProblems.logdensity_and_gradient( + p::Prepared{<:ADTypes.AbstractADType,<:VectorEvaluator}, x +) + val, grad = AbstractPPL.value_and_gradient!!(p, x) + return val, copy(grad) +end + include("mixedad_logdensity.jl") # Variational Families diff --git a/test/Project.toml b/test/Project.toml index 6aad881ab..d85b47d4d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -26,15 +27,15 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [sources] -AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "evaluator-interface"} DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} +Bijectors = {url = "https://github.com/TuringLang/Bijectors.jl", rev = "replace-di-with-abstractppl"} [compat] ADTypes = "0.2.1, 1" DiffResults = "1" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" -DynamicPPL = "0.40, 0.41" +DynamicPPL = "0.42" Enzyme = "0.13, 0.14, 0.15" FillArrays = "1.6.1" ForwardDiff = "0.10.36, 1" From 92124d6aa0a117aee86a1988d1fb281cbb849556 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 19 May 2026 18:14:55 +0100 Subject: [PATCH 3/8] Revert formatting-only changes pulled in by the migration `git diff main --stat` was full of `return` keyword and `*` spacing adjustments unrelated to the AbstractPPL 0.15 switch. Restoring those files to `main`'s state shrinks the review surface to just the load-bearing changes (AbstractPPL/DynamicPPL bump, LDP+Hessian wiring, stale-DI comment fixes). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/AdvancedVI.jl | 10 +++++----- src/algorithms/abstractobjective.jl | 2 +- src/algorithms/common.jl | 2 +- src/algorithms/fisherminbatchmatch.jl | 18 +++++++++--------- src/algorithms/klminnaturalgraddescent.jl | 10 +++++----- src/algorithms/klminsqrtnaturalgraddescent.jl | 6 +++--- src/algorithms/klminwassfwdbwd.jl | 6 +++--- src/optimization/rules.jl | 8 ++------ test/integration/dynamicppl.jl | 14 +++++++------- 9 files changed, 36 insertions(+), 40 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 82c865a5e..5c1105a3b 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -41,9 +41,9 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif `f` may receive auxiliary input as `f(x,aux)`. # Arguments -- `ad::ADTypes.AbstractADType`: +- `ad::ADTypes.AbstractADType`: automatic differentiation backend. Currently supports - `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, `ADTypes.ReverseDiff()`, + `ADTypes.AutoZygote()`, `ADTypes.ForwardDiff()`, `ADTypes.ReverseDiff()`, `ADTypes.AutoMooncake()` and `ADTypes.AutoEnzyme(; mode=Enzyme.set_runtime_activity(Enzyme.Reverse), @@ -280,7 +280,7 @@ function step( objargs...; kwargs..., ) - return nothing + nothing end """ @@ -318,11 +318,11 @@ Please refer to the respective documentation of each algorithm for more info. function estimate_objective( ::Random.AbstractRNG, ::AbstractVariationalAlgorithm, q, prob; kwargs... ) - return nothing + nothing end function estimate_objective(alg::AbstractVariationalAlgorithm, q, prob; kwargs...) - return estimate_objective(Random.default_rng(), alg, q, prob; kwargs...) + estimate_objective(Random.default_rng(), alg, q, prob; kwargs...) end export estimate_objective diff --git a/src/algorithms/abstractobjective.jl b/src/algorithms/abstractobjective.jl index ff027bb54..65316c7e7 100644 --- a/src/algorithms/abstractobjective.jl +++ b/src/algorithms/abstractobjective.jl @@ -31,7 +31,7 @@ function init( ::Any, ::Any, ) - return nothing + nothing end """ diff --git a/src/algorithms/common.jl b/src/algorithms/common.jl index 96187fe6a..0b99ff0d0 100644 --- a/src/algorithms/common.jl +++ b/src/algorithms/common.jl @@ -116,5 +116,5 @@ function step( ) info = !isnothing(info′) ? merge(info′, info) : info end - return state, false, info + state, false, info end diff --git a/src/algorithms/fisherminbatchmatch.jl b/src/algorithms/fisherminbatchmatch.jl index f3ea6fd8d..b794a12af 100644 --- a/src/algorithms/fisherminbatchmatch.jl +++ b/src/algorithms/fisherminbatchmatch.jl @@ -89,14 +89,14 @@ function rand_batch_match_samples_with_objective!( μ = q.location C = q.scale u = Random.randn!(rng, u_buf) - z = C * u .+ μ + z = C*u .+ μ logπ_sum = zero(eltype(μ)) for b in 1:n_samples logπb, gb = LogDensityProblems.logdensity_and_gradient(prob, view(z, :, b)) grad_buf[:, b] = gb logπ_sum += logπb end - logπ_avg = logπ_sum / n_samples + logπ_avg = logπ_sum/n_samples # Estimate objective values # @@ -105,7 +105,7 @@ function rand_batch_match_samples_with_objective!( # = E[| C' ( -(CC')\((Cu + μ) - μ) - ∇logπ(z)) |^2] (z = Cu + μ) # = E[| C' ( -(CC')\(Cu) - ∇logπ(z)) |^2] # = E[| -u - C'∇logπ(z)) |^2] - fisher = sum(abs2, -u_buf - (C' * grad_buf)) / n_samples + fisher = sum(abs2, -u_buf - (C'*grad_buf))/n_samples return u_buf, z, grad_buf, fisher, logπ_avg end @@ -145,13 +145,13 @@ function step( gbar, Γ = mean_and_cov(grad_buf, 2) μmz = μ - zbar - λ = convert(eltype(μ), d * n_samples / iteration) + λ = convert(eltype(μ), d*n_samples / iteration) - U = Symmetric(λ * Γ + (λ / (1 + λ) * gbar) * gbar') - V = Symmetric(Σ + λ * C + (λ / (1 + λ) * μmz) * μmz') + U = Symmetric(λ*Γ + (λ/(1 + λ)*gbar)*gbar') + V = Symmetric(Σ + λ*C + (λ/(1 + λ)*μmz)*μmz') - Σ′ = Hermitian(2 * V / (I + real(sqrt(I + 4 * U * V)))) - μ′ = 1 / (1 + λ) * μ + λ / (1 + λ) * (Σ′ * gbar + zbar) + Σ′ = Hermitian(2*V/(I + real(sqrt(I + 4*U*V)))) + μ′ = 1/(1 + λ)*μ + λ/(1 + λ)*(Σ′*gbar + zbar) q′ = MvLocationScale(μ′[:, 1], cholesky(Σ′).L, q.dist) elbo = logπ_avg + entropy(q) @@ -163,7 +163,7 @@ function step( info′ = callback(; rng, iteration, q, state) info = !isnothing(info′) ? merge(info′, info) : info end - return state, false, info + state, false, info end """ diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index 596cd372c..c101537a4 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -81,10 +81,10 @@ function init( grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) scale = q_init.scale - qcov = Hermitian(scale * scale') + qcov = Hermitian(scale*scale') scale_inv = inv(scale) prec_chol = scale_inv' - prec = Hermitian(prec_chol * prec_chol') + prec = Hermitian(prec_chol*prec_chol') return KLMinNaturalGradDescentState( q_init, prob, prec, qcov, 0, sub_st, grad_buf, hess_buf ) @@ -127,7 +127,7 @@ function step( # Handling the positive-definite constraint in the Bayesian learning rule. # In ICML 2020. G_hat = S - (-hess_buf) - Hermitian(S - η * G_hat + η^2 / 2 * G_hat * qcov * G_hat) + Hermitian(S - η*G_hat + η^2/2*G_hat*qcov*G_hat) else Hermitian(((1 - η) * S + η * (-hess_buf))) end @@ -136,7 +136,7 @@ function step( prec_chol = cholesky(S′).L prec_chol_inv = inv(prec_chol) scale = prec_chol_inv' - qcov = Hermitian(scale * scale') + qcov = Hermitian(scale*scale') q′ = MvLocationScale(m′, scale, q.dist) state = KLMinNaturalGradDescentState( @@ -149,7 +149,7 @@ function step( info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end - return state, false, info + state, false, info end """ diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index 3b29c3b2b..622c111c3 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -105,8 +105,8 @@ function step( rng, q, n_samples, grad_buf, hess_buf, prob_sub ) - CtHCmI = C' * (-hess_buf) * C - I - CtHCmI_tril = LowerTriangular(tril(CtHCmI) - Diagonal(diag(CtHCmI)) / 2) + CtHCmI = C'*(-hess_buf)*C - I + CtHCmI_tril = LowerTriangular(tril(CtHCmI) - Diagonal(diag(CtHCmI))/2) m′ = m - η * C * (C' * -grad_buf) C′ = C - η * C * CtHCmI_tril @@ -123,7 +123,7 @@ function step( info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end - return state, false, info + state, false, info end """ diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index 40577bfc4..602f4be41 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -104,10 +104,10 @@ function step( m′ = m - η * (-grad_buf) M = I - η * (-hess_buf') - Σ_half = Hermitian(M * Σ * M') + Σ_half = Hermitian(M*Σ*M') # Compute the JKO proximal operator - Σ′ = (Σ_half + 2 * η * I + sqrt(Hermitian(Σ_half * (Σ_half + 4 * η * I)))) / 2 + Σ′ = (Σ_half + 2*η*I + sqrt(Hermitian(Σ_half*(Σ_half + 4*η*I))))/2 q′ = MvLocationScale(m′, cholesky(Σ′).L, q.dist) state = KLMinWassFwdBwdState(q′, prob, Σ′, iteration, sub_st′, grad_buf, hess_buf) @@ -118,7 +118,7 @@ function step( info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end - return state, false, info + state, false, info end """ diff --git a/src/optimization/rules.jl b/src/optimization/rules.jl index 5015f4ebe..2025632e3 100644 --- a/src/optimization/rules.jl +++ b/src/optimization/rules.jl @@ -18,9 +18,7 @@ Optimisers.@def struct DoWG <: Optimisers.AbstractRule alpha = 1e-6 end -function Optimisers.init(o::DoWG, x::AbstractArray{T}) where {T} - return (copy(x), zero(T), T(o.alpha) * (1 + norm(x))) -end +Optimisers.init(o::DoWG, x::AbstractArray{T}) where {T} = (copy(x), zero(T), T(o.alpha)*(1 + norm(x))) function Optimisers.apply!(::DoWG, state, x::AbstractArray{T}, dx) where {T} x0, v, r = state @@ -49,9 +47,7 @@ Optimisers.@def struct DoG <: Optimisers.AbstractRule alpha = 1e-6 end -function Optimisers.init(o::DoG, x::AbstractArray{T}) where {T} - return (copy(x), zero(T), T(o.alpha) * (1 + norm(x))) -end +Optimisers.init(o::DoG, x::AbstractArray{T}) where {T} = (copy(x), zero(T), T(o.alpha)*(1 + norm(x))) function Optimisers.apply!(::DoG, state, x::AbstractArray{T}, dx) where {T} x0, v, r = state diff --git a/test/integration/dynamicppl.jl b/test/integration/dynamicppl.jl index a3ed01d20..43a38d3b8 100644 --- a/test/integration/dynamicppl.jl +++ b/test/integration/dynamicppl.jl @@ -1,7 +1,7 @@ @testset "DynamicPPL" begin DynamicPPL.@model function normal(μ) - return x ~ MvNormal(μ, I) + x ~ MvNormal(μ, I) end DynamicPPL.@model function normal_subsampled(μs; datapoints=1:size(μs, 2)) @@ -22,18 +22,18 @@ alg = KLMinRepGradProxDescent(AD) d = LogDensityProblems.dimension(prob) - q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6 * I, d, d))) + q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d))) q, _, _ = AdvancedVI.optimize(alg, 1000, prob, q0; show_progress=false) Δλ0 = sum(abs2, q0.location - μ_true) Δλ = sum(abs2, q.location - μ_true) - @test Δλ ≤ Δλ0 / 2 + @test Δλ ≤ Δλ0/2 end @testset "subsampling" begin n_data = 32 - μs = 3 * randn(2, n_data) - μ_true = mean(μs; dims=2)[:, 1] + μs = 3*randn(2, n_data) + μ_true = mean(μs, dims=2)[:, 1] model = normal_subsampled(μs) vi = DynamicPPL.VarInfo(model) @@ -48,11 +48,11 @@ alg = KLMinRepGradProxDescent(AD; subsampling) d = LogDensityProblems.dimension(prob) - q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6 * I, d, d))) + q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(0.6*I, d, d))) q, _, _ = AdvancedVI.optimize(alg, 1000, prob, q0; show_progress=false) Δλ0 = sum(abs2, q0.location - μ_true) Δλ = sum(abs2, q.location - μ_true) - @test Δλ ≤ Δλ0 / 2 + @test Δλ ≤ Δλ0/2 end end From 6da323210bf150e640ad52c050264a380307c2cd Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 19 May 2026 20:54:34 +0100 Subject: [PATCH 4/8] Move LDP+Hessian-on-Prepared work to a follow-up branch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Keeps this branch focused on the AbstractPPL evaluator-interface migration: the LDP fallback methods on `Prepared`, the Mooncake forward-over-reverse Hessian, and the DynamicPPLModelLogDensityFunction LDP-delegation refactor have been moved to the `ldp-on-prepared-hessian` branch (branched from `main` and carrying both the migration and the feature work). What stays here: - Switch DI → AbstractPPL.prepare / value_and_gradient!! in core. - DynamicPPLModelLogDensityFunction uses AbstractPPL.prepare instead of DI.prepare_gradient (Hessian path is dropped as in the original migration commit; `use_hessian=true` warns and is ignored). - Compat bumps: AbstractPPL@0.15, DynamicPPL@0.42, plus Bijectors source pin in `test/`. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AdvancedVIDynamicPPLExt.jl | 56 +++++++++++++++++++--------------- ext/AdvancedVIMooncakeExt.jl | 22 ------------- src/AdvancedVI.jl | 28 ----------------- 3 files changed, 32 insertions(+), 74 deletions(-) diff --git a/ext/AdvancedVIDynamicPPLExt.jl b/ext/AdvancedVIDynamicPPLExt.jl index 6e5f86eef..af3839bc6 100644 --- a/ext/AdvancedVIDynamicPPLExt.jl +++ b/ext/AdvancedVIDynamicPPLExt.jl @@ -1,12 +1,22 @@ module AdvancedVIDynamicPPLExt using ADTypes: ADTypes +using Accessors using AdvancedVI: AdvancedVI using AbstractPPL: AbstractPPL +using Distributions: Distributions using DynamicPPL: DynamicPPL using LogDensityProblems: LogDensityProblems using Random +function adtype_capabilities(::Type{Nothing}) + return LogDensityProblems.LogDensityOrder{0}() +end + +function adtype_capabilities(::Type{<:ADTypes.AbstractADType}) + return LogDensityProblems.LogDensityOrder{1}() +end + function logdensity_impl( params, model::DynamicPPL.Model, loglikeadj::Real, varinfo::DynamicPPL.AbstractVarInfo ) @@ -25,29 +35,31 @@ function subsample_dynamicpplmodel( return DynamicPPL.Model{Threaded}(model.f, model.args, new_kwargs, model.context) end -# `LogDensityProblems.capabilities` and the gradient/Hessian methods dispatch -# off `Prep`, so the AD backend's `Prepared` type drives the LDP capability. struct DynamicPPLModelLogDensityFunction{ Model<:DynamicPPL.Model, VarInfo<:DynamicPPL.AbstractVarInfo, ADType<:Union{Nothing,ADTypes.AbstractADType}, - Prep, + PrepGrad, } model::Model varinfo::VarInfo adtype::ADType + # Refs are updated in-place by subsample; the prepared AD evaluator reads + # through them on every call, so the prep remains valid across subsampling. model_ref::Ref{Any} loglikeadj_ref::Ref{Float64} - prep::Prep + prep_grad::PrepGrad end function DynamicPPLModelLogDensityFunction( model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; + use_hessian::Bool=false, adtype::Union{Nothing,ADTypes.AbstractADType}=nothing, loglikeadj::Real=1.0, subsampling::Union{Nothing,AdvancedVI.AbstractSubsampling}=nothing, ) + use_hessian && @warn "`use_hessian` is no longer supported and will be ignored." model_sub = if isnothing(subsampling) model else @@ -57,20 +69,24 @@ function DynamicPPLModelLogDensityFunction( subsample_dynamicpplmodel(model, batch) end - params = collect(varinfo[:]) + params = [val for val in varinfo[:]] + cap = adtype_capabilities(typeof(adtype)) model_ref = Ref{Any}(model_sub) loglikeadj_ref = Ref{Float64}(float(loglikeadj)) - prep = if isnothing(adtype) - nothing + prep_grad = if cap >= LogDensityProblems.LogDensityOrder{1}() + AbstractPPL.prepare( + adtype, + params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo), + params, + ) else - f = params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo) - AbstractPPL.prepare(adtype, f, params) + nothing end return DynamicPPLModelLogDensityFunction( - model, varinfo, adtype, model_ref, loglikeadj_ref, prep + model, varinfo, adtype, model_ref, loglikeadj_ref, prep_grad ) end @@ -81,25 +97,19 @@ end function LogDensityProblems.logdensity_and_gradient( prob::DynamicPPLModelLogDensityFunction, params ) - return LogDensityProblems.logdensity_and_gradient(prob.prep, params) -end - -function LogDensityProblems.logdensity_gradient_and_hessian( - prob::DynamicPPLModelLogDensityFunction, params -) - return LogDensityProblems.logdensity_gradient_and_hessian(prob.prep, params) + return AbstractPPL.value_and_gradient!!(prob.prep_grad, params) end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPLModelLogDensityFunction{M,V,Nothing,P}} -) where {M,V,P} + ::Type{<:DynamicPPLModelLogDensityFunction{M,V,Nothing,G}} +) where {M,V,G} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPLModelLogDensityFunction{M,V,A,P}} -) where {M,V,A<:ADTypes.AbstractADType,P} - return LogDensityProblems.capabilities(P) + ::Type{<:DynamicPPLModelLogDensityFunction{M,V,<:ADTypes.AbstractADType,G}} +) where {M,V,G} + return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.dimension(prob::DynamicPPLModelLogDensityFunction) @@ -122,8 +132,6 @@ function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch) model_sub = subsample_dynamicpplmodel(model, batch) loglikeadj = n_datapoints / batchsize - # Mutates the refs so the previously prepared AD evaluator keeps reading - # the latest batch without needing a re-prepare. prob.model_ref[] = model_sub prob.loglikeadj_ref[] = loglikeadj diff --git a/ext/AdvancedVIMooncakeExt.jl b/ext/AdvancedVIMooncakeExt.jl index 9b562aad8..605f77bfa 100644 --- a/ext/AdvancedVIMooncakeExt.jl +++ b/ext/AdvancedVIMooncakeExt.jl @@ -1,8 +1,5 @@ module AdvancedVIMooncakeExt -using ADTypes: AutoMooncake, AutoMooncakeForward -using AbstractPPL: AbstractPPL -using AbstractPPL.Evaluators: Prepared, VectorEvaluator using AdvancedVI using LogDensityProblems using Mooncake @@ -34,23 +31,4 @@ function Mooncake.rrule!!( return Mooncake.zero_fcodual(ℓπ), logdensity_pb end -const _MooncakePrepared = Prepared{<:AutoMooncake,<:VectorEvaluator} - -# Order-1 LDP methods are inherited from the AbstractADType fallback in -# AdvancedVI core. -function LogDensityProblems.capabilities(::Type{<:_MooncakePrepared}) - LogDensityProblems.LogDensityOrder{2}() -end - -# Mooncake forward-over-reverse Hessian: a fresh forward-mode Jacobian cache -# is built per call, so this is fine for occasional use but costly inside a -# tight per-sample loop. -function LogDensityProblems.logdensity_gradient_and_hessian(p::_MooncakePrepared, x) - val, grad = LogDensityProblems.logdensity_and_gradient(p, x) - grad_fn = y -> LogDensityProblems.logdensity_and_gradient(p, y)[2] - fwd_jac = AbstractPPL.prepare(AutoMooncakeForward(), grad_fn, x) - _, H = AbstractPPL.value_and_jacobian!!(fwd_jac, x) - return val, grad, copy(H) -end - end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 5c1105a3b..765ef2977 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -18,7 +18,6 @@ using LogDensityProblems using ADTypes using DiffResults using AbstractPPL: AbstractPPL -using AbstractPPL.Evaluators: Prepared, VectorEvaluator using ChainRulesCore: ChainRulesCore using FillArrays @@ -111,33 +110,6 @@ This is an indirection for handling the type stability of `restructure`, as some """ restructure_ad_forward(::ADTypes.AbstractADType, restructure, params) = restructure(params) -# Gradient-only LDP fallback for any AD-prepared evaluator; backend extensions -# override `capabilities` and add `logdensity_gradient_and_hessian` if they can. -function LogDensityProblems.capabilities( - ::Type{<:Prepared{<:ADTypes.AbstractADType,<:VectorEvaluator}} -) - LogDensityProblems.LogDensityOrder{1}() -end - -function LogDensityProblems.dimension( - p::Prepared{<:ADTypes.AbstractADType,<:VectorEvaluator} -) - p.evaluator.dim -end - -function LogDensityProblems.logdensity( - p::Prepared{<:ADTypes.AbstractADType,<:VectorEvaluator}, x -) - p(x) -end - -function LogDensityProblems.logdensity_and_gradient( - p::Prepared{<:ADTypes.AbstractADType,<:VectorEvaluator}, x -) - val, grad = AbstractPPL.value_and_gradient!!(p, x) - return val, copy(grad) -end - include("mixedad_logdensity.jl") # Variational Families From d106e7b373c7a6748ff4ab4e19b7a0a8e4d9acdf Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 19 May 2026 22:14:21 +0100 Subject: [PATCH 5/8] Restore Hessian support via AbstractPPL hg/hessian-order MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pins AbstractPPL to `hg/hessian-order`, which adds `prepare(adtype, f, x; order=2)` and `value_gradient_and_hessian!!`. `DynamicPPLModelLogDensityFunction` goes back to main's `prep_grad` + `prep_hess` shape (default `use_hessian=true`), so the diff reads as a clean DI → AbstractPPL swap rather than a redesign. `use_hessian` falls back to gradient-only when the AD backend refuses `order=2` (MethodError only; other errors propagate). Co-Authored-By: Claude Opus 4.7 (1M context) --- Project.toml | 1 + ext/AdvancedVIDynamicPPLExt.jl | 82 +++++++++++++++++++++------------- src/AdvancedVI.jl | 4 +- test/Project.toml | 1 + 4 files changed, 54 insertions(+), 34 deletions(-) diff --git a/Project.toml b/Project.toml index 8cffa4dcf..264f66372 100644 --- a/Project.toml +++ b/Project.toml @@ -65,4 +65,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" test = ["Pkg", "Test"] [sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "hg/hessian-order"} DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} diff --git a/ext/AdvancedVIDynamicPPLExt.jl b/ext/AdvancedVIDynamicPPLExt.jl index af3839bc6..95a28ba3d 100644 --- a/ext/AdvancedVIDynamicPPLExt.jl +++ b/ext/AdvancedVIDynamicPPLExt.jl @@ -1,17 +1,13 @@ module AdvancedVIDynamicPPLExt using ADTypes: ADTypes -using Accessors using AdvancedVI: AdvancedVI using AbstractPPL: AbstractPPL -using Distributions: Distributions using DynamicPPL: DynamicPPL using LogDensityProblems: LogDensityProblems using Random -function adtype_capabilities(::Type{Nothing}) - return LogDensityProblems.LogDensityOrder{0}() -end +adtype_capabilities(::Type{Nothing}) = LogDensityProblems.LogDensityOrder{0}() function adtype_capabilities(::Type{<:ADTypes.AbstractADType}) return LogDensityProblems.LogDensityOrder{1}() @@ -35,31 +31,33 @@ function subsample_dynamicpplmodel( return DynamicPPL.Model{Threaded}(model.f, model.args, new_kwargs, model.context) end +# `model_ref`/`loglikeadj_ref` are mutated in place by `subsample`; the closure +# inside `prep_grad`/`prep_hess` reads through them so the prep stays valid +# across subsampling steps (AbstractPPL bakes the closure into the prep, unlike +# DI's `Constant` which can be rebound at call time). struct DynamicPPLModelLogDensityFunction{ Model<:DynamicPPL.Model, VarInfo<:DynamicPPL.AbstractVarInfo, ADType<:Union{Nothing,ADTypes.AbstractADType}, PrepGrad, + PrepHess, } - model::Model - varinfo::VarInfo - adtype::ADType - # Refs are updated in-place by subsample; the prepared AD evaluator reads - # through them on every call, so the prep remains valid across subsampling. model_ref::Ref{Any} loglikeadj_ref::Ref{Float64} + varinfo::VarInfo + adtype::ADType prep_grad::PrepGrad + prep_hess::PrepHess end function DynamicPPLModelLogDensityFunction( model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; - use_hessian::Bool=false, + use_hessian::Bool=true, adtype::Union{Nothing,ADTypes.AbstractADType}=nothing, loglikeadj::Real=1.0, subsampling::Union{Nothing,AdvancedVI.AbstractSubsampling}=nothing, ) - use_hessian && @warn "`use_hessian` is no longer supported and will be ignored." model_sub = if isnothing(subsampling) model else @@ -69,24 +67,36 @@ function DynamicPPLModelLogDensityFunction( subsample_dynamicpplmodel(model, batch) end - params = [val for val in varinfo[:]] - cap = adtype_capabilities(typeof(adtype)) - + params = collect(varinfo[:]) model_ref = Ref{Any}(model_sub) loglikeadj_ref = Ref{Float64}(float(loglikeadj)) + f = params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo) + cap = adtype_capabilities(typeof(adtype)) prep_grad = if cap >= LogDensityProblems.LogDensityOrder{1}() - AbstractPPL.prepare( - adtype, - params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo), - params, - ) + AbstractPPL.prepare(adtype, f, params) else nothing end - - return DynamicPPLModelLogDensityFunction( - model, varinfo, adtype, model_ref, loglikeadj_ref, prep_grad + prep_hess = if cap >= LogDensityProblems.LogDensityOrder{1}() && use_hessian + try + AbstractPPL.prepare(adtype, f, params; order=2) + catch err + err isa MethodError || rethrow() + @warn "The selected AD backend does not support `AbstractPPL.prepare(...; order=2)`. AdvancedVI will treat the model as first-order only." + nothing + end + else + nothing + end + return DynamicPPLModelLogDensityFunction{ + typeof(model), + typeof(varinfo), + typeof(adtype), + typeof(prep_grad), + typeof(prep_hess), + }( + model_ref, loglikeadj_ref, varinfo, adtype, prep_grad, prep_hess ) end @@ -97,19 +107,27 @@ end function LogDensityProblems.logdensity_and_gradient( prob::DynamicPPLModelLogDensityFunction, params ) - return AbstractPPL.value_and_gradient!!(prob.prep_grad, params) + val, grad = AbstractPPL.value_and_gradient!!(prob.prep_grad, params) + return val, copy(grad) end -function LogDensityProblems.capabilities( - ::Type{<:DynamicPPLModelLogDensityFunction{M,V,Nothing,G}} -) where {M,V,G} - return LogDensityProblems.LogDensityOrder{0}() +function LogDensityProblems.logdensity_gradient_and_hessian( + prob::DynamicPPLModelLogDensityFunction, params +) + val, grad, H = AbstractPPL.value_gradient_and_hessian!!(prob.prep_hess, params) + return val, copy(grad), copy(H) end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPLModelLogDensityFunction{M,V,<:ADTypes.AbstractADType,G}} -) where {M,V,G} - return LogDensityProblems.LogDensityOrder{1}() + ::Type{<:DynamicPPLModelLogDensityFunction{M,V,ADType,PG,PH}} +) where {M,V,ADType<:ADTypes.AbstractADType,PG,PH} + return if PH != Nothing + LogDensityProblems.LogDensityOrder{2}() + elseif PG != Nothing + LogDensityProblems.LogDensityOrder{1}() + else + LogDensityProblems.LogDensityOrder{0}() + end end function LogDensityProblems.dimension(prob::DynamicPPLModelLogDensityFunction) @@ -117,7 +135,7 @@ function LogDensityProblems.dimension(prob::DynamicPPLModelLogDensityFunction) end function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch) - model = prob.model + model = prob.model_ref[] if !haskey(model.defaults, :datapoints) throw( diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 765ef2977..9fd6b64cd 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -24,8 +24,8 @@ using FillArrays using StatsBase -# `aux` is captured by Ref so the same prepared evaluator can be reused after -# aux changes — re-preparing per call would defeat the cache. +# `AbstractPPL.prepare` bakes the closure at prep time, so `aux` is captured +# via a `Ref` that callers mutate before each evaluation. struct _VIGradPrep{P,R} prepared::P aux_ref::R diff --git a/test/Project.toml b/test/Project.toml index d85b47d4d..ae013611c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -27,6 +27,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [sources] +AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "hg/hessian-order"} DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} Bijectors = {url = "https://github.com/TuringLang/Bijectors.jl", rev = "replace-di-with-abstractppl"} From 79db2fb87b90e7e76d3c7ffa737e715cfec4d0b4 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 19 May 2026 22:22:03 +0100 Subject: [PATCH 6/8] Drop _VIGradPrep; thread aux through AbstractPPL context Wrapping the prepared evaluator and the aux `Ref` in a struct was just working around the lack of mutable state on `AbstractPPL.Prepared`. AbstractPPL 0.15's `prepare(...; context=tuple)` keeps the tuple on the evaluator and threads it through every call, so passing `Ref(aux)` in the context stores the ref inside the prep itself; callers mutate `prep.evaluator.context[1][]` before each evaluation. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AdvancedVIDynamicPPLExt.jl | 6 +----- src/AdvancedVI.jl | 22 ++++------------------ 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/ext/AdvancedVIDynamicPPLExt.jl b/ext/AdvancedVIDynamicPPLExt.jl index 95a28ba3d..53ed27c35 100644 --- a/ext/AdvancedVIDynamicPPLExt.jl +++ b/ext/AdvancedVIDynamicPPLExt.jl @@ -90,11 +90,7 @@ function DynamicPPLModelLogDensityFunction( nothing end return DynamicPPLModelLogDensityFunction{ - typeof(model), - typeof(varinfo), - typeof(adtype), - typeof(prep_grad), - typeof(prep_hess), + typeof(model),typeof(varinfo),typeof(adtype),typeof(prep_grad),typeof(prep_hess) }( model_ref, loglikeadj_ref, varinfo, adtype, prep_grad, prep_hess ) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 9fd6b64cd..3e32f7359 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -24,13 +24,6 @@ using FillArrays using StatsBase -# `AbstractPPL.prepare` bakes the closure at prep time, so `aux` is captured -# via a `Ref` that callers mutate before each evaluation. -struct _VIGradPrep{P,R} - prepared::P - aux_ref::R -end - # Derivatives """ _value_and_gradient!(f, out, ad, x, aux) @@ -66,15 +59,10 @@ function _value_and_gradient!( end function _value_and_gradient!( - f, - out::DiffResults.MutableDiffResult, - prep::_VIGradPrep, - ad::ADTypes.AbstractADType, - x, - aux, + f, out::DiffResults.MutableDiffResult, prep, ad::ADTypes.AbstractADType, x, aux ) - prep.aux_ref[] = aux - val, grad = AbstractPPL.value_and_gradient!!(prep.prepared, x) + prep.evaluator.context[1][] = aux + val, grad = AbstractPPL.value_and_gradient!!(prep, x) DiffResults.value!(out, val) copyto!(DiffResults.gradient(out), grad) return out @@ -92,9 +80,7 @@ Prepare AD backend for taking gradients of a function `f` at `x` using the autom - `aux`: Auxiliary input passed to `f`. """ function _prepare_gradient(f, ad::ADTypes.AbstractADType, x, aux) - aux_ref = Ref(aux) - prepared = AbstractPPL.prepare(ad, x -> f(x, aux_ref[]), x) - return _VIGradPrep(prepared, aux_ref) + return AbstractPPL.prepare(ad, (x, aref) -> f(x, aref[]), x; context=(Ref(aux),)) end """ From 451e7b706fd2b7749b751213670f787306869293 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 20 May 2026 11:13:12 +0100 Subject: [PATCH 7/8] Parameterise loglikeadj eltype; expand model_ref WHY comment `loglikeadj_ref` now tracks the input numeric type via a `LogLikeAdj<:Real` struct parameter, restoring `Float32`/`BigFloat`/dual support that the hard-pinned `Ref{Float64}` quietly dropped. `subsample` computes the adjustment in the field's eltype to preserve precision across minibatches. The block comment now also flags why `model_ref::Ref{Any}` cannot be tightened (subsample-widened `defaults` NamedTuple), so a future reader doesn't try. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AdvancedVIDynamicPPLExt.jl | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/ext/AdvancedVIDynamicPPLExt.jl b/ext/AdvancedVIDynamicPPLExt.jl index 53ed27c35..699327fc6 100644 --- a/ext/AdvancedVIDynamicPPLExt.jl +++ b/ext/AdvancedVIDynamicPPLExt.jl @@ -35,15 +35,21 @@ end # inside `prep_grad`/`prep_hess` reads through them so the prep stays valid # across subsampling steps (AbstractPPL bakes the closure into the prep, unlike # DI's `Constant` which can be rebound at call time). +# +# `model_ref` is typed `Ref{Any}` because `subsample_dynamicpplmodel` returns a +# `DynamicPPL.Model` whose `defaults` NamedTuple type varies with the batch — a +# typed `Ref{<:DynamicPPL.Model}` would throw on reassignment. The tradeoff is a +# dynamic dispatch on each `prob.model_ref[]` read; do not "tighten" it. struct DynamicPPLModelLogDensityFunction{ Model<:DynamicPPL.Model, + LogLikeAdj<:Real, VarInfo<:DynamicPPL.AbstractVarInfo, ADType<:Union{Nothing,ADTypes.AbstractADType}, PrepGrad, PrepHess, } model_ref::Ref{Any} - loglikeadj_ref::Ref{Float64} + loglikeadj_ref::Ref{LogLikeAdj} varinfo::VarInfo adtype::ADType prep_grad::PrepGrad @@ -69,7 +75,8 @@ function DynamicPPLModelLogDensityFunction( params = collect(varinfo[:]) model_ref = Ref{Any}(model_sub) - loglikeadj_ref = Ref{Float64}(float(loglikeadj)) + adj0 = float(loglikeadj) + loglikeadj_ref = Ref(adj0) f = params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo) cap = adtype_capabilities(typeof(adtype)) @@ -90,7 +97,12 @@ function DynamicPPLModelLogDensityFunction( nothing end return DynamicPPLModelLogDensityFunction{ - typeof(model),typeof(varinfo),typeof(adtype),typeof(prep_grad),typeof(prep_hess) + typeof(model), + typeof(adj0), + typeof(varinfo), + typeof(adtype), + typeof(prep_grad), + typeof(prep_hess), }( model_ref, loglikeadj_ref, varinfo, adtype, prep_grad, prep_hess ) @@ -115,8 +127,8 @@ function LogDensityProblems.logdensity_gradient_and_hessian( end function LogDensityProblems.capabilities( - ::Type{<:DynamicPPLModelLogDensityFunction{M,V,ADType,PG,PH}} -) where {M,V,ADType<:ADTypes.AbstractADType,PG,PH} + ::Type{<:DynamicPPLModelLogDensityFunction{M,L,V,ADType,PG,PH}} +) where {M,L,V,ADType<:ADTypes.AbstractADType,PG,PH} return if PH != Nothing LogDensityProblems.LogDensityOrder{2}() elseif PG != Nothing @@ -144,7 +156,8 @@ function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch) n_datapoints = length(model.defaults.datapoints) batchsize = length(batch) model_sub = subsample_dynamicpplmodel(model, batch) - loglikeadj = n_datapoints / batchsize + T = eltype(prob.loglikeadj_ref) + loglikeadj = T(n_datapoints) / T(batchsize) prob.model_ref[] = model_sub prob.loglikeadj_ref[] = loglikeadj From 5790cd422612cdc891b14dae828e3092a8d822d2 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 21 May 2026 20:48:50 +0100 Subject: [PATCH 8/8] Drop AbstractPPL and Bijectors source pins; restrict Bijectors compat AbstractPPL 0.15.1 and Bijectors 0.16.0 are now released. The DynamicPPL `adproblems-interface` branch has been updated to require Bijectors 0.16, so keeping the DynamicPPL source pin no longer pulls in an older Bijectors. Restrict the package and test envs to Bijectors 0.16; relax docs to "0.15, 0.16" since NormalizingFlows 0.2.2 still requires 0.15.x. Co-Authored-By: Claude Opus 4.7 (1M context) --- Project.toml | 1 - docs/Project.toml | 2 +- test/Project.toml | 3 +-- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 264f66372..8cffa4dcf 100644 --- a/Project.toml +++ b/Project.toml @@ -65,5 +65,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" test = ["Pkg", "Test"] [sources] -AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "hg/hessian-order"} DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} diff --git a/docs/Project.toml b/docs/Project.toml index c05bf9e27..9374d4c52 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -26,7 +26,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" ADTypes = "1" Accessors = "0.1" AdvancedVI = "0.7, 0.6" -Bijectors = "0.13.6, 0.14, 0.15" +Bijectors = "0.15, 0.16" DataFrames = "1" DifferentiationInterface = "0.7" Distributions = "0.25" diff --git a/test/Project.toml b/test/Project.toml index d7583ac7b..6114a8fdf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -26,12 +26,11 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [sources] -AbstractPPL = {url = "https://github.com/TuringLang/AbstractPPL.jl", rev = "hg/hessian-order"} DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} -Bijectors = {url = "https://github.com/TuringLang/Bijectors.jl", rev = "replace-di-with-abstractppl"} [compat] ADTypes = "0.2.1, 1" +Bijectors = "0.16" DiffResults = "1" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111"