Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probabilistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.15"
version = "0.15.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
52 changes: 52 additions & 0 deletions docs/src/evaluators.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,57 @@ library invokes the inner callable many times with same-length dual arrays
derived from a single user-supplied `x`; re-validating on each invocation
would be redundant work in the hot path.

## Hessian (`order=2`)

Pass `order=2` to `prepare` to build a Hessian-capable evaluator. The
returned object answers `value_gradient_and_hessian!!`, which returns
`(value, gradient, hessian)` in a single call. `order=2` requires
`problem` to be scalar-valued; a vector-valued probe throws at preparation
time.

```julia
using AbstractPPL: prepare, value_gradient_and_hessian!!
using ADTypes: AutoForwardDiff
using ForwardDiff, DifferentiationInterface

quadratic(x) = sum(abs2, x)
prepared = prepare(AutoForwardDiff(), quadratic, zeros(3); order=2)
val, grad, hess = value_gradient_and_hessian!!(prepared, [1.0, 2.0, 3.0])
# val == 14.0
# grad == [2.0, 4.0, 6.0]
# hess == [2 0 0; 0 2 0; 0 0 2]
```

Both `context=` and `check_dims=` apply to `order=2` preps with the same
semantics as for `order=1`. The `!!` aliasing contract also extends: the
returned gradient and Hessian may alias internal cache buffers of
`prepared`, so copy before retaining them past the next call. NamedTuple
inputs are not supported at `order=2`.

For DifferentiationInterface, `adtype` can be either a single backend
(letting DI pick its own Hessian strategy) or a
[`DifferentiationInterface.SecondOrder(outer, inner)`](https://juliadiff.org/DifferentiationInterface.jl/stable/api/#DifferentiationInterface.SecondOrder)
composition that selects the outer differentiator and the inner gradient
backend independently — typically forward-over-reverse:

```julia
using DifferentiationInterface: SecondOrder
using ADTypes: AutoForwardDiff, AutoReverseDiff

adtype = SecondOrder(AutoForwardDiff(), AutoReverseDiff())
prepared = prepare(adtype, quadratic, zeros(3); order=2)
```

`SecondOrder <: AbstractADType`, so the same `prepare(adtype, problem, x; order=2)`
entry handles it.

Calling `value_gradient_and_hessian!!` on an `order=1` prep throws an
`ArgumentError` — re-prepare with `order=2` instead. The reverse is allowed:
`value_and_gradient!!` on an `order=2` prep returns `(value, gradient)`
without paying the Hessian cost, since `prepare` builds a dedicated
gradient prep alongside the Hessian one. `value_and_jacobian!!` is rejected
because `order=2` requires a scalar-valued problem.

## Constant context arguments

When the underlying callable naturally takes the form `f(x, context...)` —
Expand Down Expand Up @@ -177,4 +228,5 @@ p([1.0, 2.0, 3.0])
AbstractPPL.prepare
AbstractPPL.value_and_gradient!!
AbstractPPL.value_and_jacobian!!
AbstractPPL.value_gradient_and_hessian!!
```
253 changes: 193 additions & 60 deletions ext/AbstractPPLDifferentiationInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,55 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_
using ADTypes: AbstractADType, AutoReverseDiff
using DifferentiationInterface: DifferentiationInterface as DI

# AD target used by both `DICache` modes. `Vararg{Any,N}` with a free `N`
# forces specialization on the trailing arity (a bare `Vararg{Any}` would
# skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the
# constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on
# the closure path — empty `ctx` then makes the splat a no-op.
# AD target used by every DI cache. `Vararg{Any,N}` with a free `N` forces
# specialization on the trailing arity (a bare `Vararg{Any}` would skip it).
# DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the constants path,
# and as `_call_evaluator(x, evaluator)` (via `Fix2`) on the closure path —
# empty `ctx` then makes the splat a no-op.
@inline _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N} = f(x, ctx...)

# `Mode` tags the cache shape:
# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure,
# the AD call passes **0** `DI.Constant`s.
# * `N::Int` — constants path: `N == length(evaluator.context)`, the
# AD call passes **N + 1** `DI.Constant`s (`f` plus the
# `N` context values).
# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*`
# at compile time without a runtime branch.
struct DICache{Mode,F,GP,JP}
# `Mode` tags the call shape:
# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure; the
# AD call passes **0** `DI.Constant`s.
# * `N::Int` — constants path: `N == length(evaluator.context)`; the AD
# call passes **N + 1** `DI.Constant`s (`f` plus the `N`
# context values).
# Encoding `Mode` in each cache type resolves the closure-vs-constants dispatch
# in `_di_value_and_*` at compile time without a runtime branch.

# `Nothing` in the prep slot flags the empty-input cache (DI prep paths fail
# on length-0 input, e.g. ForwardDiff `BoundsError`). Hot paths dispatch on the
# `Nothing` parameter to short-circuit before any DI call. Same convention for
# `DIJacobianCache` and `DIHessianCache` below.
struct DIGradientCache{Mode,F,GP}
target::F
gradient_prep::GP
function DIGradientCache(target::F, gp::GP, ::Val{Mode}) where {Mode,F,GP}
return new{Mode,F,GP}(target, gp)
end
end

struct DIJacobianCache{Mode,F,JP}
target::F
jacobian_prep::JP
function DICache{Mode}(target::F, gp::GP, jp::JP) where {Mode,F,GP,JP}
return new{Mode,F,GP,JP}(target, gp, jp)
function DIJacobianCache(target::F, jp::JP, ::Val{Mode}) where {Mode,F,JP}
return new{Mode,F,JP}(target, jp)
end
end

# Order=2 (scalar-output). `grad_buf` / `hess_buf` are caller-owned output
# buffers handed to `DI.value_gradient_and_hessian!`; the returned arrays alias
# them (`!!` contract).
struct DIHessianCache{Mode,F,GP,HP,G,H}
target::F
gradient_prep::GP
hessian_prep::HP
grad_buf::G
hess_buf::H
function DIHessianCache(
target::F, gp::GP, hp::HP, g::G, h::H, ::Val{Mode}
) where {Mode,F,GP,HP,G,H}
return new{Mode,F,GP,HP,G,H}(target, gp, hp, g, h)
end
end

Expand All @@ -34,69 +62,107 @@ end
# target and call DI without constants. Context (if any) is captured inside
# the evaluator closure rather than lowered out — the lowered path would also
# require a closure here, so the wrapper cost is unavoidable for compiled tapes.
function _prepare_di(prep::F, adtype::AutoReverseDiff{true}, x, evaluator) where {F}
target = Base.Fix2(_call_evaluator, evaluator)
return target, prep(target, adtype, x), Val(:closure)
#
# `_di_call_shape` returns `(target, mode, constants)`. For the closure path
# `constants == ()` and the splat at every prep/call site collapses to nothing,
# letting prep and call sites share one shape regardless of mode.
function _di_call_shape(::AutoReverseDiff{true}, evaluator)
return Base.Fix2(_call_evaluator, evaluator), Val(:closure), ()
end

function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F}
constants = (DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...)
return (
_call_evaluator,
prep(_call_evaluator, adtype, x, constants...),
Val(length(evaluator.context)),
)
function _di_call_shape(::AbstractADType, evaluator)
return _call_evaluator,
Val(length(evaluator.context)),
(DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...)
end

@inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} =
DICache{Mode}(target, gp, jp)
# `SecondOrder` doesn't define gradient prep; per DI's contract the inner
# adtype is the one used for the first derivative.
@inline _gradient_adtype(adtype::AbstractADType) = adtype
@inline _gradient_adtype(adtype::DI.SecondOrder) = DI.inner(adtype)

function _prepare_di(prep::F, adtype, x, evaluator) where {F}
target, mode, constants = _di_call_shape(adtype, evaluator)
return target, prep(target, adtype, x, constants...), mode
end

function AbstractPPL.prepare(
adtype::AbstractADType,
problem,
x::AbstractVector{<:Real};
check_dims::Bool=true,
context::Tuple=(),
order::Int=1,
)
Evaluators._validate_ad_order(order)
evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator
arity = _ad_output_arity(evaluator(x))
mode_empty = Val(length(context))
if order == 2
arity === :scalar || Evaluators._throw_hessian_needs_scalar()
if length(x) == 0
cache = DIHessianCache(
_call_evaluator, nothing, nothing, nothing, nothing, mode_empty
)
return Prepared(adtype, evaluator, cache, Val(2))
end
# Build both gradient and Hessian preps against the same target so
# `value_and_gradient!!` on the order=2 prep skips the O(n²) Hessian
# cost. Sharing the target matters for compiled-tape ReverseDiff —
# two `Fix2` instances may not be interchangeable in DI.
target, mode, constants = _di_call_shape(adtype, evaluator)
gradient_prep = DI.prepare_gradient(
target, _gradient_adtype(adtype), x, constants...
)
hessian_prep = DI.prepare_hessian(target, adtype, x, constants...)
# Buffers pre-allocated from `x`: hot path is zero-allocation on the
# gradient/Hessian outputs, returned arrays alias these slots.
cache = DIHessianCache(
target,
gradient_prep,
hessian_prep,
similar(x),
similar(x, length(x), length(x)),
mode,
)
return Prepared(adtype, evaluator, cache, Val(2))
end
if length(x) == 0
# DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`).
# `Val(0)` is an arity sentinel for the `gradient_prep === nothing`
# check below; the AD entry short-circuits before any DI call.
gp, jp = arity === :scalar ? (Val(0), nothing) : (nothing, Val(0))
cache = _wrap_cache(_call_evaluator, gp, jp, Val(length(context)))
cache = if arity === :scalar
DIGradientCache(_call_evaluator, nothing, mode_empty)
else
DIJacobianCache(_call_evaluator, nothing, mode_empty)
end
return Prepared(adtype, evaluator, cache)
end
if arity === :scalar
target, gradient_prep, mode = _prepare_di(DI.prepare_gradient, adtype, x, evaluator)
return Prepared(
adtype, evaluator, _wrap_cache(target, gradient_prep, nothing, mode)
)
return Prepared(adtype, evaluator, DIGradientCache(target, gradient_prep, mode))
end
target, jacobian_prep, mode = _prepare_di(DI.prepare_jacobian, adtype, x, evaluator)
return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, mode))
return Prepared(adtype, evaluator, DIJacobianCache(target, jacobian_prep, mode))
end

# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile
# time. The unconstrained method matches every non-`:closure` `Mode` (i.e.
# any `Int N`); `:closure` is strictly more specific and wins for compiled
# tapes. On the constants path we always pass `DI.Constant(eval.f)` plus the
# `N` context constants — `N == 0` collapses the `map` splat to nothing.
@inline _di_value_and_gradient(c::DICache{:closure}, ad, x, _) =
DI.value_and_gradient(c.target, c.gradient_prep, ad, x)
@inline _di_value_and_gradient(c::DICache, ad, x, eval) = DI.value_and_gradient(
# Hot-path dispatch is by cache type + `Mode` (closure vs constants), both
# resolved at compile time. On the constants path we always pass
# `DI.Constant(eval.f)` plus the `N` context constants — `N == 0` collapses
# the `map` splat to nothing.
const _GradientCapable = Union{DIGradientCache,DIHessianCache}

@inline _di_value_and_gradient(
c::Union{DIGradientCache{:closure},DIHessianCache{:closure}}, ad, x, _
) = DI.value_and_gradient(c.target, c.gradient_prep, _gradient_adtype(ad), x)
@inline _di_value_and_gradient(c::_GradientCapable, ad, x, eval) = DI.value_and_gradient(
c.target,
c.gradient_prep,
ad,
_gradient_adtype(ad),
x,
DI.Constant(eval.f),
map(DI.Constant, eval.context)...,
)

@inline _di_value_and_jacobian(c::DICache{:closure}, ad, x, _) =
@inline _di_value_and_jacobian(c::DIJacobianCache{:closure}, ad, x, _) =
DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x)
@inline _di_value_and_jacobian(c::DICache, ad, x, eval) = DI.value_and_jacobian(
@inline _di_value_and_jacobian(c::DIJacobianCache, ad, x, eval) = DI.value_and_jacobian(
c.target,
c.jacobian_prep,
ad,
Expand All @@ -105,27 +171,94 @@ end
map(DI.Constant, eval.context)...,
)

@inline _di_value_gradient_and_hessian(c::DIHessianCache{:closure}, ad, x, _) =
DI.value_gradient_and_hessian!(c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x)
@inline _di_value_gradient_and_hessian(c::DIHessianCache, ad, x, eval) =
DI.value_gradient_and_hessian!(
c.target,
c.grad_buf,
c.hess_buf,
c.hessian_prep,
ad,
x,
DI.Constant(eval.f),
map(DI.Constant, eval.context)...,
)

# `value_and_gradient!!`: works on both `DIGradientCache` (order=1 scalar) and
# `DIHessianCache` (order=2). Empty-input caches carry `gradient_prep::Nothing`
# and dispatch to the short-circuit method below; vector-output caches reject.
@inline function AbstractPPL.value_and_gradient!!(
p::Prepared{
<:AbstractADType,
<:VectorEvaluator,
<:Union{DIGradientCache{<:Any,<:Any,Nothing},DIHessianCache{<:Any,<:Any,Nothing}},
},
x::AbstractVector{T},
) where {T<:Real}
Evaluators._check_ad_input(p.evaluator, x)
return (p.evaluator(x), T[])
end

@inline function AbstractPPL.value_and_gradient!!(
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:_GradientCapable}, x::AbstractVector{T}
) where {T<:Real}
p.cache.gradient_prep === nothing && Evaluators._throw_gradient_needs_scalar()
Evaluators._check_ad_input(p.evaluator, x)
# Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff
# `BoundsError`); typed `T[]` matches the caller's element type.
length(x) == 0 && return (p.evaluator(x), T[])
return _di_value_and_gradient(p.cache, p.adtype, x, p.evaluator)
end

@inline function AbstractPPL.value_and_gradient!!(
::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache},
::AbstractVector{<:Real},
)
return Evaluators._throw_gradient_needs_scalar()
end

@inline function AbstractPPL.value_and_jacobian!!(
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache{<:Any,<:Any,Nothing}},
x::AbstractVector{T},
) where {T<:Real}
Evaluators._check_ad_input(p.evaluator, x)
val = p.evaluator(x)
return (val, similar(x, length(val), 0))
end

@inline function AbstractPPL.value_and_jacobian!!(
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache}, x::AbstractVector{T}
) where {T<:Real}
p.cache.jacobian_prep === nothing && Evaluators._throw_jacobian_needs_vector()
Evaluators._check_ad_input(p.evaluator, x)
if length(x) == 0
val = p.evaluator(x)
return (val, similar(x, length(val), 0))
end
return _di_value_and_jacobian(p.cache, p.adtype, x, p.evaluator)
end

@inline function AbstractPPL.value_and_jacobian!!(
::Prepared{<:AbstractADType,<:VectorEvaluator,<:_GradientCapable},
::AbstractVector{<:Real},
)
return Evaluators._throw_jacobian_needs_vector()
end

@inline function AbstractPPL.value_gradient_and_hessian!!(
p::Prepared{
<:AbstractADType,<:VectorEvaluator,<:DIHessianCache{<:Any,<:Any,<:Any,Nothing}
},
x::AbstractVector{T},
) where {T<:Real}
Evaluators._check_ad_input(p.evaluator, x)
return (p.evaluator(x), T[], similar(x, 0, 0))
end

@inline function AbstractPPL.value_gradient_and_hessian!!(
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIHessianCache}, x::AbstractVector{T}
) where {T<:Real}
Evaluators._check_ad_input(p.evaluator, x)
return _di_value_gradient_and_hessian(p.cache, p.adtype, x, p.evaluator)
end

@inline function AbstractPPL.value_gradient_and_hessian!!(
::Prepared{<:AbstractADType,<:VectorEvaluator,<:Union{DIGradientCache,DIJacobianCache}},
::AbstractVector{<:Real},
)
return Evaluators._throw_hessian_needs_order_2_prep()
end

end # module
Loading
Loading