Add Mooncake extension; extend AD conformance suite with NamedTuple and cache-reuse coverage #160
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #160 +/- ##
==========================================
+ Coverage 87.11% 88.14% +1.03%
==========================================
Files 14 15 +1
Lines 784 886 +102
==========================================
+ Hits 683 781 +98
- Misses 101 105 +4
Mooncake AD-backend extension built on the evaluator interface, with the
shared conformance suite extended to cover NamedTuple inputs and empty-input
arity errors. Squashed from prior incremental commits:
- AbstractPPLMooncakeExt: cache reuse, scalar/vector dispatch, NamedTuple
inputs via VectorEvaluator/NamedTupleEvaluator wrappers; integration test
in test/ext/mooncake.
- Evaluators._ad_output_arity: lift the duplicated `Union{Number,
AbstractVector}` output check from both extensions into one helper that
returns `:scalar` / `:vector` for downstream dispatch.
- Empty-input arity tagging (`Val(:scalar)` / `Val(:vector)`) so the
empty-input fast path raises the same "requires a scalar/vector-valued
function" error as the DI path instead of silently succeeding.
- AbstractPPLTestExt: add `Val(:namedtuple)` group (one ValueCase + one
ErrorCase); tighten regex assertions on the existing arity-mismatch cases.
- check_dims threaded through the inner `prepare` call so AD hot paths can
skip per-call shape checks.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
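The output-arity helper in the commit above can be pictured with a small dispatch sketch. This is only an illustration: `output_arity` and `describe` are hypothetical names, and the real `Evaluators._ad_output_arity` signature is not shown in this thread.

```julia
# Hypothetical stand-in for the `_ad_output_arity` helper described above.
# One method per accepted output kind; anything else is rejected.
output_arity(::Number) = :scalar
output_arity(::AbstractVector) = :vector
output_arity(y) = throw(ArgumentError(
    "expected a scalar- or vector-valued function, got $(typeof(y))"))

# Downstream code can branch on the returned symbol once, at prepare time.
describe(f, x) = output_arity(f(x)) === :scalar ? "gradient" : "jacobian"

describe(sum, [1.0, 2.0])          # scalar output
describe(x -> 2 .* x, [1.0, 2.0])  # vector output
```

Keeping the check in one place means both extensions raise the same error for unsupported outputs.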
`_check_ad_input(evaluator, x)` in `Evaluators` replaces the duplicated
`T <: Integer` rejection plus length check that appeared at six AD entry
points (two in the DI extension, four in Mooncake). Compile-time `T`
elision is preserved.
Move `generate_testcases(::Val{:namedtuple})` and
`run_testcases(::Val{:namedtuple})` to sit alongside the `:vector` and
`:edge` definitions so the file reads generate-then-run for all three
groups.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
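A minimal sketch of the kind of check `_check_ad_input` consolidates, assuming the evaluator stores an expected input length. The name `check_ad_input` and the `expected_dim` argument are hypothetical; the real helper takes the evaluator itself.

```julia
# Hypothetical sketch of a shared AD input check.
function check_ad_input(expected_dim::Int, x::AbstractVector{T}) where {T}
    # `T` is a static parameter, so for a Float64 input this branch is
    # known to be false when the method is specialized and can be dropped.
    T <: Integer && throw(ArgumentError("AD needs non-integer input, got eltype $T"))
    length(x) == expected_dim ||
        throw(DimensionMismatch("expected length $expected_dim, got $(length(x))"))
    return nothing
end

check_ad_input(2, [1.0, 2.0])  # passes and returns nothing
```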
- `:cache_reuse` conformance group in `AbstractPPLTestExt` drives
`value_and_{gradient,jacobian}!!` three times per case against a single
`prepared` evaluator to catch backend cache corruption between calls.
- DI ext sub-environment now also loads `ReverseDiff` and exercises
`AutoReverseDiff(compile=true)` against the conformance suite, covering
the `_prepare_di(::AutoReverseDiff{true}, …)` compiled-tape path.
- Lift the duplicated `value_and_{gradient,jacobian}!!` arity-mismatch
`ArgumentError` strings into shared `Evaluators._throw_*` helpers used by
both the DI and Mooncake extensions.
- `generate_testcases` docstring lists `:namedtuple` and `:cache_reuse`
alongside `:vector` / `:edge` as reserved group keys.
- Trim verbose `check_dims` clarifications in docstrings and
`docs/src/evaluators.md` to one sentence each.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
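The idea behind the `:cache_reuse` group can be sketched without any AD backend. Everything here is hypothetical (`PreparedSumSq`, the local `value_and_gradient!!`, `cache_reuse_ok`); it only illustrates calling one prepared object repeatedly and comparing results.

```julia
# Hypothetical prepared object that reuses one gradient buffer across calls.
struct PreparedSumSq
    grad::Vector{Float64}
end

# f(x) = sum(abs2, x); the gradient 2x is written into the reused buffer.
function value_and_gradient!!(p::PreparedSumSq, x::Vector{Float64})
    p.grad .= 2 .* x
    return sum(abs2, x), p.grad
end

# Call several times against the same prepared object. Each gradient is
# copied because the buffer may be overwritten by the next call.
function cache_reuse_ok(p, x; ncalls=3)
    results = map(1:ncalls) do _
        v, g = value_and_gradient!!(p, x)
        (v, copy(g))
    end
    return all(r -> r == first(results), results)
end

cache_reuse_ok(PreparedSumSq(zeros(2)), [1.0, 2.0])
```

A backend that accumulated into the buffer instead of overwriting it would return different gradients on the second and third calls, which is the failure mode this kind of test is meant to surface.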
@shravanngoswamii, this should be the final PR in the series; it adds a native Mooncake backend. EDIT: There are also some minor changes to the DI and Mooncake extensions, motivated by DynamicPPL's needs.
Stale manifests cause subtle resolution and loading issues; document the
expected `Pkg.update()` step alongside the existing test commands.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two regressions were visible on tiny-model gradients that went through the
new AbstractPPL evaluator interface:
- `_check_ad_input` always ran on `value_and_{gradient,jacobian}!!`
entry, even when the evaluator was prepared with `check_dims=false`.
Now dispatch-gated on `VectorEvaluator{CheckInput}`: the `{false}`
overload is a no-op, so the `DimensionMismatch` and integer-rejection
paths are elided from the LLVM IR of the AD hot path.
- `DICache` stored `use_context::Bool` as a runtime field, leaving a
branch in the compiled call selecting the context vs no-context DI
form. `UseContext` is now a type parameter and the branch is resolved
by dispatch via `_di_value_and_{gradient,jacobian}` helpers.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
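Both fixes above move a runtime `Bool` into a type parameter so that dispatch picks the method. A minimal sketch of that pattern, with a hypothetical `Evaluator` type and `check_input` function rather than the package's real definitions:

```julia
# Hypothetical evaluator with the check flag lifted into the type.
struct Evaluator{CheckInput,F}
    f::F
    dim::Int
end
Evaluator{C}(f::F, dim::Int) where {C,F} = Evaluator{C,F}(f, dim)

# `{true}` validates the input length; `{false}` is a no-op chosen by
# dispatch, so no runtime branch on a Bool field is needed.
check_input(e::Evaluator{true}, x) =
    length(x) == e.dim || throw(DimensionMismatch("expected $(e.dim), got $(length(x))"))
check_input(::Evaluator{false}, x) = nothing

function (e::Evaluator)(x)
    check_input(e, x)
    return e.f(x)
end

Evaluator{true}(sum, 2)([1.0, 2.0])   # checked call
Evaluator{false}(sum, 2)([1.0])       # unchecked call, wrong length goes through
```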
`Mooncake.value_and_gradient!!(cache, evaluator, x)` reset the evaluator's
tangent buffer on every call, even though AbstractPPL discards `∂f` and only
surfaces `∂x`. For evaluators that wrap a model with large fields (e.g. a
128-tuple of `Float64`), the zeroing was the dominant per-call overhead at
tiny model sizes.
Pass `args_to_zero=(false, true)` to the reverse-mode `Mooncake.Cache` path
to skip the `∂f` reset while still zeroing the `∂x` buffer. The forward-mode
`Mooncake.ForwardCache` doesn't accept the kwarg, so the branch is
`isa`-dispatched on the concrete cache type and constant-folds at compile
time.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Concise inline notes on:
- `VectorEvaluator{true|false}` callable bodies (shared `T <: Integer`
compile-time elision, and the `{false}` skip of `_check_vector_length`).
- Mooncake ext empty-input and arity-mismatch methods (compile-time dispatch
via `MooncakeCache{…,Nothing}` and `MooncakeCache{:scalar|:vector}` to
avoid runtime branches).
- `args_to_zero=(false, true)` at both Mooncake gradient call sites
(skipping the evaluator's tangent re-zeroing per call — `∂f` is discarded).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mooncake was deriving a nested `Tangent{NamedTuple{f::Tangent{...}}}` for
every `VectorEvaluator`/`NamedTupleEvaluator` it received, then walking
that structure on every backward pass. The evaluators are AbstractPPL's
own wrapper types and never appear as a downstream gradient target — the
public API only returns `(value, ∂x)`.
Register `Mooncake.tangent_type(::Type{<:VectorEvaluator}) = NoTangent`
(and the same for `NamedTupleEvaluator`) so the cache carries no tangent
for the user's problem fields. The `args_to_zero=(false, true)` mitigation
and the `_ConstantEvaluator` wrapper from the prior pass are both no
longer needed; the call sites pass `p.evaluator` directly.
Verified on the MWE setup: `Mooncake.Tangent{` count in the prepared cache
type is 0; value and gradient match a direct `logdensity_at(x, state, …)`
call.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Callers who know an equivalent raw `f(x, contexts...) ≡ problem(x)` can pass
it via `prepare(AutoMooncake(), problem, x; raw_gradient_target=(f, contexts))`.
Mooncake then compiles the tape on the raw call shape with
`args_to_zero=(false, true, false, …)` instead of the generic `evaluator(x)`
wrapper, sidestepping the fixed overhead seen on tiny scalar-vector problems.
`prepared(x)` still calls `problem(x)`; only the AD entry uses the lowered
cache (a new `MooncakeLoweredCache` carries `cache`, `f`, `contexts`, and
`args_to_zero`).
Scoped strictly to reverse-mode `AutoMooncake` and scalar arity with
non-empty input; anything else errors at prepare time. Jacobian on a lowered
cache surfaces the existing arity-mismatch error.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…rget on all AD prepare methods
- Collapsed `MooncakeLoweredCache` into `MooncakeCache{A,C,F,CT,AZ}`. The
three new type params default to `Nothing` via the existing constructor;
the lowered-path constructor populates them. Dispatch on `CT<:Tuple`
(excluding the `Nothing` default) picks the lowered AD entry. No new
type, no runtime branching.
- DI extension's `prepare(::AbstractADType, ...)` now accepts
`raw_gradient_target=nothing` and silently ignores it. Same for the
Mooncake NamedTuple `prepare`. Generic user code that passes the kwarg
to non-Mooncake backends (or to the Mooncake NamedTuple path) no longer
hits a MethodError.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
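The "default to `Nothing`, dispatch on `<:Tuple`" pattern from the first bullet above can be sketched as follows. The `Cache` type and `entry` function are hypothetical and much smaller than the real `MooncakeCache`.

```julia
# Hypothetical cache: `contexts` is `nothing` on the generic path and a
# tuple on the lowered path, so its type parameter selects the method.
struct Cache{C,CT}
    cache::C
    contexts::CT
end
Cache(cache) = Cache(cache, nothing)  # existing-style constructor keeps the default

entry(::Cache{<:Any,Nothing}) = :generic
entry(::Cache{<:Any,<:Tuple}) = :lowered

entry(Cache(1))          # generic path
entry(Cache(1, (2.0,)))  # lowered path
```

Because `Nothing` is not a subtype of `Tuple`, the two methods never overlap and no runtime check on the field is required.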
- `args_to_zero` was a derived value (`(false, true, false×length(contexts))`)
  stored as a struct field plus a 5th type parameter. Moved the construction
  to the AD entry; the tuple constant-folds for any concrete `contexts` arity.
  Saves one type parameter and one field.
- Dropped two trailing comments on `raw_gradient_target=nothing` kwargs (the
  comment didn't explain WHY; the kwarg name and surrounding context already
  convey "this is a backend-specific optimization that defaults off").
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
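As a rough illustration of deriving the `args_to_zero` tuple at the call site, here is one way to build it from a `contexts` tuple. The helper name `args_to_zero` as a function is an assumption; the commit only says the tuple is constructed at the AD entry.

```julia
# Hypothetical helper: do not zero the tangent of `f`, zero the tangent of
# `x`, and skip every context argument.
args_to_zero(contexts::Tuple) = (false, true, map(_ -> false, contexts)...)

args_to_zero(())          # (false, true)
args_to_zero((1.0, "a"))  # (false, true, false, false)
```

Since `map` over a tuple returns a tuple of the same length, the result has a length that depends only on the type of `contexts`.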
shravanngoswamii left a comment:
Other than these comments, PR looks good to me. Happy to approve!
…egation and raw_gradient_target as unsafe
Addresses PR #160 review comments:
- Throw a clear `ArgumentError` for non-`DenseVector` inputs instead of
  letting Mooncake return a shape-incorrect tangent (reverse) or crash inside
  Mooncake (forward/Jacobian).
- Document that NamedTuple input-shape validation is intentionally delegated
  to Mooncake's `PreparedCacheSpec` to avoid duplicating checks on every AD
  call.
- Add a docstring on the vector `prepare` method describing
  `raw_gradient_target` as an unsafe escape hatch that bypasses evaluator
  indirection and shape checks.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Lift the lowered-AD escape hatch into a first-class API: every vector
`prepare` now accepts `context::Tuple=()`, the prepared evaluator computes
`problem(x, context...)`, and AD differentiates only `x`. `VectorEvaluator`
carries the context as a third type parameter so callers can recover it
from the evaluator without going through the kwarg again.
Mooncake ext
- Compile every scalar gradient cache on the raw `evaluator.f` /
`evaluator.context` (not the raw `problem`/`context` kwargs), so a
downstream override of structural `prepare` that returns a different
`f`/`context` doesn't desync from the hot path.
- Forward-mode `AutoMooncakeForward` now also accepts non-empty `context`.
`_mooncake_value_and_gradient` dispatches reverse-mode to the
`args_to_zero` kwarg form and forward-mode to the splat-only form
(`ForwardCache` rejects `args_to_zero`).
- Vector-jacobian path now also runs on `evaluator.f` for the same reason.
- Empty input with non-empty `context` is supported (was rejected).
The `MooncakeCache{arity,Nothing}` empty-input shortcut already
evaluates `evaluator(x)` without invoking Mooncake.
DI ext
- `DICache{Mode}`: `Mode == :closure` for compiled-tape ReverseDiff,
`Mode::Int == length(evaluator.context)` for the constants path.
The `Int` doubles as documentation of how many `DI.Constant`s the AD
call passes (`N + 1`, including `f`).
- Single shared AD target `_call_evaluator(x, f, ctx::Vararg{Any,N}) where {N}`.
`Vararg{Any,N}` forces specialization on the trailing arity.
Docs/tests
- New `Constant context arguments` section in `docs/src/evaluators.md`.
- New regression tests covering: context threading via structural
`prepare`, forward-mode Mooncake context, empty-input + non-empty
context shortcut, and `DICache` mode-tag pinning.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
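The context-threading design in this commit can be sketched in plain Julia. Apart from the `Vararg{Any,N}` signature quoted in the commit message, every name here (`call_evaluator`, `ContextEvaluator`, `problem`, `scale`) is hypothetical.

```julia
# Shared call target: `Vararg{Any,N}` with a bound `N` makes Julia
# specialize on the number of trailing context arguments.
call_evaluator(x, f, ctx::Vararg{Any,N}) where {N} = f(x, ctx...)

# Hypothetical evaluator carrying the context tuple as a type parameter.
struct ContextEvaluator{F,C<:Tuple}
    f::F
    context::C
end
(e::ContextEvaluator)(x) = call_evaluator(x, e.f, e.context...)

# problem(x, context...) with one constant argument `scale`.
problem(x, scale) = scale * sum(abs2, x)

e = ContextEvaluator(problem, (3.0,))
e([1.0, 2.0])  # 3.0 * (1.0 + 4.0) = 15.0
```

An AD backend would differentiate only with respect to `x` and treat the entries of `e.context` as constants; with an empty context the splat disappears and the call reduces to `f(x)`.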
@shravanngoswamii I made some further improvements; this should be ready for merging if you are happy too.
…omments
Three single-call helpers folded into their unique call sites:
- 4-arg `_mooncake_gradient_cache(adtype, f, x, context::Tuple)` and the
`_mooncake_jacobian_cache` pair → one `if adtype isa AutoMooncake ... end`
branch in vector `prepare`. The `isa` is compile-folded since `adtype`'s
concrete type lives in the method's specialization.
- `_mooncake_value_and_gradient(::Auto*, ...)` → an inlined branch in the
scalar-gradient hot path's `value_and_gradient!!` body, using the same
compile-folded `isa`. Kept in a single `<:_MooncakeAD` method so the
empty-input shortcut's cache specificity (`MooncakeCache{:scalar,Nothing}`)
doesn't clash with a per-AD-type method.
The 3-arg `_mooncake_gradient_cache(::Auto*, f, x)` pair (NamedTuple path)
stays factored — it's reused by the NamedTuple `prepare` and has a distinct
call shape (evaluator + values, not raw `f` + context splat).
Comment cleanup:
- `tangent_type` defensive guard now describes the current state rather
than the refactor that produced it ("after the raw-target merge").
- Cache-prep block trimmed from two paragraphs to one — kept the
evaluator-as-source-of-truth rationale and the forward-mode unification
note, dropped the redundant splat-no-op elaboration.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Let me go through it again. I just need to test a few more things to be sure! It was fun to try both of these PRs aggressively😅!
Thanks for being patient, @yebai. I think this PR is ready to merge! I suppose we can now experiment with Mooncake of our own accord without much hassle.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
- `AbstractPPLMooncakeExt`: gradient/jacobian via Mooncake with cache reuse, scalar/vector dispatch, and `NamedTuple` inputs (wrapping `VectorEvaluator` and `NamedTupleEvaluator`). Bumps `AbstractPPL` to `0.15`.
- `AbstractPPLTestExt` conformance suite grows two reusable groups:
  - `Val(:namedtuple)`: value + gradient over `NamedTuple` inputs, plus an `ErrorCase` for structure mismatch.
  - `Val(:cache_reuse)`: three sequential `value_and_{gradient,jacobian}!!` calls against a single `Prepared` to catch backend cache corruption.
- `Evaluators._ad_output_arity` (`:scalar`/`:vector`) and `Evaluators._check_ad_input` factor out the duplicated output-arity check and input validation that appeared across both AD extensions; compile-time `T <: Integer` elision is preserved.
- Empty-input arity tagging (`Val(:scalar)`/`Val(:vector)`) so the `gradient_prep === nothing` arity check still fires for length-0 inputs.