diff --git a/Project.toml b/Project.toml index 620d9597..1430bdab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "JumpProcesses" uuid = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" authors = ["Chris Rackauckas "] -version = "9.28.0" +version = "9.28.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/extended_jump_array.jl b/src/extended_jump_array.jl index 0eacede3..b5bfcf2a 100644 --- a/src/extended_jump_array.jl +++ b/src/extended_jump_array.jl @@ -94,6 +94,10 @@ function LinearAlgebra.mul!(c::ExtendedJumpArray, A::AbstractVecOrMat, u::Abstra Nu = length(c.u) if size(A, 1) == Nu mul!(c.u, A, u) + # zero c.jump_u so callers (e.g. adaptive SDE step caches that reuse `c` + # as a scratchpad) do not see stale values as a noise contribution on + # the jump-rate-integral state + fill!(c.jump_u, zero(eltype(c.jump_u))) elseif size(A, 1) == length(c) mul!(c.u, @view(A[1:Nu, :]), u) mul!(c.jump_u, @view(A[(Nu + 1):end, :]), u) diff --git a/test/extended_jump_array.jl b/test/extended_jump_array.jl index 9f6d59a5..23e42b3a 100644 --- a/test/extended_jump_array.jl +++ b/test/extended_jump_array.jl @@ -7,7 +7,8 @@ rng = StableRNG(123) # Check that the new broadcast norm gives the same result as the old one rand_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 5), rand(rng, 2)) -old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) / max(DiffEqBase.recursive_length(rand_array), 1)) +old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) / + max(DiffEqBase.recursive_length(rand_array), 1)) new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0) @test old_norm ≈ new_norm @@ -15,7 +16,8 @@ new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0) rand_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Int64}}(rand(rng, 5), rand(rng, 1:1000, 2)) -old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) / max(DiffEqBase.recursive_length(rand_array), 1)) +old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) / + max(DiffEqBase.recursive_length(rand_array), 1)) new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0) @test old_norm ≈ new_norm @@ -119,6 +121,31 @@ let @test SciMLBase.plottable_indices(sol.u[1]) == 1:length(u₀) end +# Regression for https://github.com/SciML/JumpProcesses.jl/issues/592: +# mul!(c::ExtendedJumpArray, A, u) must clear c.jump_u when A only addresses +# the c.u portion. Otherwise stale scratchpad values pollute the jump-rate +# integral state in adaptive SDE solvers. +let rng = StableRNG(592) + c = ExtendedJumpArray(rand(rng, 3), [1.0, 2.0, -3.0]) # pre-populated jump_u + A = rand(rng, 3, 4) # noise-rate-prototype-sized + u = rand(rng, 4) + expected_u = A * u + mul!(c, A, u) + @test c.u ≈ expected_u + @test all(iszero, c.jump_u) # jump_u zeroed +end + +# Full-state matrix case still scatters into both halves. +let rng = StableRNG(593) + c = ExtendedJumpArray(zeros(3), [9.0, 9.0, 9.0]) + A = rand(rng, 6, 4) + u = rand(rng, 4) + full = A * u + mul!(c, A, u) + @test c.u ≈ full[1:3] + @test c.jump_u ≈ full[4:6] +end + # Test ldiv! and lmul! for stiff solver support let rng = StableRNG(456) u = rand(rng, 3) diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 0b81a14b..30edce20 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -420,7 +420,8 @@ let integrator.p[3] += 1 nothing end - birth_jump = VariableRateJump(birth_rate, birth_affect!; save_positions = (false, false)) + birth_jump = VariableRateJump(birth_rate, birth_affect!; save_positions = ( + false, false)) # Define death jump: X → ∅ death_rate(u, p, t) = 0.5 * u[1] @@ -429,7 +430,8 @@ let integrator.p[3] += 1 nothing end - death_jump = VariableRateJump(death_rate, death_affect!; save_positions = (false, false)) + death_jump = VariableRateJump(death_rate, death_affect!; save_positions = ( + false, false)) Nsims = 100 results = Dict() @@ -566,3 +568,67 @@ let @test sol.u[end][1] + sol.u[end][2] ≈ u0[1] + u0[2] end end + +# Regression for https://github.com/SciML/JumpProcesses.jl/issues/592 +# VR_FRM + SDEProblem with `noise_rate_prototype` + adaptive non-diagonal SDE +# solver: ExtendedJumpArray's `mul!` needs to zero `c.jump_u` when the noise +# matrix only addresses the original state, otherwise stale scratchpad values +# from the adaptive error estimate blow up `jump_u` and the VRJ callbacks +# never fire. +let + rng = StableRNG(592) + + mutable struct Issue592Params + λ::Float64 + μ::Float64 + mode::Symbol + end + p = Issue592Params(0.3, 0.5, :CLE) + + function f!(du, u, p, t) + du[1] = (p.mode === :CLE) ? (p.λ - p.μ) * u[1] : 0.0 + nothing + end + function g!(G, u, p, t) + if p.mode === :CLE + x = max(u[1], 0.0) + G[1, 1] = sqrt(p.λ * x) + G[1, 2] = -sqrt(p.μ * x) + else + G .= 0.0 + end + nothing + end + + rate_birth(u, p, t) = (p.mode === :SSA) ? p.λ * u[1] : 0.0 + rate_death(u, p, t) = (p.mode === :SSA) ? p.μ * u[1] : 0.0 + birth_affect!(integ) = (integ.u[1] += 1.0) + death_affect!(integ) = (integ.u[1] -= 1.0) + birth = VariableRateJump(rate_birth, birth_affect!) + death = VariableRateJump(rate_death, death_affect!) + + switch_cond(u, t, integ) = (integ.p.mode === :CLE) ? u[1] - 95.0 : 1.0 + function switch_affect!(integ) + integ.u[1] = max(round(integ.u[1]), 0.0) + integ.p.mode = :SSA + u_modified!(integ, true) + reset_aggregated_jumps!(integ) + end + switch_cb = ContinuousCallback(switch_cond, switch_affect!) + + u0 = [150.0] + tspan = (0.0, 10.0) + sde_prob = SDEProblem(f!, g!, u0, tspan, p; noise_rate_prototype = zeros(1, 2)) + jprob = JumpProblem(sde_prob, Direct(), birth, death; vr_aggregator = VR_FRM(), rng) + + sol = solve(jprob, LambaEM(); callback = switch_cb, adaptive = true) + @test SciMLBase.successful_retcode(sol) + @test sol.t[end] == tspan[2] + # The bug drove `jump_u` past ±1e15 within a few steps; with the fix it + # stays bounded near its initial -randexp() values. + @test all(isfinite, sol.u[end].jump_u) + @test maximum(abs, sol.u[end].jump_u) < 1e6 + # SSA jumps must actually fire after the mode switch — without the fix u + # gets pinned at the switch threshold (95) forever. + @test sol.u[end].u[1] != 95.0 +end