diff --git a/Project.toml b/Project.toml index 0e7ce54d..e0b3fa4c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probabilistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.15" +version = "0.15.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index 9dd6f8af..7e1b1282 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -138,6 +138,57 @@ library invokes the inner callable many times with same-length dual arrays derived from a single user-supplied `x`; re-validating on each invocation would be redundant work in the hot path. +## Hessian (`order=2`) + +Pass `order=2` to `prepare` to build a Hessian-capable evaluator. The +returned object answers `value_gradient_and_hessian!!`, which returns +`(value, gradient, hessian)` in a single call. `order=2` requires +`problem` to be scalar-valued; a vector-valued probe throws at preparation +time. + +```julia +using AbstractPPL: prepare, value_gradient_and_hessian!! +using ADTypes: AutoForwardDiff +using ForwardDiff, DifferentiationInterface + +quadratic(x) = sum(abs2, x) +prepared = prepare(AutoForwardDiff(), quadratic, zeros(3); order=2) +val, grad, hess = value_gradient_and_hessian!!(prepared, [1.0, 2.0, 3.0]) +# val == 14.0 +# grad == [2.0, 4.0, 6.0] +# hess == [2 0 0; 0 2 0; 0 0 2] +``` + +Both `context=` and `check_dims=` apply to `order=2` preps with the same +semantics as for `order=1`. The `!!` aliasing contract also extends: the +returned gradient and Hessian may alias internal cache buffers of +`prepared`, so copy before retaining them past the next call. NamedTuple +inputs are not supported at `order=2`. + +For DifferentiationInterface, `adtype` can be either a single backend +(letting DI pick its own Hessian strategy) or a +[`DifferentiationInterface.SecondOrder(outer, inner)`](https://juliadiff.org/DifferentiationInterface.jl/stable/api/#DifferentiationInterface.SecondOrder) +composition that selects the outer differentiator and the inner gradient +backend independently — typically forward-over-reverse: + +```julia +using DifferentiationInterface: SecondOrder +using ADTypes: AutoForwardDiff, AutoReverseDiff + +adtype = SecondOrder(AutoForwardDiff(), AutoReverseDiff()) +prepared = prepare(adtype, quadratic, zeros(3); order=2) +``` + +`SecondOrder <: AbstractADType`, so the same `prepare(adtype, problem, x; order=2)` +entry handles it. + +Calling `value_gradient_and_hessian!!` on an `order=1` prep throws an +`ArgumentError` — re-prepare with `order=2` instead. The reverse is allowed: +`value_and_gradient!!` on an `order=2` prep returns `(value, gradient)` +without paying the Hessian cost, since `prepare` builds a dedicated +gradient prep alongside the Hessian one. `value_and_jacobian!!` is rejected +because `order=2` requires a scalar-valued problem. + ## Constant context arguments When the underlying callable naturally takes the form `f(x, context...)` — @@ -177,4 +228,5 @@ p([1.0, 2.0, 3.0]) AbstractPPL.prepare AbstractPPL.value_and_gradient!! AbstractPPL.value_and_jacobian!! +AbstractPPL.value_gradient_and_hessian!! ``` diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 1f4cfe8b..4ba7c344 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -5,27 +5,55 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_ using ADTypes: AbstractADType, AutoReverseDiff using DifferentiationInterface: DifferentiationInterface as DI -# AD target used by both `DICache` modes. `Vararg{Any,N}` with a free `N` -# forces specialization on the trailing arity (a bare `Vararg{Any}` would -# skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the -# constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on -# the closure path — empty `ctx` then makes the splat a no-op. +# AD target used by every DI cache. `Vararg{Any,N}` with a free `N` forces +# specialization on the trailing arity (a bare `Vararg{Any}` would skip it). +# DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the constants path, +# and as `_call_evaluator(x, evaluator)` (via `Fix2`) on the closure path — +# empty `ctx` then makes the splat a no-op. @inline _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N} = f(x, ctx...) -# `Mode` tags the cache shape: -# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure, -# the AD call passes **0** `DI.Constant`s. -# * `N::Int` — constants path: `N == length(evaluator.context)`, the -# AD call passes **N + 1** `DI.Constant`s (`f` plus the -# `N` context values). -# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*` -# at compile time without a runtime branch. -struct DICache{Mode,F,GP,JP} +# `Mode` tags the call shape: +# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure; the +# AD call passes **0** `DI.Constant`s. +# * `N::Int` — constants path: `N == length(evaluator.context)`; the AD +# call passes **N + 1** `DI.Constant`s (`f` plus the `N` +# context values). +# Encoding `Mode` in each cache type resolves the closure-vs-constants dispatch +# in `_di_value_and_*` at compile time without a runtime branch. + +# `Nothing` in the prep slot flags the empty-input cache (DI prep paths fail +# on length-0 input, e.g. ForwardDiff `BoundsError`). Hot paths dispatch on the +# `Nothing` parameter to short-circuit before any DI call. Same convention for +# `DIJacobianCache` and `DIHessianCache` below. +struct DIGradientCache{Mode,F,GP} target::F gradient_prep::GP + function DIGradientCache(target::F, gp::GP, ::Val{Mode}) where {Mode,F,GP} + return new{Mode,F,GP}(target, gp) + end +end + +struct DIJacobianCache{Mode,F,JP} + target::F jacobian_prep::JP - function DICache{Mode}(target::F, gp::GP, jp::JP) where {Mode,F,GP,JP} - return new{Mode,F,GP,JP}(target, gp, jp) + function DIJacobianCache(target::F, jp::JP, ::Val{Mode}) where {Mode,F,JP} + return new{Mode,F,JP}(target, jp) + end +end + +# Order=2 (scalar-output). `grad_buf` / `hess_buf` are caller-owned output +# buffers handed to `DI.value_gradient_and_hessian!`; the returned arrays alias +# them (`!!` contract). +struct DIHessianCache{Mode,F,GP,HP,G,H} + target::F + gradient_prep::GP + hessian_prep::HP + grad_buf::G + hess_buf::H + function DIHessianCache( + target::F, gp::GP, hp::HP, g::G, h::H, ::Val{Mode} + ) where {Mode,F,GP,HP,G,H} + return new{Mode,F,GP,HP,G,H}(target, gp, hp, g, h) end end @@ -34,22 +62,28 @@ end # target and call DI without constants. Context (if any) is captured inside # the evaluator closure rather than lowered out — the lowered path would also # require a closure here, so the wrapper cost is unavoidable for compiled tapes. -function _prepare_di(prep::F, adtype::AutoReverseDiff{true}, x, evaluator) where {F} - target = Base.Fix2(_call_evaluator, evaluator) - return target, prep(target, adtype, x), Val(:closure) +# +# `_di_call_shape` returns `(target, mode, constants)`. For the closure path +# `constants == ()` and the splat at every prep/call site collapses to nothing, +# letting prep and call sites share one shape regardless of mode. +function _di_call_shape(::AutoReverseDiff{true}, evaluator) + return Base.Fix2(_call_evaluator, evaluator), Val(:closure), () end - -function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F} - constants = (DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...) - return ( - _call_evaluator, - prep(_call_evaluator, adtype, x, constants...), - Val(length(evaluator.context)), - ) +function _di_call_shape(::AbstractADType, evaluator) + return _call_evaluator, + Val(length(evaluator.context)), + (DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...) end -@inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} = - DICache{Mode}(target, gp, jp) +# `SecondOrder` doesn't define gradient prep; per DI's contract the inner +# adtype is the one used for the first derivative. +@inline _gradient_adtype(adtype::AbstractADType) = adtype +@inline _gradient_adtype(adtype::DI.SecondOrder) = DI.inner(adtype) + +function _prepare_di(prep::F, adtype, x, evaluator) where {F} + target, mode, constants = _di_call_shape(adtype, evaluator) + return target, prep(target, adtype, x, constants...), mode +end function AbstractPPL.prepare( adtype::AbstractADType, @@ -57,46 +91,78 @@ function AbstractPPL.prepare( x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=(), + order::Int=1, ) + Evaluators._validate_ad_order(order) evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) + mode_empty = Val(length(context)) + if order == 2 + arity === :scalar || Evaluators._throw_hessian_needs_scalar() + if length(x) == 0 + cache = DIHessianCache( + _call_evaluator, nothing, nothing, nothing, nothing, mode_empty + ) + return Prepared(adtype, evaluator, cache, Val(2)) + end + # Build both gradient and Hessian preps against the same target so + # `value_and_gradient!!` on the order=2 prep skips the O(n²) Hessian + # cost. Sharing the target matters for compiled-tape ReverseDiff — + # two `Fix2` instances may not be interchangeable in DI. + target, mode, constants = _di_call_shape(adtype, evaluator) + gradient_prep = DI.prepare_gradient( + target, _gradient_adtype(adtype), x, constants... + ) + hessian_prep = DI.prepare_hessian(target, adtype, x, constants...) + # Buffers pre-allocated from `x`: hot path is zero-allocation on the + # gradient/Hessian outputs, returned arrays alias these slots. + cache = DIHessianCache( + target, + gradient_prep, + hessian_prep, + similar(x), + similar(x, length(x), length(x)), + mode, + ) + return Prepared(adtype, evaluator, cache, Val(2)) + end if length(x) == 0 - # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`). - # `Val(0)` is an arity sentinel for the `gradient_prep === nothing` - # check below; the AD entry short-circuits before any DI call. - gp, jp = arity === :scalar ? (Val(0), nothing) : (nothing, Val(0)) - cache = _wrap_cache(_call_evaluator, gp, jp, Val(length(context))) + cache = if arity === :scalar + DIGradientCache(_call_evaluator, nothing, mode_empty) + else + DIJacobianCache(_call_evaluator, nothing, mode_empty) + end return Prepared(adtype, evaluator, cache) end if arity === :scalar target, gradient_prep, mode = _prepare_di(DI.prepare_gradient, adtype, x, evaluator) - return Prepared( - adtype, evaluator, _wrap_cache(target, gradient_prep, nothing, mode) - ) + return Prepared(adtype, evaluator, DIGradientCache(target, gradient_prep, mode)) end target, jacobian_prep, mode = _prepare_di(DI.prepare_jacobian, adtype, x, evaluator) - return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, mode)) + return Prepared(adtype, evaluator, DIJacobianCache(target, jacobian_prep, mode)) end -# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile -# time. The unconstrained method matches every non-`:closure` `Mode` (i.e. -# any `Int N`); `:closure` is strictly more specific and wins for compiled -# tapes. On the constants path we always pass `DI.Constant(eval.f)` plus the -# `N` context constants — `N == 0` collapses the `map` splat to nothing. -@inline _di_value_and_gradient(c::DICache{:closure}, ad, x, _) = - DI.value_and_gradient(c.target, c.gradient_prep, ad, x) -@inline _di_value_and_gradient(c::DICache, ad, x, eval) = DI.value_and_gradient( +# Hot-path dispatch is by cache type + `Mode` (closure vs constants), both +# resolved at compile time. On the constants path we always pass +# `DI.Constant(eval.f)` plus the `N` context constants — `N == 0` collapses +# the `map` splat to nothing. +const _GradientCapable = Union{DIGradientCache,DIHessianCache} + +@inline _di_value_and_gradient( + c::Union{DIGradientCache{:closure},DIHessianCache{:closure}}, ad, x, _ +) = DI.value_and_gradient(c.target, c.gradient_prep, _gradient_adtype(ad), x) +@inline _di_value_and_gradient(c::_GradientCapable, ad, x, eval) = DI.value_and_gradient( c.target, c.gradient_prep, - ad, + _gradient_adtype(ad), x, DI.Constant(eval.f), map(DI.Constant, eval.context)..., ) -@inline _di_value_and_jacobian(c::DICache{:closure}, ad, x, _) = +@inline _di_value_and_jacobian(c::DIJacobianCache{:closure}, ad, x, _) = DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x) -@inline _di_value_and_jacobian(c::DICache, ad, x, eval) = DI.value_and_jacobian( +@inline _di_value_and_jacobian(c::DIJacobianCache, ad, x, eval) = DI.value_and_jacobian( c.target, c.jacobian_prep, ad, @@ -105,27 +171,94 @@ end map(DI.Constant, eval.context)..., ) +@inline _di_value_gradient_and_hessian(c::DIHessianCache{:closure}, ad, x, _) = + DI.value_gradient_and_hessian!(c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x) +@inline _di_value_gradient_and_hessian(c::DIHessianCache, ad, x, eval) = + DI.value_gradient_and_hessian!( + c.target, + c.grad_buf, + c.hess_buf, + c.hessian_prep, + ad, + x, + DI.Constant(eval.f), + map(DI.Constant, eval.context)..., + ) + +# `value_and_gradient!!`: works on both `DIGradientCache` (order=1 scalar) and +# `DIHessianCache` (order=2). Empty-input caches carry `gradient_prep::Nothing` +# and dispatch to the short-circuit method below; vector-output caches reject. +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{ + <:AbstractADType, + <:VectorEvaluator, + <:Union{DIGradientCache{<:Any,<:Any,Nothing},DIHessianCache{<:Any,<:Any,Nothing}}, + }, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return (p.evaluator(x), T[]) +end + @inline function AbstractPPL.value_and_gradient!!( - p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:_GradientCapable}, x::AbstractVector{T} ) where {T<:Real} - p.cache.gradient_prep === nothing && Evaluators._throw_gradient_needs_scalar() Evaluators._check_ad_input(p.evaluator, x) - # Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff - # `BoundsError`); typed `T[]` matches the caller's element type. - length(x) == 0 && return (p.evaluator(x), T[]) return _di_value_and_gradient(p.cache, p.adtype, x, p.evaluator) end +@inline function AbstractPPL.value_and_gradient!!( + ::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_gradient_needs_scalar() +end + @inline function AbstractPPL.value_and_jacobian!!( - p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache{<:Any,<:Any,Nothing}}, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + val = p.evaluator(x) + return (val, similar(x, length(val), 0)) +end + +@inline function AbstractPPL.value_and_jacobian!!( + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache}, x::AbstractVector{T} ) where {T<:Real} - p.cache.jacobian_prep === nothing && Evaluators._throw_jacobian_needs_vector() Evaluators._check_ad_input(p.evaluator, x) - if length(x) == 0 - val = p.evaluator(x) - return (val, similar(x, length(val), 0)) - end return _di_value_and_jacobian(p.cache, p.adtype, x, p.evaluator) end +@inline function AbstractPPL.value_and_jacobian!!( + ::Prepared{<:AbstractADType,<:VectorEvaluator,<:_GradientCapable}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_jacobian_needs_vector() +end + +@inline function AbstractPPL.value_gradient_and_hessian!!( + p::Prepared{ + <:AbstractADType,<:VectorEvaluator,<:DIHessianCache{<:Any,<:Any,<:Any,Nothing} + }, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return (p.evaluator(x), T[], similar(x, 0, 0)) +end + +@inline function AbstractPPL.value_gradient_and_hessian!!( + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIHessianCache}, x::AbstractVector{T} +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return _di_value_gradient_and_hessian(p.cache, p.adtype, x, p.evaluator) +end + +@inline function AbstractPPL.value_gradient_and_hessian!!( + ::Prepared{<:AbstractADType,<:VectorEvaluator,<:Union{DIGradientCache,DIJacobianCache}}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_hessian_needs_order_2_prep() +end + end # module diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index b07af8ae..3ff32313 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -8,33 +8,53 @@ using Mooncake: Mooncake const _MooncakeAD = Union{AutoMooncake,AutoMooncakeForward} -# `NamedTupleEvaluator` is the callable on the NamedTuple path; `NoTangent` -# stops Mooncake from deriving a `Tangent{NamedTuple{...}}` for its fields -# on every backward pass. The `VectorEvaluator` override is a defensive -# guard — vector preps no longer pass the evaluator wrapper to Mooncake. +# `NoTangent` stops Mooncake from deriving a `Tangent{...}` for the evaluator +# wrapper's fields on each backward pass. Load-bearing for both: +# * `NamedTupleEvaluator` — passed directly to Mooncake on the NamedTuple +# gradient path. +# * `VectorEvaluator` — wrapped by the order=2 path (Mooncake's Hessian +# API accepts only `AbstractVector` arguments, so +# context is closed over via `VectorEvaluator{false}`). Mooncake.tangent_type(::Type{<:VectorEvaluator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent # Type parameters: # -# * `A::Symbol` — output arity, `:scalar` or `:vector`. Drives the -# gradient/jacobian dispatch and the arity-mismatch errors. +# * `A::Symbol` — `:scalar` / `:vector` for order=1 (output arity), `:hessian` +# for order=2. Drives every dispatch decision below. +# * `Target` — `Nothing` for order=1 (Mooncake's gradient/Jacobian API +# takes `f` and `context` as separate args). For `:hessian`, +# a `VectorEvaluator{false}` that closes over `f` and +# `context`, since Mooncake's Hessian API accepts only +# `AbstractVector` arguments. The evaluator's `NoTangent` +# tangent type prevents differentiation of its fields. # * `C` — the underlying Mooncake cache, or `Nothing` for the # empty-input shortcut. -struct MooncakeCache{A,C} +# * `G` — gradient cache populated only at order=2 so the order=1 +# `value_and_gradient!!` entry on a Hessian prep skips the +# Hessian work. `Nothing` for every order=1 path. +struct MooncakeCache{A,Target,C,G} + target::Target cache::C + gradient_cache::G + function MooncakeCache{A}( + target::Target, cache::C, gradient_cache::G=nothing + ) where {A,Target,C,G} + return new{A,Target,C,G}(target, cache, gradient_cache) + end end -MooncakeCache{A}(cache::C) where {A,C} = MooncakeCache{A,C}(cache) +MooncakeCache{A}(cache) where {A} = MooncakeCache{A}(nothing, cache) _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config -# NamedTuple-path helper: Mooncake exposes separate `prepare_*_cache` -# entries per AD mode but the call shape (target + values) is the same. -function _mooncake_gradient_cache(::AutoMooncake, f, x; config) - return Mooncake.prepare_gradient_cache(f, x; config) +# Mooncake exposes separate `prepare_*_cache` entries per AD mode; the call +# shape (callable + active arg + extra args) is the same. Used by the +# NamedTuple path, the order=1 scalar branch, and the order=2 gradient prep. +function _mooncake_gradient_cache(::AutoMooncake, f, x, args...; config) + return Mooncake.prepare_gradient_cache(f, x, args...; config) end -function _mooncake_gradient_cache(::AutoMooncakeForward, f, x; config) - return Mooncake.prepare_derivative_cache(f, x; config) +function _mooncake_gradient_cache(::AutoMooncakeForward, f, x, args...; config) + return Mooncake.prepare_derivative_cache(f, x, args...; config) end function AbstractPPL.prepare( @@ -47,10 +67,13 @@ function AbstractPPL.prepare( end """ - prepare(adtype::AutoMooncake, problem, x; check_dims=true, context::Tuple=()) - prepare(adtype::AutoMooncakeForward, problem, x; check_dims=true, context::Tuple=()) + prepare(adtype::AutoMooncake, problem, x; check_dims=true, context::Tuple=(), order=1) + prepare(adtype::AutoMooncakeForward, problem, x; check_dims=true, context::Tuple=(), order=1) -Prepare a Mooncake gradient/Jacobian evaluator for a dense vector input. +Prepare a Mooncake gradient, Jacobian, or Hessian evaluator for a dense vector +input. `order=1` (default) picks gradient/Jacobian by output arity; +`order=2` builds Hessian machinery (`value_gradient_and_hessian!!`) and +requires a scalar-valued problem. Non-`DenseVector` inputs (views, strided slices) are rejected: Mooncake assumes a contiguous primal and otherwise returns shape-incorrect tangents @@ -58,12 +81,12 @@ on reverse mode and crashes on forward/Jacobian paths. `context` follows the base `prepare` contract — the prepared evaluator computes `problem(x, context...)` with AD differentiating only `x`. One -Mooncake-specific restriction: vector-valued problems require `context=()`. +Mooncake-specific restriction for `order=1`: vector-valued problems require +`context=()`. `order=2` accepts any `context`. Empty input (`length(x) == 0`) is supported with any `context`; Mooncake builds no tape for zero-length `x`, so the prepared evaluator's AD entry -short-circuits to `(problem(x, context...), eltype(x)[])` without invoking -Mooncake. +short-circuits without invoking Mooncake. """ function AbstractPPL.prepare( adtype::_MooncakeAD, @@ -71,7 +94,9 @@ function AbstractPPL.prepare( x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=(), + order::Int=1, ) + Evaluators._validate_ad_order(order) x isa DenseVector || throw( ArgumentError( "AutoMooncake / AutoMooncakeForward require a dense vector input " * @@ -82,6 +107,29 @@ function AbstractPPL.prepare( evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) config = _mooncake_config(adtype) + if order == 2 + arity === :scalar || Evaluators._throw_hessian_needs_scalar() + # `{false}` skips the per-call shape check — `_check_ad_input` on the + # AD entry already validates `x`. `dim` is unused for `{false}`. + target = VectorEvaluator{false}(evaluator.f, 0, evaluator.context) + length(x) == 0 && return Prepared( + adtype, evaluator, MooncakeCache{:hessian}(target, nothing), Val(2) + ) + hess_cache = Mooncake.prepare_hessian_cache(target, x; config) + # Order=1 gradient cache so `value_and_gradient!!` on the same prep + # skips the Hessian work. Mooncake's `value_and_gradient!!` runs on + # `evaluator.f` with context-as-extra-args, distinct from the wrapped + # `target` used by the Hessian API. + grad_cache = _mooncake_gradient_cache( + adtype, evaluator.f, x, evaluator.context...; config + ) + return Prepared( + adtype, + evaluator, + MooncakeCache{:hessian}(target, hess_cache, grad_cache), + Val(2), + ) + end if !isempty(evaluator.context) && arity !== :scalar throw( ArgumentError( @@ -98,16 +146,12 @@ function AbstractPPL.prepare( # `problem` / `context` kwargs): a downstream override of structural # `prepare` may return a `VectorEvaluator` whose `.f`/`.context` differ # from the caller-supplied values, and the hot path reads them off the - # evaluator. Forward mode uses `prepare_derivative_cache` for both - # arities; the splat is a no-op for vector arity (empty `context`). - cache = if adtype isa AutoMooncake - if arity === :scalar - Mooncake.prepare_gradient_cache(evaluator.f, x, evaluator.context...; config) - else - Mooncake.prepare_pullback_cache(evaluator.f, x; config) - end + # evaluator. The reverse-mode vector branch is the only one that can't + # share `_mooncake_gradient_cache` — it needs `prepare_pullback_cache`. + cache = if arity !== :scalar && adtype isa AutoMooncake + Mooncake.prepare_pullback_cache(evaluator.f, x; config) else - Mooncake.prepare_derivative_cache(evaluator.f, x, evaluator.context...; config) + _mooncake_gradient_cache(adtype, evaluator.f, x, evaluator.context...; config) end return Prepared(adtype, evaluator, MooncakeCache{arity}(cache)) end @@ -125,12 +169,12 @@ end return (val, grad) end -# Empty-input shortcut. `MooncakeCache{:scalar,Nothing}` is strictly more -# specific than `MooncakeCache{:scalar}` on `C`, so dispatch unambiguously -# selects this method over the general scalar-gradient hot path below for -# zero-length `x`. +# Empty-input shortcut. `MooncakeCache{:scalar,Nothing,Nothing}` is strictly +# more specific than `MooncakeCache{:scalar}`, so dispatch unambiguously selects +# this method over the general scalar-gradient hot path below for zero-length +# `x`. @inline function AbstractPPL.value_and_gradient!!( - p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing}}, + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing,Nothing}}, x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) @@ -141,34 +185,37 @@ end # `args_to_zero` to mark `x` as the lone active input (`false` on `f`, # `true` on `x`, `false` on each context value); forward mode # (`ForwardCache`) derives activity from its seeded argument and rejects -# the kwarg. The `p.adtype isa AutoMooncake` branch is compile-folded -# since `adtype`'s concrete type lives in `Prepared`'s type parameters. -# Empty `context` collapses the splat and reduces `args_to_zero` to +# the kwarg. The `adtype isa AutoMooncake` branch is compile-folded +# since the concrete type lives in `Prepared`'s type parameters. Empty +# `context` collapses the splat and reduces `args_to_zero` to # `(false, true)`. `tangents[2]` is the `x`-gradient; trailing entries # (one per context value) are inactive and discarded. -@inline function AbstractPPL.value_and_gradient!!( - p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, - x::AbstractVector{T}, -) where {T<:Real} - Evaluators._check_ad_input(p.evaluator, x) - e = p.evaluator - val, tangents = if p.adtype isa AutoMooncake +@inline function _mooncake_value_and_gradient(adtype, gcache, e::VectorEvaluator, x) + val, tangents = if adtype isa AutoMooncake Mooncake.value_and_gradient!!( - p.cache.cache, + gcache, e.f, x, e.context...; args_to_zero=(false, true, map(_ -> false, e.context)...), ) else - Mooncake.value_and_gradient!!(p.cache.cache, e.f, x, e.context...) + Mooncake.value_and_gradient!!(gcache, e.f, x, e.context...) end return (val, tangents[2]) end +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, + x::AbstractVector{<:Real}, +) + Evaluators._check_ad_input(p.evaluator, x) + return _mooncake_value_and_gradient(p.adtype, p.cache.cache, p.evaluator, x) +end + # Arity-mismatch errors as dedicated methods so dispatch on -# `MooncakeCache{:scalar}` vs `{:vector}` resolves at compile time instead of -# a runtime check on the cache contents. +# `MooncakeCache{:scalar}` vs `{:vector}` vs `{:hessian}` resolves at compile +# time instead of a runtime check on the cache contents. @inline function AbstractPPL.value_and_gradient!!( ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector}}, ::AbstractVector{<:Real}, @@ -176,8 +223,34 @@ end return Evaluators._throw_gradient_needs_scalar() end +# Empty-input shortcut for order=2 preps: same `Nothing` specificity trick +# as the scalar case. `gradient_cache` is `Nothing` only on the empty-x prep. +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{ + <:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian,<:Any,Nothing,Nothing} + }, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return (p.evaluator(x), T[]) +end + +# Order=2 prep also satisfies the order=1 gradient contract via the dedicated +# gradient cache built at prep time — skips the O(n²) Hessian work. +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian}}, + x::AbstractVector{<:Real}, +) + Evaluators._check_ad_input(p.evaluator, x) + return _mooncake_value_and_gradient(p.adtype, p.cache.gradient_cache, p.evaluator, x) +end + @inline function AbstractPPL.value_and_jacobian!!( - ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, + ::Prepared{ + <:_MooncakeAD, + <:VectorEvaluator, + <:Union{MooncakeCache{:scalar},MooncakeCache{:hessian}}, + }, ::AbstractVector{<:Real}, ) return Evaluators._throw_jacobian_needs_vector() @@ -186,7 +259,7 @@ end # Empty-input jacobian shortcut. Same `Nothing` specificity trick as the # scalar case above; skips Mooncake entirely. @inline function AbstractPPL.value_and_jacobian!!( - p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector,Nothing}}, + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector,Nothing,Nothing}}, x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) @@ -206,4 +279,37 @@ end return Mooncake.value_and_jacobian!!(p.cache.cache, p.evaluator.f, x) end +# Order=1 prep rejected for Hessian. `MooncakeCache{:hessian}` has dedicated +# methods below that are strictly more specific, so this catch-all only fires +# for `:scalar` / `:vector`. +@inline function AbstractPPL.value_gradient_and_hessian!!( + ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache}, ::AbstractVector{<:Real} +) + return Evaluators._throw_hessian_needs_order_2_prep() +end + +# Empty-input shortcut — Mooncake builds no tape for length-zero `x`. +@inline function AbstractPPL.value_gradient_and_hessian!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian,<:Any,Nothing}}, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return (p.evaluator(x), T[], similar(x, 0, 0)) +end + +@inline function AbstractPPL.value_gradient_and_hessian!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian}}, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + # Mooncake's `value_gradient_and_hessian!!` currently allocates fresh + # gradient and Hessian arrays per call despite the `!!` name (unlike its + # `value_and_gradient!!`, which does alias `HVPCache` storage). The + # AbstractPPL `!!` contract permits aliasing rather than requiring it, so + # this is conformant; once Mooncake's `HVPCache` is updated to reuse + # output buffers, the returned arrays here will alias automatically with + # no extension change needed. + return Mooncake.value_gradient_and_hessian!!(p.cache.cache, p.cache.target, x) +end + end # module diff --git a/ext/AbstractPPLTestExt.jl b/ext/AbstractPPLTestExt.jl index 3ac01060..c5c74727 100644 --- a/ext/AbstractPPLTestExt.jl +++ b/ext/AbstractPPLTestExt.jl @@ -19,6 +19,16 @@ struct ValueCase jacobian::Any end +struct HessianCase + name::String + f::Any + x_proto::Any + x::Any + value::Any + gradient::Any + hessian::Any +end + struct ErrorCase name::String f::Any @@ -69,6 +79,45 @@ function AbstractPPL.generate_testcases(::Val{:vector}) ) end +function AbstractPPL.generate_testcases(::Val{:hessian}) + return ( + HessianCase( + "quadratic (scalar output)", + QuadraticProblem(), + zeros(3), + [3.0, 1.0, 2.0], + 14.0, + [6.0, 2.0, 4.0], + [2.0 0.0 0.0; 0.0 2.0 0.0; 0.0 0.0 2.0], + ), + HessianCase( + "empty input, scalar output", + x -> 7.5, + Float64[], + Float64[], + 7.5, + Float64[], + zeros(0, 0), + ), + ) +end + +function AbstractPPL.generate_testcases(::Val{:hessian_edge}) + return ( + # `value_gradient_and_hessian!!` rejects order=1 preps regardless of + # the underlying problem arity — both paths share the same dispatch + # so one case suffices. + ErrorCase( + "value_gradient_and_hessian!! on order=1 prep", + QuadraticProblem(), + zeros(3), + [3.0, 1.0, 2.0], + (prepared, x) -> AbstractPPL.value_gradient_and_hessian!!(prepared, x), + r"order=2", + ), + ) +end + function AbstractPPL.generate_testcases(::Val{:edge}) return ( ErrorCase( @@ -174,6 +223,7 @@ function AbstractPPL.run_testcases( for case in generate_testcases(Val(:vector)) @testset "$(case.name)" begin prepared = prepare_fn(adtype, case.f, case.x_proto) + @test AbstractPPL.order(prepared) == 1 @test prepared(case.x) ≈ case.value atol = atol rtol = rtol if case.gradient !== nothing val, grad = AbstractPPL.value_and_gradient!!(prepared, case.x) @@ -190,6 +240,33 @@ function AbstractPPL.run_testcases( return nothing end +function AbstractPPL.run_testcases( + ::Val{:hessian}, prepare_fn=AbstractPPL.prepare; adtype, atol=0, rtol=1e-10 +) + for case in generate_testcases(Val(:hessian)) + @testset "$(case.name)" begin + prepared = prepare_fn(adtype, case.f, case.x_proto; order=2) + @test AbstractPPL.order(prepared) == 2 + @test prepared(case.x) ≈ case.value atol = atol rtol = rtol + val, grad, hess = AbstractPPL.value_gradient_and_hessian!!(prepared, case.x) + @test val ≈ case.value atol = atol rtol = rtol + @test grad ≈ case.gradient atol = atol rtol = rtol + @test hess ≈ case.hessian atol = atol rtol = rtol + # Order=2 prep also satisfies the order=1 gradient contract. + val1, grad1 = AbstractPPL.value_and_gradient!!(prepared, case.x) + @test val1 ≈ case.value atol = atol rtol = rtol + @test grad1 ≈ case.gradient atol = atol rtol = rtol + end + end + for case in generate_testcases(Val(:hessian_edge)) + @testset "$(case.name)" begin + prepared = prepare_fn(adtype, case.f, case.x_proto) + @test_throws case.exception case.op(prepared, case.x) + end + end + return nothing +end + function AbstractPPL.run_testcases(::Val{:edge}, prepare_fn=AbstractPPL.prepare; adtype) for case in generate_testcases(Val(:edge)) @testset "$(case.name)" begin diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index c70b349d..2a32494d 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -11,7 +11,8 @@ include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") include("evaluators/Evaluators.jl") -using .Evaluators: prepare, value_and_gradient!!, value_and_jacobian!! +using .Evaluators: + prepare, value_and_gradient!!, value_and_jacobian!!, value_gradient_and_hessian!!, order """ generate_testcases(::Val{group}) @@ -19,9 +20,11 @@ using .Evaluators: prepare, value_and_gradient!!, value_and_jacobian!! Return a tuple of test cases for the conformance `group`. Implemented by the `Test` extension (`AbstractPPLTestExt`). Reserved group keys (extensions must not redefine these): `:vector` for value/gradient/jacobian round-trips on -vector-input evaluators; `:namedtuple` for `NamedTuple`-input evaluators; -`:edge` for error-path cases; `:cache_reuse` for repeated calls against a -single prepared evaluator. Downstream packages may add other keys. +vector-input evaluators; `:hessian` for `order=2` value/gradient/Hessian +round-trips on vector-input scalar-output evaluators; `:namedtuple` for +`NamedTuple`-input evaluators; `:edge` for error-path cases; `:cache_reuse` +for repeated calls against a single prepared evaluator. Downstream packages +may add other keys. """ function generate_testcases end @@ -38,7 +41,7 @@ function run_testcases end @static if VERSION >= v"1.11.0" eval( Meta.parse( - "public prepare, value_and_gradient!!, value_and_jacobian!!, generate_testcases, run_testcases", + "public prepare, value_and_gradient!!, value_and_jacobian!!, value_gradient_and_hessian!!, order, generate_testcases, run_testcases", ), ) end diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index dfae89f3..124c7134 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -6,10 +6,14 @@ import ..evaluate!! include("utils.jl") """ - Prepared{AD<:AbstractADType,E,C}(adtype, evaluator, cache) - Prepared(adtype, evaluator) # cache defaults to `nothing` + Prepared{AD<:AbstractADType,E,C,Order}(adtype, evaluator, cache) + Prepared(adtype, evaluator, cache, Val(Order)) + Prepared(adtype, evaluator, cache) # defaults `Order` to 1 + Prepared(adtype, evaluator) # cache defaults to `nothing` -AD-prepared evaluator parameterised by backend type `AD`. +AD-prepared evaluator parameterised by backend type `AD` and derivative order +`Order` (`1` for gradient/jacobian, `2` for Hessian). Retrieve `Order` via +[`order`](@ref). - `adtype` — the backend, used for dispatch. - `evaluator` — the user-facing callable (typically a `VectorEvaluator` or @@ -29,27 +33,49 @@ function AbstractPPL.value_and_gradient!!( end ``` """ -struct Prepared{AD<:AbstractADType,E,C} +struct Prepared{AD<:AbstractADType,E,C,Order} adtype::AD evaluator::E cache::C + function Prepared{AD,E,C,Order}( + adtype, evaluator, cache + ) where {AD<:AbstractADType,E,C,Order} + return new{AD,E,C,Order}(adtype, evaluator, cache) + end end -Prepared(adtype::AbstractADType, evaluator) = Prepared(adtype, evaluator, nothing) +function Prepared( + adtype::AD, evaluator::E, cache::C, ::Val{Order} +) where {AD<:AbstractADType,E,C,Order} + return Prepared{AD,E,C,Order}(adtype, evaluator, cache) +end +function Prepared(adtype::AbstractADType, evaluator, cache) + return Prepared(adtype, evaluator, cache, Val(1)) +end +Prepared(adtype::AbstractADType, evaluator) = Prepared(adtype, evaluator, nothing, Val(1)) (p::Prepared)(x) = p.evaluator(x) +""" + order(p::Prepared) + +Return the derivative order `p` was prepared for (`1` for gradient/jacobian, +`2` for Hessian). Type-stable — folds to the `Order` type parameter at compile +time. +""" +order(::Prepared{<:Any,<:Any,<:Any,O}) where {O} = O + """ prepare(problem, values::NamedTuple; check_dims::Bool=true) prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=()) - prepare(adtype, problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=()) + prepare(adtype, problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=(), order::Int=1) Prepare a callable evaluator for `problem`. Use the two-argument form with a `NamedTuple` when the evaluator works with named inputs, or with a vector when it works with vector inputs. The three-argument form, contributed by AD-backend extensions, additionally -prepares gradient or jacobian machinery for vector inputs. +prepares gradient, jacobian, or Hessian machinery for vector inputs. `check_dims` (default `true`) controls whether the returned evaluator validates the input shape on each call. Pass `check_dims=false` to skip the per-call @@ -61,9 +87,15 @@ through to `problem`: the prepared evaluator computes `problem(x, context...)`, and AD backends differentiate only with respect to `x`. `context=()` (the default) preserves the unary `problem(x)` contract. +`order` selects the derivative order to prepare for on the AD-aware form. The +default `order=1` prepares gradient (scalar output) or jacobian (vector output) +machinery. `order=2` prepares Hessian machinery via `value_gradient_and_hessian!!` +and requires `problem` to be scalar-valued — vector-valued problems will throw +during preparation. + The three-argument AD-aware form may invoke `problem` once during preparation -to detect output arity (scalar vs vector) and select gradient or jacobian -machinery accordingly. Avoid `prepare` calls when `problem` has side effects +to detect output arity (scalar vs vector) and select the appropriate +derivative machinery. Avoid `prepare` calls when `problem` has side effects that should fire only on user-driven evaluations. """ function prepare end @@ -99,6 +131,17 @@ The Jacobian has shape `(length(value), length(x))`. """ function value_and_jacobian!! end +""" + value_gradient_and_hessian!!(prepared, x::AbstractVector{<:Real}) + +Return `(value, gradient::AbstractVector, hessian::AbstractMatrix)` for a +scalar-valued evaluator prepared with `order=2`, potentially reusing internal +cache buffers. The returned gradient and Hessian may alias `prepared`'s +internal storage; copy if you need to retain them past the next call. +The Hessian has shape `(length(x), length(x))`. +""" +function value_gradient_and_hessian!! end + """ VectorEvaluator{CheckInput}(f, dim, context::Tuple=()) VectorEvaluator(f, dim, context::Tuple=()) # equivalent to `VectorEvaluator{true}(f, dim, context)` @@ -264,15 +307,32 @@ function _ad_output_arity(y) ) end -# Arity-mismatch errors shared by the DI and Mooncake extensions; kept here so -# the `:edge` testcase regexes (`r"scalar-valued"`, `r"vector-valued"`) pin a -# single error string instead of one per backend. +# Error helpers shared by the DI and Mooncake extensions; kept here so the +# `:edge` testcase regexes (`r"scalar-valued"`, `r"vector-valued"`, `r"order=2"`) +# pin a single error string instead of one per backend. function _throw_gradient_needs_scalar() throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function.")) end function _throw_jacobian_needs_vector() throw(ArgumentError("`value_and_jacobian!!` requires a vector-valued function.")) end +function _throw_hessian_needs_scalar() + throw( + ArgumentError("`value_gradient_and_hessian!!` requires a scalar-valued function.") + ) +end +function _throw_hessian_needs_order_2_prep() + throw( + ArgumentError( + "`value_gradient_and_hessian!!` requires an evaluator prepared with `order=2`." + ), + ) +end + +# Validate the `order=` kwarg of `prepare(adtype, problem, x; order)`. Shared by +# the DI and Mooncake extensions so the error string is identical. +@inline _validate_ad_order(order::Int) = + order in (1, 2) || throw(ArgumentError("`order` must be 1 or 2, got $order.")) # Complements the `typeof` check above: same-typed arrays can differ in `size`. # Arrays with non-`Real`/`Complex` eltype are walked element-wise to catch @@ -324,11 +384,14 @@ function __init__() end # Same fire-only-when-no-backend-loaded logic as the `prepare` hint above. Base.Experimental.register_error_hint(MethodError) do io, exc, args, kwargs - exc.f === value_and_gradient!! || exc.f === value_and_jacobian!! || return nothing + exc.f === value_and_gradient!! || + exc.f === value_and_jacobian!! || + exc.f === value_gradient_and_hessian!! || + return nothing isempty(methods(exc.f)) || return nothing print( io, - "\nNo AD backend extension is loaded. Load `DifferentiationInterface` (with a backend like `ForwardDiff`) or `Mooncake` to enable gradient/jacobian computation.", + "\nNo AD backend extension is loaded. Load `DifferentiationInterface` (with a backend like `ForwardDiff`) or `Mooncake` to enable gradient/jacobian/Hessian computation.", ) end end diff --git a/test/ext/differentiationinterface/main.jl b/test/ext/differentiationinterface/main.jl index 636f237e..dff6473e 100644 --- a/test/ext/differentiationinterface/main.jl +++ b/test/ext/differentiationinterface/main.jl @@ -3,9 +3,15 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) Pkg.instantiate() -using AbstractPPL: AbstractPPL, prepare, run_testcases, value_and_gradient!! +using AbstractPPL: + AbstractPPL, + prepare, + run_testcases, + value_and_gradient!!, + value_gradient_and_hessian!!, + order using ADTypes: AutoForwardDiff, AutoReverseDiff -using DifferentiationInterface: DifferentiationInterface as DI +using DifferentiationInterface: DifferentiationInterface as DI, SecondOrder using ForwardDiff using ReverseDiff using Test @@ -17,11 +23,12 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) @testset "AbstractPPLDifferentiationInterfaceExt" begin @testset "ForwardDiff" begin run_testcases(Val(:vector); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) + run_testcases(Val(:hessian); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) run_testcases(Val(:cache_reuse); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) run_testcases(Val(:edge); adtype=AutoForwardDiff()) end - # Compiled-tape ReverseDiff goes through the `_prepare_di(::AutoReverseDiff{true}, …)` + # Compiled-tape ReverseDiff goes through the `_di_call_shape(::AutoReverseDiff{true}, …)` # specialisation that closes the evaluator into a `Base.Fix2` target — the # `:cache_reuse` group exercises that path across multiple inputs. @testset "ReverseDiff (compiled tape)" begin @@ -31,20 +38,20 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) run_testcases(Val(:edge); adtype=adtype) end - # `DICache`'s `Mode` parameter is either `:closure` (compiled-tape + # The DI cache types' `Mode` parameter is either `:closure` (compiled-tape # ReverseDiff) or the integer context length on the constants path. The # constants-path integer also documents how many `DI.Constant`s the AD # call passes. - @testset "DICache encodes the call mode as a type parameter" begin + @testset "DI cache encodes the call mode as a type parameter" begin x = [1.0, 2.0, 3.0] prep_noctx = prepare(AutoForwardDiff(), quadratic, x) prep_closure = prepare(AutoReverseDiff(; compile=true), quadratic, x) affine(y, a, b) = a * sum(abs2, y) + b prep_ctx = prepare(AutoForwardDiff(), affine, x; context=(2.0, 1.0)) - @test prep_noctx.cache isa DIExt.DICache{0} - @test prep_closure.cache isa DIExt.DICache{:closure} - @test prep_ctx.cache isa DIExt.DICache{2} + @test prep_noctx.cache isa DIExt.DIGradientCache{0} + @test prep_closure.cache isa DIExt.DIGradientCache{:closure} + @test prep_ctx.cache isa DIExt.DIGradientCache{2} # Non-empty-context primal matches the underlying `f(x, context...)`. @test prep_ctx(x) == affine(x, 2.0, 1.0) @@ -57,4 +64,35 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) @inferred value_and_gradient!!(prep_closure, x) @inferred value_and_gradient!!(prep_ctx, x) end + + # `SecondOrder(outer, inner)` lets the caller pick the inner gradient + # backend and the outer differentiator independently — useful when the + # default Hessian strategy DI picks for a single `adtype` is suboptimal. + # Since `SecondOrder <: AbstractADType`, the existing `order=2` dispatch + # routes it through `DI.prepare_hessian` / `DI.value_gradient_and_hessian` + # without any extension-side changes. + @testset "SecondOrder for order=2" begin + adtype = SecondOrder(AutoForwardDiff(), AutoReverseDiff()) + x = [1.0, 2.0, 3.0] + prep = prepare(adtype, quadratic, zeros(3); order=2) + @test order(prep) == 2 + val, grad, hess = value_gradient_and_hessian!!(prep, x) + @test val ≈ 14.0 + @test grad ≈ [2.0, 4.0, 6.0] + @test hess ≈ [2.0 0 0; 0 2.0 0; 0 0 2.0] + + # `value_and_gradient!!` on a `SecondOrder` order=2 prep routes through + # the inner adtype — the only non-trivial case for `_gradient_adtype`. + val1, grad1 = value_and_gradient!!(prep, x) + @test val1 ≈ 14.0 + @test grad1 ≈ [2.0, 4.0, 6.0] + + # `context=` composes with `SecondOrder` the same way as for a plain `adtype`. + affine(y, a, b) = a * sum(abs2, y) + b + prep_ctx = prepare(adtype, affine, zeros(3); context=(2.0, 1.0), order=2) + val_ctx, grad_ctx, hess_ctx = value_gradient_and_hessian!!(prep_ctx, x) + @test val_ctx ≈ affine(x, 2.0, 1.0) + @test grad_ctx ≈ [4.0, 8.0, 12.0] + @test hess_ctx ≈ [4.0 0 0; 0 4.0 0; 0 0 4.0] + end end diff --git a/test/ext/mooncake/main.jl b/test/ext/mooncake/main.jl index 3c328a1a..855d5f5b 100644 --- a/test/ext/mooncake/main.jl +++ b/test/ext/mooncake/main.jl @@ -18,6 +18,10 @@ using Test run_testcases(Val(:namedtuple); adtype=adtype, atol=1e-6, rtol=1e-6) run_testcases(Val(:cache_reuse); adtype=adtype, atol=1e-6, rtol=1e-6) run_testcases(Val(:edge); adtype=adtype) + # Hessian (`order=2`) is reverse-mode only on the AutoMooncake side; + # AutoMooncakeForward routes through the same generic Hessian path + # since `Mooncake.prepare_hessian_cache` is mode-agnostic. + run_testcases(Val(:hessian); adtype=adtype, atol=1e-6, rtol=1e-6) end end