Skip to content

Fix Tracker gradient on RAT v4 by preserving the ODESolution wrapper#3665

Merged
ChrisRackauckas merged 2 commits into
SciML:masterfrom
ChrisRackauckas-Claude:cc/tracker-rat-v4-preserve-wrapper
May 21, 2026
Merged

Fix Tracker gradient on RAT v4 by preserving the ODESolution wrapper#3665
ChrisRackauckas merged 2 commits into
SciML:masterfrom
ChrisRackauckas-Claude:cc/tracker-rat-v4-preserve-wrapper

Conversation

@ChrisRackauckas-Claude
Copy link
Copy Markdown
Contributor

Please ignore until reviewed by @ChrisRackauckas.

Summary

Re-take on the Tracker + RAT v4 fix after the original #3663 was reverted in #3664. Same root cause, different correct shape.

In RecursiveArrayTools v4, AbstractVectorOfArray <: AbstractArray, so the sol isa AbstractArray branch in Tracker.@grad function DiffEqBase.solve_up now matches ODESolution and returns the nested sol.u :: Vector{Vector{Float64}} directly. Tracker tracks the vector-of-vectors, and downstream sum(solve(...)) reduces the outer vector element-wise into a Vector{Float64}, breaking Tracker.gradient(loss, p) callers with "Function output is not scalar".

The previous fix (#3663) stacked the data into a fresh Matrix{Float64} via Array(sol). That changed the return type of solve(...) from a wrapper to a matrix and broke downstream consumers — hence the revert in #3664. This PR returns the ODESolution wrapper itself for AbstractVectorOfArray inputs, so callers reduce through the RAT v4 AbstractArray interface and produce a scalar without losing the solution type.

Verification

Local on Julia 1.12.6 + RAT 4.3.0 + Tracker 0.2.38 + SciMLSensitivity master, with DiffEqBase 7.5.1 from this branch:

Pattern Result
sum(solve(remake(prob; p), Tsit5())) scalar gradient ✓
sum(solve(...; save_idxs = [1])) scalar gradient ✓
sum(solve(...; save_everystep = false)) scalar gradient ✓
sum(solve(...; saveat = 2.3)) scalar gradient ✓

The five @test_broken false guards in test/concrete_solve_derivatives.jl (lines 518, 548, 579, 611, 735) on SciMLSensitivity are slated for removal in SciML/SciMLSensitivity.jl#1452, gated on this PR landing.

Refs

Test plan

  • All 4 Tracker patterns produce scalar gradients locally
  • Comment on solve_up @grad updated to explain RAT v4 semantics
  • CI green

In RecursiveArrayTools v4 `AbstractVectorOfArray <: AbstractArray`, so
the `sol isa AbstractArray` branch in `Tracker.@Grad function
DiffEqBase.solve_up` now matches ODESolution and returns the nested
`sol.u :: Vector{Vector{Float64}}` directly. Tracker tracks the
vector-of-vectors, and downstream `sum(solve(...))` reduces the outer
vector element-wise into a `Vector{Float64}`, breaking
`Tracker.gradient(loss, p)` callers with "Function output is not
scalar".

Return the ODESolution wrapper itself for `AbstractVectorOfArray`
inputs so callers reduce through the RAT v4 AbstractArray interface
and get a scalar as before. Earlier attempt (SciML#3663) stacked into a
fresh matrix via `Array(sol)`; that was reverted in SciML#3664 because it
changed the return type and broke downstream consumers. Preserving the
wrapper keeps the contract.

Refs SciML/SciMLSensitivity.jl#1331.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
@ChrisRackauckas ChrisRackauckas marked this pull request as ready for review May 21, 2026 13:48
@ChrisRackauckas ChrisRackauckas merged commit fb060b3 into SciML:master May 21, 2026
77 of 84 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants