From 1ebe83e082fe50e2c6ed484969c0b3f7bc0d79ba Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sat, 9 May 2026 00:29:52 +0100 Subject: [PATCH 01/15] Add Mooncake extension Mooncake AD-backend extension built on the evaluator interface, with the shared conformance suite extended to cover NamedTuple inputs and empty-input arity errors. Squashed from prior incremental commits: - AbstractPPLMooncakeExt: cache reuse, scalar/vector dispatch, NamedTuple inputs via VectorEvaluator/NamedTupleEvaluator wrappers; integration test in test/ext/mooncake. - Evaluators._ad_output_arity: lift the duplicated `Union{Number, AbstractVector}` output check from both extensions into one helper that returns `:scalar` / `:vector` for downstream dispatch. - Empty-input arity tagging (`Val(:scalar)` / `Val(:vector)`) so the empty-input fast path raises the same "requires a scalar/vector-valued function" error as the DI path instead of silently succeeding. - AbstractPPLTestExt: add `Val(:namedtuple)` group (one ValueCase + one ErrorCase); tighten regex assertions on the existing arity-mismatch cases. - check_dims threaded through the inner `prepare` call so AD hot paths can skip per-call shape checks. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .github/workflows/CI.yml | 1 + Project.toml | 5 +- ext/AbstractPPLDifferentiationInterfaceExt.jl | 13 +- ext/AbstractPPLMooncakeExt.jl | 133 ++++++++++++++++++ ext/AbstractPPLTestExt.jl | 56 +++++++- src/evaluators/Evaluators.jl | 14 ++ test/Project.toml | 2 +- test/ext/mooncake/Project.toml | 11 ++ test/ext/mooncake/main.jl | 22 +++ test/run_extras.jl | 3 +- 10 files changed, 247 insertions(+), 13 deletions(-) create mode 100644 ext/AbstractPPLMooncakeExt.jl create mode 100644 test/ext/mooncake/Project.toml create mode 100644 test/ext/mooncake/main.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6a4375b1..6a4189a8 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -61,6 +61,7 @@ jobs: matrix: label: - ext/differentiationinterface + - ext/mooncake version: - '1' - 'min' diff --git a/Project.toml b/Project.toml index e767cf77..0e7ce54d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probabilistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.14.3" +version = "0.15" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -21,11 +21,13 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [weakdeps] DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [extensions] AbstractPPLDifferentiationInterfaceExt = ["DifferentiationInterface"] AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"] +AbstractPPLMooncakeExt = ["Mooncake"] AbstractPPLTestExt = ["Test"] [compat] @@ -39,6 +41,7 @@ Distributions = "0.25" JSON = "0.19 - 0.21, 1" LinearAlgebra = "<0.0.1, 1" MacroTools = "0.5" +Mooncake = "0.5.27" OrderedCollections = "1.8.1" Random = "1.6" StatsBase = "0.32, 0.33, 0.34" diff 
--git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 80f95ee1..d2f35d0c 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -1,7 +1,7 @@ module AbstractPPLDifferentiationInterfaceExt using AbstractPPL: AbstractPPL -using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator +using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_arity using ADTypes: AbstractADType, AutoReverseDiff using DifferentiationInterface: DifferentiationInterface as DI @@ -32,19 +32,14 @@ function AbstractPPL.prepare( adtype::AbstractADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true ) evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator - y = evaluator(x) - y isa Union{Number,AbstractVector} || throw( - ArgumentError( - "A prepared AD evaluator must return a scalar or AbstractVector; got $(typeof(y)).", - ), - ) + arity = _ad_output_arity(evaluator(x)) if length(x) == 0 # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`); the # `Val(0)` sentinel keeps the `gradient_prep === nothing` arity check meaningful. - gp, jp = y isa Number ? (Val(0), nothing) : (nothing, Val(0)) + gp, jp = arity === :scalar ? 
(Val(0), nothing) : (nothing, Val(0)) return Prepared(adtype, evaluator, DICache(_call_evaluator, gp, jp, true)) end - if y isa Number + if arity === :scalar target, gradient_prep, use_context = _prepare_di( DI.prepare_gradient, adtype, x, evaluator ) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl new file mode 100644 index 00000000..ac48df63 --- /dev/null +++ b/ext/AbstractPPLMooncakeExt.jl @@ -0,0 +1,133 @@ +module AbstractPPLMooncakeExt + +using AbstractPPL: AbstractPPL +using AbstractPPL.Evaluators: + Evaluators, + Prepared, + VectorEvaluator, + NamedTupleEvaluator, + _ad_output_arity, + _assert_namedtuple_shape +using ADTypes: AutoMooncake, AutoMooncakeForward +using Mooncake: Mooncake + +const _MooncakeAD = Union{AutoMooncake,AutoMooncakeForward} + +# Tag a Mooncake cache with the prepared evaluator's output arity (`:scalar` +# or `:vector`) so `value_and_gradient!!` / `value_and_jacobian!!` can raise +# helpful arity-mismatch errors instead of failing inside Mooncake. The inner +# `cache` is `Nothing` for the empty-input shortcut path. +struct MooncakeCache{A,C} + cache::C +end +MooncakeCache{A}(cache::C) where {A,C} = MooncakeCache{A,C}(cache) + +_mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config + +# `value_and_gradient!!` accepts either a reverse-mode gradient cache +# (AutoMooncake) or a forward-mode derivative cache (AutoMooncakeForward). +function _mooncake_gradient_cache(::AutoMooncake, f, x; config) + return Mooncake.prepare_gradient_cache(f, x; config=config) +end +function _mooncake_gradient_cache(::AutoMooncakeForward, f, x; config) + return Mooncake.prepare_derivative_cache(f, x; config=config) +end + +# `value_and_jacobian!!`: reverse mode wants a pullback cache, forward mode +# wants a derivative cache. 
+function _mooncake_jacobian_cache(::AutoMooncake, f, x; config) + return Mooncake.prepare_pullback_cache(f, x; config=config) +end +function _mooncake_jacobian_cache(::AutoMooncakeForward, f, x; config) + return Mooncake.prepare_derivative_cache(f, x; config=config) +end + +function AbstractPPL.prepare( + adtype::_MooncakeAD, problem, values::NamedTuple; check_dims::Bool=true +) + evaluator = AbstractPPL.prepare(problem, values; check_dims)::NamedTupleEvaluator + config = _mooncake_config(adtype) + cache = _mooncake_gradient_cache(adtype, evaluator, values; config) + return Prepared(adtype, evaluator, cache) +end + +function AbstractPPL.prepare( + adtype::_MooncakeAD, problem, x::AbstractVector{<:Real}; check_dims::Bool=true +) + evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator + arity = _ad_output_arity(evaluator(x)) + # Mooncake builds no tape for length-zero `x`; tag with `Nothing` so the + # empty-input methods below shortcut without invoking Mooncake. + length(x) == 0 && return Prepared(adtype, evaluator, MooncakeCache{arity}(nothing)) + config = _mooncake_config(adtype) + cache = if arity === :scalar + _mooncake_gradient_cache(adtype, evaluator, x; config) + else + _mooncake_jacobian_cache(adtype, evaluator, x; config) + end + return Prepared(adtype, evaluator, MooncakeCache{arity}(cache)) +end + +# `Mooncake.value_and_gradient!!` returns `(val, (∂f, ∂x))`; we discard the +# function tangent `∂f` and surface only `∂x` as the user-facing gradient. 
+@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{<:_MooncakeAD,<:NamedTupleEvaluator}, values::NamedTuple +) + _assert_namedtuple_shape(p.evaluator, values) + val, (_, grad) = Mooncake.value_and_gradient!!(p.cache, p.evaluator, values) + return (val, grad) +end + +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing}}, + x::AbstractVector{T}, +) where {T<:Real} + T <: Integer && Evaluators._reject_integer_input(x) + Evaluators._check_vector_length(p.evaluator.dim, x) + return (p.evaluator(x), T[]) +end + +@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, + x::AbstractVector{T}, +) where {T<:Real} + T <: Integer && Evaluators._reject_integer_input(x) + Evaluators._check_vector_length(p.evaluator.dim, x) + val, (_, grad) = Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x) + return (val, grad) +end + +@inline function AbstractPPL.value_and_gradient!!( + ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector}}, + ::AbstractVector{<:Real}, +) + throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function.")) +end + +@inline function AbstractPPL.value_and_jacobian!!( + ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, + ::AbstractVector{<:Real}, +) + throw(ArgumentError("`value_and_jacobian!!` requires a vector-valued function.")) +end + +@inline function AbstractPPL.value_and_jacobian!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector,Nothing}}, + x::AbstractVector{T}, +) where {T<:Real} + T <: Integer && Evaluators._reject_integer_input(x) + Evaluators._check_vector_length(p.evaluator.dim, x) + val = p.evaluator(x) + return (val, similar(x, length(val), 0)) +end + +@inline function AbstractPPL.value_and_jacobian!!( + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector}}, + x::AbstractVector{T}, +) where 
{T<:Real} + T <: Integer && Evaluators._reject_integer_input(x) + Evaluators._check_vector_length(p.evaluator.dim, x) + return Mooncake.value_and_jacobian!!(p.cache.cache, p.evaluator, x) +end + +end # module diff --git a/ext/AbstractPPLTestExt.jl b/ext/AbstractPPLTestExt.jl index 6f7463da..3bd3f86a 100644 --- a/ext/AbstractPPLTestExt.jl +++ b/ext/AbstractPPLTestExt.jl @@ -93,7 +93,15 @@ function AbstractPPL.generate_testcases(::Val{:edge}) zeros(3), [2.0, 3.0, 4.0], (prepared, x) -> AbstractPPL.value_and_gradient!!(prepared, x), - ArgumentError, + r"scalar-valued", + ), + ErrorCase( + "jacobian of scalar output", + QuadraticProblem(), + zeros(3), + [3.0, 1.0, 2.0], + (prepared, x) -> AbstractPPL.value_and_jacobian!!(prepared, x), + r"vector-valued", ), ErrorCase( "gradient of vector-valued output, empty input", @@ -178,4 +186,50 @@ function AbstractPPL.run_testcases(::Val{:edge}, prepare_fn=AbstractPPL.prepare; return nothing end +function AbstractPPL.generate_testcases(::Val{:namedtuple}) + return ( + ValueCase( + "scalar output over (x::Real, y::Vector)", + vs -> vs.x^2 + sum(abs2, vs.y), + (x=0.0, y=zeros(2)), + (x=3.0, y=[1.0, 2.0]), + 14.0, + (x=6.0, y=[2.0, 4.0]), + nothing, + ), + ErrorCase( + "wrong NamedTuple structure", + vs -> vs.x^2 + sum(abs2, vs.y), + (x=0.0, y=zeros(2)), + (x=3.0, z=[1.0, 2.0]), + (prepared, x) -> AbstractPPL.value_and_gradient!!(prepared, x), + r"same NamedTuple structure", + ), + ) +end + +function AbstractPPL.run_testcases( + ::Val{:namedtuple}, prepare_fn=AbstractPPL.prepare; adtype, atol=0, rtol=1e-10 +) + for case in generate_testcases(Val(:namedtuple)) + @testset "$(case.name)" begin + prepared = prepare_fn(adtype, case.f, case.x_proto) + if case isa ErrorCase + @test_throws case.exception case.op(prepared, case.x) + continue + end + @test prepared(case.x) ≈ case.value atol = atol rtol = rtol + if case.gradient !== nothing + val, grad = AbstractPPL.value_and_gradient!!(prepared, case.x) + @test val ≈ case.value atol = atol 
rtol = rtol + for k in keys(case.gradient) + @test getproperty(grad, k) ≈ getproperty(case.gradient, k) atol = atol rtol = + rtol + end + end + end + end + return nothing +end + end # module diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index a88f2910..aff28353 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -218,6 +218,20 @@ function _assert_namedtuple_shape(e::NamedTupleEvaluator{true}, values) end _assert_namedtuple_shape(::NamedTupleEvaluator{false}, _) = nothing +# Classify the output of a probe `evaluator(x)` call into the two arities the +# AD interface supports — `:scalar` routes to gradient prep, `:vector` to +# jacobian prep. Shared by the DI and Mooncake extensions so both surface the +# same error message for unsupported output types. +function _ad_output_arity(y) + y isa Number && return :scalar + y isa AbstractVector && return :vector + throw( + ArgumentError( + "A prepared AD evaluator must return a scalar or AbstractVector; got $(typeof(y)).", + ), + ) +end + # Complements the `typeof` check above: same-typed arrays can differ in `size`. # Arrays with non-`Real`/`Complex` eltype are walked element-wise to catch # inner mismatches. 
Unknown leaves throw, mirroring the supported-leaves diff --git a/test/Project.toml b/test/Project.toml index 1bd57d0d..122f7e4e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -16,7 +16,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] ADTypes = "1" -AbstractPPL = "0.14" +AbstractPPL = "0.15" Accessors = "0.1" Aqua = "0.8" DimensionalData = "0.29, 0.30" diff --git a/test/ext/mooncake/Project.toml b/test/ext/mooncake/Project.toml new file mode 100644 index 00000000..6a5c2039 --- /dev/null +++ b/test/ext/mooncake/Project.toml @@ -0,0 +1,11 @@ +[deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[compat] +ADTypes = "1" +Mooncake = "0.5.27" +julia = "1.10" diff --git a/test/ext/mooncake/main.jl b/test/ext/mooncake/main.jl new file mode 100644 index 00000000..55723c7f --- /dev/null +++ b/test/ext/mooncake/main.jl @@ -0,0 +1,22 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) +Pkg.instantiate() + +using AbstractPPL: AbstractPPL, prepare, run_testcases +using ADTypes: AutoMooncake, AutoMooncakeForward +using Mooncake +using Test + +@testset "AbstractPPLMooncakeExt" begin + for (label, adtype) in ( + ("Mooncake (reverse)", AutoMooncake()), + ("Mooncake (forward)", AutoMooncakeForward()), + ) + @testset "$label" begin + run_testcases(Val(:vector); adtype=adtype, atol=1e-6, rtol=1e-6) + run_testcases(Val(:namedtuple); adtype=adtype, atol=1e-6, rtol=1e-6) + run_testcases(Val(:edge); adtype=adtype) + end + end +end diff --git a/test/run_extras.jl b/test/run_extras.jl index e557f363..cd2c157e 100644 --- a/test/run_extras.jl +++ b/test/run_extras.jl @@ -2,8 +2,9 @@ # # Usage (from the repo root): # LABEL=ext/differentiationinterface julia test/run_extras.jl +# LABEL=ext/mooncake julia test/run_extras.jl -const 
VALID_LABELS = ("ext/differentiationinterface",) +const VALID_LABELS = ("ext/differentiationinterface", "ext/mooncake") label = get(ENV, "LABEL", nothing) label in VALID_LABELS || From d3d94935c20591b70ccef66977720086cb656fcd Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sat, 9 May 2026 00:37:30 +0100 Subject: [PATCH 02/15] Factor AD-input validation; reorder NamedTuple test group `_check_ad_input(evaluator, x)` in `Evaluators` replaces the duplicated `T <: Integer` rejection plus length check that appeared at six AD entry points (two in the DI extension, four in Mooncake). Compile-time `T` elision is preserved. Move `generate_testcases(::Val{:namedtuple})` and `run_testcases(::Val{:namedtuple})` to sit alongside the `:vector` and `:edge` definitions so the file reads generate-then-run for all three groups. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLDifferentiationInterfaceExt.jl | 6 +-- ext/AbstractPPLMooncakeExt.jl | 15 +++---- ext/AbstractPPLTestExt.jl | 44 +++++++++---------- src/evaluators/Evaluators.jl | 8 ++++ 4 files changed, 37 insertions(+), 36 deletions(-) diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index d2f35d0c..b5ee8b72 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -58,8 +58,7 @@ end ) where {T<:Real} p.cache.gradient_prep === nothing && throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function.")) - T <: Integer && Evaluators._reject_integer_input(x) - Evaluators._check_vector_length(p.evaluator.dim, x) + Evaluators._check_ad_input(p.evaluator, x) # Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff # `BoundsError`); typed `T[]` matches the caller's element type. 
length(x) == 0 && return (p.evaluator(x), T[]) @@ -77,8 +76,7 @@ end ) where {T<:Real} p.cache.jacobian_prep === nothing && throw(ArgumentError("`value_and_jacobian!!` requires a vector-valued function.")) - T <: Integer && Evaluators._reject_integer_input(x) - Evaluators._check_vector_length(p.evaluator.dim, x) + Evaluators._check_ad_input(p.evaluator, x) if length(x) == 0 val = p.evaluator(x) return (val, similar(x, length(val), 0)) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index ac48df63..d483ec46 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -15,8 +15,7 @@ const _MooncakeAD = Union{AutoMooncake,AutoMooncakeForward} # Tag a Mooncake cache with the prepared evaluator's output arity (`:scalar` # or `:vector`) so `value_and_gradient!!` / `value_and_jacobian!!` can raise -# helpful arity-mismatch errors instead of failing inside Mooncake. The inner -# `cache` is `Nothing` for the empty-input shortcut path. +# helpful arity-mismatch errors instead of failing inside Mooncake. 
struct MooncakeCache{A,C} cache::C end @@ -82,8 +81,7 @@ end p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing}}, x::AbstractVector{T}, ) where {T<:Real} - T <: Integer && Evaluators._reject_integer_input(x) - Evaluators._check_vector_length(p.evaluator.dim, x) + Evaluators._check_ad_input(p.evaluator, x) return (p.evaluator(x), T[]) end @@ -91,8 +89,7 @@ end p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, x::AbstractVector{T}, ) where {T<:Real} - T <: Integer && Evaluators._reject_integer_input(x) - Evaluators._check_vector_length(p.evaluator.dim, x) + Evaluators._check_ad_input(p.evaluator, x) val, (_, grad) = Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x) return (val, grad) end @@ -115,8 +112,7 @@ end p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector,Nothing}}, x::AbstractVector{T}, ) where {T<:Real} - T <: Integer && Evaluators._reject_integer_input(x) - Evaluators._check_vector_length(p.evaluator.dim, x) + Evaluators._check_ad_input(p.evaluator, x) val = p.evaluator(x) return (val, similar(x, length(val), 0)) end @@ -125,8 +121,7 @@ end p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector}}, x::AbstractVector{T}, ) where {T<:Real} - T <: Integer && Evaluators._reject_integer_input(x) - Evaluators._check_vector_length(p.evaluator.dim, x) + Evaluators._check_ad_input(p.evaluator, x) return Mooncake.value_and_jacobian!!(p.cache.cache, p.evaluator, x) end diff --git a/ext/AbstractPPLTestExt.jl b/ext/AbstractPPLTestExt.jl index 3bd3f86a..07bf45ff 100644 --- a/ext/AbstractPPLTestExt.jl +++ b/ext/AbstractPPLTestExt.jl @@ -154,6 +154,28 @@ function AbstractPPL.generate_testcases(::Val{:edge}) ) end +function AbstractPPL.generate_testcases(::Val{:namedtuple}) + return ( + ValueCase( + "scalar output over (x::Real, y::Vector)", + vs -> vs.x^2 + sum(abs2, vs.y), + (x=0.0, y=zeros(2)), + (x=3.0, y=[1.0, 2.0]), + 14.0, + (x=6.0, y=[2.0, 4.0]), + nothing, + ), + ErrorCase( + 
"wrong NamedTuple structure", + vs -> vs.x^2 + sum(abs2, vs.y), + (x=0.0, y=zeros(2)), + (x=3.0, z=[1.0, 2.0]), + (prepared, x) -> AbstractPPL.value_and_gradient!!(prepared, x), + r"same NamedTuple structure", + ), + ) +end + function AbstractPPL.run_testcases( ::Val{:vector}, prepare_fn=AbstractPPL.prepare; adtype, atol=0, rtol=1e-10 ) @@ -186,28 +208,6 @@ function AbstractPPL.run_testcases(::Val{:edge}, prepare_fn=AbstractPPL.prepare; return nothing end -function AbstractPPL.generate_testcases(::Val{:namedtuple}) - return ( - ValueCase( - "scalar output over (x::Real, y::Vector)", - vs -> vs.x^2 + sum(abs2, vs.y), - (x=0.0, y=zeros(2)), - (x=3.0, y=[1.0, 2.0]), - 14.0, - (x=6.0, y=[2.0, 4.0]), - nothing, - ), - ErrorCase( - "wrong NamedTuple structure", - vs -> vs.x^2 + sum(abs2, vs.y), - (x=0.0, y=zeros(2)), - (x=3.0, z=[1.0, 2.0]), - (prepared, x) -> AbstractPPL.value_and_gradient!!(prepared, x), - r"same NamedTuple structure", - ), - ) -end - function AbstractPPL.run_testcases( ::Val{:namedtuple}, prepare_fn=AbstractPPL.prepare; adtype, atol=0, rtol=1e-10 ) diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index aff28353..d3b36e69 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -177,6 +177,14 @@ function _check_vector_length(dim::Int, x) return nothing end +# Shared input validation for AD-backend `value_and_{gradient,jacobian}!!` entry +# points. Same compile-time `T <: Integer` elision as the `VectorEvaluator` body. 
+function _check_ad_input(e::VectorEvaluator, x::AbstractVector{T}) where {T} + T <: Integer && _reject_integer_input(x) + _check_vector_length(e.dim, x) + return nothing +end + function (e::VectorEvaluator{true})(x::AbstractVector{T}) where {T} T <: Integer && _reject_integer_input(x) _check_vector_length(e.dim, x) From 1f8b15783c68d9057bd3442a252ec10e3ff6e23a Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sat, 9 May 2026 00:57:04 +0100 Subject: [PATCH 03/15] Add cache-reuse tests; tidy AD arity errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `:cache_reuse` conformance group in `AbstractPPLTestExt` drives `value_and_{gradient,jacobian}!!` three times per case against a single `prepared` evaluator to catch backend cache corruption between calls. - DI ext sub-environment now also loads `ReverseDiff` and exercises `AutoReverseDiff(compile=true)` against the conformance suite, covering the `_prepare_di(::AutoReverseDiff{true}, …)` compiled-tape path. - Lift the duplicated `value_and_{gradient,jacobian}!!` arity-mismatch `ArgumentError` strings into shared `Evaluators._throw_*` helpers used by both the DI and Mooncake extensions. - `generate_testcases` docstring lists `:namedtuple` and `:cache_reuse` alongside `:vector` / `:edge` as reserved group keys. - Trim verbose `check_dims` clarifications in docstrings and `docs/src/evaluators.md` to one sentence each. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLDifferentiationInterfaceExt.jl | 6 +-- ext/AbstractPPLMooncakeExt.jl | 22 ++++----- ext/AbstractPPLTestExt.jl | 45 ++++++++++++++----- src/AbstractPPL.jl | 5 ++- src/evaluators/Evaluators.jl | 18 ++++++-- test/ext/differentiationinterface/main.jl | 30 ++++++------- test/ext/mooncake/main.jl | 3 +- 7 files changed, 79 insertions(+), 50 deletions(-) diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index b5ee8b72..732fc796 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -56,8 +56,7 @@ end @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} ) where {T<:Real} - p.cache.gradient_prep === nothing && - throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function.")) + p.cache.gradient_prep === nothing && Evaluators._throw_gradient_needs_scalar() Evaluators._check_ad_input(p.evaluator, x) # Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff # `BoundsError`); typed `T[]` matches the caller's element type. 
@@ -74,8 +73,7 @@ end @inline function AbstractPPL.value_and_jacobian!!( p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} ) where {T<:Real} - p.cache.jacobian_prep === nothing && - throw(ArgumentError("`value_and_jacobian!!` requires a vector-valued function.")) + p.cache.jacobian_prep === nothing && Evaluators._throw_jacobian_needs_vector() Evaluators._check_ad_input(p.evaluator, x) if length(x) == 0 val = p.evaluator(x) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index d483ec46..3180650b 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -2,12 +2,7 @@ module AbstractPPLMooncakeExt using AbstractPPL: AbstractPPL using AbstractPPL.Evaluators: - Evaluators, - Prepared, - VectorEvaluator, - NamedTupleEvaluator, - _ad_output_arity, - _assert_namedtuple_shape + Evaluators, Prepared, VectorEvaluator, NamedTupleEvaluator, _ad_output_arity using ADTypes: AutoMooncake, AutoMooncakeForward using Mooncake: Mooncake @@ -26,19 +21,19 @@ _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtyp # `value_and_gradient!!` accepts either a reverse-mode gradient cache # (AutoMooncake) or a forward-mode derivative cache (AutoMooncakeForward). function _mooncake_gradient_cache(::AutoMooncake, f, x; config) - return Mooncake.prepare_gradient_cache(f, x; config=config) + return Mooncake.prepare_gradient_cache(f, x; config) end function _mooncake_gradient_cache(::AutoMooncakeForward, f, x; config) - return Mooncake.prepare_derivative_cache(f, x; config=config) + return Mooncake.prepare_derivative_cache(f, x; config) end # `value_and_jacobian!!`: reverse mode wants a pullback cache, forward mode # wants a derivative cache. 
function _mooncake_jacobian_cache(::AutoMooncake, f, x; config) - return Mooncake.prepare_pullback_cache(f, x; config=config) + return Mooncake.prepare_pullback_cache(f, x; config) end function _mooncake_jacobian_cache(::AutoMooncakeForward, f, x; config) - return Mooncake.prepare_derivative_cache(f, x; config=config) + return Mooncake.prepare_derivative_cache(f, x; config) end function AbstractPPL.prepare( @@ -69,10 +64,11 @@ end # `Mooncake.value_and_gradient!!` returns `(val, (∂f, ∂x))`; we discard the # function tangent `∂f` and surface only `∂x` as the user-facing gradient. +# Shape validation is delegated to the inner `NamedTupleEvaluator{CheckInput}` +# callable Mooncake invokes — gated by the user's `check_dims` choice. @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:_MooncakeAD,<:NamedTupleEvaluator}, values::NamedTuple ) - _assert_namedtuple_shape(p.evaluator, values) val, (_, grad) = Mooncake.value_and_gradient!!(p.cache, p.evaluator, values) return (val, grad) end @@ -98,14 +94,14 @@ end ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector}}, ::AbstractVector{<:Real}, ) - throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function.")) + return Evaluators._throw_gradient_needs_scalar() end @inline function AbstractPPL.value_and_jacobian!!( ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, ::AbstractVector{<:Real}, ) - throw(ArgumentError("`value_and_jacobian!!` requires a vector-valued function.")) + return Evaluators._throw_jacobian_needs_vector() end @inline function AbstractPPL.value_and_jacobian!!( diff --git a/ext/AbstractPPLTestExt.jl b/ext/AbstractPPLTestExt.jl index 07bf45ff..3ac01060 100644 --- a/ext/AbstractPPLTestExt.jl +++ b/ext/AbstractPPLTestExt.jl @@ -165,14 +165,6 @@ function AbstractPPL.generate_testcases(::Val{:namedtuple}) (x=6.0, y=[2.0, 4.0]), nothing, ), - ErrorCase( - "wrong NamedTuple structure", - vs -> vs.x^2 + sum(abs2, vs.y), - (x=0.0, y=zeros(2)), - (x=3.0, 
z=[1.0, 2.0]), - (prepared, x) -> AbstractPPL.value_and_gradient!!(prepared, x), - r"same NamedTuple structure", - ), ) end @@ -214,10 +206,6 @@ function AbstractPPL.run_testcases( for case in generate_testcases(Val(:namedtuple)) @testset "$(case.name)" begin prepared = prepare_fn(adtype, case.f, case.x_proto) - if case isa ErrorCase - @test_throws case.exception case.op(prepared, case.x) - continue - end @test prepared(case.x) ≈ case.value atol = atol rtol = rtol if case.gradient !== nothing val, grad = AbstractPPL.value_and_gradient!!(prepared, case.x) @@ -232,4 +220,37 @@ function AbstractPPL.run_testcases( return nothing end +# Drive `value_and_{gradient,jacobian}!!` three times with different inputs against +# the same `prepared` evaluator to exercise cache reuse — catches backends +# whose cache state is corrupted by a prior call. +function AbstractPPL.run_testcases( + ::Val{:cache_reuse}, prepare_fn=AbstractPPL.prepare; adtype, atol=0, rtol=1e-10 +) + @testset "scalar output, repeated calls" begin + prepared = prepare_fn(adtype, QuadraticProblem(), zeros(3)) + for (x, value, gradient) in ( + ([1.0, 2.0, 3.0], 14.0, [2.0, 4.0, 6.0]), + ([4.0, 5.0, 6.0], 77.0, [8.0, 10.0, 12.0]), + ([0.5, -1.0, 2.0], 5.25, [1.0, -2.0, 4.0]), + ) + val, grad = AbstractPPL.value_and_gradient!!(prepared, x) + @test val ≈ value atol = atol rtol = rtol + @test grad ≈ gradient atol = atol rtol = rtol + end + end + @testset "vector output, repeated calls" begin + prepared = prepare_fn(adtype, VectorValuedProblem(), zeros(3)) + for (x, value, jacobian) in ( + ([2.0, 3.0, 4.0], [6.0, 7.0], [3.0 2.0 0.0; 0.0 1.0 1.0]), + ([5.0, 1.0, 7.0], [5.0, 8.0], [1.0 5.0 0.0; 0.0 1.0 1.0]), + ([0.0, 4.0, -2.0], [0.0, 2.0], [4.0 0.0 0.0; 0.0 1.0 1.0]), + ) + val, jac = AbstractPPL.value_and_jacobian!!(prepared, x) + @test val ≈ value atol = atol rtol = rtol + @test jac ≈ jacobian atol = atol rtol = rtol + end + end + return nothing +end + end # module diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl 
index 0b50ca9a..c70b349d 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -19,8 +19,9 @@ using .Evaluators: prepare, value_and_gradient!!, value_and_jacobian!! Return a tuple of test cases for the conformance `group`. Implemented by the `Test` extension (`AbstractPPLTestExt`). Reserved group keys (extensions must not redefine these): `:vector` for value/gradient/jacobian round-trips on -vector-input evaluators; `:edge` for error-path cases. Downstream packages may -add other keys. +vector-input evaluators; `:namedtuple` for `NamedTuple`-input evaluators; +`:edge` for error-path cases; `:cache_reuse` for repeated calls against a +single prepared evaluator. Downstream packages may add other keys. """ function generate_testcases end diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index d3b36e69..728e99b8 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -203,13 +203,15 @@ end (e::NamedTupleEvaluator{false})(values::NamedTuple) = e.f(values) """ - _assert_namedtuple_shape(e::NamedTupleEvaluator, values) + _assert_namedtuple_shape(e::NamedTupleEvaluator{true}, values) Throw `ArgumentError` unless `values` has the same type as the prototype captured during preparation, including matching `size` for any nested `AbstractArray` leaves. Also throws if the prototype contains a leaf type outside the supported -set (`Real`, `Complex`, `AbstractArray`, `Tuple`, `NamedTuple`). No-op when `e` -was constructed with `CheckInput=false`. +set (`Real`, `Complex`, `AbstractArray`, `Tuple`, `NamedTuple`). + +Gated by `CheckInput`: the `{false}` overload is a no-op so AD hot paths and +other opt-out callers pay nothing. 
""" function _assert_namedtuple_shape(e::NamedTupleEvaluator{true}, values) typeof(values) === typeof(e.inputspec) || throw( @@ -240,6 +242,16 @@ function _ad_output_arity(y) ) end +# Arity-mismatch errors shared by the DI and Mooncake extensions; kept here so +# the `:edge` testcase regexes (`r"scalar-valued"`, `r"vector-valued"`) pin a +# single error string instead of one per backend. +function _throw_gradient_needs_scalar() + throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function.")) +end +function _throw_jacobian_needs_vector() + throw(ArgumentError("`value_and_jacobian!!` requires a vector-valued function.")) +end + # Complements the `typeof` check above: same-typed arrays can differ in `size`. # Arrays with non-`Real`/`Complex` eltype are walked element-wise to catch # inner mismatches. Unknown leaves throw, mirroring the supported-leaves diff --git a/test/ext/differentiationinterface/main.jl b/test/ext/differentiationinterface/main.jl index 4e19a8ce..cd02fb6b 100644 --- a/test/ext/differentiationinterface/main.jl +++ b/test/ext/differentiationinterface/main.jl @@ -3,27 +3,27 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) Pkg.instantiate() -using AbstractPPL: AbstractPPL, run_testcases +using AbstractPPL: run_testcases using ADTypes: AutoForwardDiff, AutoReverseDiff -using DifferentiationInterface +using DifferentiationInterface: DifferentiationInterface as DI using ForwardDiff using ReverseDiff using Test @testset "AbstractPPLDifferentiationInterfaceExt" begin - run_testcases(Val(:vector); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) - run_testcases(Val(:edge); adtype=AutoForwardDiff()) - - @testset "AutoReverseDiff compiled tape (no-context path)" begin - ad = AutoReverseDiff(; compile=true) - p_scalar = AbstractPPL.prepare(ad, x -> sum(abs2, x), zeros(3)) - p_vector = AbstractPPL.prepare(ad, x -> [x[1] * x[2], x[2] + x[3]], zeros(3)) - - @test !p_scalar.cache.use_context - @test 
!isnothing(p_scalar.cache.gradient_prep.tape) - @test !p_vector.cache.use_context - @test !isnothing(p_vector.cache.jacobian_prep.tape) + @testset "ForwardDiff" begin + run_testcases(Val(:vector); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) + run_testcases(Val(:cache_reuse); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) + run_testcases(Val(:edge); adtype=AutoForwardDiff()) + end - run_testcases(Val(:vector); adtype=ad, atol=1e-6, rtol=1e-6) + # Compiled-tape ReverseDiff goes through the `_prepare_di(::AutoReverseDiff{true}, …)` + # specialisation that closes the evaluator into a `Base.Fix2` target — the + # `:cache_reuse` group exercises that path across multiple inputs. + @testset "ReverseDiff (compiled tape)" begin + adtype = AutoReverseDiff(; compile=true) + run_testcases(Val(:vector); adtype=adtype, atol=1e-6, rtol=1e-6) + run_testcases(Val(:cache_reuse); adtype=adtype, atol=1e-6, rtol=1e-6) + run_testcases(Val(:edge); adtype=adtype) end end diff --git a/test/ext/mooncake/main.jl b/test/ext/mooncake/main.jl index 55723c7f..527a219e 100644 --- a/test/ext/mooncake/main.jl +++ b/test/ext/mooncake/main.jl @@ -3,7 +3,7 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) Pkg.instantiate() -using AbstractPPL: AbstractPPL, prepare, run_testcases +using AbstractPPL: run_testcases using ADTypes: AutoMooncake, AutoMooncakeForward using Mooncake using Test @@ -16,6 +16,7 @@ using Test @testset "$label" begin run_testcases(Val(:vector); adtype=adtype, atol=1e-6, rtol=1e-6) run_testcases(Val(:namedtuple); adtype=adtype, atol=1e-6, rtol=1e-6) + run_testcases(Val(:cache_reuse); adtype=adtype, atol=1e-6, rtol=1e-6) run_testcases(Val(:edge); adtype=adtype) end end From bb45ccab605b38c0296b16c2c63e1d393e32d9ad Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 13 May 2026 11:45:58 +0100 Subject: [PATCH 04/15] AGENTS.md: refresh env before tests / doc builds Stale manifests cause subtle resolution and loading issues; document the expected 
`Pkg.update()` step alongside the existing test commands. Co-Authored-By: Claude Opus 4.7 (1M context) --- AGENTS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/AGENTS.md b/AGENTS.md index 139496fe..8a03e58a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -32,6 +32,8 @@ AbstractPPL.jl provides shared interfaces and utilities for probabilistic progra - Full package tests: `julia --project=. -e 'using Pkg; Pkg.test()'` - Docs: `julia --project=docs docs/make.jl` +Always refresh each environment (`Pkg.update()` / `up`) before tests or doc builds — a stale manifest can cause subtle resolution and loading issues. + Run the smallest relevant test first, then broaden when changing public interfaces, extensions, or downstream-facing behaviour. Do not weaken tests just to make CI pass. ## Documentation From cf42151168ad0961d870f17d22ee747a0b458fb2 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 13 May 2026 14:57:42 +0100 Subject: [PATCH 05/15] Cut AD hot-path overhead in VectorEvaluator/DICache Two regressions visible on tiny-model gradients went through the new AbstractPPL evaluator interface: - `_check_ad_input` always ran on `value_and_{gradient,jacobian}!!` entry, even when the evaluator was prepared with `check_dims=false`. Now dispatch-gated on `VectorEvaluator{CheckInput}`: the `{false}` overload is a no-op, so the `DimensionMismatch` and integer-rejection paths are elided from the LLVM IR of the AD hot path. - `DICache` stored `use_context::Bool` as a runtime field, leaving a branch in the compiled call selecting the context vs no-context DI form. `UseContext` is now a type parameter and the branch is resolved by dispatch via `_di_value_and_{gradient,jacobian}` helpers. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLDifferentiationInterfaceExt.jl | 70 +++++++++++-------- src/evaluators/Evaluators.jl | 6 +- test/evaluators/Evaluators.jl | 15 ++++ test/ext/differentiationinterface/main.jl | 23 +++++- 4 files changed, 81 insertions(+), 33 deletions(-) diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 732fc796..8ab744d1 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -9,25 +9,37 @@ using DifferentiationInterface: DifferentiationInterface as DI # that in DynamicPPL the model and other evaluator state stay constant. @inline _call_evaluator(x, evaluator) = evaluator(x) -struct DICache{F,GP,JP} +# `UseContext` is type-encoded so the dispatch between the context and +# no-context DI call is resolved at compile time; on tiny problems the runtime +# branch would otherwise show up as fixed overhead in the AD hot path. +struct DICache{UseContext,F,GP,JP} target::F gradient_prep::GP jacobian_prep::JP - use_context::Bool + function DICache{UseContext}(target::F, gp::GP, jp::JP) where {UseContext,F,GP,JP} + UseContext isa Bool || throw(ArgumentError("`UseContext` must be a Bool.")) + return new{UseContext,F,GP,JP}(target, gp, jp) + end end # Compiled ReverseDiff only reuses a compiled tape on the one-argument path; # `DI.Constant` deactivates tape recording, so close the evaluator into the -# target and call DI without contexts. +# target and call DI without contexts. The trailing `Val(false)`/`Val(true)` +# carries `UseContext` to the `DICache` constructor at compile time. 
function _prepare_di(prep::F, adtype::AutoReverseDiff{true}, x, evaluator) where {F} target = Base.Fix2(_call_evaluator, evaluator) - return target, prep(target, adtype, x), false + return target, prep(target, adtype, x), Val(false) end function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F} - return _call_evaluator, prep(_call_evaluator, adtype, x, DI.Constant(evaluator)), true + return ( + _call_evaluator, prep(_call_evaluator, adtype, x, DI.Constant(evaluator)), Val(true) + ) end +@inline _wrap_cache(target, gp, jp, ::Val{UseContext}) where {UseContext} = + DICache{UseContext}(target, gp, jp) + function AbstractPPL.prepare( adtype::AbstractADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true ) @@ -35,24 +47,32 @@ function AbstractPPL.prepare( arity = _ad_output_arity(evaluator(x)) if length(x) == 0 # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`); the - # `Val(0)` sentinel keeps the `gradient_prep === nothing` arity check meaningful. + # `Val(0)` sentinel keeps the `gradient_prep === nothing` arity check + # meaningful. `UseContext` is irrelevant on this shortcut path — the AD + # entry returns `(p.evaluator(x), T[])` before any DI call. gp, jp = arity === :scalar ? 
(Val(0), nothing) : (nothing, Val(0)) - return Prepared(adtype, evaluator, DICache(_call_evaluator, gp, jp, true)) + return Prepared(adtype, evaluator, DICache{true}(_call_evaluator, gp, jp)) end if arity === :scalar - target, gradient_prep, use_context = _prepare_di( - DI.prepare_gradient, adtype, x, evaluator - ) - return Prepared( - adtype, evaluator, DICache(target, gradient_prep, nothing, use_context) - ) + target, gradient_prep, ctx = _prepare_di(DI.prepare_gradient, adtype, x, evaluator) + return Prepared(adtype, evaluator, _wrap_cache(target, gradient_prep, nothing, ctx)) end - target, jacobian_prep, use_context = _prepare_di( - DI.prepare_jacobian, adtype, x, evaluator - ) - return Prepared(adtype, evaluator, DICache(target, nothing, jacobian_prep, use_context)) + target, jacobian_prep, ctx = _prepare_di(DI.prepare_jacobian, adtype, x, evaluator) + return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, ctx)) end +# Compile-time dispatch on the `UseContext` type parameter eliminates the +# context-vs-no-context branch from the AD hot path. +@inline _di_value_and_gradient(c::DICache{true}, ad, x, eval) = + DI.value_and_gradient(c.target, c.gradient_prep, ad, x, DI.Constant(eval)) +@inline _di_value_and_gradient(c::DICache{false}, ad, x, _) = + DI.value_and_gradient(c.target, c.gradient_prep, ad, x) + +@inline _di_value_and_jacobian(c::DICache{true}, ad, x, eval) = + DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x, DI.Constant(eval)) +@inline _di_value_and_jacobian(c::DICache{false}, ad, x, _) = + DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x) + @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} ) where {T<:Real} @@ -61,13 +81,7 @@ end # Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff # `BoundsError`); typed `T[]` matches the caller's element type. 
length(x) == 0 && return (p.evaluator(x), T[]) - return if p.cache.use_context - DI.value_and_gradient( - p.cache.target, p.cache.gradient_prep, p.adtype, x, DI.Constant(p.evaluator) - ) - else - DI.value_and_gradient(p.cache.target, p.cache.gradient_prep, p.adtype, x) - end + return _di_value_and_gradient(p.cache, p.adtype, x, p.evaluator) end @inline function AbstractPPL.value_and_jacobian!!( @@ -79,13 +93,7 @@ end val = p.evaluator(x) return (val, similar(x, length(val), 0)) end - return if p.cache.use_context - DI.value_and_jacobian( - p.cache.target, p.cache.jacobian_prep, p.adtype, x, DI.Constant(p.evaluator) - ) - else - DI.value_and_jacobian(p.cache.target, p.cache.jacobian_prep, p.adtype, x) - end + return _di_value_and_jacobian(p.cache, p.adtype, x, p.evaluator) end end # module diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index 728e99b8..fd9d0fd1 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -179,11 +179,15 @@ end # Shared input validation for AD-backend `value_and_{gradient,jacobian}!!` entry # points. Same compile-time `T <: Integer` elision as the `VectorEvaluator` body. -function _check_ad_input(e::VectorEvaluator, x::AbstractVector{T}) where {T} +# Gated by `CheckInput`: the `{false}` overload is a no-op so the AD hot path +# pays nothing when the caller has already validated the input (e.g. via +# `prepare(...; check_dims=false)`). 
+function _check_ad_input(e::VectorEvaluator{true}, x::AbstractVector{T}) where {T} T <: Integer && _reject_integer_input(x) _check_vector_length(e.dim, x) return nothing end +_check_ad_input(::VectorEvaluator{false}, ::AbstractVector) = nothing function (e::VectorEvaluator{true})(x::AbstractVector{T}) where {T} T <: Integer && _reject_integer_input(x) diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl index c8454438..6aaab96c 100644 --- a/test/evaluators/Evaluators.jl +++ b/test/evaluators/Evaluators.jl @@ -72,6 +72,21 @@ end # Unsupported leaf types are rejected rather than silently passing. ne_string = AbstractPPL.Evaluators.NamedTupleEvaluator(x -> length(x.s), (s="abc",)) @test_throws r"Supported leaves" ne_string((s="abcde",)) + + # `_check_ad_input` is dispatch-gated by `CheckInput` so the AD hot + # path pays nothing when the evaluator was prepared with + # `check_dims=false`. + ve_checked = AbstractPPL.Evaluators.VectorEvaluator{true}(sum, 3) + @test AbstractPPL.Evaluators._check_ad_input(ve_checked, [1.0, 2.0, 3.0]) === + nothing + @test_throws DimensionMismatch AbstractPPL.Evaluators._check_ad_input( + ve_checked, [1.0, 2.0] + ) + @test_throws r"floating-point" AbstractPPL.Evaluators._check_ad_input( + ve_checked, [1, 2, 3] + ) + @test AbstractPPL.Evaluators._check_ad_input(ve_unchecked, [1.0, 2.0]) === nothing + @test AbstractPPL.Evaluators._check_ad_input(ve_unchecked, [1, 2, 3]) === nothing end @testset "prepare (structural)" begin diff --git a/test/ext/differentiationinterface/main.jl b/test/ext/differentiationinterface/main.jl index cd02fb6b..9788eda7 100644 --- a/test/ext/differentiationinterface/main.jl +++ b/test/ext/differentiationinterface/main.jl @@ -3,13 +3,17 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) Pkg.instantiate() -using AbstractPPL: run_testcases +using AbstractPPL: AbstractPPL, prepare, run_testcases, value_and_gradient!! 
using ADTypes: AutoForwardDiff, AutoReverseDiff using DifferentiationInterface: DifferentiationInterface as DI using ForwardDiff using ReverseDiff using Test +const DIExt = Base.get_extension(AbstractPPL, :AbstractPPLDifferentiationInterfaceExt) + +quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) + @testset "AbstractPPLDifferentiationInterfaceExt" begin @testset "ForwardDiff" begin run_testcases(Val(:vector); adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6) @@ -26,4 +30,21 @@ using Test run_testcases(Val(:cache_reuse); adtype=adtype, atol=1e-6, rtol=1e-6) run_testcases(Val(:edge); adtype=adtype) end + + # `DICache` encodes `UseContext` as a type parameter so the + # context-vs-no-context DI call is resolved by dispatch, not a runtime + # `Bool` branch in the AD hot path. + @testset "DICache encodes UseContext as a type parameter" begin + x = [1.0, 2.0, 3.0] + prep_ctx = prepare(AutoForwardDiff(), quadratic, x) + prep_noctx = prepare(AutoReverseDiff(; compile=true), quadratic, x) + + @test prep_ctx.cache isa DIExt.DICache{true} + @test prep_noctx.cache isa DIExt.DICache{false} + @test !hasfield(typeof(prep_ctx.cache), :use_context) + + # Hot path is type-stable on both branches. + @inferred value_and_gradient!!(prep_ctx, x) + @inferred value_and_gradient!!(prep_noctx, x) + end end From 82f8731fd453d937529a33f5f403ad1ee685392c Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 13 May 2026 15:31:01 +0100 Subject: [PATCH 06/15] Mooncake ext: skip re-zeroing the evaluator tangent each gradient call MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `Mooncake.value_and_gradient!!(cache, evaluator, x)` reset the evaluator's tangent buffer on every call, even though AbstractPPL discards `∂f` and only surfaces `∂x`. For evaluators that wrap a model with large fields (e.g. a 128-tuple of `Float64`), the zeroing was the dominant per-call overhead at tiny model sizes. 
Pass `args_to_zero=(false, true)` to the reverse-mode `Mooncake.Cache` path to skip the `∂f` reset while still zeroing the `∂x` buffer. The forward-mode `Mooncake.ForwardCache` doesn't accept the kwarg, so the branch is `isa`-dispatched on the concrete cache type and constant-folds at compile time. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLMooncakeExt.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 3180650b..d859db63 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -64,12 +64,21 @@ end # `Mooncake.value_and_gradient!!` returns `(val, (∂f, ∂x))`; we discard the # function tangent `∂f` and surface only `∂x` as the user-facing gradient. +# Reverse-mode caches accept `args_to_zero` to skip re-zeroing the evaluator's +# tangent buffer each call (the dominant overhead when the user's problem +# carries large fields whose gradient we never consume); forward-mode caches +# don't take the kwarg, so the branch is `isa`-dispatched on the concrete +# `p.cache` type and constant-folds away. # Shape validation is delegated to the inner `NamedTupleEvaluator{CheckInput}` # callable Mooncake invokes — gated by the user's `check_dims` choice. 
@inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:_MooncakeAD,<:NamedTupleEvaluator}, values::NamedTuple ) - val, (_, grad) = Mooncake.value_and_gradient!!(p.cache, p.evaluator, values) + val, (_, grad) = if p.cache isa Mooncake.Cache + Mooncake.value_and_gradient!!(p.cache, p.evaluator, values; args_to_zero=(false, true)) + else + Mooncake.value_and_gradient!!(p.cache, p.evaluator, values) + end return (val, grad) end @@ -86,7 +95,11 @@ end x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) - val, (_, grad) = Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x) + val, (_, grad) = if p.cache.cache isa Mooncake.Cache + Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x; args_to_zero=(false, true)) + else + Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x) + end return (val, grad) end From 5b5f39f16b7f4bb350bc1a1f8f6fdd272fa9763f Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 13 May 2026 16:28:18 +0100 Subject: [PATCH 07/15] Comment perf-critical dispatch and tangent-zero sites MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Concise inline notes on: - `VectorEvaluator{true|false}` callable bodies (shared `T <: Integer` compile-time elision, and the `{false}` skip of `_check_vector_length`). - Mooncake ext empty-input and arity-mismatch methods (compile-time dispatch via `MooncakeCache{…,Nothing}` and `MooncakeCache{:scalar|:vector}` to avoid runtime branches). - `args_to_zero=(false, true)` at both Mooncake gradient call sites (skipping the evaluator's tangent re-zeroing per call — `∂f` is discarded). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLMooncakeExt.jl | 10 ++++++++++ src/evaluators/Evaluators.jl | 3 +++ 2 files changed, 13 insertions(+) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index d859db63..6a0f9bd0 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -75,6 +75,7 @@ end p::Prepared{<:_MooncakeAD,<:NamedTupleEvaluator}, values::NamedTuple ) val, (_, grad) = if p.cache isa Mooncake.Cache + # Skip re-zeroing the evaluator's tangent buffer; we discard `∂f`. Mooncake.value_and_gradient!!(p.cache, p.evaluator, values; args_to_zero=(false, true)) else Mooncake.value_and_gradient!!(p.cache, p.evaluator, values) @@ -82,6 +83,9 @@ end return (val, grad) end +# Empty-input shortcut: tagged with `MooncakeCache{…,Nothing}` at prepare time +# so dispatch resolves the no-Mooncake path at compile time — no runtime +# `isnothing(cache)` branch in the hot path. @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing}}, x::AbstractVector{T}, @@ -96,6 +100,7 @@ end ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) val, (_, grad) = if p.cache.cache isa Mooncake.Cache + # Skip re-zeroing the evaluator's tangent buffer; we discard `∂f`. Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x; args_to_zero=(false, true)) else Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x) @@ -103,6 +108,9 @@ end return (val, grad) end +# Arity-mismatch errors as dedicated methods so dispatch on +# `MooncakeCache{:scalar}` vs `{:vector}` resolves at compile time instead of +# a runtime check on the cache contents. 
@inline function AbstractPPL.value_and_gradient!!( ::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector}}, ::AbstractVector{<:Real}, @@ -117,6 +125,8 @@ end return Evaluators._throw_jacobian_needs_vector() end +# Empty-input jacobian shortcut — same compile-time dispatch trick as the +# scalar Nothing-tagged case; skips Mooncake entirely. @inline function AbstractPPL.value_and_jacobian!!( p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector,Nothing}}, x::AbstractVector{T}, diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index fd9d0fd1..e4c8b977 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -189,6 +189,9 @@ function _check_ad_input(e::VectorEvaluator{true}, x::AbstractVector{T}) where { end _check_ad_input(::VectorEvaluator{false}, ::AbstractVector) = nothing +# Both bodies rely on `T <: Integer` being a static check so the AD hot path +# (Float/dual `T`) elides the branch; the `{false}` callable additionally skips +# `_check_vector_length` since AD libraries pass length-matching dual inputs. function (e::VectorEvaluator{true})(x::AbstractVector{T}) where {T} T <: Integer && _reject_integer_input(x) _check_vector_length(e.dim, x) From 4b1ed3527798c7ce747c0bc39cbbecdebef809ce Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 13 May 2026 17:00:32 +0100 Subject: [PATCH 08/15] Mooncake ext: hide evaluator from tangent derivation via tangent_type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mooncake was deriving a nested `Tangent{NamedTuple{f::Tangent{...}}}` for every `VectorEvaluator`/`NamedTupleEvaluator` it received, then walking that structure on every backward pass. The evaluators are AbstractPPL's own wrapper types and never appear as a downstream gradient target — the public API only returns `(value, ∂x)`. 
Register `Mooncake.tangent_type(::Type{<:VectorEvaluator}) = NoTangent` (and the same for `NamedTupleEvaluator`) so the cache carries no tangent for the user's problem fields. The `args_to_zero=(false, true)` mitigation and the `_ConstantEvaluator` wrapper from the prior pass are both no longer needed; the call sites pass `p.evaluator` directly. Verified on the MWE setup: `Mooncake.Tangent{` count in the prepared cache type is 0; value and gradient match a direct `logdensity_at(x, state, …)` call. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLMooncakeExt.jl | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 6a0f9bd0..ecfa656b 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -8,6 +8,15 @@ using Mooncake: Mooncake const _MooncakeAD = Union{AutoMooncake,AutoMooncakeForward} +# Tell Mooncake that the evaluator wrappers are constants from its +# perspective: their fields hold the user's problem state, which Mooncake +# would otherwise derive a nested `Tangent{NamedTuple{f::Tangent{...}}}` for +# and walk on every backward pass. The evaluators are AbstractPPL's own +# types and only ever appear as the callable argument to Mooncake — no +# downstream caller asks for a gradient w.r.t. them. +Mooncake.tangent_type(::Type{<:VectorEvaluator}) = Mooncake.NoTangent +Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent + # Tag a Mooncake cache with the prepared evaluator's output arity (`:scalar` # or `:vector`) so `value_and_gradient!!` / `value_and_jacobian!!` can raise # helpful arity-mismatch errors instead of failing inside Mooncake. 
@@ -62,24 +71,15 @@ function AbstractPPL.prepare( return Prepared(adtype, evaluator, MooncakeCache{arity}(cache)) end -# `Mooncake.value_and_gradient!!` returns `(val, (∂f, ∂x))`; we discard the -# function tangent `∂f` and surface only `∂x` as the user-facing gradient. -# Reverse-mode caches accept `args_to_zero` to skip re-zeroing the evaluator's -# tangent buffer each call (the dominant overhead when the user's problem -# carries large fields whose gradient we never consume); forward-mode caches -# don't take the kwarg, so the branch is `isa`-dispatched on the concrete -# `p.cache` type and constant-folds away. +# `Mooncake.value_and_gradient!!` returns `(val, (∂f, ∂x))`; `∂f` is `NoTangent` +# because we registered `tangent_type(::Type{<:NamedTupleEvaluator}) = NoTangent` +# above, so the cache never carries a tangent for the user's problem. # Shape validation is delegated to the inner `NamedTupleEvaluator{CheckInput}` # callable Mooncake invokes — gated by the user's `check_dims` choice. @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:_MooncakeAD,<:NamedTupleEvaluator}, values::NamedTuple ) - val, (_, grad) = if p.cache isa Mooncake.Cache - # Skip re-zeroing the evaluator's tangent buffer; we discard `∂f`. - Mooncake.value_and_gradient!!(p.cache, p.evaluator, values; args_to_zero=(false, true)) - else - Mooncake.value_and_gradient!!(p.cache, p.evaluator, values) - end + val, (_, grad) = Mooncake.value_and_gradient!!(p.cache, p.evaluator, values) return (val, grad) end @@ -99,12 +99,7 @@ end x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) - val, (_, grad) = if p.cache.cache isa Mooncake.Cache - # Skip re-zeroing the evaluator's tangent buffer; we discard `∂f`. 
- Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x; args_to_zero=(false, true)) - else - Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x) - end + val, (_, grad) = Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x) return (val, grad) end From e05f547d9ed284497b3c82784c89fe27c50ae21e Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 13 May 2026 17:47:43 +0100 Subject: [PATCH 09/15] Mooncake ext: add `raw_gradient_target` for lowered AD entry shape MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Callers who know an equivalent raw `f(x, contexts...) ≡ problem(x)` can pass it via `prepare(AutoMooncake(), problem, x; raw_gradient_target=(f, contexts))`. Mooncake then compiles the tape on the raw call shape with `args_to_zero= (false, true, false, …)` instead of the generic `evaluator(x)` wrapper — sidestepping the fixed-overhead seen on tiny scalar-vector problems. `prepared(x)` still calls `problem(x)`; only the AD entry uses the lowered cache (a new `MooncakeLoweredCache` carries `cache`, `f`, `contexts`, and `args_to_zero`). Scoped strictly to reverse-mode `AutoMooncake` and scalar arity with non-empty input — anything else errors at prepare time. Jacobian on a lowered cache surfaces the existing arity-mismatch error. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLMooncakeExt.jl | 71 ++++++++++++++++++++++++++++++++++- test/ext/mooncake/main.jl | 57 +++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 3 deletions(-) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index ecfa656b..2c867a49 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -25,6 +25,20 @@ struct MooncakeCache{A,C} end MooncakeCache{A}(cache::C) where {A,C} = MooncakeCache{A,C}(cache) +# Opt-in lowered-target cache for reverse-mode `AutoMooncake`: callers who +# know a raw `f(x, contexts...)` equivalent to `problem(x)` can hand it in +# via `raw_gradient_target=(f, contexts)`. Mooncake then compiles a tape on +# `(f, x, contexts...)` rather than the generic `evaluator(x)` shape — the +# inactive `contexts` ride along as plain positional args with +# `args_to_zero=false`. `prepared(x)` still calls `problem(x)`; only the AD +# entry point uses the lowered cache. +struct MooncakeLoweredCache{C,F,CT<:Tuple,AZ<:Tuple} + cache::C + f::F + contexts::CT + args_to_zero::AZ +end + _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config # `value_and_gradient!!` accepts either a reverse-mode gradient cache @@ -55,14 +69,43 @@ function AbstractPPL.prepare( end function AbstractPPL.prepare( - adtype::_MooncakeAD, problem, x::AbstractVector{<:Real}; check_dims::Bool=true + adtype::_MooncakeAD, + problem, + x::AbstractVector{<:Real}; + check_dims::Bool=true, + raw_gradient_target=nothing, ) + # Validate `raw_gradient_target` preconditions that don't need an arity + # probe, so the probe `evaluator(x)` below cannot crash on user code that + # assumes non-empty `x`. 
+ if raw_gradient_target !== nothing + adtype isa AutoMooncake || throw( + ArgumentError( + "`raw_gradient_target` is only supported with reverse-mode `AutoMooncake`.", + ), + ) + length(x) > 0 || + throw(ArgumentError("`raw_gradient_target` is not supported for empty input.")) + end evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) + config = _mooncake_config(adtype) + if raw_gradient_target !== nothing + arity === :scalar || throw( + ArgumentError( + "`raw_gradient_target` is only supported for scalar-valued problems." + ), + ) + f, contexts = raw_gradient_target + cache = Mooncake.prepare_gradient_cache(f, x, contexts...; config) + args_to_zero = (false, true, map(_ -> false, contexts)...) + return Prepared( + adtype, evaluator, MooncakeLoweredCache(cache, f, contexts, args_to_zero) + ) + end # Mooncake builds no tape for length-zero `x`; tag with `Nothing` so the # empty-input methods below shortcut without invoking Mooncake. length(x) == 0 && return Prepared(adtype, evaluator, MooncakeCache{arity}(nothing)) - config = _mooncake_config(adtype) cache = if arity === :scalar _mooncake_gradient_cache(adtype, evaluator, x; config) else @@ -103,6 +146,21 @@ end return (val, grad) end +# Lowered raw-target gradient — `p.cache.f(x, p.cache.contexts...) ≡ p.evaluator(x)` +# by the `raw_gradient_target` contract. Mooncake's tape was compiled on the +# raw shape, sidestepping the fixed `evaluator(x)` overhead. 
+@inline function AbstractPPL.value_and_gradient!!( + p::Prepared{<:AutoMooncake,<:VectorEvaluator,<:MooncakeLoweredCache}, + x::AbstractVector{T}, +) where {T<:Real} + Evaluators._check_ad_input(p.evaluator, x) + c = p.cache + val, tangents = Mooncake.value_and_gradient!!( + c.cache, c.f, x, c.contexts...; args_to_zero=c.args_to_zero + ) + return (val, tangents[2]) +end + # Arity-mismatch errors as dedicated methods so dispatch on # `MooncakeCache{:scalar}` vs `{:vector}` resolves at compile time instead of # a runtime check on the cache contents. @@ -120,6 +178,15 @@ end return Evaluators._throw_jacobian_needs_vector() end +# `raw_gradient_target` is a scalar-only fast path; jacobians must use the +# generic preparation. +@inline function AbstractPPL.value_and_jacobian!!( + ::Prepared{<:AutoMooncake,<:VectorEvaluator,<:MooncakeLoweredCache}, + ::AbstractVector{<:Real}, +) + return Evaluators._throw_jacobian_needs_vector() +end + # Empty-input jacobian shortcut — same compile-time dispatch trick as the # scalar Nothing-tagged case; skips Mooncake entirely. @inline function AbstractPPL.value_and_jacobian!!( diff --git a/test/ext/mooncake/main.jl b/test/ext/mooncake/main.jl index 527a219e..d306ddea 100644 --- a/test/ext/mooncake/main.jl +++ b/test/ext/mooncake/main.jl @@ -3,7 +3,7 @@ Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) Pkg.instantiate() -using AbstractPPL: run_testcases +using AbstractPPL: AbstractPPL, prepare, run_testcases, value_and_gradient!! 
using ADTypes: AutoMooncake, AutoMooncakeForward using Mooncake using Test @@ -20,4 +20,59 @@ using Test run_testcases(Val(:edge); adtype=adtype) end end + + @testset "raw_gradient_target" begin + struct TinyProblem{T} + offset::T + end + raw_logdensity(x::AbstractVector{<:Real}, offset) = -0.5 * (x[1] - offset)^2 + (p::TinyProblem)(x::AbstractVector{<:Real}) = raw_logdensity(x, p.offset) + + x = [0.3] + problem = TinyProblem(0.1) + ad = AutoMooncake(; config=nothing) + + generic = prepare(ad, problem, x; check_dims=false) + lowered = prepare( + ad, + problem, + x; + check_dims=false, + raw_gradient_target=(raw_logdensity, (problem.offset,)), + ) + + # `prepared(x)` still calls `problem(x)` on both paths. + @test generic(x) == problem(x) + @test lowered(x) == problem(x) + + # Same value and gradient as the generic path. + @test value_and_gradient!!(generic, x) == value_and_gradient!!(lowered, x) + + # Rejects on forward mode, vector-valued problems, and empty input. + vec_problem = x -> [x[1]^2, x[1] + 1.0] + @test_throws ArgumentError prepare( + AutoMooncakeForward(; config=nothing), + problem, + x; + check_dims=false, + raw_gradient_target=(raw_logdensity, (problem.offset,)), + ) + @test_throws ArgumentError prepare( + ad, + vec_problem, + x; + check_dims=false, + raw_gradient_target=((y, c) -> [y[1] * c], (1.0,)), + ) + @test_throws ArgumentError prepare( + ad, + problem, + Float64[]; + check_dims=false, + raw_gradient_target=(raw_logdensity, (problem.offset,)), + ) + + # Jacobian on a scalar-only lowered cache surfaces our arity-mismatch error. + @test_throws r"vector-valued" AbstractPPL.value_and_jacobian!!(lowered, x) + end end From b23398d81586b05395c020f504d948abae961800 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 13 May 2026 17:53:41 +0100 Subject: [PATCH 10/15] Merge MooncakeLoweredCache into MooncakeCache; accept raw_gradient_target on all AD prepare methods - Collapsed `MooncakeLoweredCache` into `MooncakeCache{A,C,F,CT,AZ}`. 
The three new type params default to `Nothing` via the existing constructor; the lowered-path constructor populates them. Dispatch on `CT<:Tuple` (excluding the `Nothing` default) picks the lowered AD entry. No new type, no runtime branching. - DI extension's `prepare(::AbstractADType, ...)` now accepts `raw_gradient_target=nothing` and silently ignores it. Same for the Mooncake NamedTuple `prepare`. Generic user code that passes the kwarg to non-Mooncake backends (or to the Mooncake NamedTuple path) no longer hits a MethodError. Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLDifferentiationInterfaceExt.jl | 6 ++- ext/AbstractPPLMooncakeExt.jl | 54 ++++++++++--------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 8ab744d1..7461aab4 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -41,7 +41,11 @@ end DICache{UseContext}(target, gp, jp) function AbstractPPL.prepare( - adtype::AbstractADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true + adtype::AbstractADType, + problem, + x::AbstractVector{<:Real}; + check_dims::Bool=true, + raw_gradient_target=nothing, # Mooncake-only optimization; ignored here. ) evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 2c867a49..15255f76 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -20,24 +20,26 @@ Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent # Tag a Mooncake cache with the prepared evaluator's output arity (`:scalar` # or `:vector`) so `value_and_gradient!!` / `value_and_jacobian!!` can raise # helpful arity-mismatch errors instead of failing inside Mooncake. 
-struct MooncakeCache{A,C} - cache::C -end -MooncakeCache{A}(cache::C) where {A,C} = MooncakeCache{A,C}(cache) - -# Opt-in lowered-target cache for reverse-mode `AutoMooncake`: callers who -# know a raw `f(x, contexts...)` equivalent to `problem(x)` can hand it in -# via `raw_gradient_target=(f, contexts)`. Mooncake then compiles a tape on -# `(f, x, contexts...)` rather than the generic `evaluator(x)` shape — the -# inactive `contexts` ride along as plain positional args with -# `args_to_zero=false`. `prepared(x)` still calls `problem(x)`; only the AD -# entry point uses the lowered cache. -struct MooncakeLoweredCache{C,F,CT<:Tuple,AZ<:Tuple} +# +# The optional `f` / `contexts` / `args_to_zero` fields carry the lowered +# raw-target path opted into by `raw_gradient_target=(f, contexts)` on +# reverse-mode `AutoMooncake`. They default to `nothing`; dispatch on +# `CT<:Tuple` (vs `Nothing`) picks the lowered AD entry. `prepared(x)` still +# calls `problem(x)` — only the AD entry consults the lowered fields. +struct MooncakeCache{A,C,F,CT,AZ} cache::C f::F contexts::CT args_to_zero::AZ end +function MooncakeCache{A}(cache::C) where {A,C} + return MooncakeCache{A,C,Nothing,Nothing,Nothing}(cache, nothing, nothing, nothing) +end +function MooncakeCache{A}( + cache::C, f::F, contexts::CT, args_to_zero::AZ +) where {A,C,F,CT<:Tuple,AZ<:Tuple} + return MooncakeCache{A,C,F,CT,AZ}(cache, f, contexts, args_to_zero) +end _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config @@ -60,7 +62,11 @@ function _mooncake_jacobian_cache(::AutoMooncakeForward, f, x; config) end function AbstractPPL.prepare( - adtype::_MooncakeAD, problem, values::NamedTuple; check_dims::Bool=true + adtype::_MooncakeAD, + problem, + values::NamedTuple; + check_dims::Bool=true, + raw_gradient_target=nothing, # vector-only optimization; ignored here. 
) evaluator = AbstractPPL.prepare(problem, values; check_dims)::NamedTupleEvaluator config = _mooncake_config(adtype) @@ -100,7 +106,7 @@ function AbstractPPL.prepare( cache = Mooncake.prepare_gradient_cache(f, x, contexts...; config) args_to_zero = (false, true, map(_ -> false, contexts)...) return Prepared( - adtype, evaluator, MooncakeLoweredCache(cache, f, contexts, args_to_zero) + adtype, evaluator, MooncakeCache{:scalar}(cache, f, contexts, args_to_zero) ) end # Mooncake builds no tape for length-zero `x`; tag with `Nothing` so the @@ -148,9 +154,14 @@ end # Lowered raw-target gradient — `p.cache.f(x, p.cache.contexts...) ≡ p.evaluator(x)` # by the `raw_gradient_target` contract. Mooncake's tape was compiled on the -# raw shape, sidestepping the fixed `evaluator(x)` overhead. +# raw shape, sidestepping the fixed `evaluator(x)` overhead. `CT<:Tuple` +# distinguishes the lowered cache from the generic one (where `CT=Nothing`). @inline function AbstractPPL.value_and_gradient!!( - p::Prepared{<:AutoMooncake,<:VectorEvaluator,<:MooncakeLoweredCache}, + p::Prepared{ + <:AutoMooncake, + <:VectorEvaluator, + <:MooncakeCache{:scalar,<:Any,<:Any,<:Tuple,<:Tuple}, + }, x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) @@ -178,15 +189,6 @@ end return Evaluators._throw_jacobian_needs_vector() end -# `raw_gradient_target` is a scalar-only fast path; jacobians must use the -# generic preparation. -@inline function AbstractPPL.value_and_jacobian!!( - ::Prepared{<:AutoMooncake,<:VectorEvaluator,<:MooncakeLoweredCache}, - ::AbstractVector{<:Real}, -) - return Evaluators._throw_jacobian_needs_vector() -end - # Empty-input jacobian shortcut — same compile-time dispatch trick as the # scalar Nothing-tagged case; skips Mooncake entirely. 
@inline function AbstractPPL.value_and_jacobian!!( From 8c219d1272ed73088ad5bab2532ac519449c26f8 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Wed, 13 May 2026 17:58:25 +0100 Subject: [PATCH 11/15] Drop `args_to_zero` field from MooncakeCache; trim noise comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `args_to_zero` was a derived value (`(false, true, false×length(contexts))`) stored as a struct field plus a 5th type parameter. Moved the construction to the AD entry; the tuple constant-folds for any concrete `contexts` arity. Saves one type parameter and one field. - Dropped two trailing comments on `raw_gradient_target=nothing` kwargs (the comment didn't explain WHY — the kwarg name and surrounding context already convey "this is a backend-specific optimization that defaults off"). Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLDifferentiationInterfaceExt.jl | 2 +- ext/AbstractPPLMooncakeExt.jl | 39 +++++++++---------- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 7461aab4..3dc70d26 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -45,7 +45,7 @@ function AbstractPPL.prepare( problem, x::AbstractVector{<:Real}; check_dims::Bool=true, - raw_gradient_target=nothing, # Mooncake-only optimization; ignored here. 
+ raw_gradient_target=nothing, ) evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 15255f76..5bf6abe7 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -21,24 +21,23 @@ Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent # or `:vector`) so `value_and_gradient!!` / `value_and_jacobian!!` can raise # helpful arity-mismatch errors instead of failing inside Mooncake. # -# The optional `f` / `contexts` / `args_to_zero` fields carry the lowered -# raw-target path opted into by `raw_gradient_target=(f, contexts)` on -# reverse-mode `AutoMooncake`. They default to `nothing`; dispatch on -# `CT<:Tuple` (vs `Nothing`) picks the lowered AD entry. `prepared(x)` still +# The optional `f` / `contexts` fields carry the lowered raw-target path +# opted into by `raw_gradient_target=(f, contexts)` on reverse-mode +# `AutoMooncake`. They default to `nothing`; dispatch on `CT<:Tuple` (vs +# `Nothing`) picks the lowered AD entry. `args_to_zero` is derived from +# `contexts` at the AD entry — it's a `false, true, false…` literal that +# constant-folds for any concrete `contexts` arity. `prepared(x)` still # calls `problem(x)` — only the AD entry consults the lowered fields. 
-struct MooncakeCache{A,C,F,CT,AZ} +struct MooncakeCache{A,C,F,CT} cache::C f::F contexts::CT - args_to_zero::AZ end function MooncakeCache{A}(cache::C) where {A,C} - return MooncakeCache{A,C,Nothing,Nothing,Nothing}(cache, nothing, nothing, nothing) + return MooncakeCache{A,C,Nothing,Nothing}(cache, nothing, nothing) end -function MooncakeCache{A}( - cache::C, f::F, contexts::CT, args_to_zero::AZ -) where {A,C,F,CT<:Tuple,AZ<:Tuple} - return MooncakeCache{A,C,F,CT,AZ}(cache, f, contexts, args_to_zero) +function MooncakeCache{A}(cache::C, f::F, contexts::CT) where {A,C,F,CT<:Tuple} + return MooncakeCache{A,C,F,CT}(cache, f, contexts) end _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config @@ -66,7 +65,7 @@ function AbstractPPL.prepare( problem, values::NamedTuple; check_dims::Bool=true, - raw_gradient_target=nothing, # vector-only optimization; ignored here. + raw_gradient_target=nothing, ) evaluator = AbstractPPL.prepare(problem, values; check_dims)::NamedTupleEvaluator config = _mooncake_config(adtype) @@ -104,10 +103,7 @@ function AbstractPPL.prepare( ) f, contexts = raw_gradient_target cache = Mooncake.prepare_gradient_cache(f, x, contexts...; config) - args_to_zero = (false, true, map(_ -> false, contexts)...) - return Prepared( - adtype, evaluator, MooncakeCache{:scalar}(cache, f, contexts, args_to_zero) - ) + return Prepared(adtype, evaluator, MooncakeCache{:scalar}(cache, f, contexts)) end # Mooncake builds no tape for length-zero `x`; tag with `Nothing` so the # empty-input methods below shortcut without invoking Mooncake. @@ -156,18 +152,21 @@ end # by the `raw_gradient_target` contract. Mooncake's tape was compiled on the # raw shape, sidestepping the fixed `evaluator(x)` overhead. `CT<:Tuple` # distinguishes the lowered cache from the generic one (where `CT=Nothing`). +# `args_to_zero` is constant-folded from `c.contexts`'s arity at compile time. 
@inline function AbstractPPL.value_and_gradient!!( p::Prepared{ - <:AutoMooncake, - <:VectorEvaluator, - <:MooncakeCache{:scalar,<:Any,<:Any,<:Tuple,<:Tuple}, + <:AutoMooncake,<:VectorEvaluator,<:MooncakeCache{:scalar,<:Any,<:Any,<:Tuple} }, x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) c = p.cache val, tangents = Mooncake.value_and_gradient!!( - c.cache, c.f, x, c.contexts...; args_to_zero=c.args_to_zero + c.cache, + c.f, + x, + c.contexts...; + args_to_zero=(false, true, map(_ -> false, c.contexts)...), ) return (val, tangents[2]) end From 88eea2a48c36e8367f061cfafb774f6695a8f229 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Thu, 14 May 2026 12:18:30 +0100 Subject: [PATCH 12/15] Mooncake ext: reject non-dense vectors; document NamedTuple check delegation and raw_gradient_target as unsafe Addresses PR #160 review comments: - Throw a clear `ArgumentError` for non-`DenseVector` inputs instead of letting Mooncake return a shape-incorrect tangent (reverse) or crash inside Mooncake (forward/Jacobian). - Document that NamedTuple input-shape validation is intentionally delegated to Mooncake's `PreparedCacheSpec` to avoid duplicating checks on every AD call. - Add a docstring on the vector `prepare` method describing `raw_gradient_target` as an unsafe escape hatch that bypasses evaluator indirection and shape checks. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLDifferentiationInterfaceExt.jl | 9 ++- ext/AbstractPPLMooncakeExt.jl | 74 ++++++++++++++----- test/ext/mooncake/main.jl | 10 +++ 3 files changed, 72 insertions(+), 21 deletions(-) diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 3dc70d26..3f0cc48f 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -17,7 +17,6 @@ struct DICache{UseContext,F,GP,JP} gradient_prep::GP jacobian_prep::JP function DICache{UseContext}(target::F, gp::GP, jp::JP) where {UseContext,F,GP,JP} - UseContext isa Bool || throw(ArgumentError("`UseContext` must be a Bool.")) return new{UseContext,F,GP,JP}(target, gp, jp) end end @@ -40,6 +39,9 @@ end @inline _wrap_cache(target, gp, jp, ::Val{UseContext}) where {UseContext} = DICache{UseContext}(target, gp, jp) +# `raw_gradient_target` is accepted for signature parity with the Mooncake +# extension's vector `prepare`, but DI has no equivalent context-lowering +# entry — only `nothing` is supported here. 
function AbstractPPL.prepare( adtype::AbstractADType, problem, @@ -47,6 +49,11 @@ function AbstractPPL.prepare( check_dims::Bool=true, raw_gradient_target=nothing, ) + raw_gradient_target === nothing || throw( + ArgumentError( + "`raw_gradient_target` is not supported by the DifferentiationInterface extension.", + ), + ) evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) if length(x) == 0 diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 5bf6abe7..31157ef5 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -17,17 +17,11 @@ const _MooncakeAD = Union{AutoMooncake,AutoMooncakeForward} Mooncake.tangent_type(::Type{<:VectorEvaluator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent -# Tag a Mooncake cache with the prepared evaluator's output arity (`:scalar` -# or `:vector`) so `value_and_gradient!!` / `value_and_jacobian!!` can raise -# helpful arity-mismatch errors instead of failing inside Mooncake. -# -# The optional `f` / `contexts` fields carry the lowered raw-target path -# opted into by `raw_gradient_target=(f, contexts)` on reverse-mode -# `AutoMooncake`. They default to `nothing`; dispatch on `CT<:Tuple` (vs -# `Nothing`) picks the lowered AD entry. `args_to_zero` is derived from -# `contexts` at the AD entry — it's a `false, true, false…` literal that -# constant-folds for any concrete `contexts` arity. `prepared(x)` still -# calls `problem(x)` — only the AD entry consults the lowered fields. +# `A` tags the evaluator's output arity (`:scalar`/`:vector`) so arity +# mismatches dispatch to a dedicated error method instead of failing inside +# Mooncake. `f`/`contexts` are `nothing` on the generic path; on the +# `raw_gradient_target` path they carry the lowered target so `CT<:Tuple` +# selects the lowered AD entry by dispatch. 
struct MooncakeCache{A,C,F,CT} cache::C f::F @@ -42,8 +36,6 @@ end _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config -# `value_and_gradient!!` accepts either a reverse-mode gradient cache -# (AutoMooncake) or a forward-mode derivative cache (AutoMooncakeForward). function _mooncake_gradient_cache(::AutoMooncake, f, x; config) return Mooncake.prepare_gradient_cache(f, x; config) end @@ -51,8 +43,6 @@ function _mooncake_gradient_cache(::AutoMooncakeForward, f, x; config) return Mooncake.prepare_derivative_cache(f, x; config) end -# `value_and_jacobian!!`: reverse mode wants a pullback cache, forward mode -# wants a derivative cache. function _mooncake_jacobian_cache(::AutoMooncake, f, x; config) return Mooncake.prepare_pullback_cache(f, x; config) end @@ -67,12 +57,48 @@ function AbstractPPL.prepare( check_dims::Bool=true, raw_gradient_target=nothing, ) + raw_gradient_target === nothing || throw( + ArgumentError( + "`raw_gradient_target` is only supported on the vector `prepare` path." + ), + ) evaluator = AbstractPPL.prepare(problem, values; check_dims)::NamedTupleEvaluator config = _mooncake_config(adtype) cache = _mooncake_gradient_cache(adtype, evaluator, values; config) return Prepared(adtype, evaluator, cache) end +""" + prepare(adtype::AutoMooncake, problem, x; check_dims=true, raw_gradient_target=nothing) + prepare(adtype::AutoMooncakeForward, problem, x; check_dims=true) + +Prepare a Mooncake gradient/Jacobian evaluator for a dense vector input. + +Non-`DenseVector` inputs (views, strided slices) are rejected: Mooncake +assumes a contiguous primal and otherwise returns shape-incorrect tangents +on reverse mode and crashes on forward/Jacobian paths. + +# `raw_gradient_target` (unsafe) + +Optional reverse-mode kwarg of the form `(f, contexts::Tuple)`. 
When +supplied, Mooncake compiles its tape against `f(x, contexts...)` instead of +the wrapping `VectorEvaluator`, which avoids the per-call indirection +through the evaluator on the AD hot path. + +This is an **unsafe escape hatch**: + + - The caller asserts `f(x, contexts...) ≡ evaluator(x)` for every `x` + Mooncake will see — AbstractPPL does not (and cannot) verify this. + - The AD pass calls `f(x, contexts...)` directly; the `VectorEvaluator` + wrapper is bypassed. Input shape is still validated up front by + `_check_ad_input` on the user-facing call. + - The `(f, contexts)` shape is destructured directly; malformed values + (e.g. a bare function, or `contexts` that isn't a tuple) will raise + `MethodError`/`BoundsError` rather than a structured `ArgumentError`. + +Use only when the indirection cost is measured and the equivalence is +known to hold. +""" function AbstractPPL.prepare( adtype::_MooncakeAD, problem, @@ -80,6 +106,13 @@ function AbstractPPL.prepare( check_dims::Bool=true, raw_gradient_target=nothing, ) + x isa DenseVector || throw( + ArgumentError( + "AutoMooncake / AutoMooncakeForward require a dense vector input " * + "(e.g. `Vector{<:Real}`); got $(typeof(x)). Wrap non-dense inputs " * + "(views, strided slices) with `collect` before calling `prepare`.", + ), + ) # Validate `raw_gradient_target` preconditions that don't need an arity # probe, so the probe `evaluator(x)` below cannot crash on user code that # assumes non-empty `x`. @@ -116,11 +149,12 @@ function AbstractPPL.prepare( return Prepared(adtype, evaluator, MooncakeCache{arity}(cache)) end -# `Mooncake.value_and_gradient!!` returns `(val, (∂f, ∂x))`; `∂f` is `NoTangent` -# because we registered `tangent_type(::Type{<:NamedTupleEvaluator}) = NoTangent` -# above, so the cache never carries a tangent for the user's problem. -# Shape validation is delegated to the inner `NamedTupleEvaluator{CheckInput}` -# callable Mooncake invokes — gated by the user's `check_dims` choice. 
+# Input-shape validation is delegated to the AD backend: Mooncake catches +# top-level NamedTuple-type mismatches, and the inner +# `NamedTupleEvaluator{CheckInput}` callable catches nested-array size +# mismatches (gated by `check_dims`). Running `_assert_namedtuple_shape` +# again here would duplicate the second check on every AD call. +# (`∂f` is `NoTangent` thanks to the `tangent_type` overload above.) @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:_MooncakeAD,<:NamedTupleEvaluator}, values::NamedTuple ) diff --git a/test/ext/mooncake/main.jl b/test/ext/mooncake/main.jl index d306ddea..76a4aaf5 100644 --- a/test/ext/mooncake/main.jl +++ b/test/ext/mooncake/main.jl @@ -75,4 +75,14 @@ using Test # Jacobian on a scalar-only lowered cache surfaces our arity-mismatch error. @test_throws r"vector-valued" AbstractPPL.value_and_jacobian!!(lowered, x) end + + @testset "dense vector requirement" begin + # Non-dense AbstractVectors (e.g. `view`s) are rejected up front rather + # than reaching Mooncake, where reverse-mode silently returns a + # `Mooncake.Tangent` and forward/Jacobian paths crash. + problem = x -> sum(abs2, x) + v = view([1.0, 2.0, 3.0], :) + @test_throws r"dense vector" prepare(AutoMooncake(), problem, v) + @test_throws r"dense vector" prepare(AutoMooncakeForward(), problem, v) + end end From c638d5c47482e1965656948aa92f27be3d93c24d Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 18 May 2026 12:45:08 +0100 Subject: [PATCH 13/15] Replace raw_gradient_target with context::Tuple=() on prepare Lift the lowered-AD escape hatch into a first-class API: every vector `prepare` now accepts `context::Tuple=()`, the prepared evaluator computes `problem(x, context...)`, and AD differentiates only `x`. `VectorEvaluator` carries the context as a third type parameter so callers can recover it from the evaluator without going through the kwarg again. 
Mooncake ext - Compile every scalar gradient cache on the raw `evaluator.f` / `evaluator.context` (not the raw `problem`/`context` kwargs), so a downstream override of structural `prepare` that returns a different `f`/`context` doesn't desync from the hot path. - Forward-mode `AutoMooncakeForward` now also accepts non-empty `context`. `_mooncake_value_and_gradient` dispatches reverse-mode to the `args_to_zero` kwarg form and forward-mode to the splat-only form (`ForwardCache` rejects `args_to_zero`). - Vector-jacobian path now also runs on `evaluator.f` for the same reason. - Empty input with non-empty `context` is supported (was rejected). The `MooncakeCache{arity,Nothing}` empty-input shortcut already evaluates `evaluator(x)` without invoking Mooncake. DI ext - `DICache{Mode}`: `Mode == :closure` for compiled-tape ReverseDiff, `Mode::Int == length(evaluator.context)` for the constants path. The `Int` doubles as documentation of how many `DI.Constant`s the AD call passes (`N + 1`, including `f`). - Single shared AD target `_call_evaluator(x, f, ctx::Vararg{Any,N}) where {N}`. `Vararg{Any,N}` forces specialization on the trailing arity. Docs/tests - New `Constant context arguments` section in `docs/src/evaluators.md`. - New regression tests covering: context threading via structural `prepare`, forward-mode Mooncake context, empty-input + non-empty context shortcut, and `DICache` mode-tag pinning. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/src/evaluators.md | 17 ++ ext/AbstractPPLDifferentiationInterfaceExt.jl | 105 ++++++---- ext/AbstractPPLMooncakeExt.jl | 180 ++++++++---------- src/evaluators/Evaluators.jl | 39 ++-- test/evaluators/Evaluators.jl | 9 + test/ext/differentiationinterface/main.jl | 36 ++-- test/ext/mooncake/main.jl | 58 +++--- 7 files changed, 247 insertions(+), 197 deletions(-) diff --git a/docs/src/evaluators.md b/docs/src/evaluators.md index 4f1e1512..9dd6f8af 100644 --- a/docs/src/evaluators.md +++ b/docs/src/evaluators.md @@ -138,6 +138,23 @@ library invokes the inner callable many times with same-length dual arrays derived from a single user-supplied `x`; re-validating on each invocation would be redundant work in the hot path. +## Constant context arguments + +When the underlying callable naturally takes the form `f(x, context...)` — +where everything after `x` is constant state — pass `context` as a tuple to +the vector form of `prepare`. AD differentiates only w.r.t. `x`; every +value in `context` is treated as inactive: + +```julia +affine(x, scale, offset) = scale * sum(x) + offset +prepared = prepare(adtype, affine, zeros(3); context=(2.0, 1.0)) +val, grad = value_and_gradient!!(prepared, [1.0, 2.0, 3.0]) +# val == 2.0 * 6.0 + 1.0; grad == [2.0, 2.0, 2.0] +``` + +`prepared(x)` evaluates `f(x, context...)`, and `context=()` (the default) +preserves the unary `f(x)` shape. 
+ ## Without an AD backend The two-argument form `prepare(problem, x)` is available without any AD diff --git a/ext/AbstractPPLDifferentiationInterfaceExt.jl b/ext/AbstractPPLDifferentiationInterfaceExt.jl index 3f0cc48f..1f4cfe8b 100644 --- a/ext/AbstractPPLDifferentiationInterfaceExt.jl +++ b/ext/AbstractPPLDifferentiationInterfaceExt.jl @@ -5,84 +5,105 @@ using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_ using ADTypes: AbstractADType, AutoReverseDiff using DifferentiationInterface: DifferentiationInterface as DI -# Differentiate only `x`; the evaluator is passed as a `DI.Constant` context so -# that in DynamicPPL the model and other evaluator state stay constant. -@inline _call_evaluator(x, evaluator) = evaluator(x) +# AD target used by both `DICache` modes. `Vararg{Any,N}` with a free `N` +# forces specialization on the trailing arity (a bare `Vararg{Any}` would +# skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the +# constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on +# the closure path — empty `ctx` then makes the splat a no-op. +@inline _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N} = f(x, ctx...) -# `UseContext` is type-encoded so the dispatch between the context and -# no-context DI call is resolved at compile time; on tiny problems the runtime -# branch would otherwise show up as fixed overhead in the AD hot path. -struct DICache{UseContext,F,GP,JP} +# `Mode` tags the cache shape: +# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure, +# the AD call passes **0** `DI.Constant`s. +# * `N::Int` — constants path: `N == length(evaluator.context)`, the +# AD call passes **N + 1** `DI.Constant`s (`f` plus the +# `N` context values). +# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*` +# at compile time without a runtime branch. 
+struct DICache{Mode,F,GP,JP} target::F gradient_prep::GP jacobian_prep::JP - function DICache{UseContext}(target::F, gp::GP, jp::JP) where {UseContext,F,GP,JP} - return new{UseContext,F,GP,JP}(target, gp, jp) + function DICache{Mode}(target::F, gp::GP, jp::JP) where {Mode,F,GP,JP} + return new{Mode,F,GP,JP}(target, gp, jp) end end # Compiled ReverseDiff only reuses a compiled tape on the one-argument path; # `DI.Constant` deactivates tape recording, so close the evaluator into the -# target and call DI without contexts. The trailing `Val(false)`/`Val(true)` -# carries `UseContext` to the `DICache` constructor at compile time. +# target and call DI without constants. Context (if any) is captured inside +# the evaluator closure rather than lowered out — the lowered path would also +# require a closure here, so the wrapper cost is unavoidable for compiled tapes. function _prepare_di(prep::F, adtype::AutoReverseDiff{true}, x, evaluator) where {F} target = Base.Fix2(_call_evaluator, evaluator) - return target, prep(target, adtype, x), Val(false) + return target, prep(target, adtype, x), Val(:closure) end function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F} + constants = (DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...) return ( - _call_evaluator, prep(_call_evaluator, adtype, x, DI.Constant(evaluator)), Val(true) + _call_evaluator, + prep(_call_evaluator, adtype, x, constants...), + Val(length(evaluator.context)), ) end -@inline _wrap_cache(target, gp, jp, ::Val{UseContext}) where {UseContext} = - DICache{UseContext}(target, gp, jp) +@inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} = + DICache{Mode}(target, gp, jp) -# `raw_gradient_target` is accepted for signature parity with the Mooncake -# extension's vector `prepare`, but DI has no equivalent context-lowering -# entry — only `nothing` is supported here. 
function AbstractPPL.prepare( adtype::AbstractADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true, - raw_gradient_target=nothing, + context::Tuple=(), ) - raw_gradient_target === nothing || throw( - ArgumentError( - "`raw_gradient_target` is not supported by the DifferentiationInterface extension.", - ), - ) - evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator + evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) if length(x) == 0 - # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`); the - # `Val(0)` sentinel keeps the `gradient_prep === nothing` arity check - # meaningful. `UseContext` is irrelevant on this shortcut path — the AD - # entry returns `(p.evaluator(x), T[])` before any DI call. + # DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`). + # `Val(0)` is an arity sentinel for the `gradient_prep === nothing` + # check below; the AD entry short-circuits before any DI call. gp, jp = arity === :scalar ? 
(Val(0), nothing) : (nothing, Val(0)) - return Prepared(adtype, evaluator, DICache{true}(_call_evaluator, gp, jp)) + cache = _wrap_cache(_call_evaluator, gp, jp, Val(length(context))) + return Prepared(adtype, evaluator, cache) end if arity === :scalar - target, gradient_prep, ctx = _prepare_di(DI.prepare_gradient, adtype, x, evaluator) - return Prepared(adtype, evaluator, _wrap_cache(target, gradient_prep, nothing, ctx)) + target, gradient_prep, mode = _prepare_di(DI.prepare_gradient, adtype, x, evaluator) + return Prepared( + adtype, evaluator, _wrap_cache(target, gradient_prep, nothing, mode) + ) end - target, jacobian_prep, ctx = _prepare_di(DI.prepare_jacobian, adtype, x, evaluator) - return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, ctx)) + target, jacobian_prep, mode = _prepare_di(DI.prepare_jacobian, adtype, x, evaluator) + return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, mode)) end -# Compile-time dispatch on the `UseContext` type parameter eliminates the -# context-vs-no-context branch from the AD hot path. -@inline _di_value_and_gradient(c::DICache{true}, ad, x, eval) = - DI.value_and_gradient(c.target, c.gradient_prep, ad, x, DI.Constant(eval)) -@inline _di_value_and_gradient(c::DICache{false}, ad, x, _) = +# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile +# time. The unconstrained method matches every non-`:closure` `Mode` (i.e. +# any `Int N`); `:closure` is strictly more specific and wins for compiled +# tapes. On the constants path we always pass `DI.Constant(eval.f)` plus the +# `N` context constants — `N == 0` collapses the `map` splat to nothing. 
+@inline _di_value_and_gradient(c::DICache{:closure}, ad, x, _) = DI.value_and_gradient(c.target, c.gradient_prep, ad, x) +@inline _di_value_and_gradient(c::DICache, ad, x, eval) = DI.value_and_gradient( + c.target, + c.gradient_prep, + ad, + x, + DI.Constant(eval.f), + map(DI.Constant, eval.context)..., +) -@inline _di_value_and_jacobian(c::DICache{true}, ad, x, eval) = - DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x, DI.Constant(eval)) -@inline _di_value_and_jacobian(c::DICache{false}, ad, x, _) = +@inline _di_value_and_jacobian(c::DICache{:closure}, ad, x, _) = DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x) +@inline _di_value_and_jacobian(c::DICache, ad, x, eval) = DI.value_and_jacobian( + c.target, + c.jacobian_prep, + ad, + x, + DI.Constant(eval.f), + map(DI.Constant, eval.context)..., +) @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T} diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 31157ef5..554d36aa 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -8,31 +8,23 @@ using Mooncake: Mooncake const _MooncakeAD = Union{AutoMooncake,AutoMooncakeForward} -# Tell Mooncake that the evaluator wrappers are constants from its -# perspective: their fields hold the user's problem state, which Mooncake -# would otherwise derive a nested `Tangent{NamedTuple{f::Tangent{...}}}` for -# and walk on every backward pass. The evaluators are AbstractPPL's own -# types and only ever appear as the callable argument to Mooncake — no -# downstream caller asks for a gradient w.r.t. them. +# `NamedTupleEvaluator` is the callable on the NamedTuple path; `NoTangent` +# stops Mooncake from deriving a `Tangent{NamedTuple{...}}` for its fields +# on every backward pass. `VectorEvaluator` is no longer passed to Mooncake +# after the raw-target merge; the override is kept as a defensive guard. 
Mooncake.tangent_type(::Type{<:VectorEvaluator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent -# `A` tags the evaluator's output arity (`:scalar`/`:vector`) so arity -# mismatches dispatch to a dedicated error method instead of failing inside -# Mooncake. `f`/`contexts` are `nothing` on the generic path; on the -# `raw_gradient_target` path they carry the lowered target so `CT<:Tuple` -# selects the lowered AD entry by dispatch. -struct MooncakeCache{A,C,F,CT} +# Type parameters: +# +# * `A::Symbol` — output arity, `:scalar` or `:vector`. Drives the +# gradient/jacobian dispatch and the arity-mismatch errors. +# * `C` — the underlying Mooncake cache, or `Nothing` for the +# empty-input shortcut. +struct MooncakeCache{A,C} cache::C - f::F - contexts::CT -end -function MooncakeCache{A}(cache::C) where {A,C} - return MooncakeCache{A,C,Nothing,Nothing}(cache, nothing, nothing) -end -function MooncakeCache{A}(cache::C, f::F, contexts::CT) where {A,C,F,CT<:Tuple} - return MooncakeCache{A,C,F,CT}(cache, f, contexts) end +MooncakeCache{A}(cache::C) where {A,C} = MooncakeCache{A,C}(cache) _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config @@ -43,6 +35,15 @@ function _mooncake_gradient_cache(::AutoMooncakeForward, f, x; config) return Mooncake.prepare_derivative_cache(f, x; config) end +# Vector-scalar overloads — splat `context` into the underlying Mooncake +# prep call (empty tuple is a no-op). 
+function _mooncake_gradient_cache(::AutoMooncake, f, x, context::Tuple; config) + return Mooncake.prepare_gradient_cache(f, x, context...; config) +end +function _mooncake_gradient_cache(::AutoMooncakeForward, f, x, context::Tuple; config) + return Mooncake.prepare_derivative_cache(f, x, context...; config) +end + function _mooncake_jacobian_cache(::AutoMooncake, f, x; config) return Mooncake.prepare_pullback_cache(f, x; config) end @@ -51,17 +52,8 @@ function _mooncake_jacobian_cache(::AutoMooncakeForward, f, x; config) end function AbstractPPL.prepare( - adtype::_MooncakeAD, - problem, - values::NamedTuple; - check_dims::Bool=true, - raw_gradient_target=nothing, + adtype::_MooncakeAD, problem, values::NamedTuple; check_dims::Bool=true ) - raw_gradient_target === nothing || throw( - ArgumentError( - "`raw_gradient_target` is only supported on the vector `prepare` path." - ), - ) evaluator = AbstractPPL.prepare(problem, values; check_dims)::NamedTupleEvaluator config = _mooncake_config(adtype) cache = _mooncake_gradient_cache(adtype, evaluator, values; config) @@ -69,8 +61,8 @@ function AbstractPPL.prepare( end """ - prepare(adtype::AutoMooncake, problem, x; check_dims=true, raw_gradient_target=nothing) - prepare(adtype::AutoMooncakeForward, problem, x; check_dims=true) + prepare(adtype::AutoMooncake, problem, x; check_dims=true, context::Tuple=()) + prepare(adtype::AutoMooncakeForward, problem, x; check_dims=true, context::Tuple=()) Prepare a Mooncake gradient/Jacobian evaluator for a dense vector input. @@ -78,33 +70,21 @@ Non-`DenseVector` inputs (views, strided slices) are rejected: Mooncake assumes a contiguous primal and otherwise returns shape-incorrect tangents on reverse mode and crashes on forward/Jacobian paths. -# `raw_gradient_target` (unsafe) - -Optional reverse-mode kwarg of the form `(f, contexts::Tuple)`. 
When -supplied, Mooncake compiles its tape against `f(x, contexts...)` instead of -the wrapping `VectorEvaluator`, which avoids the per-call indirection -through the evaluator on the AD hot path. - -This is an **unsafe escape hatch**: - - - The caller asserts `f(x, contexts...) ≡ evaluator(x)` for every `x` - Mooncake will see — AbstractPPL does not (and cannot) verify this. - - The AD pass calls `f(x, contexts...)` directly; the `VectorEvaluator` - wrapper is bypassed. Input shape is still validated up front by - `_check_ad_input` on the user-facing call. - - The `(f, contexts)` shape is destructured directly; malformed values - (e.g. a bare function, or `contexts` that isn't a tuple) will raise - `MethodError`/`BoundsError` rather than a structured `ArgumentError`. +`context` follows the base `prepare` contract — the prepared evaluator +computes `problem(x, context...)` with AD differentiating only `x`. One +Mooncake-specific restriction: vector-valued problems require `context=()`. -Use only when the indirection cost is measured and the equivalence is -known to hold. +Empty input (`length(x) == 0`) is supported under the same restriction; +Mooncake builds no tape for zero-length `x`, so the gradient entry +short-circuits to `(problem(x, context...), eltype(x)[])` without invoking +Mooncake. """ function AbstractPPL.prepare( adtype::_MooncakeAD, problem, x::AbstractVector{<:Real}; check_dims::Bool=true, - raw_gradient_target=nothing, + context::Tuple=(), ) x isa DenseVector || throw( ArgumentError( @@ -113,38 +93,30 @@ function AbstractPPL.prepare( "(views, strided slices) with `collect` before calling `prepare`.", ), ) - # Validate `raw_gradient_target` preconditions that don't need an arity - # probe, so the probe `evaluator(x)` below cannot crash on user code that - # assumes non-empty `x`.
- if raw_gradient_target !== nothing - adtype isa AutoMooncake || throw( - ArgumentError( - "`raw_gradient_target` is only supported with reverse-mode `AutoMooncake`.", - ), - ) - length(x) > 0 || - throw(ArgumentError("`raw_gradient_target` is not supported for empty input.")) - end - evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator + evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator arity = _ad_output_arity(evaluator(x)) config = _mooncake_config(adtype) - if raw_gradient_target !== nothing - arity === :scalar || throw( + if !isempty(evaluator.context) && arity !== :scalar + throw( ArgumentError( - "`raw_gradient_target` is only supported for scalar-valued problems." + "Non-empty `context` is only supported for scalar-valued problems." ), ) - f, contexts = raw_gradient_target - cache = Mooncake.prepare_gradient_cache(f, x, contexts...; config) - return Prepared(adtype, evaluator, MooncakeCache{:scalar}(cache, f, contexts)) end # Mooncake builds no tape for length-zero `x`; tag with `Nothing` so the - # empty-input methods below shortcut without invoking Mooncake. + # empty-input methods below shortcut without invoking Mooncake. Empty `x` + # with non-empty context also routes here — the hot-path shortcut just + # calls `p.evaluator(x)` which already does `f([], context...)`. length(x) == 0 && return Prepared(adtype, evaluator, MooncakeCache{arity}(nothing)) + # Compile the tape on the evaluator's `f` and `context` (not the raw + # `problem` passed in): a downstream override of structural `prepare` + # may return a `VectorEvaluator` whose `.f`/`.context` differ from the + # caller-supplied `problem`/`context`. The hot path uses `evaluator.f` + # / `evaluator.context`, so the cache must agree. 
cache = if arity === :scalar - _mooncake_gradient_cache(adtype, evaluator, x; config) + _mooncake_gradient_cache(adtype, evaluator.f, x, evaluator.context; config) else - _mooncake_jacobian_cache(adtype, evaluator, x; config) + _mooncake_jacobian_cache(adtype, evaluator.f, x; config) end return Prepared(adtype, evaluator, MooncakeCache{arity}(cache)) end @@ -162,9 +134,10 @@ end return (val, grad) end -# Empty-input shortcut: tagged with `MooncakeCache{…,Nothing}` at prepare time -# so dispatch resolves the no-Mooncake path at compile time — no runtime -# `isnothing(cache)` branch in the hot path. +# Empty-input shortcut. `MooncakeCache{:scalar,Nothing}` is strictly more +# specific than `MooncakeCache{:scalar}` on `C`, so dispatch unambiguously +# selects this method over the general scalar-gradient hot path below for +# zero-length `x`. @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar,Nothing}}, x::AbstractVector{T}, @@ -173,35 +146,30 @@ end return (p.evaluator(x), T[]) end -@inline function AbstractPPL.value_and_gradient!!( - p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, - x::AbstractVector{T}, -) where {T<:Real} - Evaluators._check_ad_input(p.evaluator, x) - val, (_, grad) = Mooncake.value_and_gradient!!(p.cache.cache, p.evaluator, x) - return (val, grad) -end +# Reverse-mode `Mooncake.Cache` needs `args_to_zero` to mark `x` as the lone +# active input (`false` on `f`, `true` on `x`, `false` on each context value). +# Forward-mode `ForwardCache` derives activity from its seeded argument and +# rejects the kwarg. Dispatching on the AD type keeps each call mode-specific +# without a runtime branch. +@inline _mooncake_value_and_gradient( + ::AutoMooncake, cache, f::F, x, context::Tuple +) where {F} = Mooncake.value_and_gradient!!( + cache, f, x, context...; args_to_zero=(false, true, map(_ -> false, context)...) 
+) +@inline _mooncake_value_and_gradient( + ::AutoMooncakeForward, cache, f::F, x, context::Tuple +) where {F} = Mooncake.value_and_gradient!!(cache, f, x, context...) -# Lowered raw-target gradient — `p.cache.f(x, p.cache.contexts...) ≡ p.evaluator(x)` -# by the `raw_gradient_target` contract. Mooncake's tape was compiled on the -# raw shape, sidestepping the fixed `evaluator(x)` overhead. `CT<:Tuple` -# distinguishes the lowered cache from the generic one (where `CT=Nothing`). -# `args_to_zero` is constant-folded from `c.contexts`'s arity at compile time. +# Scalar-gradient hot path. Empty `context` collapses the splat and reduces +# `args_to_zero` to `(false, true)`. `tangents[2]` is the `x`-gradient — the +# trailing entries (one per context value) are zeroed and discarded. @inline function AbstractPPL.value_and_gradient!!( - p::Prepared{ - <:AutoMooncake,<:VectorEvaluator,<:MooncakeCache{:scalar,<:Any,<:Any,<:Tuple} - }, + p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) - c = p.cache - val, tangents = Mooncake.value_and_gradient!!( - c.cache, - c.f, - x, - c.contexts...; - args_to_zero=(false, true, map(_ -> false, c.contexts)...), - ) + e = p.evaluator + val, tangents = _mooncake_value_and_gradient(p.adtype, p.cache.cache, e.f, x, e.context) return (val, tangents[2]) end @@ -222,8 +190,8 @@ end return Evaluators._throw_jacobian_needs_vector() end -# Empty-input jacobian shortcut — same compile-time dispatch trick as the -# scalar Nothing-tagged case; skips Mooncake entirely. +# Empty-input jacobian shortcut. Same `Nothing` specificity trick as the +# scalar case above; skips Mooncake entirely. 
@inline function AbstractPPL.value_and_jacobian!!( p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:vector,Nothing}}, x::AbstractVector{T}, @@ -238,7 +206,11 @@ end x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) - return Mooncake.value_and_jacobian!!(p.cache.cache, p.evaluator, x) + # Vector arity rejects non-empty `context` at prepare time, so the tape + # is compiled on `problem(x)` and there is no splat or `args_to_zero` to + # propagate. Mooncake's `value_and_jacobian!!` returns `(val, jac)` + # directly with `x` as the only active argument. + return Mooncake.value_and_jacobian!!(p.cache.cache, p.evaluator.f, x) end end # module diff --git a/src/evaluators/Evaluators.jl b/src/evaluators/Evaluators.jl index e4c8b977..dfae89f3 100644 --- a/src/evaluators/Evaluators.jl +++ b/src/evaluators/Evaluators.jl @@ -41,8 +41,8 @@ Prepared(adtype::AbstractADType, evaluator) = Prepared(adtype, evaluator, nothin """ prepare(problem, values::NamedTuple; check_dims::Bool=true) - prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true) - prepare(adtype, problem, x::AbstractVector{<:Real}; check_dims::Bool=true) + prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=()) + prepare(adtype, problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=()) Prepare a callable evaluator for `problem`. @@ -56,6 +56,11 @@ the input shape on each call. Pass `check_dims=false` to skip the per-call check, e.g. inside an AD backend's hot path where the input shape is already guaranteed. +The vector-input forms accept a `context::Tuple` of constant arguments threaded +through to `problem`: the prepared evaluator computes `problem(x, context...)`, +and AD backends differentiate only with respect to `x`. `context=()` (the +default) preserves the unary `problem(x)` contract. 
+ The three-argument AD-aware form may invoke `problem` once during preparation to detect output arity (scalar vs vector) and select gradient or jacobian machinery accordingly. Avoid `prepare` calls when `problem` has side effects @@ -69,8 +74,10 @@ function prepare end function prepare(problem, values::NamedTuple; check_dims::Bool=true) return NamedTupleEvaluator{check_dims}(problem, values) end -function prepare(problem, x::AbstractVector{<:Real}; check_dims::Bool=true) - return VectorEvaluator{check_dims}(problem, length(x)) +function prepare( + problem, x::AbstractVector{<:Real}; check_dims::Bool=true, context::Tuple=() +) + return VectorEvaluator{check_dims}(problem, length(x), context) end """ @@ -93,8 +100,8 @@ The Jacobian has shape `(length(value), length(x))`. function value_and_jacobian!! end """ - VectorEvaluator{CheckInput}(f, dim) - VectorEvaluator(f, dim) # equivalent to `VectorEvaluator{true}(f, dim)` + VectorEvaluator{CheckInput}(f, dim, context::Tuple=()) + VectorEvaluator(f, dim, context::Tuple=()) # equivalent to `VectorEvaluator{true}(f, dim, context)` Evaluator shape for scalar functions of a vector input. Part of the extension author API; end users interact with the wrapping `Prepared` instead. @@ -105,20 +112,28 @@ author API; end users interact with the wrapping `Prepared` instead. where input shape is already guaranteed and the runtime check would persist in the dual/shadow hot path. +`context` is a tuple of constant arguments threaded through to `f`: +`evaluator(x)` computes `f(x, context...)`. AD backends treat every value in +`context` as inactive and differentiate only with respect to `x`. The default +empty tuple keeps the unary `f(x)` contract. + A bare `VectorEvaluator` is *not* differentiable; gradient capability is the contract of the wrapping `Prepared` returned by `prepare(adtype, ...)`. 
""" -struct VectorEvaluator{CheckInput,F} +struct VectorEvaluator{CheckInput,F,C<:Tuple} f::F dim::Int - function VectorEvaluator{CheckInput}(f::F, dim::Int) where {CheckInput,F} + context::C + function VectorEvaluator{CheckInput}( + f::F, dim::Int, context::C=() + ) where {CheckInput,F,C<:Tuple} CheckInput isa Bool || throw(ArgumentError("`CheckInput` must be a Bool.")) dim >= 0 || throw(ArgumentError("`dim` must be non-negative, got $dim.")) - return new{CheckInput,F}(f, dim) + return new{CheckInput,F,C}(f, dim, context) end end -VectorEvaluator(f, dim::Int) = VectorEvaluator{true}(f, dim) +VectorEvaluator(f, dim::Int, context::Tuple=()) = VectorEvaluator{true}(f, dim, context) """ NamedTupleEvaluator{CheckInput}(f, inputspec) @@ -195,12 +210,12 @@ _check_ad_input(::VectorEvaluator{false}, ::AbstractVector) = nothing function (e::VectorEvaluator{true})(x::AbstractVector{T}) where {T} T <: Integer && _reject_integer_input(x) _check_vector_length(e.dim, x) - return e.f(x) + return e.f(x, e.context...) end function (e::VectorEvaluator{false})(x::AbstractVector{T}) where {T} T <: Integer && _reject_integer_input(x) - return e.f(x) + return e.f(x, e.context...) end function (e::NamedTupleEvaluator{true})(values::NamedTuple) diff --git a/test/evaluators/Evaluators.jl b/test/evaluators/Evaluators.jl index 6aaab96c..8b96d445 100644 --- a/test/evaluators/Evaluators.jl +++ b/test/evaluators/Evaluators.jl @@ -117,6 +117,15 @@ end pv_unchecked = prepare(sum, zeros(3); check_dims=false) @test pv_unchecked isa VectorEvaluator{false} @test pv_unchecked([1.0, 2.0]) == 3.0 # wrong length, no error + + # `context` threads constant args through to the callable; AD-unaware + # `prepare` constructs the `VectorEvaluator` with the same shape and + # `prepared(x)` evaluates `f(x, context...)`. 
+ affine(x, a, b) = sum(x) * a + b + pv_ctx = prepare(affine, zeros(2); context=(2.0, 1.0)) + @test pv_ctx isa VectorEvaluator{true} + @test pv_ctx.context === (2.0, 1.0) + @test pv_ctx([3.0, 4.0]) == 15.0 end @testset "prepare (AD-aware)" begin diff --git a/test/ext/differentiationinterface/main.jl b/test/ext/differentiationinterface/main.jl index 9788eda7..636f237e 100644 --- a/test/ext/differentiationinterface/main.jl +++ b/test/ext/differentiationinterface/main.jl @@ -31,20 +31,30 @@ quadratic(x::AbstractVector{<:Real}) = sum(xi -> xi^2, x) run_testcases(Val(:edge); adtype=adtype) end - # `DICache` encodes `UseContext` as a type parameter so the - # context-vs-no-context DI call is resolved by dispatch, not a runtime - # `Bool` branch in the AD hot path. - @testset "DICache encodes UseContext as a type parameter" begin + # `DICache`'s `Mode` parameter is either `:closure` (compiled-tape + # ReverseDiff) or the integer context length on the constants path. The + # constants-path integer also documents how many `DI.Constant`s the AD + # call passes. + @testset "DICache encodes the call mode as a type parameter" begin x = [1.0, 2.0, 3.0] - prep_ctx = prepare(AutoForwardDiff(), quadratic, x) - prep_noctx = prepare(AutoReverseDiff(; compile=true), quadratic, x) - - @test prep_ctx.cache isa DIExt.DICache{true} - @test prep_noctx.cache isa DIExt.DICache{false} - @test !hasfield(typeof(prep_ctx.cache), :use_context) - - # Hot path is type-stable on both branches. - @inferred value_and_gradient!!(prep_ctx, x) + prep_noctx = prepare(AutoForwardDiff(), quadratic, x) + prep_closure = prepare(AutoReverseDiff(; compile=true), quadratic, x) + affine(y, a, b) = a * sum(abs2, y) + b + prep_ctx = prepare(AutoForwardDiff(), affine, x; context=(2.0, 1.0)) + + @test prep_noctx.cache isa DIExt.DICache{0} + @test prep_closure.cache isa DIExt.DICache{:closure} + @test prep_ctx.cache isa DIExt.DICache{2} + + # Non-empty-context primal matches the underlying `f(x, context...)`. 
+ @test prep_ctx(x) == affine(x, 2.0, 1.0) + val, grad = value_and_gradient!!(prep_ctx, x) + @test val == affine(x, 2.0, 1.0) + @test grad ≈ [4.0, 8.0, 12.0] # 2 * 2x + + # Hot path is type-stable on all three preps. @inferred value_and_gradient!!(prep_noctx, x) + @inferred value_and_gradient!!(prep_closure, x) + @inferred value_and_gradient!!(prep_ctx, x) end end diff --git a/test/ext/mooncake/main.jl b/test/ext/mooncake/main.jl index 76a4aaf5..3c328a1a 100644 --- a/test/ext/mooncake/main.jl +++ b/test/ext/mooncake/main.jl @@ -21,7 +21,7 @@ using Test end end - @testset "raw_gradient_target" begin + @testset "context-lowered gradient" begin struct TinyProblem{T} offset::T end @@ -34,43 +34,49 @@ using Test generic = prepare(ad, problem, x; check_dims=false) lowered = prepare( - ad, - problem, - x; - check_dims=false, - raw_gradient_target=(raw_logdensity, (problem.offset,)), + ad, raw_logdensity, x; check_dims=false, context=(problem.offset,) ) - # `prepared(x)` still calls `problem(x)` on both paths. + # `prepared(x)` evaluates `problem(x)` on the generic path and + # `raw_logdensity(x, context...)` on the lowered path; both should + # produce the same scalar. @test generic(x) == problem(x) @test lowered(x) == problem(x) # Same value and gradient as the generic path. @test value_and_gradient!!(generic, x) == value_and_gradient!!(lowered, x) - # Rejects on forward mode, vector-valued problems, and empty input. - vec_problem = x -> [x[1]^2, x[1] + 1.0] - @test_throws ArgumentError prepare( - AutoMooncakeForward(; config=nothing), - problem, - x; - check_dims=false, - raw_gradient_target=(raw_logdensity, (problem.offset,)), + # Forward mode supports context too — same primal and (approximately) + # the same derivative as the reverse-mode lowered path on this scalar + # problem. Use `≈` because forward and reverse may differ in the last + # ULPs. 
+ ad_fwd = AutoMooncakeForward(; config=nothing) + lowered_fwd = prepare( + ad_fwd, raw_logdensity, x; check_dims=false, context=(problem.offset,) ) + @test lowered_fwd(x) == problem(x) + val_fwd, grad_fwd = value_and_gradient!!(lowered_fwd, x) + val_rev, grad_rev = value_and_gradient!!(lowered, x) + @test val_fwd ≈ val_rev atol = 1e-12 + @test grad_fwd ≈ grad_rev atol = 1e-12 + + # Rejects on vector-valued problems with non-empty context. + vec_problem(y, c) = [y[1] * c, y[1] + c] @test_throws ArgumentError prepare( - ad, - vec_problem, - x; - check_dims=false, - raw_gradient_target=((y, c) -> [y[1] * c], (1.0,)), + ad, vec_problem, x; check_dims=false, context=(1.0,) ) - @test_throws ArgumentError prepare( - ad, - problem, - Float64[]; - check_dims=false, - raw_gradient_target=(raw_logdensity, (problem.offset,)), + + # Empty input with non-empty context is supported — the empty-input + # shortcut bypasses Mooncake and just calls `f([], context...)`. Use + # a `sum(...; init=0.0)`-based `f` since `raw_logdensity` indexes `x[1]`. + empty_logdensity(y::AbstractVector{<:Real}, offset) = + sum(y; init=zero(eltype(y))) + offset + empty_lowered = prepare( + ad, empty_logdensity, Float64[]; check_dims=false, context=(0.5,) ) + val0, grad0 = value_and_gradient!!(empty_lowered, Float64[]) + @test val0 == empty_logdensity(Float64[], 0.5) + @test grad0 == Float64[] # Jacobian on a scalar-only lowered cache surfaces our arity-mismatch error. 
@test_throws r"vector-valued" AbstractPPL.value_and_jacobian!!(lowered, x) From 08a611d3fbdc622b1c36747eb43b8e8dc1a3416c Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 18 May 2026 14:56:49 +0100 Subject: [PATCH 14/15] Mooncake ext: inline single-call helpers, trim stale design-history comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three single-call helpers folded into their unique call sites: - 4-arg `_mooncake_gradient_cache(adtype, f, x, context::Tuple)` and the `_mooncake_jacobian_cache` pair → one `if adtype isa AutoMooncake ... end` branch in vector `prepare`. The `isa` is compile-folded since `adtype`'s concrete type lives in the method's specialization. - `_mooncake_value_and_gradient(::Auto*, ...)` → an inlined branch in the scalar-gradient hot path's `value_and_gradient!!` body, using the same compile-folded `isa`. Kept in a single `<:_MooncakeAD` method so the empty-input shortcut's cache specificity (`MooncakeCache{:scalar,Nothing}`) doesn't clash with a per-AD-type method. The 3-arg `_mooncake_gradient_cache(::Auto*, f, x)` pair (NamedTuple path) stays factored — it's reused by the NamedTuple `prepare` and has a distinct call shape (evaluator + values, not raw `f` + context splat). Comment cleanup: - `tangent_type` defensive guard now describes the current state rather than the refactor that produced it ("after the raw-target merge"). - Cache-prep block trimmed from two paragraphs to one — kept the evaluator-as-source-of-truth rationale and the forward-mode unification note, dropped the redundant splat-no-op elaboration. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- ext/AbstractPPLMooncakeExt.jl | 79 ++++++++++++++++------------------- 1 file changed, 36 insertions(+), 43 deletions(-) diff --git a/ext/AbstractPPLMooncakeExt.jl b/ext/AbstractPPLMooncakeExt.jl index 554d36aa..b07af8ae 100644 --- a/ext/AbstractPPLMooncakeExt.jl +++ b/ext/AbstractPPLMooncakeExt.jl @@ -10,8 +10,8 @@ const _MooncakeAD = Union{AutoMooncake,AutoMooncakeForward} # `NamedTupleEvaluator` is the callable on the NamedTuple path; `NoTangent` # stops Mooncake from deriving a `Tangent{NamedTuple{...}}` for its fields -# on every backward pass. `VectorEvaluator` is no longer passed to Mooncake -# after the raw-target merge; the override is kept as a defensive guard. +# on every backward pass. The `VectorEvaluator` override is a defensive +# guard — vector preps no longer pass the evaluator wrapper to Mooncake. Mooncake.tangent_type(::Type{<:VectorEvaluator}) = Mooncake.NoTangent Mooncake.tangent_type(::Type{<:NamedTupleEvaluator}) = Mooncake.NoTangent @@ -28,6 +28,8 @@ MooncakeCache{A}(cache::C) where {A,C} = MooncakeCache{A,C}(cache) _mooncake_config(adtype) = adtype.config === nothing ? Mooncake.Config() : adtype.config +# NamedTuple-path helper: Mooncake exposes separate `prepare_*_cache` +# entries per AD mode but the call shape (target + values) is the same. function _mooncake_gradient_cache(::AutoMooncake, f, x; config) return Mooncake.prepare_gradient_cache(f, x; config) end @@ -35,22 +37,6 @@ function _mooncake_gradient_cache(::AutoMooncakeForward, f, x; config) return Mooncake.prepare_derivative_cache(f, x; config) end -# Vector-scalar overloads — splat `context` into the underlying Mooncake -# prep call (empty tuple is a no-op). 
-function _mooncake_gradient_cache(::AutoMooncake, f, x, context::Tuple; config) - return Mooncake.prepare_gradient_cache(f, x, context...; config) -end -function _mooncake_gradient_cache(::AutoMooncakeForward, f, x, context::Tuple; config) - return Mooncake.prepare_derivative_cache(f, x, context...; config) -end - -function _mooncake_jacobian_cache(::AutoMooncake, f, x; config) - return Mooncake.prepare_pullback_cache(f, x; config) -end -function _mooncake_jacobian_cache(::AutoMooncakeForward, f, x; config) - return Mooncake.prepare_derivative_cache(f, x; config) -end - function AbstractPPL.prepare( adtype::_MooncakeAD, problem, values::NamedTuple; check_dims::Bool=true ) @@ -109,14 +95,19 @@ function AbstractPPL.prepare( # calls `p.evaluator(x)` which already does `f([], context...)`. length(x) == 0 && return Prepared(adtype, evaluator, MooncakeCache{arity}(nothing)) # Compile the tape on the evaluator's `f` and `context` (not the raw - # `problem` passed in): a downstream override of structural `prepare` - # may return a `VectorEvaluator` whose `.f`/`.context` differ from the - # caller-supplied `problem`/`context`. The hot path uses `evaluator.f` - # / `evaluator.context`, so the cache must agree. - cache = if arity === :scalar - _mooncake_gradient_cache(adtype, evaluator.f, x, evaluator.context; config) + # `problem` / `context` kwargs): a downstream override of structural + # `prepare` may return a `VectorEvaluator` whose `.f`/`.context` differ + # from the caller-supplied values, and the hot path reads them off the + # evaluator. Forward mode uses `prepare_derivative_cache` for both + # arities; the splat is a no-op for vector arity (empty `context`). 
+ cache = if adtype isa AutoMooncake + if arity === :scalar + Mooncake.prepare_gradient_cache(evaluator.f, x, evaluator.context...; config) + else + Mooncake.prepare_pullback_cache(evaluator.f, x; config) + end else - _mooncake_jacobian_cache(adtype, evaluator.f, x; config) + Mooncake.prepare_derivative_cache(evaluator.f, x, evaluator.context...; config) end return Prepared(adtype, evaluator, MooncakeCache{arity}(cache)) end @@ -146,30 +137,32 @@ end return (p.evaluator(x), T[]) end -# Reverse-mode `Mooncake.Cache` needs `args_to_zero` to mark `x` as the lone -# active input (`false` on `f`, `true` on `x`, `false` on each context value). -# Forward-mode `ForwardCache` derives activity from its seeded argument and -# rejects the kwarg. Dispatching on the AD type keeps each call mode-specific -# without a runtime branch. -@inline _mooncake_value_and_gradient( - ::AutoMooncake, cache, f::F, x, context::Tuple -) where {F} = Mooncake.value_and_gradient!!( - cache, f, x, context...; args_to_zero=(false, true, map(_ -> false, context)...) -) -@inline _mooncake_value_and_gradient( - ::AutoMooncakeForward, cache, f::F, x, context::Tuple -) where {F} = Mooncake.value_and_gradient!!(cache, f, x, context...) - -# Scalar-gradient hot path. Empty `context` collapses the splat and reduces -# `args_to_zero` to `(false, true)`. `tangents[2]` is the `x`-gradient — the -# trailing entries (one per context value) are zeroed and discarded. +# Scalar-gradient hot path. Reverse mode (`Mooncake.Cache`) needs +# `args_to_zero` to mark `x` as the lone active input (`false` on `f`, +# `true` on `x`, `false` on each context value); forward mode +# (`ForwardCache`) derives activity from its seeded argument and rejects +# the kwarg. The `p.adtype isa AutoMooncake` branch is compile-folded +# since `adtype`'s concrete type lives in `Prepared`'s type parameters. +# Empty `context` collapses the splat and reduces `args_to_zero` to +# `(false, true)`. 
`tangents[2]` is the `x`-gradient; trailing entries +# (one per context value) are inactive and discarded. @inline function AbstractPPL.value_and_gradient!!( p::Prepared{<:_MooncakeAD,<:VectorEvaluator,<:MooncakeCache{:scalar}}, x::AbstractVector{T}, ) where {T<:Real} Evaluators._check_ad_input(p.evaluator, x) e = p.evaluator - val, tangents = _mooncake_value_and_gradient(p.adtype, p.cache.cache, e.f, x, e.context) + val, tangents = if p.adtype isa AutoMooncake + Mooncake.value_and_gradient!!( + p.cache.cache, + e.f, + x, + e.context...; + args_to_zero=(false, true, map(_ -> false, e.context)...), + ) + else + Mooncake.value_and_gradient!!(p.cache.cache, e.f, x, e.context...) + end return (val, tangents[2]) end From d7194aff3edd616e11de2f6f2169fac90a718997 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Tue, 19 May 2026 12:20:57 +0100 Subject: [PATCH 15/15] HISTORY.md: add 0.15.0 entry for evaluator and AD interface Co-Authored-By: Claude Opus 4.7 (1M context) --- HISTORY.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 8d8e0e49..4cc01a81 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,19 @@ +## 0.15.0 + +New evaluator-preparation and AD interface: `prepare` binds a callable to a sample input (vector or `NamedTuple`); `value_and_gradient!!` / `value_and_jacobian!!` return value-and-derivative pairs from the resulting `Prepared` wrapper. The `!!` suffix signals the returned derivative may alias the cache — copy if you need to keep it. + +```julia +using ADTypes, Mooncake # or DifferentiationInterface + ForwardDiff +using AbstractPPL: prepare, value_and_gradient!! +prepared = prepare(AutoMooncake(), x -> -0.5 * sum(abs2, x), zeros(3)) +val, grad = value_and_gradient!!(prepared, [1.0, 2.0, 3.0]) +# val == -7.0; grad == [-1.0, -2.0, -3.0] +``` + +Two new AD-backend extensions ship with it: `AbstractPPLDifferentiationInterfaceExt` (any DI backend) and `AbstractPPLMooncakeExt` (`AutoMooncake`, `AutoMooncakeForward`). 
`AbstractPPLTestExt` gains a conformance harness via `generate_testcases` / `run_testcases` (reserved groups: `:vector`, `:namedtuple`, `:edge`, `:cache_reuse`). + +See [`docs/src/evaluators.md`](docs/src/evaluators.md) for the full interface, the `check_dims` and `context::Tuple` options, the `NamedTuple` input path, and extension-author guidance. + ## 0.14.2 Fix string serialisation of VarNames such that the order of keyword arguments is preserved (this was previously guaranteed, but JSON.jl v1.5.0 introduced a change that caused the keyword arguments to always be sorted.)