Skip to content

Add direct ForwardDiff extension#166

Merged
yebai merged 21 commits into
mainfrom
copilot/add-forwarddiff-support
May 25, 2026
Merged

Add direct ForwardDiff extension#166
yebai merged 21 commits into
mainfrom
copilot/add-forwarddiff-support

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented May 25, 2026

Adds AbstractPPLForwardDiffExt using ForwardDiff's public API directly, mirroring the Mooncake extension pattern. This enables ForwardDiff-based gradient computation without requiring DifferentiationInterface as an intermediary.

Changes

  • ext/AbstractPPLForwardDiffExt.jl — Extension dispatching on AutoForwardDiff, using ForwardDiff.gradient!/jacobian!/hessian! with pre-allocated DiffResults buffers and configs. Supports order=1 (gradient/jacobian by output arity), order=2 (hessian), context arguments, chunk size control, and empty-input edge cases.
  • Project.tomlForwardDiff + DiffResults as weakdeps with extension registration
  • test/ext/forwarddiff/ — Dedicated test environment exercising all standard test case groups

Usage

using AbstractPPL, ADTypes, ForwardDiff, DiffResults

f(x) = -0.5 * sum(abs2, x)
x = [1.0, 2.0, 3.0]

prep = prepare(AutoForwardDiff(), f, x)
val, grad = value_and_gradient!!(prep, x)

# With explicit chunk size
prep2 = prepare(AutoForwardDiff(; chunksize=2), f, x)

# Hessian support
prep3 = prepare(AutoForwardDiff(), f, x; order=2)
val, grad, hess = value_gradient_and_hessian!!(prep3, x)

When both this extension and AbstractPPLDifferentiationInterfaceExt are loaded, the direct extension wins for AutoForwardDiff via dispatch specificity (AutoForwardDiff vs AbstractADType).

Warning

Firewall rules blocked me from connecting to one or more addresses (expand for details)

I tried to connect to the following addresses, but was blocked by firewall rules:

  • https://api.github.com/repos/FluxML/MacroTools.jl/tarball/1e0228a030642014fe5cfe68c2c0a818f9e3f522
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaArrays/StaticArraysCore.jl/tarball/6ab403037779dae8c514bad259f32a447262455a
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (http block)
  • https://api.github.com/repos/JuliaBinaryWrappers/OpenSpecFun_jll.jl/tarball/1346c9208249809840c91b26703912dff463d335
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaCollections/AbstractTrees.jl/tarball/2d9c9a55f9c93e8887ad391fbae72f8ef55e1177
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaCollections/DataStructures.jl/tarball/e86f4a2805f7f19bec5129bc9150c38208e5dc23
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaCollections/LeftChildRightSiblingTrees.jl/tarball/95ba48564903b43b2462318aa243ee79d81135ff
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaCollections/OrderedCollections.jl/tarball/05868e21324cede2207c6f0f466b4bfef6d5e7ee
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaCollections/SortingAlgorithms.jl/tarball/64d974c2e6fdf07f8155b5b2ca2ffa9069b608d9
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaData/DataAPI.jl/tarball/abe83f3a2f1b857aac70ef8b269080af17764bbe
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaData/Missings.jl/tarball/ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaData/Parsers.jl/tarball/5d5e0a78e971354b1c7bff0655d11fdc1b0e12c8
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaDiff/DiffResults.jl/tarball/782dd5f4561f5d267313f23853baaaa4c52ea621
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (http block)
  • https://api.github.com/repos/JuliaDiff/DiffRules.jl/tarball/23163d55f885173722d1e4cf0f6110cdbaf7e272
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (http block)
  • https://api.github.com/repos/JuliaDiff/DifferentiationInterface.jl/tarball/2147a95a217cc8a78ec96ee03581adf129468e49
    • Triggering command: /usr/bin/julia julia -e using Pkg Pkg.activate(; temp=true) Pkg.develop(; path=".") Pkg.add(["ForwardDiff", "DiffResults", "DifferentiationInterface", "ADTypes"]) using AbstractPPL, ADTypes, ForwardDiff, DiffResults, DifferentiationInterface println("ForwardDiff ext: ", Base.ge --history-file=no --warn-overwrite=yes --color=yes - (http block)
  • https://api.github.com/repos/JuliaDiff/ForwardDiff.jl/tarball/cddeab6487248a39dae1a960fff0ac17b2a28888
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (http block)
  • https://api.github.com/repos/JuliaDocs/DocStringExtensions.jl/tarball/7442a5dfe1ebb773c29cc2962a8980f47221d76c
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaFolds/InitialValues.jl/tarball/4da0f88e9a39111c2fa3add390ab15f3a44f3ca3
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaFolds2/BangBang.jl/tarball/cceb62468025be98d42a5dc581b163c20896b040
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaFunctional/CompositionsBase.jl/tarball/802bb88cd69dfd1509f6670416bd4434015693ad
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaIO/JSON.jl/tarball/f76f7560267b840e492180f9899b472f30b88450
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (http block)
  • https://api.github.com/repos/JuliaLang/PrecompileTools.jl/tarball/edbeefc7a4889f528644251bdb5fc9ab5348bc2c
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaLogging/LoggingExtras.jl/tarball/f00544d95982ea270145636c181ceda21c4e2575
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaLogging/ProgressLogging.jl/tarball/f0803bc1171e455a04124affa9c21bba5ac4db32
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaLogging/TerminalLoggers.jl/tarball/f133fab380933d042f6796eda4e130272ba520ca
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaMath/DensityInterface.jl/tarball/80c3e8639e3353e5d2912fb3a1916b8455e2494b
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaMath/InverseFunctions.jl/tarball/a779299d77cd080bf77b97535acecd73e1c5e5cb
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaMath/IrrationalConstants.jl/tarball/b2d91fe939cae05960e760110b328288867b5758
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (http block)
  • https://api.github.com/repos/JuliaMath/NaNMath.jl/tarball/9b8215b1ee9e78a293f99797cd31375471b2bcae
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (http block)
  • https://api.github.com/repos/JuliaMath/SpecialFunctions.jl/tarball/2700b235561b0335d5bef7097a111dc513b8655e
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (http block)
  • https://api.github.com/repos/JuliaObjects/Accessors.jl/tarball/2eeb2c9bef11013efc6f8f97f32ee59b146b09fb
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (http block)
  • https://api.github.com/repos/JuliaObjects/ConstructionBase.jl/tarball/b4b092499347b18a015186eae3042f72267106cb
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaPackaging/JLLWrappers.jl/tarball/7204148362dafe5fe6a273f855b8ccbe4df8173e
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (http block)
  • https://api.github.com/repos/JuliaPackaging/Preferences.jl/tarball/8b770b60760d4451834fe79dd483e318eee709c4
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (http block)
  • https://api.github.com/repos/JuliaServices/StructUtils.jl/tarball/82bee338d650aa515f31866c460cb7e3bcef90b8
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaStats/LogExpFunctions.jl/tarball/13ca9e2586b89836fd20cccf56e57e2b9ae7f38f
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaStats/Statistics.jl/tarball/ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaStats/StatsAPI.jl/tarball/178ed29fd5b2a2cfc3bd31c13375ae925623ff36
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/JuliaStats/StatsBase.jl/tarball/aceda6f4e598d331548e04cc6b2124a6148138e3
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/LilithHafner/AliasTables.jl/tarball/9876e1e164b144ca45e9e3198d0b689cadfed9ff
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/LilithHafner/PtrArrays.jl/tarball/4fbbafbc6251b883f4d2705356f3641f3652a7fe
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (http block)
  • https://api.github.com/repos/SciML/ADTypes.jl/tarball/bbc22a9a08a0ef6460041086d8a7b27940ed4ffd
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/TuringLang/AbstractMCMC.jl/tarball/8ac6182431567907e0d5170bcac6dd48fa541f78
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (http block)
  • https://api.github.com/repos/jw3126/ArgCheck.jl/tarball/f9e9a66c9b7be1ad7372bbd9b062d9230c30c5ce
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (http block)
  • https://api.github.com/repos/rdeits/CommonSubexpressions.jl/tarball/cda2cfaebb4be89c9084adaca7dd7333369715c5
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (http block)
  • https://api.github.com/repos/timholy/ProgressMeter.jl/tarball/fbb92c6c56b34e1a2c4c36058f68f332bec840e7
    • Triggering command: REDACTED, pid is -1 (http block)
  • https://api.github.com/repos/tkf/ConsoleProgressMonitor.jl/tarball/3ab7b2136722890b9af903859afcf457fa3059e8
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (http block)
  • https://api.github.com/repos/tpapp/LogDensityProblems.jl/tarball/d9625f27ded4ad726ceca7819394a4cc77ed25b3
    • Triggering command: REDACTED, pid is -1 (http block)
  • pkg.julialang.org
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.instantiate(); using ADTypes; ad = AutoForwardDiff(); println(fieldnames(typeof(ad))); println(typeof(ad)) (dns block)
    • Triggering command: /usr/bin/julia julia --project=. -e using Pkg; Pkg.add("ForwardDiff") (dns block)
    • Triggering command: `/usr/bin/julia julia --project=. -e
      using Pkg; Pkg.add("DiffResults")
      using ForwardDiff, DiffResults

x = [1.0, 2.0, 3.0]
f(x) = sum(abs2, x)

For gradient: value + gradient

result = DiffResults.MutableDiffResult(zero(eltype(x)), (similar(x),))
cfg = ForwardDiff.GradientConfig(f, x)
Forw --warn-overwrite=yes --color=yes -` (dns block)

If you need me to access, download, or install something from one of these locations, you can either:

Copilot AI linked an issue May 25, 2026 that may be closed by this pull request
Adds `AbstractPPLForwardDiffExt` that directly uses ForwardDiff's public
API (gradient!, jacobian!, hessian! with DiffResults and pre-allocated
configs), mirroring the Mooncake extension pattern.

- ext/AbstractPPLForwardDiffExt.jl: full extension with prepare,
  value_and_gradient!!, value_and_jacobian!!, and
  value_gradient_and_hessian!! implementations
- Project.toml: ForwardDiff + DiffResults as weakdeps with extension
  registration and compat entries
- test/ext/forwarddiff/: dedicated test environment running all standard
  test cases plus context and empty-input tests

Agent-Logs-Url: https://github.com/TuringLang/AbstractPPL.jl/sessions/2f9552bb-c72d-4891-a973-8ecc68959e06

Co-authored-by: yebai <3279477+yebai@users.noreply.github.com>
Copilot AI changed the title [WIP] Add ForwardDiff support for external use Add direct ForwardDiff extension May 25, 2026
Copilot finished work on behalf of yebai May 25, 2026 13:29
Copilot AI requested a review from yebai May 25, 2026 13:29
@github-actions
Copy link
Copy Markdown
Contributor

AbstractPPL.jl documentation for PR #166 is available at:
https://TuringLang.github.io/AbstractPPL.jl/previews/PR166/

@codecov
Copy link
Copy Markdown

codecov Bot commented May 25, 2026

Codecov Report

❌ Patch coverage is 95.32164% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 89.39%. Comparing base (2d35e51) to head (d84606a).

Files with missing lines Patch % Lines
ext/AbstractPPLTestExt.jl 92.52% 8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #166      +/-   ##
==========================================
+ Coverage   88.79%   89.39%   +0.59%     
==========================================
  Files          15       16       +1     
  Lines         982     1056      +74     
==========================================
+ Hits          872      944      +72     
- Misses        110      112       +2     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

yebai and others added 11 commits May 25, 2026 14:55
…utoReverseDiff

Run JuliaFormatter on ext/AbstractPPLForwardDiffExt.jl and
test/ext/forwarddiff/main.jl to satisfy the Format CI job.

The "DI cache encodes the call mode as a type parameter" testset
asserted `DIGradientCache{0}` and `DIGradientCache{2}` cache types for
`AutoForwardDiff`, but the new direct `AbstractPPLForwardDiffExt` path
now takes precedence over DI when both extensions are loaded.
`AutoReverseDiff()` (non-compiled) exercises the same DI constants
path and keeps the structural assertion meaningful.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Drop section banners and WHAT-comments in the extension; keep WHYs
  (chunk-size dispatch, fresh `Fix2` per call is Tag-type-stable,
  separate gradient cache on the order=2 prep).
- Tighten the `prepare` docstring's second paragraph.
- Remove the "empty input" testset: `run_testcases(Val(:vector))` and
  `run_testcases(Val(:hessian))` already cover zero-length input for
  every arity / order combination via `AbstractPPLTestExt`.
- Remove the trailing arity-mismatch `@test_throws` from the
  "context-lowered gradient" testset: `run_testcases(Val(:edge))`
  already covers "jacobian of scalar output".
- Drop now-unused imports (`value_and_jacobian!!`,
  `value_gradient_and_hessian!!`, `order`, `DiffResults`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add `ext/forwarddiff` to `VALID_LABELS` in `test/run_extras.jl` and to
the CI ext matrix so the chunk-size and context tests this branch
introduces actually run on CI (they were silently skipped before —
`AutoForwardDiff` was only exercised via the DI test env).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace `FDGradientCache`, `FDJacobianCache`, and `FDHessianCache` with
one parametric `FDCache{A,R,C,GR,GC}` keyed on an arity/order `Symbol`
`A ∈ (:scalar, :vector, :hessian)`, mirroring the `MooncakeCache{A}`
pattern. Hot paths and arity-mismatch rejections dispatch on the tag at
compile time exactly as before; `result::Nothing` remains the
empty-input sentinel. Verified type-stable on all four `!!` entries.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`ext/AbstractPPLForwardDiffExt.jl`:

* Thread `adtype.tag` into every `*Config` constructor via a small
  `_fd_tag` helper; `nothing` (the ADTypes default) reproduces
  ForwardDiff's per-constructor default of `Tag(target, eltype(x))`,
  so callers can now use `AutoForwardDiff(; tag=…)` for nested
  differentiation through AbstractPPL.

* Hoist the arity-probe `evaluator(x)` to a single `y_probe` local and
  reuse it as the Jacobian-result prototype on the vector branch. The
  base `prepare` contract promises one prep-time call into `problem`;
  the vector path was invoking it twice.

* Cache `target = _fd_target(evaluator)` once locally rather than
  reconstructing the `Fix2` per config.

`test/ext/forwarddiff/main.jl`: add a regression test asserting the
user-supplied tag flows into the stored config's first type parameter.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Drop the unused `where {T<:Real}` binding on the empty-input
  Jacobian method; the non-empty sibling already uses
  `x::AbstractVector{<:Real}` directly.
* Pass two `nothing`s to `FDCache{:hessian}(nothing, nothing)` for the
  empty-input order=2 cache instead of four — the constructor defaults
  `gradient_result` and `gradient_config` to `nothing`, so the
  resulting type is identical and the line is consistent with the
  `:scalar` / `:vector` empty-input shortcuts above.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`_fd_target(e)` was just `Base.Fix2(_fd_call, e)` — inline the five
call sites and drop the helper. `_fd_call` stays as a top-level named
function: ForwardDiff's `*Config` keys its `Tag` on the target type,
and a closure built inside one method would have a different type
from one built inside another, desyncing the per-call target from the
config captured at prep time. Reworded the comment to make that
constraint (not the harmless cost of per-call `Fix2`) the WHY.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The pending changes in `test/ext/forwarddiff/main.jl` and
`test/ext/mooncake/main.jl` had duplicated the same helper functions
and testset bodies for "allocation-free hot paths" and "type-stable
hot paths". Lift the shared logic into `AbstractPPLTestExt`:

* `IdentityProblem`: allocation-free vector-output problem (avoids
  `VectorValuedProblem`'s result allocation masking AD-path allocations).
* `_inferred_*` helpers wrap `@inferred` so it can be marked broken via
  `@test_broken`.
* `run_testcases(Val(:allocations); ...)`: `@allocated == 0` checks on
  scalar gradient and vector Jacobian, with `gradient_broken` /
  `jacobian_broken` kwargs for backends with known regressions.
* `run_testcases(Val(:type_stability); ...)`: `@inferred` checks on
  gradient/Jacobian/Hessian hot paths, with matching `*_broken` kwargs.

Both extension test files now invoke the shared groups; Mooncake passes
`jacobian_broken=true` for `:allocations` (both modes) and for
`:type_stability` only on `AutoMooncakeForward` (`Tuple{Any,
Union{Array{T,3}, Matrix}}` inference).

Docstring on `generate_testcases` updated to list the new keys.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Generalize the `*_broken` comment: previously cited only "Mooncake's
  forward-mode Jacobian", but the kwargs cover other regressions too
  (Mooncake's `value_and_jacobian!!` allocates on every call across
  both modes; only the forward-mode Jacobian *inference* is broken).
* Unify the `:allocations` vs `:type_stability` branch style — both
  now use the same `if/else` form rather than the ternary the former
  was using inconsistently.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… group

* Format: match the CI JuliaFormatter v1.0.62 baseline (the local one I
  was using is on 2.x and disagrees on `return`-keyword placement).
* `:allocations` group: Julia 1.10 heap-allocates `Fix2`/closure
  captures that 1.11+ elides. Mark `gradient_broken=VERSION < v"1.11"`
  (and `jacobian_broken=VERSION < v"1.11"` on FD) so min CI doesn't
  flag the older runtime as a regression.
* New `:context` group: lifts the inline "context-lowered gradient"
  testset from the FD test into `AbstractPPLTestExt`. Verifies
  `prepare(adtype, f, x; context=(c,))` lowers the context out of the
  gradient. FD calls it in place of the inline testset; Mooncake adds
  it alongside its richer Mooncake-specific context testset (forward
  parity, vector arity rejection, empty input with context).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@yebai yebai marked this pull request as ready for review May 25, 2026 16:10
yebai and others added 7 commits May 25, 2026 17:14
The custom-tag path was only structurally tested (the tag flowed into
the config type parameter), not exercised through an actual AD call.
DynamicPPL's downstream tests caught the gap: `AutoForwardDiff(;
tag=Tag{DynamicPPLTag,Float64}())` carries a sentinel tag whose first
type parameter is *not* `typeof(target)`, so ForwardDiff's default
`checktag` errors when the hot path calls `ForwardDiff.gradient!`.

Pass `Val(false)` to skip `checktag` in all four hot-path calls (this
is what DifferentiationInterface does). The tag is purely a label for
the outer Dual scope; the config built at prep time already encodes
the right tag, so the check is redundant and harmful in the
custom-tag case.

Strengthen the regression test to actually run `value_and_gradient!!`
on a prep built with a custom sentinel tag and assert the gradient
matches the analytic value — would have caught the original bug.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The gradient comparison was hardcoding `atol = 1e-10` while the value
comparison above it (and every other group in this file) used
`atol = atol`. The hardcoded value silently overrode the caller's
kwarg — Mooncake calls with `atol = 1e-6` were getting the tighter
1e-10. Use `atol = atol` to match the surrounding pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Collapse the four case structs (`ValueCase`, `HessianCase`, `ErrorCase`,
`CacheReuseCase`) and seven `run_testcases(Val(:group); ...)` methods
into a single tagged `TestCase` and a single `run_testcase(case; ...)`
that dispatches on `case.tag` via `Val`.

Each backend's test script is now a single uniform loop:

```julia
for case in generate_testcases()
    run_testcase(case; adtype=AutoForwardDiff(), atol=1e-6, rtol=1e-6,
                 allocations=:test, type_stability=:test)
end
```

Tags are `:vector`, `:hessian`, `:context`, `:edge`, `:cache_reuse`,
`:namedtuple`. NamedTuple-input cases live in
`generate_namedtuple_testcases()` so backends that don't support that
input shape don't need to filter at the call site.

`allocations` / `type_stability` accept `:skip` / `:test` / `:broken`
(`:broken` wraps as `@test_broken` for known regressions). Per-case
`allocations_safe::Bool` defaults to `true`; cases with allocating
primals (`VectorValuedProblem` result vector, empty-input shortcuts,
hessian scratch, cache-reuse loops) opt out so the runner skips the
alloc check regardless of caller intent.

Case types and stubs (`TestCase`, `generate_testcases`,
`generate_namedtuple_testcases`, `run_testcase`) live in
`src/AbstractPPL.jl`; the generators and dispatched runner live in
`ext/AbstractPPLTestExt.jl`. The old `Val{group}` API and the standalone
`IdentityProblem` fixture are gone.

Backend-specific broken predicates (`_mooncake_alloc`,
`_mooncake_inferred`) sit next to the loop they drive — they encode
Mooncake's known issues (allocating Jacobian, forward-mode Jacobian and
context inference) without touching the shared harness.

Local: FD 111/111, Mooncake 149 pass + 3 broken, DI 115/115.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`src/AbstractPPL.jl` keeps only the two function stubs
`generate_testcases` and `run_testcase`. The `TestCase` struct (and its
keyword-arg constructor) moves into `ext/AbstractPPLTestExt.jl` — test
scripts only access `case.tag` via field access, so the type itself
doesn't need to live in main.

Collapse the two separate generators into a single Val-dispatched
function:

  generate_testcases(Val(:vector))     — all vector-input cases
  generate_testcases(Val(:namedtuple)) — NamedTuple-input cases

Backends iterate `generate_testcases(Val(:vector))` (and Mooncake also
`Val(:namedtuple)`).

Local: FD 111/111, Mooncake 149 pass + 3 broken, DI 115/115 — no
behavioural change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The two `_run(Val{...})` methods differed only by passing
`context=case.context` to `prepare` (no-op for `:vector` cases since
`case.context` defaults to `()`). Collapse into one method with
`Union{Val{:vector},Val{:context}}` dispatch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The empty-input hessian case and both cache-reuse cases set
`allocations_safe=false` without an inline reason, while the other
four instances do. Add brief explanations to match.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CI reported `Unexpected Pass` on Julia 1.10 for `quadratic (scalar
output)` (FD + Mooncake) and `scalar gradient with context` (FD): the
recent FD-ext tweaks (skip-checktag, hoisted target/tag locals) made
these paths alloc-free on 1.10 too. Drop the `VERSION < v"1.11"`
gating; the per-case `allocations_safe=false` still filters the
genuinely-allocating paths (vector jacobian, empty-input shortcuts,
hessian, cache-reuse loops).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CI reported back-to-back inconsistent results on Julia 1.10 for the
same code: one run had Mooncake's scalar-gradient `@allocated` come
out 0 (an Unexpected Pass when marked `:broken`), the next run had it
at 256 (a Test Failed when marked `:test`). The dependency resolver
picks slightly different Mooncake versions between runs, and Mooncake
0.5.x's allocation behaviour on 1.10 isn't stable across them.

Set `_mooncake_alloc` to return `:skip` on Julia 1.10 instead of
either `:test` or `:broken` — that way the check doesn't fire on min,
regardless of which Mooncake version the resolver picked. Latest-Julia
coverage is unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@yebai yebai merged commit 9ebf18f into main May 25, 2026
17 checks passed
@yebai yebai deleted the copilot/add-forwarddiff-support branch May 25, 2026 18:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add ForwardDiff support

2 participants