Skip to content
11 changes: 7 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ version = "0.7.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -29,17 +29,17 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
AdvancedVIEnzymeExt = ["Enzyme", "ChainRulesCore"]
AdvancedVIMooncakeExt = ["Mooncake", "ChainRulesCore"]
AdvancedVIReverseDiffExt = ["ReverseDiff", "ChainRulesCore"]
AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "DifferentiationInterface", "LogDensityProblems"]
AdvancedVIDynamicPPLExt = ["DynamicPPL", "Accessors", "Distributions", "LogDensityProblems"]

[compat]
ADTypes = "1"
AbstractPPL = "0.15"
Accessors = "0.1"
ChainRulesCore = "1"
DiffResults = "1"
DifferentiationInterface = "0.6, 0.7"
Distributions = "0.25.111"
DocStringExtensions = "0.8, 0.9"
DynamicPPL = "0.40, 0.41"
DynamicPPL = "0.42"
Enzyme = "0.13"
FillArrays = "1.3"
Functors = "0.4, 0.5"
Expand All @@ -63,3 +63,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Pkg", "Test"]

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"}
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
ADTypes = "1"
Accessors = "0.1"
AdvancedVI = "0.7, 0.6"
Bijectors = "0.13.6, 0.14, 0.15"
Bijectors = "0.15, 0.16"
DataFrames = "1"
DifferentiationInterface = "0.7"
Distributions = "0.25"
Expand Down
136 changes: 53 additions & 83 deletions ext/AdvancedVIDynamicPPLExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
module AdvancedVIDynamicPPLExt

using ADTypes: ADTypes
using Accessors
using AdvancedVI: AdvancedVI
using DifferentiationInterface: DifferentiationInterface
using Distributions: Distributions
using AbstractPPL: AbstractPPL
using DynamicPPL: DynamicPPL
using LogDensityProblems: LogDensityProblems
using Random
Expand All @@ -15,36 +13,6 @@ function adtype_capabilities(::Type{<:ADTypes.AbstractADType})
return LogDensityProblems.LogDensityOrder{1}()
end

function adtype_capabilities(
::Type{
<:Union{
<:ADTypes.AutoForwardDiff,
<:ADTypes.AutoReverseDiff,
<:ADTypes.AutoMooncake,
<:ADTypes.AutoEnzyme,
<:DifferentiationInterface.SecondOrder,
},
},
)
return LogDensityProblems.LogDensityOrder{2}()
end

struct DynamicPPLModelLogDensityFunction{
Model<:DynamicPPL.Model,
LogLikeAdj<:Real,
VarInfo<:DynamicPPL.AbstractVarInfo,
ADType<:ADTypes.AbstractADType,
PrepGrad<:Union{Nothing,DifferentiationInterface.GradientPrep},
PrepHess<:Union{Nothing,DifferentiationInterface.HessianPrep},
}
model::Model
loglikeadj::LogLikeAdj
varinfo::VarInfo
adtype::ADType
prep_grad::PrepGrad
prep_hess::PrepHess
end

function logdensity_impl(
params, model::DynamicPPL.Model, loglikeadj::Real, varinfo::DynamicPPL.AbstractVarInfo
)
Expand All @@ -63,6 +31,31 @@ function subsample_dynamicpplmodel(
return DynamicPPL.Model{Threaded}(model.f, model.args, new_kwargs, model.context)
end

# `model_ref`/`loglikeadj_ref` are mutated in place by `subsample`; the closure
# inside `prep_grad`/`prep_hess` reads through them so the prep stays valid
# across subsampling steps (AbstractPPL bakes the closure into the prep, unlike
# DI's `Constant` which can be rebound at call time).
#
# `model_ref` is typed `Ref{Any}` because `subsample_dynamicpplmodel` returns a
# `DynamicPPL.Model` whose `defaults` NamedTuple type varies with the batch — a
# typed `Ref{<:DynamicPPL.Model}` would throw on reassignment. The tradeoff is a
# dynamic dispatch on each `prob.model_ref[]` read; do not "tighten" it.
struct DynamicPPLModelLogDensityFunction{
Model<:DynamicPPL.Model,
LogLikeAdj<:Real,
VarInfo<:DynamicPPL.AbstractVarInfo,
ADType<:Union{Nothing,ADTypes.AbstractADType},
PrepGrad,
PrepHess,
}
model_ref::Ref{Any}
loglikeadj_ref::Ref{LogLikeAdj}
varinfo::VarInfo
adtype::ADType
prep_grad::PrepGrad
prep_hess::PrepHess
end

function DynamicPPLModelLogDensityFunction(
model::DynamicPPL.Model,
varinfo::DynamicPPL.AbstractVarInfo;
Expand All @@ -80,82 +73,57 @@ function DynamicPPLModelLogDensityFunction(
subsample_dynamicpplmodel(model, batch)
end

params = [val for val in varinfo[:]]
params = collect(varinfo[:])
model_ref = Ref{Any}(model_sub)
adj0 = float(loglikeadj)
loglikeadj_ref = Ref(adj0)
f = params -> logdensity_impl(params, model_ref[], loglikeadj_ref[], varinfo)
cap = adtype_capabilities(typeof(adtype))

prep_grad = if cap >= LogDensityProblems.LogDensityOrder{1}()
DifferentiationInterface.prepare_gradient(
logdensity_impl,
DifferentiationInterface.inner(adtype),
params,
DifferentiationInterface.Constant(model_sub),
DifferentiationInterface.Constant(loglikeadj),
DifferentiationInterface.Constant(varinfo),
)
AbstractPPL.prepare(adtype, f, params)
else
nothing
end
prep_hess = if cap >= LogDensityProblems.LogDensityOrder{2}() && use_hessian
prep_hess = if cap >= LogDensityProblems.LogDensityOrder{1}() && use_hessian
try
DifferentiationInterface.prepare_hessian(
logdensity_impl,
adtype,
params,
DifferentiationInterface.Constant(model_sub),
DifferentiationInterface.Constant(loglikeadj),
DifferentiationInterface.Constant(varinfo),
)
catch
@warn "The selected AD backend has second-order capabilities but `DifferentiationInterface.prepare_hessian` failed. AdvancedVI will treat the model to only have first-order capability."
AbstractPPL.prepare(adtype, f, params; order=2)
catch err
err isa MethodError || rethrow()
@warn "The selected AD backend does not support `AbstractPPL.prepare(...; order=2)`. AdvancedVI will treat the model as first-order only."
nothing
end
else
nothing
end
return DynamicPPLModelLogDensityFunction{
typeof(model),
typeof(loglikeadj),
typeof(adj0),
typeof(varinfo),
typeof(adtype),
typeof(prep_grad),
typeof(prep_hess),
}(
model, loglikeadj, varinfo, adtype, prep_grad, prep_hess
model_ref, loglikeadj_ref, varinfo, adtype, prep_grad, prep_hess
)
end

function LogDensityProblems.logdensity(prob::DynamicPPLModelLogDensityFunction, params)
(; model, loglikeadj, varinfo) = prob
return logdensity_impl(params, model, loglikeadj, varinfo)
return logdensity_impl(params, prob.model_ref[], prob.loglikeadj_ref[], prob.varinfo)
end

function LogDensityProblems.logdensity_and_gradient(
prob::DynamicPPLModelLogDensityFunction, params
)
(; model, adtype, loglikeadj, varinfo, prep_grad) = prob
return DifferentiationInterface.value_and_gradient(
logdensity_impl,
prep_grad,
DifferentiationInterface.inner(adtype),
params,
DifferentiationInterface.Constant(model),
DifferentiationInterface.Constant(loglikeadj),
DifferentiationInterface.Constant(varinfo),
)
val, grad = AbstractPPL.value_and_gradient!!(prob.prep_grad, params)
return val, copy(grad)
end

function LogDensityProblems.logdensity_gradient_and_hessian(
prob::DynamicPPLModelLogDensityFunction, params
)
(; model, adtype, loglikeadj, varinfo, prep_hess) = prob
return DifferentiationInterface.value_gradient_and_hessian(
logdensity_impl,
prep_hess,
adtype,
params,
DifferentiationInterface.Constant(model),
DifferentiationInterface.Constant(loglikeadj),
DifferentiationInterface.Constant(varinfo),
)
val, grad, H = AbstractPPL.value_gradient_and_hessian!!(prob.prep_hess, params)
return val, copy(grad), copy(H)
end

function LogDensityProblems.capabilities(
Expand All @@ -175,24 +143,26 @@ function LogDensityProblems.dimension(prob::DynamicPPLModelLogDensityFunction)
end

function AdvancedVI.subsample(prob::DynamicPPLModelLogDensityFunction, batch)
model = prob.model
model = prob.model_ref[]

if !haskey(model.defaults, :datapoints)
throw(
ArgumentError(
"Subsampling is turned on, but the model does not have have a `datapoints` keyword argument.",
"Subsampling is turned on, but the model does not have a `datapoints` keyword argument.",
),
)
end

n_datapoints = length(model.defaults.datapoints)
batchsize = length(batch)
model_sub = subsample_dynamicpplmodel(model, batch)
loglikeadj = n_datapoints / batchsize
T = eltype(prob.loglikeadj_ref)
loglikeadj = T(n_datapoints) / T(batchsize)

prob.model_ref[] = model_sub
prob.loglikeadj_ref[] = loglikeadj

prob′ = @set prob.model = model_sub
prob′′ = @set prob′.loglikeadj = loglikeadj
return prob′′
return prob
end

end
24 changes: 12 additions & 12 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using LogDensityProblems

using ADTypes
using DiffResults
using DifferentiationInterface
using AbstractPPL: AbstractPPL
using ChainRulesCore: ChainRulesCore

using FillArrays
Expand Down Expand Up @@ -45,31 +45,31 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
- `f`: Function subject to differentiation.
- `x`: The point to evaluate the gradient.
- `aux`: Auxiliary input passed to `f`.
- `prep`: Output of `DifferentiationInterface.prepare_gradient`.
- `prep`: Output of `_prepare_gradient`.
- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value.
"""
function _value_and_gradient!(
f, out::DiffResults.MutableDiffResult, ad::ADTypes.AbstractADType, x, aux
)
grad_buf = DiffResults.gradient(out)
y, _ = DifferentiationInterface.value_and_gradient!(f, grad_buf, ad, x, Constant(aux))
DiffResults.value!(out, y)
prepared = AbstractPPL.prepare(ad, Base.Fix2(f, aux), x)
val, grad = AbstractPPL.value_and_gradient!!(prepared, x)
DiffResults.value!(out, val)
copyto!(DiffResults.gradient(out), grad)
return out
end

function _value_and_gradient!(
f, out::DiffResults.MutableDiffResult, prep, ad::ADTypes.AbstractADType, x, aux
)
grad_buf = DiffResults.gradient(out)
y, _ = DifferentiationInterface.value_and_gradient!(
f, grad_buf, prep, ad, x, Constant(aux)
)
DiffResults.value!(out, y)
prep.evaluator.context[1][] = aux
val, grad = AbstractPPL.value_and_gradient!!(prep, x)
DiffResults.value!(out, val)
copyto!(DiffResults.gradient(out), grad)
return out
end

"""
_prepare_gradient!(f, ad, x, aux)
_prepare_gradient(f, ad, x, aux)

Prepare AD backend for taking gradients of a function `f` at `x` using the automatic differentiation backend `ad`.

Expand All @@ -80,7 +80,7 @@ Prepare AD backend for taking gradients of a function `f` at `x` using the autom
- `aux`: Auxiliary input passed to `f`.
"""
function _prepare_gradient(f, ad::ADTypes.AbstractADType, x, aux)
return DifferentiationInterface.prepare_gradient(f, ad, x, Constant(aux))
return AbstractPPL.prepare(ad, (x, aref) -> f(x, aref[]), x; context=(Ref(aux),))
end

"""
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/subsampledobjective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ function init(
sub_st = init(rng, subsampling)

# This is necessary to ensure that `init` sees the type "conditioned" on a minibatch
# when calling `DifferentiationInterface.prepare_*` inside it.
# so that any prepared AD evaluator inside it sees the correct batch-subsampled type.
batch, _, _ = step(rng, subsampling, sub_st, true)
prob_sub = subsample(prob, batch)
q_init_sub = subsample(q_init, batch)
Expand Down
2 changes: 1 addition & 1 deletion src/reshuffling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function step(
# Ignore the trailing batch if its size is smaller than `batchsize`.
# This should only be used when estimating gradients during optimization.
# This is necessary to ensure that all batches have the same size.
# Otherwise, `DifferentiationInterface.prepare_*` behaves incorrectly.
# Otherwise, prepared AD evaluators may see inconsistent batch sizes.
(sub_step, batch), iterator = Iterators.peel(iterator)
end
epoch = epoch + 1
Expand Down
9 changes: 8 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -22,12 +25,16 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "adproblems-interface"}

[compat]
ADTypes = "0.2.1, 1"
Bijectors = "0.16"
DiffResults = "1"
DifferentiationInterface = "0.6, 0.7"
Distributions = "0.25.111"
DynamicPPL = "0.40, 0.41"
DynamicPPL = "0.42"
Enzyme = "0.13, 0.14, 0.15"
FillArrays = "1.6.1"
ForwardDiff = "0.10.36, 1"
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Random, StableRNGs
using Statistics
using StatsBase

using DifferentiationInterface
using AdvancedVI

const PROGRESS = haskey(ENV, "PROGRESS")
Expand Down
Loading