Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JumpProcesses"
uuid = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "9.28.0"
version = "9.28.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
4 changes: 4 additions & 0 deletions src/extended_jump_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ function LinearAlgebra.mul!(c::ExtendedJumpArray, A::AbstractVecOrMat, u::Abstra
Nu = length(c.u)
if size(A, 1) == Nu
mul!(c.u, A, u)
# zero c.jump_u so callers (e.g. adaptive SDE step caches that reuse `c`
# as a scratchpad) do not see stale values as a noise contribution on
# the jump-rate-integral state
fill!(c.jump_u, zero(eltype(c.jump_u)))
elseif size(A, 1) == length(c)
mul!(c.u, @view(A[1:Nu, :]), u)
mul!(c.jump_u, @view(A[(Nu + 1):end, :]), u)
Expand Down
31 changes: 29 additions & 2 deletions test/extended_jump_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ rng = StableRNG(123)
# Check that the new broadcast norm gives the same result as the old one
rand_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 5),
rand(rng, 2))
old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) / max(DiffEqBase.recursive_length(rand_array), 1))
old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) /
max(DiffEqBase.recursive_length(rand_array), 1))
new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0)
@test old_norm ≈ new_norm

# Check for an ExtendedJumpArray where the types differ (Float64/Int64)
rand_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Int64}}(rand(rng, 5),
rand(rng, 1:1000,
2))
old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) / max(DiffEqBase.recursive_length(rand_array), 1))
old_norm = Base.FastMath.sqrt_fast(DiffEqBase.UNITLESS_ABS2(rand_array) /
max(DiffEqBase.recursive_length(rand_array), 1))
new_norm = DiffEqBase.ODE_DEFAULT_NORM(rand_array, 0.0)
@test old_norm ≈ new_norm

Expand Down Expand Up @@ -119,6 +121,31 @@ let
@test SciMLBase.plottable_indices(sol.u[1]) == 1:length(u₀)
end

# Regression for https://github.com/SciML/JumpProcesses.jl/issues/592:
# mul!(c::ExtendedJumpArray, A, u) must clear c.jump_u when A only addresses
# the c.u portion. Otherwise stale scratchpad values pollute the jump-rate
# integral state in adaptive SDE solvers.
let rng = StableRNG(592)
c = ExtendedJumpArray(rand(rng, 3), [1.0, 2.0, -3.0]) # pre-populated jump_u
A = rand(rng, 3, 4) # noise-rate-prototype-sized
u = rand(rng, 4)
expected_u = A * u
mul!(c, A, u)
@test c.u ≈ expected_u
@test all(iszero, c.jump_u) # jump_u zeroed
end

# Full-state matrix case still scatters into both halves.
let rng = StableRNG(593)
c = ExtendedJumpArray(zeros(3), [9.0, 9.0, 9.0])
A = rand(rng, 6, 4)
u = rand(rng, 4)
full = A * u
mul!(c, A, u)
@test c.u ≈ full[1:3]
@test c.jump_u ≈ full[4:6]
end

# Test ldiv! and lmul! for stiff solver support
let rng = StableRNG(456)
u = rand(rng, 3)
Expand Down
70 changes: 68 additions & 2 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ let
integrator.p[3] += 1
nothing
end
birth_jump = VariableRateJump(birth_rate, birth_affect!; save_positions = (false, false))
birth_jump = VariableRateJump(birth_rate, birth_affect!; save_positions = (
false, false))

# Define death jump: X → ∅
death_rate(u, p, t) = 0.5 * u[1]
Expand All @@ -429,7 +430,8 @@ let
integrator.p[3] += 1
nothing
end
death_jump = VariableRateJump(death_rate, death_affect!; save_positions = (false, false))
death_jump = VariableRateJump(death_rate, death_affect!; save_positions = (
false, false))

Nsims = 100
results = Dict()
Expand Down Expand Up @@ -566,3 +568,67 @@ let
@test sol.u[end][1] + sol.u[end][2] ≈ u0[1] + u0[2]
end
end

# Regression for https://github.com/SciML/JumpProcesses.jl/issues/592
# VR_FRM + SDEProblem with `noise_rate_prototype` + adaptive non-diagonal SDE
# solver: ExtendedJumpArray's `mul!` needs to zero `c.jump_u` when the noise
# matrix only addresses the original state, otherwise stale scratchpad values
# from the adaptive error estimate blow up `jump_u` and the VRJ callbacks
# never fire.
let
rng = StableRNG(592)

mutable struct Issue592Params
λ::Float64
μ::Float64
mode::Symbol
end
p = Issue592Params(0.3, 0.5, :CLE)

function f!(du, u, p, t)
du[1] = (p.mode === :CLE) ? (p.λ - p.μ) * u[1] : 0.0
nothing
end
function g!(G, u, p, t)
if p.mode === :CLE
x = max(u[1], 0.0)
G[1, 1] = sqrt(p.λ * x)
G[1, 2] = -sqrt(p.μ * x)
else
G .= 0.0
end
nothing
end

rate_birth(u, p, t) = (p.mode === :SSA) ? p.λ * u[1] : 0.0
rate_death(u, p, t) = (p.mode === :SSA) ? p.μ * u[1] : 0.0
birth_affect!(integ) = (integ.u[1] += 1.0)
death_affect!(integ) = (integ.u[1] -= 1.0)
birth = VariableRateJump(rate_birth, birth_affect!)
death = VariableRateJump(rate_death, death_affect!)

switch_cond(u, t, integ) = (integ.p.mode === :CLE) ? u[1] - 95.0 : 1.0
function switch_affect!(integ)
integ.u[1] = max(round(integ.u[1]), 0.0)
integ.p.mode = :SSA
u_modified!(integ, true)
reset_aggregated_jumps!(integ)
end
switch_cb = ContinuousCallback(switch_cond, switch_affect!)

u0 = [150.0]
tspan = (0.0, 10.0)
sde_prob = SDEProblem(f!, g!, u0, tspan, p; noise_rate_prototype = zeros(1, 2))
jprob = JumpProblem(sde_prob, Direct(), birth, death; vr_aggregator = VR_FRM(), rng)

sol = solve(jprob, LambaEM(); callback = switch_cb, adaptive = true)
@test SciMLBase.successful_retcode(sol)
@test sol.t[end] == tspan[2]
# The bug drove `jump_u` past ±1e15 within a few steps; with the fix it
# stays bounded near its initial -randexp() values.
@test all(isfinite, sol.u[end].jump_u)
@test maximum(abs, sol.u[end].jump_u) < 1e6
# SSA jumps must actually fire after the mode switch — without the fix u
# gets pinned at the switch threshold (95) forever.
@test sol.u[end].u[1] != 95.0
end
Loading