
Potential bug when switching from SDE to SSA using Callbacks #592

@mauricelanghinrichs


Potential bug

I noticed a potential bug while trying to implement a hybrid SSA-SDE solver. The MRE below is a simple birth-death process that starts at 150 counts and slowly decreases (death rate > birth rate). I want to solve the initial regime with an SDE solver (the chemical Langevin equation, CLE) and switch to the exact Gillespie algorithm (SSA) once the count reaches 95 or lower. The switching is implemented with callbacks.
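For reference, the exact SSA regime of this model can be written in a few lines of plain Julia, independently of JumpProcesses. This is only a sketch for comparison purposes. The helper name `ssa_birth_death` is hypothetical and not part of any package. It implements Gillespie's Direct method for the two channels X → 2X (propensity λ·x) and X → ∅ (propensity μ·x):

```julia
using Random

# Hypothetical helper: Gillespie Direct method for X -> 2X (rate λ*x) and X -> ∅ (rate μ*x).
# Not part of JumpProcesses; plain Julia for comparison only.
function ssa_birth_death(x0::Int, λ::Float64, μ::Float64, tmax::Float64;
                         rng = Random.default_rng())
    t, x = 0.0, x0
    ts, xs = [t], [x]
    while x > 0                      # x = 0 is absorbing
        a_birth = λ * x
        a_death = μ * x
        a_total = a_birth + a_death
        # waiting time to the next reaction ~ Exp(a_total)
        t += randexp(rng) / a_total
        t > tmax && break
        # pick the reaction with probability proportional to its propensity
        x += rand(rng) * a_total < a_birth ? 1 : -1
        push!(ts, t)
        push!(xs, x)
    end
    return ts, xs
end

# Start the SSA phase at the switching threshold used in the MRE (95 counts)
ts, xs = ssa_birth_death(95, 0.3, 0.5, 10.0; rng = MersenneTwister(1))
```

Each step draws an exponential waiting time from the total propensity and then chooses the firing channel. This is what `Direct()` does for the aggregated `ConstantRateJump`s.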

When switching from the SDE to the SSA regime, the Direct() method used afterwards appears to get stuck. See here:
[Image: solution trajectory with VariableRateJump]

This happens only with VariableRateJump. With ConstantRateJump instead, the algorithm appears to work correctly:
[Image: solution trajectory with ConstantRateJump]
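As far as I understand, the two jump types take different code paths. ConstantRateJumps are handled entirely by the aggregator (here Direct()). A VariableRateJump used with Direct() is instead simulated by integrating its rate alongside the SDE state and firing once the integrated rate reaches an Exp(1)-distributed threshold. This is the time-change idea.

The plain-Julia sketch below illustrates that idea for a single channel. It uses a fixed-step forward-Euler quadrature standing in for the SDE solver. The function name `next_jump_time` and the step size `dt` are illustrative assumptions, not JumpProcesses internals:

```julia
using Random

# Hypothetical helper: time of the next jump of a channel with time-dependent rate r(t),
# found by accumulating ∫ r(s) ds from t0 until it reaches an Exp(1) threshold.
# A left-endpoint (forward Euler) sum stands in for the ODE/SDE integrator.
function next_jump_time(rate, t0::Float64, tmax::Float64;
                        dt = 1e-3, rng = Random.default_rng())
    threshold = randexp(rng)          # E ~ Exp(1)
    integral = 0.0                    # accumulated integrated rate since t0
    t = t0
    while t < tmax
        integral += rate(t) * dt
        t += dt
        integral >= threshold && return t   # the jump fires at (approximately) t
    end
    return nothing                          # no jump before tmax
end

# Example: a rate decaying in time, r(t) = 2 exp(-t)
τ = next_jump_time(t -> 2.0 * exp(-t), 0.0, 10.0; rng = MersenneTwister(1))
```

Because the integrated rate is extra state that must be advanced together with u, the VariableRateJump path interacts with the SDE solver in a way the ConstantRateJump path does not.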

The issue seems to depend on noise_rate_prototype, in particular on specifying the noise through the 1×2 matrix G in the cle_g! method. When I rewrite the problem without noise_rate_prototype, both versions (VariableRateJump and ConstantRateJump) run fine. However, I need noise_rate_prototype to implement multi-compartment birth-death-migration processes, where noise channels may be correlated.
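With `noise_rate_prototype = zeros(1, 2)`, G is a 1×2 matrix with one column per reaction channel (birth, death). Each column is driven by its own Wiener increment, so a step has the form du = f(u) dt + G(u) dW with dW a 2-vector.

A minimal fixed-step Euler-Maruyama sketch of this in plain Julia is shown below. It uses no StochasticDiffEq, and the names `cle_em_step` and `simulate_cle` are hypothetical. It reuses the drift and diffusion entries of `cle_f!` and `cle_g!`:

```julia
using Random

# Hypothetical helper: one Euler-Maruyama step of the CLE for X -> 2X (λ), X -> ∅ (μ)
# with non-diagonal noise: G is 1×2 (column 1 = birth channel, column 2 = death channel).
function cle_em_step(x::Vector{Float64}, λ, μ, dt; rng = Random.default_rng())
    x1 = max(x[1], 0.0)
    drift = [(λ - μ) * x1]
    G = hcat(sqrt(λ * x1), -sqrt(μ * x1))   # 1×2 matrix, same entries as cle_g!
    dW = sqrt(dt) .* randn(rng, 2)          # independent Wiener increment per channel
    return x .+ drift .* dt .+ G * dW       # (1×2) * (2-vector) -> 1-vector
end

# Hypothetical helper: iterate the step nsteps times from x0
function simulate_cle(x0::Vector{Float64}, λ, μ, dt, nsteps; rng = Random.default_rng())
    xs = [x0]
    for _ in 1:nsteps
        push!(xs, cle_em_step(xs[end], λ, μ, dt; rng = rng))
    end
    return xs
end

xs = simulate_cle([150.0], 0.3, 0.5, 0.01, 100; rng = MersenneTwister(1))
```

In a multi-compartment model, G would become n×m (states × reaction channels), and correlations between compartments arise from channels that appear in more than one row.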

Maybe this is not a bug and my callbacks (or something else) are simply not implemented correctly. Happy about any input.

Expected behavior
The Direct() method should not get stuck after the callback switches from the SDE to the SSA regime when VariableRateJumps are used together with the noise_rate_prototype option.

MRE and versions

  • MRE
using JumpProcesses
using StochasticDiffEq
# using DifferentialEquations
using Random
using Plots

### version with noise_rate_prototype
# model 
u0 = [150.0]
tspan = (0.0, 10.0)

mutable struct Params
    λ::Float64 # birth in X (X -> 2 X)
    μ::Float64 # death in X (X -> )
    threshold_low::Float64
    threshold_high::Float64
    mode::Symbol
end

p = Params(
    0.3, 0.5, # λ and μ
    95.0,    # CLE -> SSA
    105.0,   # SSA -> CLE
    :SSA
)

is_cle(p) = p.mode === :CLE
is_ssa(p) = p.mode === :SSA

# CLE drift and diffusion, active only in CLE mode
function cle_f!(du, u, p, t)
    if is_cle(p)
        du[1] = (p.λ - p.μ) * u[1]
    else
        du .= 0.0
    end
    return nothing
end

function cle_g!(G, u, p, t)
    if is_cle(p)
        x1 = max(u[1], 0.0)
        G[1, 1] = sqrt(p.λ * x1)
        G[1, 2] = - sqrt(p.μ * x1)
    else
        G .= 0.0
    end
    return nothing
end

# SSA jumps, active only in SSA mode
rate_birth1(u, p, t)  = is_ssa(p) ? p.λ * u[1] : 0.0
rate_death1(u, p, t)  = is_ssa(p) ? p.μ * u[1] : 0.0

affect_birth1!(integrator)  = (integrator.u[1] += 1.0)
affect_death1!(integrator)  = (integrator.u[1] -= 1.0)

# callbacks
function cle_to_ssa_condition(u, t, integrator)
    integrator.p.mode === :CLE ? u[1] - integrator.p.threshold_low : 1.0
end

function cle_to_ssa_affect!(integrator)
    integrator.u[1] = max(round(integrator.u[1]), 0.0)

    integrator.p.mode = :SSA

    u_modified!(integrator, true)
    reset_aggregated_jumps!(integrator)
end

cle_to_ssa_cb = ContinuousCallback(
    cle_to_ssa_condition,
    cle_to_ssa_affect!)

function ssa_to_cle_condition(u, t, integrator)
    integrator.p.mode === :SSA && u[1] >= integrator.p.threshold_high
end

function ssa_to_cle_affect!(integrator)
    integrator.p.mode = :CLE

    u_modified!(integrator, true)
    reset_aggregated_jumps!(integrator)
end

ssa_to_cle_cb = DiscreteCallback(
    ssa_to_cle_condition,
    ssa_to_cle_affect!)

cb = CallbackSet(cle_to_ssa_cb, ssa_to_cle_cb)

# problem construction
sde_prob = SDEProblem(
    cle_f!,
    cle_g!,
    u0,
    tspan,
    p;
    noise_rate_prototype = zeros(1, 2)
    )

### either use ConstantRateJumps or VariableRateJumps
### CODE FAILS WITH VariableRateJumps
# works: ConstantRateJumps (uncomment to use instead)
# jumps = (
#     ConstantRateJump(rate_birth1,  affect_birth1!),
#     ConstantRateJump(rate_death1,  affect_death1!),
# )

# fails: VariableRateJumps (this is the definition used below)
jumps = (
    VariableRateJump(rate_birth1,  affect_birth1!),
    VariableRateJump(rate_death1,  affect_death1!),
)
###

jump_prob = JumpProblem(
    sde_prob,
    Direct(),
    jumps...)

Random.seed!(1)

@time begin 
    p.mode = (u0[1] >= p.threshold_high) ? :CLE : :SSA

    sol = solve(
        jump_prob,
        LambaEM();
        callback = cb,
        adaptive = true,
    )
end

plot(sol, labels="X1")
plot!(0:0.1:tspan[2], u0[1] .* exp.((p.λ - p.μ) .* (0:0.1:tspan[2])), label="E(X1)")
  • Output of using Pkg; Pkg.status()
Status `~/.julia/environments/v1.12/Project.toml`
  [6e4b80f9] BenchmarkTools v1.8.0
  [336ed68f] CSV v0.10.16
  [13f3f980] CairoMakie v0.15.10
  [aaaa29a8] Clustering v0.15.8
  [861a8166] Combinatorics v1.1.0
  [a93c6f00] DataFrames v1.8.2
  [b4f34e82] Distances v0.10.12
  [31c24e10] Distributions v0.25.125
  [cc61a311] FLoops v0.2.2
  [1ecd5474] GraphMakie v0.6.3
  [86223c79] Graphs v1.14.0
  [033835bb] JLD2 v0.6.4
  [ccbc3e58] JumpProcesses v9.28.0
  [fa8bd995] MetaGraphsNext v0.8.0
  [7509a0a4] NautyGraphs v0.7.2
  [43a3c2be] PairPlots v3.0.3
  [91a5bcdd] Plots v1.41.6
  [295af30f] Revise v3.14.3
  [2913bbd2] StatsBase v0.34.10
  [f3b207a7] StatsPlots v0.15.8
  [789caeaf] StochasticDiffEq v7.0.0
  [fdbf4ff8] XLSX v0.11.8
  [9a3f8284] Random v1.11.0
  [2f01184e] SparseArrays v1.12.0
  • Output of versioninfo()
Julia Version 1.12.6
Commit 15346901f00 (2026-04-09 19:20 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: macOS (arm64-apple-darwin24.0.0)
  CPU: 16 × Apple M4 Max
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, apple-m4)
  GC: Built with stock GC
Threads: 12 default, 1 interactive, 12 GC (on 12 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_VSCODE_REPL = 1
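For anyone checking the callback logic, the two thresholds in the MRE form a hysteresis band. The model switches CLE → SSA at ≤ 95 counts and SSA → CLE only at ≥ 105 counts, so the mode does not flip back and forth around a single threshold.

The same rule is written below as a plain function. The name `next_mode` is hypothetical; it mirrors `cle_to_ssa_condition` and `ssa_to_cle_condition` and is not a JumpProcesses API:

```julia
# Hypothetical helper: hysteresis rule for choosing the regime (thresholds as in the MRE).
function next_mode(mode::Symbol, x::Real; low = 95.0, high = 105.0)
    if mode === :CLE && x <= low
        return :SSA          # CLE -> SSA once the count drops to the lower threshold
    elseif mode === :SSA && x >= high
        return :CLE          # SSA -> CLE only after the count climbs to the upper threshold
    else
        return mode          # inside the band: keep the current regime
    end
end

# Trace the regime along a sequence of counts, starting in CLE mode
counts = [150, 120, 96, 95, 100, 104, 105, 101, 94]
modes = accumulate((m, x) -> next_mode(m, x), counts; init = :CLE)
# → [:CLE, :CLE, :CLE, :SSA, :SSA, :SSA, :CLE, :CLE, :SSA]
```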
