diff --git a/Project.toml b/Project.toml index 28d4a3035..8cffa4dcf 100644 --- a/Project.toml +++ b/Project.toml @@ -4,10 +4,10 @@ version = "0.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -29,17 +29,17 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"] AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"] AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"] -AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "DifferentiationInterface", "LogDensityProblems"] +AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "LogDensityProblems"] [compat] ADTypes = "1" +AbstractPPL = "0.15" Accessors = "0.1" ChainRulesCore = "1" DiffResults = "1" -DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" DocStringExtensions = "0.8, 0.9" -DynamicPPL = "0.40, 0.41" +DynamicPPL = "0.42" Enzyme = "0.13" FillArrays = "1.3" Functors = "0.4, 0.5" @@ -63,3 +63,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Pkg", "Test"] + +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} diff --git a/docs/Project.toml b/docs/Project.toml index c05bf9e27..9374d4c52 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -26,7 +26,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" ADTypes = "1" Accessors = "0.1" AdvancedVI = "0.7, 0.6" -Bijectors = "0.13.6, 0.14, 0.15" +Bijectors = "0.15, 0.16" DataFrames = "1" DifferentiationInterface = "0.7" Distributions = "0.25" diff --git a/ext/AdvancedVIDynamicPPLExt.jl b/ext/AdvancedVIDynamicPPLExt.jl index eaff70ae7..699327fc6 100644 --- a/ext/AdvancedVIDynamicPPLExt.jl +++ b/ext/AdvancedVIDynamicPPLExt.jl @@ -1,10 +1,8 @@ module AdvancedVIDynamicPPLExt using ADTypes: ADTypes -using Accessors using AdvancedVI: AdvancedVI -using DifferentiationInterface: DifferentiationInterface -using Distributions: Distributions +using AbstractPPL: AbstractPPL using DynamicPPL: DynamicPPL using LogDensityProblems: LogDensityProblems using Random @@ -15,36 +13,6 @@ function adtype_capabilities(::Type{<:ADTypes.AbstractADType}) return LogDensityProblems.LogDensityOrder{1}() end -function adtype_capabilities( - ::Type{ - <:Union{ - <:ADTypes.AutoForwardDiff, - <:ADTypes.AutoReverseDiff, - <:ADTypes.AutoMooncake, - <:ADTypes.AutoEnzyme, - <:DifferentiationInterface.SecondOrder, - }, - }, -) - return LogDensityProblems.LogDensityOrder{2}() -end - -struct DynamicPPLModelLogDensityFunction{ - Model<:DynamicPPL.Model, - LogLikeAdj<:Real, - VarInfo<:DynamicPPL.AbstractVarInfo, - ADType<:ADTypes.AbstractADType, - PrepGrad<:Union{Nothing,DifferentiationInterface.GradientPrep}, - PrepHess<:Union{Nothing,DifferentiationInterface.HessianPrep}, -} - model::Model - loglikeadj::LogLikeAdj - varinfo::VarInfo - adtype::ADType - prep_grad::PrepGrad - prep_hess::PrepHess -end - function logdensity_impl( params, model::DynamicPPL.Model, loglikeadj::Real, varinfo::DynamicPPL.AbstractVarInfo ) @@ -63,6 +31,31 @@ function subsample_dynamicpplmodel( return DynamicPPL.Model{Threaded}(model.f, model.args, new_kwargs, model.context) end +# `model_ref`/`loglikeadj_ref` are mutated in place by `subsample`; the closure +# inside `prep_grad`/`prep_hess` reads through them so the prep stays valid +# across subsampling steps (AbstractPPL bakes the closure into the prep, unlike +# DI's `Constant` which can be rebound at call time). +# +# `model_ref` is typed `Ref{Any}` because `subsample_dynamicpplmodel` returns a +# `DynamicPPL.Model` whose `defaults` NamedTuple type varies with the batch — a +# typed `Ref{<:DynamicPPL.Model}` would throw on reassignment. The tradeoff is a +# dynamic dispatch on each `prob.model_ref[]` read; do not "tighten" it. +struct DynamicPPLModelLogDensityFunction{ + Model<:DynamicPPL.Model, + LogLikeAdj<:Real, + VarInfo<:DynamicPPL.AbstractVarInfo, + ADType<:Union{Nothing,ADTypes.AbstractADType}, + PrepGrad, + PrepHess, +} + model_ref::Ref{Any} + loglikeadj_ref::Ref{LogLikeAdj} + varinfo::VarInfo + adtype::ADType + prep_grad::PrepGrad + prep_hess::PrepHess +end + function DynamicPPLModelLogDensityFunction( model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; @@ -80,32 +73,24 @@ function DynamicPPLModelLogDensityFunction( subsample_dynamicpplmodel(model, batch) end - params = [val for val in varinfo[:]] + params = collect(varinfo[:]) + model_ref = Ref{Any}(model_sub) + adj0 = float(loglikeadj) + loglikeadj_ref = Ref(adj0) + f = params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo) cap = adtype_capabilities(typeof(adtype)) + prep_grad = if cap >= LogDensityProblems.LogDensityOrder{1}() - DifferentiationInterface.prepare_gradient( - logdensity_impl, - DifferentiationInterface.inner(adtype), - params, - DifferentiationInterface.Constant(model_sub), - DifferentiationInterface.Constant(loglikeadj), - DifferentiationInterface.Constant(varinfo), - ) + AbstractPPL.prepare(adtype, f, params) else nothing end - prep_hess = if cap >= LogDensityProblems.LogDensityOrder{2}() && use_hessian + prep_hess = if cap >= LogDensityProblems.LogDensityOrder{1}() && use_hessian try - DifferentiationInterface.prepare_hessian( - logdensity_impl, - adtype, - params, - DifferentiationInterface.Constant(model_sub), - DifferentiationInterface.Constant(loglikeadj), - DifferentiationInterface.Constant(varinfo), - ) - catch - @warn "The selected AD backend has second-order capabilities but `DifferentiationInterface.prepare_hessian` failed. AdvancedVI will treat the model to only have first-order capability." + AbstractPPL.prepare(adtype, f, params; order=2) + catch err + err isa MethodError || rethrow() + @warn "The selected AD backend does not support `AbstractPPL.prepare(...; order=2)`. AdvancedVI will treat the model as first-order only." nothing end else @@ -113,49 +98,32 @@ function DynamicPPLModelLogDensityFunction( end return DynamicPPLModelLogDensityFunction{ typeof(model), - typeof(loglikeadj), + typeof(adj0), typeof(varinfo), typeof(adtype), typeof(prep_grad), typeof(prep_hess), }( - model, loglikeadj, varinfo, adtype, prep_grad, prep_hess + model_ref, loglikeadj_ref, varinfo, adtype, prep_grad, prep_hess ) end function LogDensityProblems.logdensity(prob::DynamicPPLModelLogDensityFunction, params) - (; model, loglikeadj, varinfo) = prob - return logdensity_impl(params, model, loglikeadj, varinfo) + return logdensity_impl(params, prob.model_ref[], prob.loglikeadj_ref[], prob.varinfo) end function LogDensityProblems.logdensity_and_gradient( prob::DynamicPPLModelLogDensityFunction, params ) - (; model, adtype, loglikeadj, varinfo, prep_grad) = prob - return DifferentiationInterface.value_and_gradient( - logdensity_impl, - prep_grad, - DifferentiationInterface.inner(adtype), - params, - DifferentiationInterface.Constant(model), - DifferentiationInterface.Constant(loglikeadj), - DifferentiationInterface.Constant(varinfo), - ) + val, grad = AbstractPPL.value_and_gradient!!(prob.prep_grad, params) + return val, copy(grad) end function LogDensityProblems.logdensity_gradient_and_hessian( prob::DynamicPPLModelLogDensityFunction, params ) - (; model, adtype, loglikeadj, varinfo, prep_hess) = prob - return DifferentiationInterface.value_gradient_and_hessian( - logdensity_impl, - prep_hess, - adtype, - params, - DifferentiationInterface.Constant(model), - DifferentiationInterface.Constant(loglikeadj), - DifferentiationInterface.Constant(varinfo), - ) + val, grad, H = AbstractPPL.value_gradient_and_hessian!!(prob.prep_hess, params) + return val, copy(grad), copy(H) end function LogDensityProblems.capabilities( @@ -175,12 +143,12 @@ function LogDensityProblems.dimension(prob::DynamicPPLModelLogDensityFunction) end function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch) - model = prob.model + model = prob.model_ref[] if !haskey(model.defaults, :datapoints) throw( ArgumentError( - "Subsampling is turned on, but the model does not have have a `datapoints` keyword argument.", + "Subsampling is turned on, but the model does not have a `datapoints` keyword argument.", ), ) end @@ -188,11 +156,13 @@ function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch) n_datapoints = length(model.defaults.datapoints) batchsize = length(batch) model_sub = subsample_dynamicpplmodel(model, batch) - loglikeadj = n_datapoints / batchsize + T = eltype(prob.loglikeadj_ref) + loglikeadj = T(n_datapoints) / T(batchsize) + + prob.model_ref[] = model_sub + prob.loglikeadj_ref[] = loglikeadj - prob′ = @set prob.model = model_sub - prob′′ = @set prob′.loglikeadj = loglikeadj - return prob′′ + return prob end end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 4044f6732..e2fa40ab2 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -17,7 +17,7 @@ using LogDensityProblems using ADTypes using DiffResults -using DifferentiationInterface +using AbstractPPL: AbstractPPL using ChainRulesCore: ChainRulesCore using FillArrays @@ -45,31 +45,31 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif - `f`: Function subject to differentiation. - `x`: The point to evaluate the gradient. - `aux`: Auxiliary input passed to `f`. -- `prep`: Output of `DifferentiationInterface.prepare_gradient`. +- `prep`: Output of `_prepare_gradient`. - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ function _value_and_gradient!( f, out::DiffResults.MutableDiffResult, ad::ADTypes.AbstractADType, x, aux ) - grad_buf = DiffResults.gradient(out) - y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux)) - DiffResults.value!(out, y) + prepared = AbstractPPL.prepare(ad, Base.Fix2(f, aux), x) + val, grad = AbstractPPL.value_and_gradient!!(prepared, x) + DiffResults.value!(out, val) + copyto!(DiffResults.gradient(out), grad) return out end function _value_and_gradient!( f, out::DiffResults.MutableDiffResult, prep, ad::ADTypes.AbstractADType, x, aux ) - grad_buf = DiffResults.gradient(out) - y, _ = DifferentiationInterface.value_and_gradient!( - f, grad_buf, prep, ad, x, Constant(aux) - ) - DiffResults.value!(out, y) + prep.evaluator.context[1][] = aux + val, grad = AbstractPPL.value_and_gradient!!(prep, x) + DiffResults.value!(out, val) + copyto!(DiffResults.gradient(out), grad) return out end """ - _prepare_gradient!(f, ad, x, aux) + _prepare_gradient(f, ad, x, aux) Prepare AD backend for taking gradients of a function `f` at `x` using the automatic differentiation backend `ad`. @@ -80,7 +80,7 @@ Prepare AD backend for taking gradients of a function `f` at `x` using the autom - `aux`: Auxiliary input passed to `f`. """ function _prepare_gradient(f, ad::ADTypes.AbstractADType, x, aux) - return DifferentiationInterface.prepare_gradient(f, ad, x, Constant(aux)) + return AbstractPPL.prepare(ad, (x, aref) -> f(x, aref[]), x; context=(Ref(aux),)) end """ diff --git a/src/algorithms/subsampledobjective.jl b/src/algorithms/subsampledobjective.jl index 21734f7ca..ad1f858ef 100644 --- a/src/algorithms/subsampledobjective.jl +++ b/src/algorithms/subsampledobjective.jl @@ -32,7 +32,7 @@ function init( sub_st = init(rng, subsampling) # This is necessary to ensure that `init` sees the type "conditioned" on a minibatch - # when calling `DifferentiationInterface.prepare_*` inside it. + # so that any prepared AD evaluator inside it sees the correct batch-subsampled type. batch, _, _ = step(rng, subsampling, sub_st, true) prob_sub = subsample(prob, batch) q_init_sub = subsample(q_init, batch) diff --git a/src/reshuffling.jl b/src/reshuffling.jl index e0e50cfbd..dd7f95aa8 100644 --- a/src/reshuffling.jl +++ b/src/reshuffling.jl @@ -49,7 +49,7 @@ function step( # Ignore the trailing batch if its size is smaller than `batchsize`. # This should only be used when estimating gradients during optimization. # This is necessary to ensure that all batches have the same size. - # Otherwise, `DifferentiationInterface.prepare_*` behaves incorrectly. + # Otherwise, prepared AD evaluators may see inconsistent batch sizes. (sub_step, batch), iterator = Iterators.peel(iterator) end epoch = epoch + 1 diff --git a/test/Project.toml b/test/Project.toml index 356cbf4e1..6114a8fdf 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,8 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -22,12 +25,16 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +[sources] +DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"} + [compat] ADTypes = "0.2.1, 1" +Bijectors = "0.16" DiffResults = "1" DifferentiationInterface = "0.6, 0.7" Distributions = "0.25.111" -DynamicPPL = "0.40, 0.41" +DynamicPPL = "0.42" Enzyme = "0.13, 0.14, 0.15" FillArrays = "1.6.1" ForwardDiff = "0.10.36, 1" diff --git a/test/runtests.jl b/test/runtests.jl index b50e69bfc..519542e99 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,7 @@ using Random, StableRNGs using Statistics using StatsBase +using DifferentiationInterface using AdvancedVI const PROGRESS = haskey(ENV, "PROGRESS")