From 2cc3790d5a6d7ad5ccca9d0630f67410cd15b56d Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 19 May 2026 21:51:52 +0100 Subject: [PATCH 1/5] Add order=2 Hessian preparation via value_gradient_and_hessian!! Extends `prepare(adtype, problem, x; order=2)` to build Hessian machinery for scalar-valued problems on the DI and Mooncake extensions, returning `(value, gradient, hessian)` from a new `value_gradient_and_hessian!!` generic. Unifies the per-extension caches (`DICache`, `MooncakeCache`) so one struct carries every derivative order, with explicit cross-arity error messages replacing prior `MethodError`s. DI uses the in-place `DI.value_gradient_and_hessian!` with caller-owned buffers; Mooncake uses its native `prepare_hessian_cache` API. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/src/evaluators.md | 51 ++++++++ ext/AbstractPPLDifferentiationInterfaceExt.jl | 95 ++++++++++++-- ext/AbstractPPLMooncakeExt.jl | 120 ++++++++++++++---- ext/AbstractPPLTestExt.jl | 63 +++++++++ src/AbstractPPL.jl | 13 +- src/evaluators/Evaluators.jl | 61 +++++++-- test/ext/differentiationinterface/main.jl | 30 ++++- test/ext/mooncake/main.jl | 4 + 8 files changed, 387 insertions(+), 50 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index 9dd6f8af..5b03b7b0 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -138,6 +138,56 @@ library invokes the inner callable many times with same-length dual arrays derived from a single user-supplied `x`; re-validating on each invocation would be redundant work in the hot path. +## Hessian (`order=2`) + +Pass `order=2` to `prepare` to build a Hessian-capable evaluator. The +returned object answers `value_gradient_and_hessian!!`, which returns +`(value, gradient, hessian)` in a single call. `order=2` requires +`problem` to be scalar-valued; a vector-valued probe throws at preparation +time. + +```julia +using AbstractPPL: prepare, value_gradient_and_hessian!! +using ADTypes: AutoForwardDiff +using ForwardDiff, DifferentiationInterface + +quadratic(x) = sum(abs2, x) +prepared = prepare(AutoForwardDiff(), quadratic, zeros(3); order=2) +val, grad, hess = value_gradient_and_hessian!!(prepared, [1.0, 2.0, 3.0]) +# val == 14.0 +# grad == [2.0, 4.0, 6.0] +# hess == [2 0 0; 0 2 0; 0 0 2] +``` + +Both `context=` and `check_dims=` apply to `order=2` preps with the same +semantics as for `order=1`. The `!!` aliasing contract also extends: the +returned gradient and Hessian may alias internal cache buffers of +`prepared`, so copy before retaining them past the next call. NamedTuple +inputs are not supported at `order=2`. + +For DifferentiationInterface, `adtype` can be either a single backend +(letting DI pick its own Hessian strategy) or a +[`DifferentiationInterface.SecondOrder(outer, inner)`](https://juliadiff.org/DifferentiationInterface.jl/stable/api/#DifferentiationInterface.SecondOrder) +composition that selects the outer differentiator and the inner gradient +backend independently — typically forward-over-reverse: + +```julia +using DifferentiationInterface: SecondOrder +using ADTypes: AutoForwardDiff, AutoReverseDiff + +adtype = SecondOrder(AutoForwardDiff(), AutoReverseDiff()) +prepared = prepare(adtype, quadratic, zeros(3); order=2) +``` + +`SecondOrder <: AbstractADType`, so the same `prepare(adtype, problem, x; order=2)` +entry handles it. + +Calling `value_gradient_and_hessian!!` on an `order=1` prep throws an +`ArgumentError` — re-prepare with `order=2` instead. Likewise, calling +`value_and_gradient!!` or `value_and_jacobian!!` on an `order=2` prep is +unsupported; use `value_gradient_and_hessian!!` and discard the unused +return value. + ## Constant context arguments When the underlying callable naturally takes the form `f(x, context...)` — @@ -177,4 +227,5 @@ p([1.0, 2.0, 3.0]) AbstractPPL.prepare AbstractPPL.value_and_gradient!! AbstractPPL.value_and_jacobian!! +AbstractPPL.value_gradient_and_hessian!! ``` diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 1f4cfe8b..3a8f993c 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -5,7 +5,7 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_ using ADTypes: AbstractADType, AutoReverseDiff using DifferentiationInterface: DifferentiationInterface as DI -# AD target used by both `DICache` modes. `Vararg{Any,N}` with a free `N` +# AD target used by every `DICache` mode. `Vararg{Any,N}` with a free `N` # forces specialization on the trailing arity (a bare `Vararg{Any}` would # skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the # constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on @@ -13,19 +13,31 @@ using DifferentiationInterface: DifferentiationInterface as DI @inline _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N} = f(x, ctx...) # `Mode` tags the cache shape: -# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure, -# the AD call passes **0** `DI.Constant`s. -# * `N::Int` — constants path: `N == length(evaluator.context)`, the -# AD call passes **N + 1** `DI.Constant`s (`f` plus the -# `N` context values). -# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*` -# at compile time without a runtime branch. -struct DICache{Mode,F,GP,JP} +# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure, the +# AD call passes **0** `DI.Constant`s. +# * `N::Int` — constants path: `N == length(evaluator.context)`, the AD +# call passes **N + 1** `DI.Constant`s (`f` plus the `N` +# context values). +# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*` at +# compile time without a runtime branch. +# +# Single cache for every derivative order. At most one of `gradient_prep`, +# `jacobian_prep`, `hessian_prep` is non-`Nothing` at any time; the hot-path +# methods discriminate via `=== nothing` checks (folded at compile time since +# field types are concrete in each instantiation). `grad_buf` / `hess_buf` are +# non-`Nothing` only for order=2 — caller-owned output buffers handed to +# `DI.value_gradient_and_hessian!`. Returned arrays alias them (`!!` contract). +struct DICache{Mode,F,GP,JP,HP,G,H} target::F gradient_prep::GP jacobian_prep::JP - function DICache{Mode}(target::F, gp::GP, jp::JP) where {Mode,F,GP,JP} - return new{Mode,F,GP,JP}(target, gp, jp) + hessian_prep::HP + grad_buf::G + hess_buf::H + function DICache{Mode}( + target::F, gp::GP, jp::JP, hp::HP, g::G, h::H + ) where {Mode,F,GP,JP,HP,G,H} + return new{Mode,F,GP,JP,HP,G,H}(target, gp, jp, hp, g, h) end end @@ -49,7 +61,7 @@ function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F} end @inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} = - DICache{Mode}(target, gp, jp) + DICache{Mode}(target, gp, jp, nothing, nothing, nothing) function AbstractPPL.prepare( adtype::AbstractADType, @@ -57,9 +69,32 @@ function AbstractPPL.prepare( x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=(), + order::Int=1, ) evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) + if order == 2 + arity === :scalar || Evaluators._throw_hessian_needs_scalar() + if length(x) == 0 + # DI Hessian prep crashes on length-0 input; the AD entry + # short-circuits before any DI call. `Val(0)` is a non-`Nothing` + # sentinel for `hessian_prep` so dispatch recognises this as an + # order=2 prep (mirrors the order=1 empty-input pattern below). + cache = _wrap_hessian_cache( + _call_evaluator, Val(0), nothing, nothing, Val(length(context)) + ) + return Prepared(adtype, evaluator, cache) + end + target, hessian_prep, mode = _prepare_di(DI.prepare_hessian, adtype, x, evaluator) + # Buffers pre-allocated from `x` (shape and eltype): the hot path is + # zero-allocation on the gradient/Hessian outputs, and the returned + # arrays alias these slots — copy if you need to retain them. + grad_buf = similar(x) + hess_buf = similar(x, length(x), length(x)) + cache = _wrap_hessian_cache(target, hessian_prep, grad_buf, hess_buf, mode) + return Prepared(adtype, evaluator, cache) + end + order == 1 || throw(ArgumentError("`order` must be 1 or 2, got $order.")) if length(x) == 0 # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`). # `Val(0)` is an arity sentinel for the `gradient_prep === nothing` @@ -78,6 +113,9 @@ function AbstractPPL.prepare( return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, mode)) end +@inline _wrap_hessian_cache(target, hp, g, h, ::Val{Mode}) where {Mode} = + DICache{Mode}(target, nothing, nothing, hp, g, h) + # Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile # time. The unconstrained method matches every non-`:closure` `Mode` (i.e. # any `Int N`); `:closure` is strictly more specific and wins for compiled @@ -108,6 +146,9 @@ end @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} ) where {T<:Real} + # Both `=== nothing` branches fold at compile time: each instantiation + # has concrete field types, so only the relevant branch survives. + p.cache.hessian_prep === nothing || Evaluators._throw_use_value_gradient_and_hessian() p.cache.gradient_prep === nothing && Evaluators._throw_gradient_needs_scalar() Evaluators._check_ad_input(p.evaluator, x) # Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff @@ -119,6 +160,7 @@ end @inline function AbstractPPL.value_and_jacobian!!( p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} ) where {T<:Real} + p.cache.hessian_prep === nothing || Evaluators._throw_use_value_gradient_and_hessian() p.cache.jacobian_prep === nothing && Evaluators._throw_jacobian_needs_vector() Evaluators._check_ad_input(p.evaluator, x) if length(x) == 0 @@ -128,4 +170,33 @@ end return _di_value_and_jacobian(p.cache, p.adtype, x, p.evaluator) end +# Hessian hot-path dispatch mirrors the gradient/jacobian helpers above: +# `:closure` (compiled-tape) vs constants `Mode`, resolved at compile time. +# Uses DI's in-place variant `value_gradient_and_hessian!` with caller-owned +# buffers; the returned `(val, grad, hess)` aliases `c.grad_buf` / `c.hess_buf`. +@inline _di_value_gradient_and_hessian(c::DICache{:closure}, ad, x, _) = + DI.value_gradient_and_hessian!(c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x) +@inline _di_value_gradient_and_hessian(c::DICache, ad, x, eval) = + DI.value_gradient_and_hessian!( + c.target, + c.grad_buf, + c.hess_buf, + c.hessian_prep, + ad, + x, + DI.Constant(eval.f), + map(DI.Constant, eval.context)..., + ) + +@inline function AbstractPPL.value_gradient_and_hessian!!( + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} +) where {T<:Real} + # Order=1 preps have `hessian_prep === nothing` (compile-folded check). + p.cache.hessian_prep === nothing && Evaluators._throw_hessian_needs_order_2_prep() + Evaluators._check_ad_input(p.evaluator, x) + # Empty-input shortcut — same reasoning as the order=1 path. + length(x) == 0 && return (p.evaluator(x), T[], similar(x, 0, 0)) + return _di_value_gradient_and_hessian(p.cache, p.adtype, x, p.evaluator) +end + end # module diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index b07af8ae..7fceadfc 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -8,23 +8,37 @@ using Mooncake: Mooncake const _MooncakeAD = Union{AutoMooncake,AutoMooncakeForward} -# `NamedTupleEvaluator` is the callable on the NamedTuple path; `NoTangent` -# stops Mooncake from deriving a `Tangent{NamedTuple{...}}` for its fields -# on every backward pass. The `VectorEvaluator` override is a defensive -# guard — vector preps no longer pass the evaluator wrapper to Mooncake. +# `NoTangent` stops Mooncake from deriving a `Tangent{...}` for the evaluator +# wrapper's fields on each backward pass. Load-bearing for both: +# * `NamedTupleEvaluator` — passed directly to Mooncake on the NamedTuple +# gradient path. +# * `VectorEvaluator` — wrapped by the order=2 path (Mooncake's Hessian +# API accepts only `AbstractVector` arguments, so +# context is closed over via `VectorEvaluator{false}`). Mooncake.tangent_type(::Type{<:VectorEvaluator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent # Type parameters: # -# * `A::Symbol` — output arity, `:scalar` or `:vector`. Drives the -# gradient/jacobian dispatch and the arity-mismatch errors. +# * `A::Symbol` — `:scalar` / `:vector` for order=1 (output arity), `:hessian` +# for order=2. Drives every dispatch decision below. +# * `Target` — `Nothing` for order=1 (Mooncake's gradient/Jacobian API +# takes `f` and `context` as separate args). For `:hessian`, +# a `VectorEvaluator{false}` that closes over `f` and +# `context`, since Mooncake's Hessian API accepts only +# `AbstractVector` arguments. The evaluator's `NoTangent` +# tangent type prevents differentiation of its fields. # * `C` — the underlying Mooncake cache, or `Nothing` for the # empty-input shortcut. -struct MooncakeCache{A,C} +struct MooncakeCache{A,Target,C} + target::Target cache::C + function MooncakeCache{A}(target::Target, cache::C) where {A,Target,C} + return new{A,Target,C}(target, cache) + end end -MooncakeCache{A}(cache::C) where {A,C} = MooncakeCache{A,C}(cache) +# Order=1 convenience: no target wrapper. +MooncakeCache{A}(cache) where {A} = MooncakeCache{A}(nothing, cache) _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config @@ -47,10 +61,13 @@ function AbstractPPL.prepare( end """ - prepare(adtype::AutoMooncake, problem, x; check_dims=true, context::Tuple=()) - prepare(adtype::AutoMooncakeForward, problem, x; check_dims=true, context::Tuple=()) + prepare(adtype::AutoMooncake, problem, x; check_dims=true, context::Tuple=(), order=1) + prepare(adtype::AutoMooncakeForward, problem, x; check_dims=true, context::Tuple=(), order=1) -Prepare a Mooncake gradient/Jacobian evaluator for a dense vector input. +Prepare a Mooncake gradient, Jacobian, or Hessian evaluator for a dense vector +input. `order=1` (default) picks gradient/Jacobian by output arity; +`order=2` builds Hessian machinery (`value_gradient_and_hessian!!`) and +requires a scalar-valued problem. Non-`DenseVector` inputs (views, strided slices) are rejected: Mooncake assumes a contiguous primal and otherwise returns shape-incorrect tangents @@ -58,12 +75,12 @@ on reverse mode and crashes on forward/Jacobian paths. `context` follows the base `prepare` contract — the prepared evaluator computes `problem(x, context...)` with AD differentiating only `x`. One -Mooncake-specific restriction: vector-valued problems require `context=()`. +Mooncake-specific restriction for `order=1`: vector-valued problems require +`context=()`. `order=2` accepts any `context`. Empty input (`length(x) == 0`) is supported with any `context`; Mooncake builds no tape for zero-length `x`, so the prepared evaluator's AD entry -short-circuits to `(problem(x, context...), eltype(x)[])` without invoking -Mooncake. +short-circuits without invoking Mooncake. """ function AbstractPPL.prepare( adtype::_MooncakeAD, @@ -71,6 +88,7 @@ function AbstractPPL.prepare( x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=(), + order::Int=1, ) x isa DenseVector || throw( ArgumentError( @@ -82,6 +100,17 @@ function AbstractPPL.prepare( evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) config = _mooncake_config(adtype) + if order == 2 + arity === :scalar || Evaluators._throw_hessian_needs_scalar() + # `{false}` skips the per-call shape check — `_check_ad_input` on the + # AD entry already validates `x`. `dim` is unused for `{false}`. + target = VectorEvaluator{false}(evaluator.f, 0, evaluator.context) + length(x) == 0 && + return Prepared(adtype, evaluator, MooncakeCache{:hessian}(target, nothing)) + cache = Mooncake.prepare_hessian_cache(target, x; config) + return Prepared(adtype, evaluator, MooncakeCache{:hessian}(target, cache)) + end + order == 1 || throw(ArgumentError("`order` must be 1 or 2, got $order.")) if !isempty(evaluator.context) && arity !== :scalar throw( ArgumentError( @@ -125,12 +154,12 @@ end return (val, grad) end -# Empty-input shortcut. `MooncakeCache{:scalar,Nothing}` is strictly more -# specific than `MooncakeCache{:scalar}` on `C`, so dispatch unambiguously -# selects this method over the general scalar-gradient hot path below for -# zero-length `x`. +# Empty-input shortcut. `MooncakeCache{:scalar,Nothing,Nothing}` is strictly +# more specific than `MooncakeCache{:scalar}`, so dispatch unambiguously selects +# this method over the general scalar-gradient hot path below for zero-length +# `x`. (`Target=Nothing` always holds for order=1 — see struct comment.) @inline function AbstractPPL.value_and_gradient!!( - p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing}}, + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing,Nothing}}, x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) @@ -167,8 +196,8 @@ end end # Arity-mismatch errors as dedicated methods so dispatch on -# `MooncakeCache{:scalar}` vs `{:vector}` resolves at compile time instead of -# a runtime check on the cache contents. +# `MooncakeCache{:scalar}` vs `{:vector}` vs `{:hessian}` resolves at compile +# time instead of a runtime check on the cache contents. @inline function AbstractPPL.value_and_gradient!!( ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector}}, ::AbstractVector{<:Real}, @@ -176,6 +205,13 @@ end return Evaluators._throw_gradient_needs_scalar() end +@inline function AbstractPPL.value_and_gradient!!( + ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian}}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_use_value_gradient_and_hessian() +end + @inline function AbstractPPL.value_and_jacobian!!( ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, ::AbstractVector{<:Real}, @@ -183,10 +219,17 @@ end return Evaluators._throw_jacobian_needs_vector() end +@inline function AbstractPPL.value_and_jacobian!!( + ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian}}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_use_value_gradient_and_hessian() +end + # Empty-input jacobian shortcut. Same `Nothing` specificity trick as the # scalar case above; skips Mooncake entirely. @inline function AbstractPPL.value_and_jacobian!!( - p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector,Nothing}}, + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector,Nothing,Nothing}}, x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) @@ -206,4 +249,37 @@ end return Mooncake.value_and_jacobian!!(p.cache.cache, p.evaluator.f, x) end +# Order=1 prep rejected for Hessian. `MooncakeCache{:hessian}` has dedicated +# methods below that are strictly more specific, so this catch-all only fires +# for `:scalar` / `:vector`. +@inline function AbstractPPL.value_gradient_and_hessian!!( + ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache}, ::AbstractVector{<:Real} +) + return Evaluators._throw_hessian_needs_order_2_prep() +end + +# Empty-input shortcut — Mooncake builds no tape for length-zero `x`. +@inline function AbstractPPL.value_gradient_and_hessian!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian,<:Any,Nothing}}, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return (p.evaluator(x), T[], similar(x, 0, 0)) +end + +@inline function AbstractPPL.value_gradient_and_hessian!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian}}, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + # Mooncake's `value_gradient_and_hessian!!` currently allocates fresh + # gradient and Hessian arrays per call despite the `!!` name (unlike its + # `value_and_gradient!!`, which does alias `HVPCache` storage). The + # AbstractPPL `!!` contract permits aliasing rather than requiring it, so + # this is conformant; once Mooncake's `HVPCache` is updated to reuse + # output buffers, the returned arrays here will alias automatically with + # no extension change needed. + return Mooncake.value_gradient_and_hessian!!(p.cache.cache, p.cache.target, x) +end + end # module diff --git a/ext/AbstractPPLTestExt.jl b/ext/AbstractPPLTestExt.jl index 3ac01060..16e360c3 100644 --- a/ext/AbstractPPLTestExt.jl +++ b/ext/AbstractPPLTestExt.jl @@ -19,6 +19,19 @@ struct ValueCase jacobian::Any end +# Mirror of `ValueCase` for `order=2` prep + `value_gradient_and_hessian!!`. +# A separate type keeps the order=1 cases narrow and lets `run_testcases` +# dispatch on the prep order without an extra field. +struct HessianCase + name::String + f::Any + x_proto::Any + x::Any + value::Any + gradient::Any + hessian::Any +end + struct ErrorCase name::String f::Any @@ -69,6 +82,29 @@ function AbstractPPL.generate_testcases(::Val{:vector}) ) end +function AbstractPPL.generate_testcases(::Val{:hessian}) + return ( + HessianCase( + "quadratic (scalar output)", + QuadraticProblem(), + zeros(3), + [3.0, 1.0, 2.0], + 14.0, + [6.0, 2.0, 4.0], + [2.0 0.0 0.0; 0.0 2.0 0.0; 0.0 0.0 2.0], + ), + HessianCase( + "empty input, scalar output", + x -> 7.5, + Float64[], + Float64[], + 7.5, + Float64[], + zeros(0, 0), + ), + ) +end + function AbstractPPL.generate_testcases(::Val{:edge}) return ( ErrorCase( @@ -151,6 +187,17 @@ function AbstractPPL.generate_testcases(::Val{:edge}) (prepared, x) -> AbstractPPL.value_and_jacobian!!(prepared, x), r"floating-point", ), + # `value_gradient_and_hessian!!` rejects order=1 preps regardless of + # the underlying problem arity — both paths share the same dispatch + # so one case suffices. + ErrorCase( + "value_gradient_and_hessian!! on order=1 prep", + QuadraticProblem(), + zeros(3), + [3.0, 1.0, 2.0], + (prepared, x) -> AbstractPPL.value_gradient_and_hessian!!(prepared, x), + r"order=2", + ), ) end @@ -190,6 +237,22 @@ function AbstractPPL.run_testcases( return nothing end +function AbstractPPL.run_testcases( + ::Val{:hessian}, prepare_fn=AbstractPPL.prepare; adtype, atol=0, rtol=1e-10 +) + for case in generate_testcases(Val(:hessian)) + @testset "$(case.name)" begin + prepared = prepare_fn(adtype, case.f, case.x_proto; order=2) + @test prepared(case.x) ≈ case.value atol = atol rtol = rtol + val, grad, hess = AbstractPPL.value_gradient_and_hessian!!(prepared, case.x) + @test val ≈ case.value atol = atol rtol = rtol + @test grad ≈ case.gradient atol = atol rtol = rtol + @test hess ≈ case.hessian atol = atol rtol = rtol + end + end + return nothing +end + function AbstractPPL.run_testcases(::Val{:edge}, prepare_fn=AbstractPPL.prepare; adtype) for case in generate_testcases(Val(:edge)) @testset "$(case.name)" begin diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index c70b349d..f9f81247 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -11,7 +11,8 @@ include("abstractmodeltrace.jl") include("abstractprobprog.jl") include("evaluate.jl") include("evaluators/Evaluators.jl") -using .Evaluators: prepare, value_and_gradient!!, value_and_jacobian!! +using .Evaluators: + prepare, value_and_gradient!!, value_and_jacobian!!, value_gradient_and_hessian!! """ generate_testcases(::Val{group}) @@ -19,9 +20,11 @@ using .Evaluators: prepare, value_and_gradient!!, value_and_jacobian!! Return a tuple of test cases for the conformance `group`. Implemented by the `Test` extension (`AbstractPPLTestExt`). Reserved group keys (extensions must not redefine these): `:vector` for value/gradient/jacobian round-trips on -vector-input evaluators; `:namedtuple` for `NamedTuple`-input evaluators; -`:edge` for error-path cases; `:cache_reuse` for repeated calls against a -single prepared evaluator. Downstream packages may add other keys. +vector-input evaluators; `:hessian` for `order=2` value/gradient/Hessian +round-trips on vector-input scalar-output evaluators; `:namedtuple` for +`NamedTuple`-input evaluators; `:edge` for error-path cases; `:cache_reuse` +for repeated calls against a single prepared evaluator. Downstream packages +may add other keys. """ function generate_testcases end @@ -38,7 +41,7 @@ function run_testcases end @static if VERSION >= v"1.11.0" eval( Meta.parse( - "public prepare, value_and_gradient!!, value_and_jacobian!!, generate_testcases, run_testcases", + "public prepare, value_and_gradient!!, value_and_jacobian!!, value_gradient_and_hessian!!, generate_testcases, run_testcases", ), ) end diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index dfae89f3..7e1e32a4 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -42,14 +42,14 @@ Prepared(adtype::AbstractADType, evaluator) = Prepared(adtype, evaluator, nothin """ prepare(problem, values::NamedTuple; check_dims::Bool=true) prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=()) - prepare(adtype, problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=()) + prepare(adtype, problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=(), order::Int=1) Prepare a callable evaluator for `problem`. Use the two-argument form with a `NamedTuple` when the evaluator works with named inputs, or with a vector when it works with vector inputs. The three-argument form, contributed by AD-backend extensions, additionally -prepares gradient or jacobian machinery for vector inputs. +prepares gradient, jacobian, or Hessian machinery for vector inputs. `check_dims` (default `true`) controls whether the returned evaluator validates the input shape on each call. Pass `check_dims=false` to skip the per-call @@ -61,9 +61,15 @@ through to `problem`: the prepared evaluator computes `problem(x, context...)`, and AD backends differentiate only with respect to `x`. `context=()` (the default) preserves the unary `problem(x)` contract. +`order` selects the derivative order to prepare for on the AD-aware form. The +default `order=1` prepares gradient (scalar output) or jacobian (vector output) +machinery. `order=2` prepares Hessian machinery via `value_gradient_and_hessian!!` +and requires `problem` to be scalar-valued — vector-valued problems will throw +during preparation. + The three-argument AD-aware form may invoke `problem` once during preparation -to detect output arity (scalar vs vector) and select gradient or jacobian -machinery accordingly. Avoid `prepare` calls when `problem` has side effects +to detect output arity (scalar vs vector) and select the appropriate +derivative machinery. Avoid `prepare` calls when `problem` has side effects that should fire only on user-driven evaluations. """ function prepare end @@ -99,6 +105,17 @@ The Jacobian has shape `(length(value), length(x))`. """ function value_and_jacobian!! end +""" + value_gradient_and_hessian!!(prepared, x::AbstractVector{<:Real}) + +Return `(value, gradient::AbstractVector, hessian::AbstractMatrix)` for a +scalar-valued evaluator prepared with `order=2`, potentially reusing internal +cache buffers. The returned gradient and Hessian may alias `prepared`'s +internal storage; copy if you need to retain them past the next call. +The Hessian has shape `(length(x), length(x))`. +""" +function value_gradient_and_hessian!! end + """ VectorEvaluator{CheckInput}(f, dim, context::Tuple=()) VectorEvaluator(f, dim, context::Tuple=()) # equivalent to `VectorEvaluator{true}(f, dim, context)` @@ -264,15 +281,38 @@ function _ad_output_arity(y) ) end -# Arity-mismatch errors shared by the DI and Mooncake extensions; kept here so -# the `:edge` testcase regexes (`r"scalar-valued"`, `r"vector-valued"`) pin a -# single error string instead of one per backend. +# Error helpers shared by the DI and Mooncake extensions; kept here so the +# `:edge` testcase regexes (`r"scalar-valued"`, `r"vector-valued"`, `r"order=2"`) +# pin a single error string instead of one per backend. Covers two failure +# modes: +# * arity mismatch — function output doesn't match the requested derivative; +# * wrong prep order — caller asked for the derivative the prep wasn't built +# for, in either direction. function _throw_gradient_needs_scalar() throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function.")) end function _throw_jacobian_needs_vector() throw(ArgumentError("`value_and_jacobian!!` requires a vector-valued function.")) end +function _throw_hessian_needs_scalar() + throw( + ArgumentError("`value_gradient_and_hessian!!` requires a scalar-valued function.") + ) +end +function _throw_hessian_needs_order_2_prep() + throw( + ArgumentError( + "`value_gradient_and_hessian!!` requires an evaluator prepared with `order=2`." + ), + ) +end +function _throw_use_value_gradient_and_hessian() + throw( + ArgumentError( + "This evaluator was prepared with `order=2`; use `value_gradient_and_hessian!!` to compute its derivatives.", + ), + ) +end # Complements the `typeof` check above: same-typed arrays can differ in `size`. # Arrays with non-`Real`/`Complex` eltype are walked element-wise to catch @@ -324,11 +364,14 @@ function __init__() end # Same fire-only-when-no-backend-loaded logic as the `prepare` hint above. Base.Experimental.register_error_hint(MethodError) do io, exc, args, kwargs - exc.f === value_and_gradient!! || exc.f === value_and_jacobian!! || return nothing + exc.f === value_and_gradient!! || + exc.f === value_and_jacobian!! || + exc.f === value_gradient_and_hessian!! || + return nothing isempty(methods(exc.f)) || return nothing print( io, - "\nNo AD backend extension is loaded. Load `DifferentiationInterface` (with a backend like `ForwardDiff`) or `Mooncake` to enable gradient/jacobian computation.", + "\nNo AD backend extension is loaded. Load `DifferentiationInterface` (with a backend like `ForwardDiff`) or `Mooncake` to enable gradient/jacobian/Hessian computation.", ) end end diff --git a/test/ext/differentiationinterface/main.jl b/test/ext/differentiationinterface/main.jl index 636f237e..9fc41639 100644 --- a/test/ext/differentiationinterface/main.jl +++ b/test/ext/differentiationinterface/main.jl @@ -3,9 +3,10 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) Pkg.instantiate() -using AbstractPPL: AbstractPPL, prepare, run_testcases, value_and_gradient!! +using AbstractPPL: + AbstractPPL, prepare, run_testcases, value_and_gradient!!, value_gradient_and_hessian!! using ADTypes: AutoForwardDiff, AutoReverseDiff -using DifferentiationInterface: DifferentiationInterface as DI +using DifferentiationInterface: DifferentiationInterface as DI, SecondOrder using ForwardDiff using ReverseDiff using Test @@ -17,6 +18,7 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) @testset "AbstractPPLDifferentiationInterfaceExt" begin @testset "ForwardDiff" begin run_testcases(Val(:vector); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) + run_testcases(Val(:hessian); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) run_testcases(Val(:cache_reuse); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) run_testcases(Val(:edge); adtype=AutoForwardDiff()) end @@ -57,4 +59,28 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) @inferred value_and_gradient!!(prep_closure, x) @inferred value_and_gradient!!(prep_ctx, x) end + + # `SecondOrder(outer, inner)` lets the caller pick the inner gradient + # backend and the outer differentiator independently — useful when the + # default Hessian strategy DI picks for a single `adtype` is suboptimal. + # Since `SecondOrder <: AbstractADType`, the existing `order=2` dispatch + # routes it through `DI.prepare_hessian` / `DI.value_gradient_and_hessian` + # without any extension-side changes. + @testset "SecondOrder for order=2" begin + adtype = SecondOrder(AutoForwardDiff(), AutoReverseDiff()) + x = [1.0, 2.0, 3.0] + prep = prepare(adtype, quadratic, zeros(3); order=2) + val, grad, hess = value_gradient_and_hessian!!(prep, x) + @test val ≈ 14.0 + @test grad ≈ [2.0, 4.0, 6.0] + @test hess ≈ [2.0 0 0; 0 2.0 0; 0 0 2.0] + + # `context=` composes with `SecondOrder` the same way as for a plain `adtype`. + affine(y, a, b) = a * sum(abs2, y) + b + prep_ctx = prepare(adtype, affine, zeros(3); context=(2.0, 1.0), order=2) + val_ctx, grad_ctx, hess_ctx = value_gradient_and_hessian!!(prep_ctx, x) + @test val_ctx ≈ affine(x, 2.0, 1.0) + @test grad_ctx ≈ [4.0, 8.0, 12.0] + @test hess_ctx ≈ [4.0 0 0; 0 4.0 0; 0 0 4.0] + end end diff --git a/test/ext/mooncake/main.jl b/test/ext/mooncake/main.jl index 3c328a1a..855d5f5b 100644 --- a/test/ext/mooncake/main.jl +++ b/test/ext/mooncake/main.jl @@ -18,6 +18,10 @@ using Test run_testcases(Val(:namedtuple); adtype=adtype, atol=1e-6, rtol=1e-6) run_testcases(Val(:cache_reuse); adtype=adtype, atol=1e-6, rtol=1e-6) run_testcases(Val(:edge); adtype=adtype) + # Hessian (`order=2`) is reverse-mode only on the AutoMooncake side; + # AutoMooncakeForward routes through the same generic Hessian path + # since `Mooncake.prepare_hessian_cache` is mode-agnostic. + run_testcases(Val(:hessian); adtype=adtype, atol=1e-6, rtol=1e-6) end end From ccd76b93c2c0b2cee84f840bb7fc94e621afdb0e Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 19 May 2026 22:56:18 +0100 Subject: [PATCH 2/5] Split Hessian edge case into separate :hessian_edge testcase group Move the order=1-prep error case for value_gradient_and_hessian!! out of :edge and into a new :hessian_edge group so Hessian-specific edge checks are only exercised by preparations that support order=2. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLTestExt.jl | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/ext/AbstractPPLTestExt.jl b/ext/AbstractPPLTestExt.jl index 16e360c3..4322fbc0 100644 --- a/ext/AbstractPPLTestExt.jl +++ b/ext/AbstractPPLTestExt.jl @@ -105,6 +105,22 @@ function AbstractPPL.generate_testcases(::Val{:hessian}) ) end +function AbstractPPL.generate_testcases(::Val{:hessian_edge}) + return ( + # `value_gradient_and_hessian!!` rejects order=1 preps regardless of + # the underlying problem arity — both paths share the same dispatch + # so one case suffices. + ErrorCase( + "value_gradient_and_hessian!! on order=1 prep", + QuadraticProblem(), + zeros(3), + [3.0, 1.0, 2.0], + (prepared, x) -> AbstractPPL.value_gradient_and_hessian!!(prepared, x), + r"order=2", + ), + ) +end + function AbstractPPL.generate_testcases(::Val{:edge}) return ( ErrorCase( @@ -187,17 +203,6 @@ function AbstractPPL.generate_testcases(::Val{:edge}) (prepared, x) -> AbstractPPL.value_and_jacobian!!(prepared, x), r"floating-point", ), - # `value_gradient_and_hessian!!` rejects order=1 preps regardless of - # the underlying problem arity — both paths share the same dispatch - # so one case suffices. - ErrorCase( - "value_gradient_and_hessian!! on order=1 prep", - QuadraticProblem(), - zeros(3), - [3.0, 1.0, 2.0], - (prepared, x) -> AbstractPPL.value_gradient_and_hessian!!(prepared, x), - r"order=2", - ), ) end @@ -250,6 +255,12 @@ function AbstractPPL.run_testcases( @test hess ≈ case.hessian atol = atol rtol = rtol end end + for case in generate_testcases(Val(:hessian_edge)) + @testset "$(case.name)" begin + prepared = prepare_fn(adtype, case.f, case.x_proto) + @test_throws case.exception case.op(prepared, case.x) + end + end return nothing end From ab74b6e0b2bfcc12e3f41c4a0bd57f0621c5d9c8 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Tue, 19 May 2026 22:58:26 +0100 Subject: [PATCH 3/5] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0e7ce54d..e0b3fa4c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probabilistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.15" +version = "0.15.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From b98f027f1050bdbef163e711d27324bfcaaa3a9e Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 21 May 2026 12:30:55 +0100 Subject: [PATCH 4/5] Encode derivative order on Prepared; split DICache into per-arity types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Prepared` gains an `Order` type parameter (`1` for gradient/jacobian, `2` for Hessian) with an `order(::Prepared)` accessor, so the prep order can be retrieved reliably without inspecting the backend-specific cache type. `value_and_gradient!!` on an order=2 prep now returns `(value, gradient)` via a dedicated gradient prep built alongside the Hessian prep — no O(n²) Hessian work for a gradient-only call. For DI's `SecondOrder` backend the gradient prep uses `DI.inner(adtype)` per DI's convention; the same unwrap runs on the hot path so prep and call use matching adtypes. `order` is now validated up-front via `Evaluators._validate_ad_order` (was duplicated across both extensions and fired only after the structural prep had already called `problem` once). DI: `DICache` is replaced by three concrete types — `DIGradientCache`, `DIJacobianCache`, `DIHessianCache` — eliminating the 6-nullable-field struct and runtime `=== nothing` checks. `_di_call_shape` is the shared target-and-constants helper used by both `_prepare_di` (order=1) and the order=2 path; the two preps share one target instance so compiled-tape ReverseDiff sees a consistent `Fix2` closure. Mooncake: `MooncakeCache` gains a `gradient_cache` field populated only at order=2; `_mooncake_gradient_cache` is now used by the NamedTuple path, the order=1 scalar branch, and the order=2 gradient prep. Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/src/evaluators.md | 9 +- ext/AbstractPPLDifferentiationInterfaceExt.jl | 282 +++++++++++------- ext/AbstractPPLMooncakeExt.jl | 124 +++++--- ext/AbstractPPLTestExt.jl | 9 +- src/AbstractPPL.jl | 4 +- src/evaluators/Evaluators.jl | 54 ++-- test/ext/differentiationinterface/main.jl | 26 +- 7 files changed, 316 insertions(+), 192 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index 5b03b7b0..7e1b1282 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -183,10 +183,11 @@ prepared = prepare(adtype, quadratic, zeros(3); order=2) entry handles it. Calling `value_gradient_and_hessian!!` on an `order=1` prep throws an -`ArgumentError` — re-prepare with `order=2` instead. Likewise, calling -`value_and_gradient!!` or `value_and_jacobian!!` on an `order=2` prep is -unsupported; use `value_gradient_and_hessian!!` and discard the unused -return value. +`ArgumentError` — re-prepare with `order=2` instead. The reverse is allowed: +`value_and_gradient!!` on an `order=2` prep returns `(value, gradient)` +without paying the Hessian cost, since `prepare` builds a dedicated +gradient prep alongside the Hessian one. `value_and_jacobian!!` is rejected +because `order=2` requires a scalar-valued problem. ## Constant context arguments diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 3a8f993c..989090d8 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -5,39 +5,55 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_ using ADTypes: AbstractADType, AutoReverseDiff using DifferentiationInterface: DifferentiationInterface as DI -# AD target used by every `DICache` mode. `Vararg{Any,N}` with a free `N` -# forces specialization on the trailing arity (a bare `Vararg{Any}` would -# skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the -# constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on -# the closure path — empty `ctx` then makes the splat a no-op. +# AD target used by every DI cache. `Vararg{Any,N}` with a free `N` forces +# specialization on the trailing arity (a bare `Vararg{Any}` would skip it). +# DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the constants path, +# and as `_call_evaluator(x, evaluator)` (via `Fix2`) on the closure path — +# empty `ctx` then makes the splat a no-op. @inline _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N} = f(x, ctx...) -# `Mode` tags the cache shape: -# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure, the +# `Mode` tags the call shape: +# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure; the # AD call passes **0** `DI.Constant`s. -# * `N::Int` — constants path: `N == length(evaluator.context)`, the AD +# * `N::Int` — constants path: `N == length(evaluator.context)`; the AD # call passes **N + 1** `DI.Constant`s (`f` plus the `N` # context values). -# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*` at -# compile time without a runtime branch. -# -# Single cache for every derivative order. At most one of `gradient_prep`, -# `jacobian_prep`, `hessian_prep` is non-`Nothing` at any time; the hot-path -# methods discriminate via `=== nothing` checks (folded at compile time since -# field types are concrete in each instantiation). `grad_buf` / `hess_buf` are -# non-`Nothing` only for order=2 — caller-owned output buffers handed to -# `DI.value_gradient_and_hessian!`. Returned arrays alias them (`!!` contract). -struct DICache{Mode,F,GP,JP,HP,G,H} +# Encoding `Mode` in each cache type resolves the closure-vs-constants dispatch +# in `_di_value_and_*` at compile time without a runtime branch. + +# `Nothing` in the prep slot flags the empty-input cache (DI prep paths fail +# on length-0 input, e.g. ForwardDiff `BoundsError`). Hot paths dispatch on the +# `Nothing` parameter to short-circuit before any DI call. Same convention for +# `DIJacobianCache` and `DIHessianCache` below. +struct DIGradientCache{Mode,F,GP} target::F gradient_prep::GP + function DIGradientCache(target::F, gp::GP, ::Val{Mode}) where {Mode,F,GP} + return new{Mode,F,GP}(target, gp) + end +end + +struct DIJacobianCache{Mode,F,JP} + target::F jacobian_prep::JP + function DIJacobianCache(target::F, jp::JP, ::Val{Mode}) where {Mode,F,JP} + return new{Mode,F,JP}(target, jp) + end +end + +# Order=2 (scalar-output). `grad_buf` / `hess_buf` are caller-owned output +# buffers handed to `DI.value_gradient_and_hessian!`; the returned arrays alias +# them (`!!` contract). +struct DIHessianCache{Mode,F,GP,HP,G,H} + target::F + gradient_prep::GP hessian_prep::HP grad_buf::G hess_buf::H - function DICache{Mode}( - target::F, gp::GP, jp::JP, hp::HP, g::G, h::H - ) where {Mode,F,GP,JP,HP,G,H} - return new{Mode,F,GP,JP,HP,G,H}(target, gp, jp, hp, g, h) + function DIHessianCache( + target::F, gp::GP, hp::HP, g::G, h::H, ::Val{Mode} + ) where {Mode,F,GP,HP,G,H} + return new{Mode,F,GP,HP,G,H}(target, gp, hp, g, h) end end @@ -46,22 +62,28 @@ end # target and call DI without constants. Context (if any) is captured inside # the evaluator closure rather than lowered out — the lowered path would also # require a closure here, so the wrapper cost is unavoidable for compiled tapes. -function _prepare_di(prep::F, adtype::AutoReverseDiff{true}, x, evaluator) where {F} - target = Base.Fix2(_call_evaluator, evaluator) - return target, prep(target, adtype, x), Val(:closure) +# +# `_di_call_shape` returns `(target, mode, constants)`. For the closure path +# `constants == ()` and the splat at every prep/call site collapses to nothing, +# letting prep and call sites share one shape regardless of mode. +function _di_call_shape(::AutoReverseDiff{true}, evaluator) + return Base.Fix2(_call_evaluator, evaluator), Val(:closure), () end - -function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F} - constants = (DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...) - return ( - _call_evaluator, - prep(_call_evaluator, adtype, x, constants...), - Val(length(evaluator.context)), - ) +function _di_call_shape(::AbstractADType, evaluator) + return _call_evaluator, + Val(length(evaluator.context)), + (DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...) end -@inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} = - DICache{Mode}(target, gp, jp, nothing, nothing, nothing) +# `SecondOrder` doesn't define gradient prep; per DI's contract the inner +# adtype is the one used for the first derivative. +@inline _gradient_adtype(adtype::AbstractADType) = adtype +@inline _gradient_adtype(adtype::DI.SecondOrder) = DI.inner(adtype) + +function _prepare_di(prep::F, adtype, x, evaluator) where {F} + target, mode, constants = _di_call_shape(adtype, evaluator) + return target, prep(target, adtype, x, constants...), mode +end function AbstractPPL.prepare( adtype::AbstractADType, @@ -71,132 +93,168 @@ function AbstractPPL.prepare( context::Tuple=(), order::Int=1, ) + Evaluators._validate_ad_order(order) evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) + mode_empty = Val(length(context)) if order == 2 arity === :scalar || Evaluators._throw_hessian_needs_scalar() if length(x) == 0 - # DI Hessian prep crashes on length-0 input; the AD entry - # short-circuits before any DI call. `Val(0)` is a non-`Nothing` - # sentinel for `hessian_prep` so dispatch recognises this as an - # order=2 prep (mirrors the order=1 empty-input pattern below). - cache = _wrap_hessian_cache( - _call_evaluator, Val(0), nothing, nothing, Val(length(context)) + cache = DIHessianCache( + _call_evaluator, nothing, nothing, nothing, nothing, mode_empty ) - return Prepared(adtype, evaluator, cache) + return Prepared(adtype, evaluator, cache, Val(2)) end - target, hessian_prep, mode = _prepare_di(DI.prepare_hessian, adtype, x, evaluator) - # Buffers pre-allocated from `x` (shape and eltype): the hot path is - # zero-allocation on the gradient/Hessian outputs, and the returned - # arrays alias these slots — copy if you need to retain them. - grad_buf = similar(x) - hess_buf = similar(x, length(x), length(x)) - cache = _wrap_hessian_cache(target, hessian_prep, grad_buf, hess_buf, mode) - return Prepared(adtype, evaluator, cache) + # Build both gradient and Hessian preps against the same target so + # `value_and_gradient!!` on the order=2 prep skips the O(n²) Hessian + # cost. Sharing the target matters for compiled-tape ReverseDiff — + # two `Fix2` instances may not be interchangeable in DI. + target, mode, constants = _di_call_shape(adtype, evaluator) + gradient_prep = DI.prepare_gradient( + target, _gradient_adtype(adtype), x, constants... + ) + hessian_prep = DI.prepare_hessian(target, adtype, x, constants...) + # Buffers pre-allocated from `x`: hot path is zero-allocation on the + # gradient/Hessian outputs, returned arrays alias these slots. + cache = DIHessianCache( + target, + gradient_prep, + hessian_prep, + similar(x), + similar(x, length(x), length(x)), + mode, + ) + return Prepared(adtype, evaluator, cache, Val(2)) end - order == 1 || throw(ArgumentError("`order` must be 1 or 2, got $order.")) if length(x) == 0 - # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`). - # `Val(0)` is an arity sentinel for the `gradient_prep === nothing` - # check below; the AD entry short-circuits before any DI call. - gp, jp = arity === :scalar ? (Val(0), nothing) : (nothing, Val(0)) - cache = _wrap_cache(_call_evaluator, gp, jp, Val(length(context))) + cache = if arity === :scalar + DIGradientCache(_call_evaluator, nothing, mode_empty) + else + DIJacobianCache(_call_evaluator, nothing, mode_empty) + end return Prepared(adtype, evaluator, cache) end if arity === :scalar target, gradient_prep, mode = _prepare_di(DI.prepare_gradient, adtype, x, evaluator) - return Prepared( - adtype, evaluator, _wrap_cache(target, gradient_prep, nothing, mode) - ) + return Prepared(adtype, evaluator, DIGradientCache(target, gradient_prep, mode)) end target, jacobian_prep, mode = _prepare_di(DI.prepare_jacobian, adtype, x, evaluator) - return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, mode)) + return Prepared(adtype, evaluator, DIJacobianCache(target, jacobian_prep, mode)) end -@inline _wrap_hessian_cache(target, hp, g, h, ::Val{Mode}) where {Mode} = - DICache{Mode}(target, nothing, nothing, hp, g, h) +# Hot-path dispatch is by cache type + `Mode` (closure vs constants), both +# resolved at compile time. On the constants path we always pass +# `DI.Constant(eval.f)` plus the `N` context constants — `N == 0` collapses +# the `map` splat to nothing. +const _GradientCapable = Union{DIGradientCache,DIHessianCache} -# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile -# time. The unconstrained method matches every non-`:closure` `Mode` (i.e. -# any `Int N`); `:closure` is strictly more specific and wins for compiled -# tapes. On the constants path we always pass `DI.Constant(eval.f)` plus the -# `N` context constants — `N == 0` collapses the `map` splat to nothing. -@inline _di_value_and_gradient(c::DICache{:closure}, ad, x, _) = - DI.value_and_gradient(c.target, c.gradient_prep, ad, x) -@inline _di_value_and_gradient(c::DICache, ad, x, eval) = DI.value_and_gradient( +@inline _di_value_and_gradient(c::Union{DIGradientCache{:closure},DIHessianCache{:closure}}, ad, x, _) = DI.value_and_gradient( + c.target, c.gradient_prep, _gradient_adtype(ad), x +) +@inline _di_value_and_gradient(c::_GradientCapable, ad, x, eval) = DI.value_and_gradient( c.target, c.gradient_prep, - ad, + _gradient_adtype(ad), x, DI.Constant(eval.f), map(DI.Constant, eval.context)..., ) -@inline _di_value_and_jacobian(c::DICache{:closure}, ad, x, _) = - DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x) -@inline _di_value_and_jacobian(c::DICache, ad, x, eval) = DI.value_and_jacobian( +@inline _di_value_and_jacobian(c::DIJacobianCache{:closure}, ad, x, _) = DI.value_and_jacobian( + c.target, c.jacobian_prep, ad, x +) +@inline _di_value_and_jacobian(c::DIJacobianCache, ad, x, eval) = DI.value_and_jacobian( + c.target, c.jacobian_prep, ad, x, DI.Constant(eval.f), map(DI.Constant, eval.context)... +) + +@inline _di_value_gradient_and_hessian(c::DIHessianCache{:closure}, ad, x, _) = DI.value_gradient_and_hessian!( + c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x +) +@inline _di_value_gradient_and_hessian(c::DIHessianCache, ad, x, eval) = DI.value_gradient_and_hessian!( c.target, - c.jacobian_prep, + c.grad_buf, + c.hess_buf, + c.hessian_prep, ad, x, DI.Constant(eval.f), map(DI.Constant, eval.context)..., ) +# `value_and_gradient!!`: works on both `DIGradientCache` (order=1 scalar) and +# `DIHessianCache` (order=2). Empty-input caches carry `gradient_prep::Nothing` +# and dispatch to the short-circuit method below; vector-output caches reject. +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{ + <:AbstractADType, + <:VectorEvaluator, + <:Union{DIGradientCache{<:Any,<:Any,Nothing},DIHessianCache{<:Any,<:Any,Nothing}}, + }, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return (p.evaluator(x), T[]) +end + @inline function AbstractPPL.value_and_gradient!!( - p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:_GradientCapable}, x::AbstractVector{T} ) where {T<:Real} - # Both `=== nothing` branches fold at compile time: each instantiation - # has concrete field types, so only the relevant branch survives. - p.cache.hessian_prep === nothing || Evaluators._throw_use_value_gradient_and_hessian() - p.cache.gradient_prep === nothing && Evaluators._throw_gradient_needs_scalar() Evaluators._check_ad_input(p.evaluator, x) - # Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff - # `BoundsError`); typed `T[]` matches the caller's element type. - length(x) == 0 && return (p.evaluator(x), T[]) return _di_value_and_gradient(p.cache, p.adtype, x, p.evaluator) end +@inline function AbstractPPL.value_and_gradient!!( + ::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_gradient_needs_scalar() +end + @inline function AbstractPPL.value_and_jacobian!!( - p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache{<:Any,<:Any,Nothing}}, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + val = p.evaluator(x) + return (val, similar(x, length(val), 0)) +end + +@inline function AbstractPPL.value_and_jacobian!!( + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIJacobianCache}, x::AbstractVector{T} ) where {T<:Real} - p.cache.hessian_prep === nothing || Evaluators._throw_use_value_gradient_and_hessian() - p.cache.jacobian_prep === nothing && Evaluators._throw_jacobian_needs_vector() Evaluators._check_ad_input(p.evaluator, x) - if length(x) == 0 - val = p.evaluator(x) - return (val, similar(x, length(val), 0)) - end return _di_value_and_jacobian(p.cache, p.adtype, x, p.evaluator) end -# Hessian hot-path dispatch mirrors the gradient/jacobian helpers above: -# `:closure` (compiled-tape) vs constants `Mode`, resolved at compile time. -# Uses DI's in-place variant `value_gradient_and_hessian!` with caller-owned -# buffers; the returned `(val, grad, hess)` aliases `c.grad_buf` / `c.hess_buf`. -@inline _di_value_gradient_and_hessian(c::DICache{:closure}, ad, x, _) = - DI.value_gradient_and_hessian!(c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x) -@inline _di_value_gradient_and_hessian(c::DICache, ad, x, eval) = - DI.value_gradient_and_hessian!( - c.target, - c.grad_buf, - c.hess_buf, - c.hessian_prep, - ad, - x, - DI.Constant(eval.f), - map(DI.Constant, eval.context)..., - ) +@inline function AbstractPPL.value_and_jacobian!!( + ::Prepared{<:AbstractADType,<:VectorEvaluator,<:_GradientCapable}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_jacobian_needs_vector() +end @inline function AbstractPPL.value_gradient_and_hessian!!( - p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} + p::Prepared{ + <:AbstractADType,<:VectorEvaluator,<:DIHessianCache{<:Any,<:Any,<:Any,Nothing} + }, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return (p.evaluator(x), T[], similar(x, 0, 0)) +end + +@inline function AbstractPPL.value_gradient_and_hessian!!( + p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DIHessianCache}, x::AbstractVector{T} ) where {T<:Real} - # Order=1 preps have `hessian_prep === nothing` (compile-folded check). - p.cache.hessian_prep === nothing && Evaluators._throw_hessian_needs_order_2_prep() Evaluators._check_ad_input(p.evaluator, x) - # Empty-input shortcut — same reasoning as the order=1 path. - length(x) == 0 && return (p.evaluator(x), T[], similar(x, 0, 0)) return _di_value_gradient_and_hessian(p.cache, p.adtype, x, p.evaluator) end +@inline function AbstractPPL.value_gradient_and_hessian!!( + ::Prepared{<:AbstractADType,<:VectorEvaluator,<:Union{DIGradientCache,DIJacobianCache}}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_hessian_needs_order_2_prep() +end + end # module diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 7fceadfc..41e49e1d 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -30,25 +30,31 @@ Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent # tangent type prevents differentiation of its fields. # * `C` — the underlying Mooncake cache, or `Nothing` for the # empty-input shortcut. -struct MooncakeCache{A,Target,C} +# * `G` — gradient cache populated only at order=2 so the order=1 +# `value_and_gradient!!` entry on a Hessian prep skips the +# Hessian work. `Nothing` for every order=1 path. +struct MooncakeCache{A,Target,C,G} target::Target cache::C - function MooncakeCache{A}(target::Target, cache::C) where {A,Target,C} - return new{A,Target,C}(target, cache) + gradient_cache::G + function MooncakeCache{A}( + target::Target, cache::C, gradient_cache::G=nothing + ) where {A,Target,C,G} + return new{A,Target,C,G}(target, cache, gradient_cache) end end -# Order=1 convenience: no target wrapper. MooncakeCache{A}(cache) where {A} = MooncakeCache{A}(nothing, cache) _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config -# NamedTuple-path helper: Mooncake exposes separate `prepare_*_cache` -# entries per AD mode but the call shape (target + values) is the same. -function _mooncake_gradient_cache(::AutoMooncake, f, x; config) - return Mooncake.prepare_gradient_cache(f, x; config) +# Mooncake exposes separate `prepare_*_cache` entries per AD mode; the call +# shape (callable + active arg + extra args) is the same. Used by the +# NamedTuple path, the order=1 scalar branch, and the order=2 gradient prep. +function _mooncake_gradient_cache(::AutoMooncake, f, x, args...; config) + Mooncake.prepare_gradient_cache(f, x, args...; config) end -function _mooncake_gradient_cache(::AutoMooncakeForward, f, x; config) - return Mooncake.prepare_derivative_cache(f, x; config) +function _mooncake_gradient_cache(::AutoMooncakeForward, f, x, args...; config) + Mooncake.prepare_derivative_cache(f, x, args...; config) end function AbstractPPL.prepare( @@ -90,6 +96,7 @@ function AbstractPPL.prepare( context::Tuple=(), order::Int=1, ) + Evaluators._validate_ad_order(order) x isa DenseVector || throw( ArgumentError( "AutoMooncake / AutoMooncakeForward require a dense vector input " * @@ -105,12 +112,24 @@ function AbstractPPL.prepare( # `{false}` skips the per-call shape check — `_check_ad_input` on the # AD entry already validates `x`. `dim` is unused for `{false}`. target = VectorEvaluator{false}(evaluator.f, 0, evaluator.context) - length(x) == 0 && - return Prepared(adtype, evaluator, MooncakeCache{:hessian}(target, nothing)) - cache = Mooncake.prepare_hessian_cache(target, x; config) - return Prepared(adtype, evaluator, MooncakeCache{:hessian}(target, cache)) + length(x) == 0 && return Prepared( + adtype, evaluator, MooncakeCache{:hessian}(target, nothing), Val(2) + ) + hess_cache = Mooncake.prepare_hessian_cache(target, x; config) + # Order=1 gradient cache so `value_and_gradient!!` on the same prep + # skips the Hessian work. Mooncake's `value_and_gradient!!` runs on + # `evaluator.f` with context-as-extra-args, distinct from the wrapped + # `target` used by the Hessian API. + grad_cache = _mooncake_gradient_cache( + adtype, evaluator.f, x, evaluator.context...; config + ) + return Prepared( + adtype, + evaluator, + MooncakeCache{:hessian}(target, hess_cache, grad_cache), + Val(2), + ) end - order == 1 || throw(ArgumentError("`order` must be 1 or 2, got $order.")) if !isempty(evaluator.context) && arity !== :scalar throw( ArgumentError( @@ -127,16 +146,12 @@ function AbstractPPL.prepare( # `problem` / `context` kwargs): a downstream override of structural # `prepare` may return a `VectorEvaluator` whose `.f`/`.context` differ # from the caller-supplied values, and the hot path reads them off the - # evaluator. Forward mode uses `prepare_derivative_cache` for both - # arities; the splat is a no-op for vector arity (empty `context`). - cache = if adtype isa AutoMooncake - if arity === :scalar - Mooncake.prepare_gradient_cache(evaluator.f, x, evaluator.context...; config) - else - Mooncake.prepare_pullback_cache(evaluator.f, x; config) - end + # evaluator. The reverse-mode vector branch is the only one that can't + # share `_mooncake_gradient_cache` — it needs `prepare_pullback_cache`. + cache = if arity !== :scalar && adtype isa AutoMooncake + Mooncake.prepare_pullback_cache(evaluator.f, x; config) else - Mooncake.prepare_derivative_cache(evaluator.f, x, evaluator.context...; config) + _mooncake_gradient_cache(adtype, evaluator.f, x, evaluator.context...; config) end return Prepared(adtype, evaluator, MooncakeCache{arity}(cache)) end @@ -157,7 +172,7 @@ end # Empty-input shortcut. `MooncakeCache{:scalar,Nothing,Nothing}` is strictly # more specific than `MooncakeCache{:scalar}`, so dispatch unambiguously selects # this method over the general scalar-gradient hot path below for zero-length -# `x`. (`Target=Nothing` always holds for order=1 — see struct comment.) +# `x`. @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing,Nothing}}, x::AbstractVector{T}, @@ -170,31 +185,34 @@ end # `args_to_zero` to mark `x` as the lone active input (`false` on `f`, # `true` on `x`, `false` on each context value); forward mode # (`ForwardCache`) derives activity from its seeded argument and rejects -# the kwarg. The `p.adtype isa AutoMooncake` branch is compile-folded -# since `adtype`'s concrete type lives in `Prepared`'s type parameters. -# Empty `context` collapses the splat and reduces `args_to_zero` to +# the kwarg. The `adtype isa AutoMooncake` branch is compile-folded +# since the concrete type lives in `Prepared`'s type parameters. Empty +# `context` collapses the splat and reduces `args_to_zero` to # `(false, true)`. `tangents[2]` is the `x`-gradient; trailing entries # (one per context value) are inactive and discarded. -@inline function AbstractPPL.value_and_gradient!!( - p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, - x::AbstractVector{T}, -) where {T<:Real} - Evaluators._check_ad_input(p.evaluator, x) - e = p.evaluator - val, tangents = if p.adtype isa AutoMooncake +@inline function _mooncake_value_and_gradient(adtype, gcache, e::VectorEvaluator, x) + val, tangents = if adtype isa AutoMooncake Mooncake.value_and_gradient!!( - p.cache.cache, + gcache, e.f, x, e.context...; args_to_zero=(false, true, map(_ -> false, e.context)...), ) else - Mooncake.value_and_gradient!!(p.cache.cache, e.f, x, e.context...) + Mooncake.value_and_gradient!!(gcache, e.f, x, e.context...) end return (val, tangents[2]) end +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, + x::AbstractVector{<:Real}, +) + Evaluators._check_ad_input(p.evaluator, x) + return _mooncake_value_and_gradient(p.adtype, p.cache.cache, p.evaluator, x) +end + # Arity-mismatch errors as dedicated methods so dispatch on # `MooncakeCache{:scalar}` vs `{:vector}` vs `{:hessian}` resolves at compile # time instead of a runtime check on the cache contents. @@ -205,25 +223,37 @@ end return Evaluators._throw_gradient_needs_scalar() end +# Empty-input shortcut for order=2 preps: same `Nothing` specificity trick +# as the scalar case. `gradient_cache` is `Nothing` only on the empty-x prep. @inline function AbstractPPL.value_and_gradient!!( - ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian}}, - ::AbstractVector{<:Real}, -) - return Evaluators._throw_use_value_gradient_and_hessian() + p::Prepared{ + <:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian,<:Any,Nothing,Nothing} + }, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + return (p.evaluator(x), T[]) end -@inline function AbstractPPL.value_and_jacobian!!( - ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, - ::AbstractVector{<:Real}, +# Order=2 prep also satisfies the order=1 gradient contract via the dedicated +# gradient cache built at prep time — skips the O(n²) Hessian work. +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian}}, + x::AbstractVector{<:Real}, ) - return Evaluators._throw_jacobian_needs_vector() + Evaluators._check_ad_input(p.evaluator, x) + return _mooncake_value_and_gradient(p.adtype, p.cache.gradient_cache, p.evaluator, x) end @inline function AbstractPPL.value_and_jacobian!!( - ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:hessian}}, + ::Prepared{ + <:_MooncakeAD, + <:VectorEvaluator, + <:Union{MooncakeCache{:scalar},MooncakeCache{:hessian}}, + }, ::AbstractVector{<:Real}, ) - return Evaluators._throw_use_value_gradient_and_hessian() + return Evaluators._throw_jacobian_needs_vector() end # Empty-input jacobian shortcut. Same `Nothing` specificity trick as the diff --git a/ext/AbstractPPLTestExt.jl b/ext/AbstractPPLTestExt.jl index 4322fbc0..c5c74727 100644 --- a/ext/AbstractPPLTestExt.jl +++ b/ext/AbstractPPLTestExt.jl @@ -19,9 +19,6 @@ struct ValueCase jacobian::Any end -# Mirror of `ValueCase` for `order=2` prep + `value_gradient_and_hessian!!`. -# A separate type keeps the order=1 cases narrow and lets `run_testcases` -# dispatch on the prep order without an extra field. struct HessianCase name::String f::Any @@ -226,6 +223,7 @@ function AbstractPPL.run_testcases( for case in generate_testcases(Val(:vector)) @testset "$(case.name)" begin prepared = prepare_fn(adtype, case.f, case.x_proto) + @test AbstractPPL.order(prepared) == 1 @test prepared(case.x) ≈ case.value atol = atol rtol = rtol if case.gradient !== nothing val, grad = AbstractPPL.value_and_gradient!!(prepared, case.x) @@ -248,11 +246,16 @@ function AbstractPPL.run_testcases( for case in generate_testcases(Val(:hessian)) @testset "$(case.name)" begin prepared = prepare_fn(adtype, case.f, case.x_proto; order=2) + @test AbstractPPL.order(prepared) == 2 @test prepared(case.x) ≈ case.value atol = atol rtol = rtol val, grad, hess = AbstractPPL.value_gradient_and_hessian!!(prepared, case.x) @test val ≈ case.value atol = atol rtol = rtol @test grad ≈ case.gradient atol = atol rtol = rtol @test hess ≈ case.hessian atol = atol rtol = rtol + # Order=2 prep also satisfies the order=1 gradient contract. + val1, grad1 = AbstractPPL.value_and_gradient!!(prepared, case.x) + @test val1 ≈ case.value atol = atol rtol = rtol + @test grad1 ≈ case.gradient atol = atol rtol = rtol end end for case in generate_testcases(Val(:hessian_edge)) diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index f9f81247..2a32494d 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -12,7 +12,7 @@ include("abstractprobprog.jl") include("evaluate.jl") include("evaluators/Evaluators.jl") using .Evaluators: - prepare, value_and_gradient!!, value_and_jacobian!!, value_gradient_and_hessian!! + prepare, value_and_gradient!!, value_and_jacobian!!, value_gradient_and_hessian!!, order """ generate_testcases(::Val{group}) @@ -41,7 +41,7 @@ function run_testcases end @static if VERSION >= v"1.11.0" eval( Meta.parse( - "public prepare, value_and_gradient!!, value_and_jacobian!!, value_gradient_and_hessian!!, generate_testcases, run_testcases", + "public prepare, value_and_gradient!!, value_and_jacobian!!, value_gradient_and_hessian!!, order, generate_testcases, run_testcases", ), ) end diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index 7e1e32a4..5dd279c0 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -6,10 +6,14 @@ import ..evaluate!! include("utils.jl") """ - Prepared{AD<:AbstractADType,E,C}(adtype, evaluator, cache) - Prepared(adtype, evaluator) # cache defaults to `nothing` + Prepared{AD<:AbstractADType,E,C,Order}(adtype, evaluator, cache) + Prepared(adtype, evaluator, cache, Val(Order)) + Prepared(adtype, evaluator, cache) # defaults `Order` to 1 + Prepared(adtype, evaluator) # cache defaults to `nothing` -AD-prepared evaluator parameterised by backend type `AD`. +AD-prepared evaluator parameterised by backend type `AD` and derivative order +`Order` (`1` for gradient/jacobian, `2` for Hessian). Retrieve `Order` via +[`order`](@ref). - `adtype` — the backend, used for dispatch. - `evaluator` — the user-facing callable (typically a `VectorEvaluator` or @@ -29,16 +33,38 @@ function AbstractPPL.value_and_gradient!!( end ``` """ -struct Prepared{AD<:AbstractADType,E,C} +struct Prepared{AD<:AbstractADType,E,C,Order} adtype::AD evaluator::E cache::C + function Prepared{AD,E,C,Order}( + adtype, evaluator, cache + ) where {AD<:AbstractADType,E,C,Order} + return new{AD,E,C,Order}(adtype, evaluator, cache) + end end -Prepared(adtype::AbstractADType, evaluator) = Prepared(adtype, evaluator, nothing) +function Prepared( + adtype::AD, evaluator::E, cache::C, ::Val{Order} +) where {AD<:AbstractADType,E,C,Order} + return Prepared{AD,E,C,Order}(adtype, evaluator, cache) +end +function Prepared(adtype::AbstractADType, evaluator, cache) + Prepared(adtype, evaluator, cache, Val(1)) +end +Prepared(adtype::AbstractADType, evaluator) = Prepared(adtype, evaluator, nothing, Val(1)) (p::Prepared)(x) = p.evaluator(x) +""" + order(p::Prepared) + +Return the derivative order `p` was prepared for (`1` for gradient/jacobian, +`2` for Hessian). Type-stable — folds to the `Order` type parameter at compile +time. +""" +order(::Prepared{<:Any,<:Any,<:Any,O}) where {O} = O + """ prepare(problem, values::NamedTuple; check_dims::Bool=true) prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=()) @@ -283,11 +309,7 @@ end # Error helpers shared by the DI and Mooncake extensions; kept here so the # `:edge` testcase regexes (`r"scalar-valued"`, `r"vector-valued"`, `r"order=2"`) -# pin a single error string instead of one per backend. Covers two failure -# modes: -# * arity mismatch — function output doesn't match the requested derivative; -# * wrong prep order — caller asked for the derivative the prep wasn't built -# for, in either direction. +# pin a single error string instead of one per backend. function _throw_gradient_needs_scalar() throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function.")) end @@ -306,13 +328,11 @@ function _throw_hessian_needs_order_2_prep() ), ) end -function _throw_use_value_gradient_and_hessian() - throw( - ArgumentError( - "This evaluator was prepared with `order=2`; use `value_gradient_and_hessian!!` to compute its derivatives.", - ), - ) -end + +# Validate the `order=` kwarg of `prepare(adtype, problem, x; order)`. Shared by +# the DI and Mooncake extensions so the error string is identical. +@inline _validate_ad_order(order::Int) = + order in (1, 2) || throw(ArgumentError("`order` must be 1 or 2, got $order.")) # Complements the `typeof` check above: same-typed arrays can differ in `size`. # Arrays with non-`Real`/`Complex` eltype are walked element-wise to catch diff --git a/test/ext/differentiationinterface/main.jl b/test/ext/differentiationinterface/main.jl index 9fc41639..dff6473e 100644 --- a/test/ext/differentiationinterface/main.jl +++ b/test/ext/differentiationinterface/main.jl @@ -4,7 +4,12 @@ Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) Pkg.instantiate() using AbstractPPL: - AbstractPPL, prepare, run_testcases, value_and_gradient!!, value_gradient_and_hessian!! + AbstractPPL, + prepare, + run_testcases, + value_and_gradient!!, + value_gradient_and_hessian!!, + order using ADTypes: AutoForwardDiff, AutoReverseDiff using DifferentiationInterface: DifferentiationInterface as DI, SecondOrder using ForwardDiff @@ -23,7 +28,7 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) run_testcases(Val(:edge); adtype=AutoForwardDiff()) end - # Compiled-tape ReverseDiff goes through the `_prepare_di(::AutoReverseDiff{true}, …)` + # Compiled-tape ReverseDiff goes through the `_di_call_shape(::AutoReverseDiff{true}, …)` # specialisation that closes the evaluator into a `Base.Fix2` target — the # `:cache_reuse` group exercises that path across multiple inputs. @testset "ReverseDiff (compiled tape)" begin @@ -33,20 +38,20 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) run_testcases(Val(:edge); adtype=adtype) end - # `DICache`'s `Mode` parameter is either `:closure` (compiled-tape + # The DI cache types' `Mode` parameter is either `:closure` (compiled-tape # ReverseDiff) or the integer context length on the constants path. The # constants-path integer also documents how many `DI.Constant`s the AD # call passes. - @testset "DICache encodes the call mode as a type parameter" begin + @testset "DI cache encodes the call mode as a type parameter" begin x = [1.0, 2.0, 3.0] prep_noctx = prepare(AutoForwardDiff(), quadratic, x) prep_closure = prepare(AutoReverseDiff(; compile=true), quadratic, x) affine(y, a, b) = a * sum(abs2, y) + b prep_ctx = prepare(AutoForwardDiff(), affine, x; context=(2.0, 1.0)) - @test prep_noctx.cache isa DIExt.DICache{0} - @test prep_closure.cache isa DIExt.DICache{:closure} - @test prep_ctx.cache isa DIExt.DICache{2} + @test prep_noctx.cache isa DIExt.DIGradientCache{0} + @test prep_closure.cache isa DIExt.DIGradientCache{:closure} + @test prep_ctx.cache isa DIExt.DIGradientCache{2} # Non-empty-context primal matches the underlying `f(x, context...)`. @test prep_ctx(x) == affine(x, 2.0, 1.0) @@ -70,11 +75,18 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) adtype = SecondOrder(AutoForwardDiff(), AutoReverseDiff()) x = [1.0, 2.0, 3.0] prep = prepare(adtype, quadratic, zeros(3); order=2) + @test order(prep) == 2 val, grad, hess = value_gradient_and_hessian!!(prep, x) @test val ≈ 14.0 @test grad ≈ [2.0, 4.0, 6.0] @test hess ≈ [2.0 0 0; 0 2.0 0; 0 0 2.0] + # `value_and_gradient!!` on a `SecondOrder` order=2 prep routes through + # the inner adtype — the only non-trivial case for `_gradient_adtype`. + val1, grad1 = value_and_gradient!!(prep, x) + @test val1 ≈ 14.0 + @test grad1 ≈ [2.0, 4.0, 6.0] + # `context=` composes with `SecondOrder` the same way as for a plain `adtype`. affine(y, a, b) = a * sum(abs2, y) + b prep_ctx = prepare(adtype, affine, zeros(3); context=(2.0, 1.0), order=2) From d70a945b545c1b250afca774e51a3e656f8363ce Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 21 May 2026 12:36:45 +0100 Subject: [PATCH 5/5] Apply JuliaFormatter v1 to changes from previous commit Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLDifferentiationInterfaceExt.jl | 36 ++++++++++--------- ext/AbstractPPLMooncakeExt.jl | 4 +-- src/evaluators/Evaluators.jl | 2 +- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 989090d8..4ba7c344 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -148,9 +148,9 @@ end # the `map` splat to nothing. const _GradientCapable = Union{DIGradientCache,DIHessianCache} -@inline _di_value_and_gradient(c::Union{DIGradientCache{:closure},DIHessianCache{:closure}}, ad, x, _) = DI.value_and_gradient( - c.target, c.gradient_prep, _gradient_adtype(ad), x -) +@inline _di_value_and_gradient( + c::Union{DIGradientCache{:closure},DIHessianCache{:closure}}, ad, x, _ +) = DI.value_and_gradient(c.target, c.gradient_prep, _gradient_adtype(ad), x) @inline _di_value_and_gradient(c::_GradientCapable, ad, x, eval) = DI.value_and_gradient( c.target, c.gradient_prep, @@ -160,27 +160,31 @@ const _GradientCapable = Union{DIGradientCache,DIHessianCache} map(DI.Constant, eval.context)..., ) -@inline _di_value_and_jacobian(c::DIJacobianCache{:closure}, ad, x, _) = DI.value_and_jacobian( - c.target, c.jacobian_prep, ad, x -) +@inline _di_value_and_jacobian(c::DIJacobianCache{:closure}, ad, x, _) = + DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x) @inline _di_value_and_jacobian(c::DIJacobianCache, ad, x, eval) = DI.value_and_jacobian( - c.target, c.jacobian_prep, ad, x, DI.Constant(eval.f), map(DI.Constant, eval.context)... -) - -@inline _di_value_gradient_and_hessian(c::DIHessianCache{:closure}, ad, x, _) = DI.value_gradient_and_hessian!( - c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x -) -@inline _di_value_gradient_and_hessian(c::DIHessianCache, ad, x, eval) = DI.value_gradient_and_hessian!( c.target, - c.grad_buf, - c.hess_buf, - c.hessian_prep, + c.jacobian_prep, ad, x, DI.Constant(eval.f), map(DI.Constant, eval.context)..., ) +@inline _di_value_gradient_and_hessian(c::DIHessianCache{:closure}, ad, x, _) = + DI.value_gradient_and_hessian!(c.target, c.grad_buf, c.hess_buf, c.hessian_prep, ad, x) +@inline _di_value_gradient_and_hessian(c::DIHessianCache, ad, x, eval) = + DI.value_gradient_and_hessian!( + c.target, + c.grad_buf, + c.hess_buf, + c.hessian_prep, + ad, + x, + DI.Constant(eval.f), + map(DI.Constant, eval.context)..., + ) + # `value_and_gradient!!`: works on both `DIGradientCache` (order=1 scalar) and # `DIHessianCache` (order=2). Empty-input caches carry `gradient_prep::Nothing` # and dispatch to the short-circuit method below; vector-output caches reject. diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 41e49e1d..3ff32313 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -51,10 +51,10 @@ _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtyp # shape (callable + active arg + extra args) is the same. Used by the # NamedTuple path, the order=1 scalar branch, and the order=2 gradient prep. function _mooncake_gradient_cache(::AutoMooncake, f, x, args...; config) - Mooncake.prepare_gradient_cache(f, x, args...; config) + return Mooncake.prepare_gradient_cache(f, x, args...; config) end function _mooncake_gradient_cache(::AutoMooncakeForward, f, x, args...; config) - Mooncake.prepare_derivative_cache(f, x, args...; config) + return Mooncake.prepare_derivative_cache(f, x, args...; config) end function AbstractPPL.prepare( diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index 5dd279c0..124c7134 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -50,7 +50,7 @@ function Prepared( return Prepared{AD,E,C,Order}(adtype, evaluator, cache) end function Prepared(adtype::AbstractADType, evaluator, cache) - Prepared(adtype, evaluator, cache, Val(1)) + return Prepared(adtype, evaluator, cache, Val(1)) end Prepared(adtype::AbstractADType, evaluator) = Prepared(adtype, evaluator, nothing, Val(1))