Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ jobs:
matrix:
label:
- ext/differentiationinterface
- ext/mooncake
version:
- '1'
- 'min'
Expand Down
2 changes: 2 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ AbstractPPL.jl provides shared interfaces and utilities for probabilistic progra
- Full package tests: `julia --project=. -e 'using Pkg; Pkg.test()'`
- Docs: `julia --project=docs docs/make.jl`

Always refresh each environment (`Pkg.update()` / `up`) before tests or doc builds — a stale manifest can cause subtle resolution and loading issues.

Run the smallest relevant test first, then broaden when changing public interfaces, extensions, or downstream-facing behaviour. Do not weaken tests just to make CI pass.

## Documentation
Expand Down
16 changes: 16 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
## 0.15.0

New evaluator-preparation and AD interface: `prepare` binds a callable to a sample input (vector or `NamedTuple`); `value_and_gradient!!` / `value_and_jacobian!!` return value-and-derivative pairs from the resulting `Prepared` wrapper. The `!!` suffix signals the returned derivative may alias the cache — copy if you need to keep it.

```julia
using ADTypes, Mooncake # or DifferentiationInterface + ForwardDiff
using AbstractPPL: prepare, value_and_gradient!!
prepared = prepare(AutoMooncake(), x -> -0.5 * sum(abs2, x), zeros(3))
val, grad = value_and_gradient!!(prepared, [1.0, 2.0, 3.0])
# val == -7.0; grad == [-1.0, -2.0, -3.0]
```

Two new AD-backend extensions ship with it: `AbstractPPLDifferentiationInterfaceExt` (any DI backend) and `AbstractPPLMooncakeExt` (`AutoMooncake`, `AutoMooncakeForward`). `AbstractPPLTestExt` gains a conformance harness via `generate_testcases` / `run_testcases` (reserved groups: `:vector`, `:namedtuple`, `:edge`, `:cache_reuse`).

See [`docs/src/evaluators.md`](docs/src/evaluators.md) for the full interface, the `check_dims` and `context::Tuple` options, the `NamedTuple` input path, and extension-author guidance.

## 0.14.2

Fix string serialisation of VarNames such that the order of keyword arguments is preserved (this was previously guaranteed, but JSON.jl v1.5.0 introduced a change that caused the keyword arguments to always be sorted.)
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
keywords = ["probabilistic programming"]
license = "MIT"
desc = "Common interfaces for probabilistic programming"
version = "0.14.3"
version = "0.15"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -21,11 +21,13 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[weakdeps]
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
AbstractPPLDifferentiationInterfaceExt = ["DifferentiationInterface"]
AbstractPPLDistributionsExt = ["Distributions", "LinearAlgebra"]
AbstractPPLMooncakeExt = ["Mooncake"]
AbstractPPLTestExt = ["Test"]

[compat]
Expand All @@ -39,6 +41,7 @@ Distributions = "0.25"
JSON = "0.19 - 0.21, 1"
LinearAlgebra = "<0.0.1, 1"
MacroTools = "0.5"
Mooncake = "0.5.27"
OrderedCollections = "1.8.1"
Random = "1.6"
StatsBase = "0.32, 0.33, 0.34"
Expand Down
17 changes: 17 additions & 0 deletions docs/src/evaluators.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,23 @@ library invokes the inner callable many times with same-length dual arrays
derived from a single user-supplied `x`; re-validating on each invocation
would be redundant work in the hot path.

## Constant context arguments

When the underlying callable naturally takes the form `f(x, context...)` —
where everything after `x` is constant state — pass `context` as a tuple to
the vector form of `prepare`. AD differentiates only w.r.t. `x`; every
value in `context` is treated as inactive:

```julia
affine(x, scale, offset) = scale * sum(x) + offset
prepared = prepare(adtype, affine, zeros(3); context=(2.0, 1.0))
val, grad = value_and_gradient!!(prepared, [1.0, 2.0, 3.0])
# val == 2.0 * 6.0 + 1.0; grad == [2.0, 2.0, 2.0]
```

`prepared(x)` evaluates `f(x, context...)`, and `context=()` (the default)
preserves the unary `f(x)` shape.

## Without an AD backend

The two-argument form `prepare(problem, x)` is available without any AD
Expand Down
135 changes: 83 additions & 52 deletions ext/AbstractPPLDifferentiationInterfaceExt.jl
Original file line number Diff line number Diff line change
@@ -1,100 +1,131 @@
module AbstractPPLDifferentiationInterfaceExt

using AbstractPPL: AbstractPPL
using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator
using AbstractPPL.Evaluators: Evaluators, Prepared, VectorEvaluator, _ad_output_arity
using ADTypes: AbstractADType, AutoReverseDiff
using DifferentiationInterface: DifferentiationInterface as DI

# Differentiate only `x`; the evaluator is passed as a `DI.Constant` context so
# that in DynamicPPL the model and other evaluator state stay constant.
@inline _call_evaluator(x, evaluator) = evaluator(x)
# AD target used by both `DICache` modes. `Vararg{Any,N}` with a free `N`
# forces specialization on the trailing arity (a bare `Vararg{Any}` would
# skip it). DI invokes this as `_call_evaluator(x, f, c1, …, cN)` on the
# constants path, and as `_call_evaluator(x, evaluator)` (via `Fix2`) on
# the closure path — empty `ctx` then makes the splat a no-op.
@inline _call_evaluator(x, f::F, ctx::Vararg{Any,N}) where {F,N} = f(x, ctx...)

struct DICache{F,GP,JP}
# `Mode` tags the cache shape:
# * `:closure` — compiled-tape ReverseDiff: target is a `Fix2` closure,
# the AD call passes **0** `DI.Constant`s.
# * `N::Int` — constants path: `N == length(evaluator.context)`, the
# AD call passes **N + 1** `DI.Constant`s (`f` plus the
# `N` context values).
# Encoding `Mode` in the type resolves the dispatch in `_di_value_and_*`
# at compile time without a runtime branch.
struct DICache{Mode,F,GP,JP}
target::F
gradient_prep::GP
jacobian_prep::JP
use_context::Bool
function DICache{Mode}(target::F, gp::GP, jp::JP) where {Mode,F,GP,JP}
return new{Mode,F,GP,JP}(target, gp, jp)
end
end

# Compiled ReverseDiff only reuses a compiled tape on the one-argument path;
# `DI.Constant` deactivates tape recording, so close the evaluator into the
# target and call DI without contexts.
# target and call DI without constants. Context (if any) is captured inside
# the evaluator closure rather than lowered out — the lowered path would also
# require a closure here, so the wrapper cost is unavoidable for compiled tapes.
function _prepare_di(prep::F, adtype::AutoReverseDiff{true}, x, evaluator) where {F}
target = Base.Fix2(_call_evaluator, evaluator)
return target, prep(target, adtype, x), false
return target, prep(target, adtype, x), Val(:closure)
end

function _prepare_di(prep::F, adtype::AbstractADType, x, evaluator) where {F}
return _call_evaluator, prep(_call_evaluator, adtype, x, DI.Constant(evaluator)), true
constants = (DI.Constant(evaluator.f), map(DI.Constant, evaluator.context)...)
return (
_call_evaluator,
prep(_call_evaluator, adtype, x, constants...),
Val(length(evaluator.context)),
)
end

@inline _wrap_cache(target, gp, jp, ::Val{Mode}) where {Mode} =
DICache{Mode}(target, gp, jp)

function AbstractPPL.prepare(
adtype::AbstractADType, problem, x::AbstractVector{<:Real}; check_dims::Bool=true
adtype::AbstractADType,
problem,
x::AbstractVector{<:Real};
check_dims::Bool=true,
context::Tuple=(),
)
evaluator = AbstractPPL.prepare(problem, x; check_dims)::VectorEvaluator
y = evaluator(x)
y isa Union{Number,AbstractVector} || throw(
ArgumentError(
"A prepared AD evaluator must return a scalar or AbstractVector; got $(typeof(y)).",
),
)
evaluator = AbstractPPL.prepare(problem, x; check_dims, context)::VectorEvaluator
arity = _ad_output_arity(evaluator(x))
if length(x) == 0
# DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`); the
# `Val(0)` sentinel keeps the `gradient_prep === nothing` arity check meaningful.
gp, jp = y isa Number ? (Val(0), nothing) : (nothing, Val(0))
return Prepared(adtype, evaluator, DICache(_call_evaluator, gp, jp, true))
# DI prep crashes on length-0 input (e.g. ForwardDiff `BoundsError`).
# `Val(0)` is an arity sentinel for the `gradient_prep === nothing`
# check below; the AD entry short-circuits before any DI call.
gp, jp = arity === :scalar ? (Val(0), nothing) : (nothing, Val(0))
cache = _wrap_cache(_call_evaluator, gp, jp, Val(length(context)))
return Prepared(adtype, evaluator, cache)
end
if y isa Number
target, gradient_prep, use_context = _prepare_di(
DI.prepare_gradient, adtype, x, evaluator
)
if arity === :scalar
target, gradient_prep, mode = _prepare_di(DI.prepare_gradient, adtype, x, evaluator)
return Prepared(
adtype, evaluator, DICache(target, gradient_prep, nothing, use_context)
adtype, evaluator, _wrap_cache(target, gradient_prep, nothing, mode)
)
end
target, jacobian_prep, use_context = _prepare_di(
DI.prepare_jacobian, adtype, x, evaluator
)
return Prepared(adtype, evaluator, DICache(target, nothing, jacobian_prep, use_context))
target, jacobian_prep, mode = _prepare_di(DI.prepare_jacobian, adtype, x, evaluator)
return Prepared(adtype, evaluator, _wrap_cache(target, nothing, jacobian_prep, mode))
end

# Hot-path dispatch is by `Mode` (closure vs constants), resolved at compile
# time. The unconstrained method matches every non-`:closure` `Mode` (i.e.
# any `Int N`); `:closure` is strictly more specific and wins for compiled
# tapes. On the constants path we always pass `DI.Constant(eval.f)` plus the
# `N` context constants — `N == 0` collapses the `map` splat to nothing.
@inline _di_value_and_gradient(c::DICache{:closure}, ad, x, _) =
DI.value_and_gradient(c.target, c.gradient_prep, ad, x)
@inline _di_value_and_gradient(c::DICache, ad, x, eval) = DI.value_and_gradient(
c.target,
c.gradient_prep,
ad,
x,
DI.Constant(eval.f),
map(DI.Constant, eval.context)...,
)

@inline _di_value_and_jacobian(c::DICache{:closure}, ad, x, _) =
DI.value_and_jacobian(c.target, c.jacobian_prep, ad, x)
@inline _di_value_and_jacobian(c::DICache, ad, x, eval) = DI.value_and_jacobian(
c.target,
c.jacobian_prep,
ad,
x,
DI.Constant(eval.f),
map(DI.Constant, eval.context)...,
)

@inline function AbstractPPL.value_and_gradient!!(
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
) where {T<:Real}
p.cache.gradient_prep === nothing &&
throw(ArgumentError("`value_and_gradient!!` requires a scalar-valued function."))
T <: Integer && Evaluators._reject_integer_input(x)
Evaluators._check_vector_length(p.evaluator.dim, x)
p.cache.gradient_prep === nothing && Evaluators._throw_gradient_needs_scalar()
Evaluators._check_ad_input(p.evaluator, x)
# Bypass DI on length-0 input — DI prep paths fail (e.g. ForwardDiff
# `BoundsError`); typed `T[]` matches the caller's element type.
length(x) == 0 && return (p.evaluator(x), T[])
return if p.cache.use_context
DI.value_and_gradient(
p.cache.target, p.cache.gradient_prep, p.adtype, x, DI.Constant(p.evaluator)
)
else
DI.value_and_gradient(p.cache.target, p.cache.gradient_prep, p.adtype, x)
end
return _di_value_and_gradient(p.cache, p.adtype, x, p.evaluator)
end

@inline function AbstractPPL.value_and_jacobian!!(
p::Prepared{<:AbstractADType,<:VectorEvaluator,<:DICache}, x::AbstractVector{T}
) where {T<:Real}
p.cache.jacobian_prep === nothing &&
throw(ArgumentError("`value_and_jacobian!!` requires a vector-valued function."))
T <: Integer && Evaluators._reject_integer_input(x)
Evaluators._check_vector_length(p.evaluator.dim, x)
p.cache.jacobian_prep === nothing && Evaluators._throw_jacobian_needs_vector()
Evaluators._check_ad_input(p.evaluator, x)
if length(x) == 0
val = p.evaluator(x)
return (val, similar(x, length(val), 0))
end
return if p.cache.use_context
DI.value_and_jacobian(
p.cache.target, p.cache.jacobian_prep, p.adtype, x, DI.Constant(p.evaluator)
)
else
DI.value_and_jacobian(p.cache.target, p.cache.jacobian_prep, p.adtype, x)
end
return _di_value_and_jacobian(p.cache, p.adtype, x, p.evaluator)
end

end # module
Loading
Loading