From 0c1a3636e62dd4d44a90ced9bbf5bd54a4a6bdff Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sat, 6 Sep 2025 00:46:23 +0530 Subject: [PATCH 1/9] SimpleAdaptiveTauLeaping --- Project.toml | 4 + src/JumpProcesses.jl | 4 +- src/simple_regular_solve.jl | 282 ++++++++++++++++++++++++++++++++++++ test/regular_jumps.jl | 137 ++++++++++++++++++ 4 files changed, 426 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index dfc8e5f5..d2ed6e0c 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -54,6 +56,8 @@ RecursiveArrayTools = "3.35, 4" Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.115, 3.1" +Setfield = "1" +SimpleNonlinearSolve = "1, 2" StableRNGs = "1" StaticArrays = "1.9.8" Statistics = "1" diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 87c8c2c2..741015a3 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -18,6 +18,8 @@ using StaticArrays: StaticArrays, SVector, setindex using Base.Threads: Threads using Base.FastMath: add_fast +using SimpleNonlinearSolve + # Import functions we extend from Base import Base: size, getindex, setindex!, length, similar, show, merge!, merge @@ -131,7 +133,7 @@ export SSAStepper # leaping: include("simple_regular_solve.jl") -export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel +export SimpleTauLeaping, SimpleExplicitTauLeaping, SimpleAdaptiveTauLeaping, NewtonImplicitSolver, TrapezoidalImplicitSolver, EnsembleGPUKernel # spatial: include("spatial/spatial_massaction_jump.jl") diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 3675ddad..6add712d 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -6,6 +6,23 @@ end SimpleExplicitTauLeaping(; epsilon = 0.05) = SimpleExplicitTauLeaping(epsilon) +# Define solver type hierarchy +abstract type AbstractImplicitSolver end +struct NewtonImplicitSolver <: AbstractImplicitSolver end +struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end + +# Adaptive tau-leaping solver +struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm + epsilon::T # Error control parameter for tau selection + solver::AbstractImplicitSolver # Solver type for implicit method +end + +# Stiffness detection threshold is computed dynamically as epsilon * sum(u), where u is the current state, +# as inspired by Cao et al. (2007), Section III.B, which uses the ratio of propensity functions to identify +# fast and slow time scales in stiff systems, scaled by system size for robustness. +SimpleAdaptiveTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver()) = + SimpleAdaptiveTauLeaping(epsilon, solver) + function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ @@ -69,6 +86,20 @@ function _process_saveat(saveat, tspan, save_start, save_end) return saveat_vec, _save_start, _save_end end +# Validation for adaptive tau-leaping +function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping) + if !(jump_prob.aggregator isa PureLeaping) + @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ + JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ + Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release." + end + isempty(jump_prob.jump_callback.continuous_callbacks) && + isempty(jump_prob.jump_callback.discrete_callbacks) && + isempty(jump_prob.constant_jumps) && + isempty(jump_prob.variable_jumps) && + jump_prob.massaction_jump !== nothing +end + function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; seed = nothing, dt = error("dt is required for SimpleTauLeaping."), saveat = nothing, save_start = nothing, save_end = nothing) @@ -405,6 +436,257 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end +function compute_hor(reactant_stoch, numjumps) + # Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. + # HOR is the sum of stoichiometric coefficients of reactants in reaction j. + hor = zeros(Int, numjumps) + for j in 1:numjumps + order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) + if order > 3 + error("Reaction $j has order $order, which is not supported (maximum order is 3).") + end + hor[j] = order + end + return hor +end + +function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) + # Precompute reaction conditions for each species i, including: + # - max_hor: the highest order of reaction (HOR) where species i is a reactant. + # - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. + # Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). + max_hor = zeros(Int, numspecies) + max_stoich = zeros(Int, numspecies) + for j in 1:numjumps + for (spec_idx, stoch) in reactant_stoch[j] + if stoch > 0 # Species is a reactant + if hor[j] > max_hor[spec_idx] + max_hor[spec_idx] = hor[j] + max_stoich[spec_idx] = stoch + elseif hor[j] == max_hor[spec_idx] + max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch) + end + end + end + end + return max_hor, max_stoich +end + +function compute_gi(u, max_hor, max_stoich, i, t) + # Compute g_i for species i to bound the relative change in propensity functions, + # as per Cao et al. (2006), Section IV, equation (27). + # g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant: + # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 + # - HOR = 2 (second-order): + # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 + # - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1) + # - HOR = 3 (third-order): + # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 + # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) + # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) + # Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep. + if max_hor[i] == 0 # No reactions involve species i as a reactant + return 1.0 + elseif max_hor[i] == 1 + return 1.0 + elseif max_hor[i] == 2 + if max_stoich[i] == 1 + return 2.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1 + end + elseif max_hor[i] == 3 + if max_stoich[i] == 1 + return 3.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1 + elseif max_stoich[i] == 3 + return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 # Fallback to 3.0 if x_i <= 2 + end + end + return 1.0 # Default case +end + +function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + # Compute the tau-leaping step-size using equation (20) from Cao et al. (2006): + # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } + # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): + # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) + # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). + rate(rate_cache, u, p, t) + if all(==(0.0), rate_cache) # Handle case where all rates are zero + return dtmin + end + tau = Inf + for i in 1:length(u) + mu = zero(eltype(u)) + sigma2 = zero(eltype(u)) + for j in 1:size(nu, 2) + mu += nu[i, j] * rate_cache[j] # Equation (9a) + sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) + end + gi = compute_gi(u, max_hor, max_stoich, i, t) + bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1) + mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) + tau = min(tau, mu_term, sigma_term) # Equation (8) + end + return max(tau, dtmin) +end + +# Define residual for implicit equation +# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004) +# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau +function implicit_equation!(resid, u_new, params) + u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params + rate(rate_cache, u_new, p, t + tau) + resid .= u_new .- u_current + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + if isa(solver, NewtonImplicitSolver) + resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004) + else # TrapezoidalImplicitSolver + rate_current = similar(rate_cache) + rate(rate_current, u_current, p, t) + resid[spec_idx] -= nu[spec_idx, j] * 0.5 * (rate_cache[j] + rate_current[j]) * tau + end + end + end + resid .= max.(resid, -u_new) # Ensure non-negative solution +end + +# Solve implicit equation using SimpleNonlinearSolve +function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver) + u_new = convert(Vector{Float64}, u_current) + prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)) + sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6) + return sol.u, sol.retcode == ReturnCode.Success +end + +# Stiffness detection based on propensity ratio scaled by system size +# Reference: Cao et al. (2007), Section III.B, using ratio of propensity functions to detect stiffness +function is_stiff(rate_cache, u, epsilon) + non_zero_rates = [rate for rate in rate_cache if rate > 0] + if length(non_zero_rates) <= 1 + return false # Use explicit method if no or only one non-zero propensity + end + max_rate = maximum(non_zero_rates) + min_rate = minimum(non_zero_rates) + threshold = epsilon * sum(u) # Dynamic threshold based on system size + return max_rate / min_rate > threshold +end + +# Adaptive tau-leaping solver +# Reference: Cao et al. (2007) for adaptive strategy, Cao et al. (2004) for implicit method, Cao et al. (2006) for tau selection +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; + seed = nothing, + dtmin = 1e-10, + saveat = nothing) + validate_pure_leaping_inputs(jump_prob, alg) || + error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") + + @unpack prob, rng = jump_prob + (seed !== nothing) && seed!(rng, seed) + + maj = jump_prob.massaction_jump + numjumps = get_num_majumps(maj) + rate = (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) + end + end + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + u_current = copy(u0) + t_current = tspan[1] + usave = [copy(u0)] + tsave = [tspan[1]] + rate_cache = zeros(Float64, numjumps) + counts = zeros(Int64, numjumps) + du = similar(u0) + t_end = tspan[2] + epsilon = alg.epsilon + solver = alg.solver + + nu = zeros(Int64, length(u0), numjumps) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + nu[spec_idx, j] = stoch + end + end + reactant_stoch = maj.reactant_stoch + hor = compute_hor(reactant_stoch, numjumps) + max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) + + saveat_times = isnothing(saveat) ? Vector{Float64}() : + (saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)) + save_idx = 1 + + while t_current < t_end + rate(rate_cache, u_current, p, t_current) + tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + tau = min(tau, t_end - t_current) + if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] + tau = saveat_times[save_idx] - t_current + end + + # Stiffness detection using dynamic propensity ratio threshold + use_implicit = is_stiff(rate_cache, u_current, epsilon) + + if use_implicit + # Implicit tau-leaping (Cao et al., 2004) + u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver) + if !converged + tau /= 2 + continue + end + rate(rate_cache, u_new_float, p, t_current + tau) + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Cao et al. (2004): Poisson sampling + du .= zero(eltype(u_current)) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + du[spec_idx] += stoch * counts[j] + end + end + u_new = u_current + du # Cao et al. (2004): Final state update + else + # Explicit tau-leaping (Cao et al., 2006) + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Equation (8): Poisson sampling + du .= zero(eltype(u_current)) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + du[spec_idx] += stoch * counts[j] + end + end + u_new = u_current + du + end + + if any(<(0), u_new) + tau /= 2 + continue + end + t_new = t_current + tau + + if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) + push!(usave, copy(ceil.(u_new))) # Ensure integer solutions for molecular populations + push!(tsave, t_new) + if !isempty(saveat_times) && t_new >= saveat_times[save_idx] + save_idx += 1 + end + end + + u_current = u_new + t_current = t_new + end + + sol = DiffEqBase.build_solution(prob, alg, tsave, usave, + calculate_error=false, + interp=DiffEqBase.ConstantInterpolation(tsave, usave)) + return sol +end + struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm backend::Backend cpu_offload::Float64 diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index f4d009ea..6338c835 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -354,3 +354,140 @@ end save_start = true, save_end = true) @test sol.t == [0.0, 2.0, 8.0, 10.0] end + +# SimpleAdaptiveTauLeaping correctness - SIR model +@testset "SimpleAdaptiveTauLeaping SIR Correctness" begin + β = 0.1 / 1000.0 + ν = 0.01 + influx_rate = 1.0 + p = (β, ν, influx_rate) + + rate1(u, p, t) = p[1] * u[1] * u[2] + rate2(u, p, t) = p[2] * u[2] + rate3(u, p, t) = p[3] + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[1] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), + ConstantRateJump(rate3, affect3!)) + + u0 = [999.0, 10.0, 0.0] + tspan = (0.0, 250.0) + prob_disc = DiscreteProblem(u0, tspan, p) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) + + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); + trajectories = Nsims, saveat = 1.0) + + reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs = param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng = rng) + + sol_adaptive_newton = solve(EnsembleProblem(jump_prob_maj), + SimpleAdaptiveTauLeaping(solver = NewtonImplicitSolver()), + EnsembleSerial(); trajectories = Nsims, saveat = 1.0) + sol_adaptive_trap = solve(EnsembleProblem(jump_prob_maj), + SimpleAdaptiveTauLeaping(solver = TrapezoidalImplicitSolver()), + EnsembleSerial(); trajectories = Nsims, saveat = 1.0) + + t_points = 0:1.0:250.0 + max_direct_I = maximum([mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_adaptive_newton = maximum([mean(sol_adaptive_newton[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_adaptive_trap = maximum([mean(sol_adaptive_trap[i](t)[2] for i in 1:Nsims) for t in t_points]) + + @test isapprox(max_direct_I, max_adaptive_newton, rtol = 0.05) + @test isapprox(max_direct_I, max_adaptive_trap, rtol = 0.05) +end + +# SimpleAdaptiveTauLeaping correctness - SEIR model +@testset "SimpleAdaptiveTauLeaping SEIR Correctness" begin + β = 0.3 / 1000.0 + σ = 0.2 + ν = 0.01 + p = (β, σ, ν) + + rate1(u, p, t) = p[1] * u[1] * u[3] + rate2(u, p, t) = p[2] * u[2] + rate3(u, p, t) = p[3] * u[3] + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), + ConstantRateJump(rate3, affect3!)) + + u0 = [999.0, 0.0, 10.0, 0.0] + tspan = (0.0, 250.0) + prob_disc = DiscreteProblem(u0, tspan, p) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) + + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); + trajectories = Nsims, saveat = 1.0) + + reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs = param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng = rng) + + sol_adaptive_newton = solve(EnsembleProblem(jump_prob_maj), + SimpleAdaptiveTauLeaping(solver = NewtonImplicitSolver()), + EnsembleSerial(); trajectories = Nsims, saveat = 1.0) + sol_adaptive_trap = solve(EnsembleProblem(jump_prob_maj), + SimpleAdaptiveTauLeaping(solver = TrapezoidalImplicitSolver()), + EnsembleSerial(); trajectories = Nsims, saveat = 1.0) + + t_points = 0:1.0:250.0 + max_direct_I = maximum([mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_adaptive_newton = maximum([mean(sol_adaptive_newton[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_adaptive_trap = maximum([mean(sol_adaptive_trap[i](t)[3] for i in 1:Nsims) for t in t_points]) + + @test isapprox(max_direct_I, max_adaptive_newton, rtol = 0.05) + @test isapprox(max_direct_I, max_adaptive_trap, rtol = 0.05) +end + +# SimpleAdaptiveTauLeaping integration tests (stiff system) +@testset "SimpleAdaptiveTauLeaping Integration Tests" begin + # Stiff system from Cao et al. (2007): S1 -> S2, S2 -> S1, S2 -> S3 + c = (1000.0, 1000.0, 1.0) + reactant_stoich1 = [Pair(1, 1)] + net_stoich1 = [Pair(1, -1), Pair(2, 1)] + reactant_stoich2 = [Pair(2, 1)] + net_stoich2 = [Pair(1, 1), Pair(2, -1)] + reactant_stoich3 = [Pair(2, 1)] + net_stoich3 = [Pair(2, -1), Pair(3, 1)] + stiff_jumps = MassActionJump([c[1], c[2], c[3]], + [reactant_stoich1, reactant_stoich2, reactant_stoich3], + [net_stoich1, net_stoich2, net_stoich3]) + + u0 = [100, 0, 0] + tspan = (0.0, 10.0) + prob = DiscreteProblem(u0, tspan) + jump_prob = JumpProblem(prob, PureLeaping(), stiff_jumps; rng = rng) + + sol = solve(jump_prob, SimpleAdaptiveTauLeaping(); saveat = 0.1) + @test sol.t[end] ≈ 10.0 atol = 1e-6 + + # Reversible isomerization from Cao et al. (2004) + c2 = (1000.0, 1000.0) + reactant_stoich_r1 = [Pair(1, 1)] + net_stoich_r1 = [Pair(1, -1), Pair(2, 1)] + reactant_stoich_r2 = [Pair(2, 1)] + net_stoich_r2 = [Pair(1, 1), Pair(2, -1)] + rev_jumps = MassActionJump([c2[1], c2[2]], + [reactant_stoich_r1, reactant_stoich_r2], + [net_stoich_r1, net_stoich_r2]) + + u0_rev = [1000, 0] + tspan_rev = (0.0, 0.1) + prob_rev = DiscreteProblem(u0_rev, tspan_rev) + jump_prob_rev = JumpProblem(prob_rev, PureLeaping(), rev_jumps; rng = rng) + + sol_newton = solve(jump_prob_rev, SimpleAdaptiveTauLeaping(epsilon = 0.05, + solver = NewtonImplicitSolver())) + sol_trap = solve(jump_prob_rev, SimpleAdaptiveTauLeaping(epsilon = 0.05, + solver = TrapezoidalImplicitSolver())) + @test sol_newton.t[end] ≈ 0.1 atol = 1e-6 + @test sol_trap.t[end] ≈ 0.1 atol = 1e-6 +end From 9ecbdc962471dfda75fc7aca34fd47510e4633a2 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sat, 6 Sep 2025 00:49:55 +0530 Subject: [PATCH 2/9] some --- src/simple_regular_solve.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 6add712d..0e157f72 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -590,11 +590,15 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; maj = jump_prob.massaction_jump numjumps = get_num_majumps(maj) - rate = (out, u, p, t) -> begin - for j in 1:numjumps - out[j] = evalrxrate(u, j, maj) + rj = jump_prob.regular_jump + # Extract rates + rate = rj !== nothing ? rj.rate : + (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) + end end - end + c = rj !== nothing ? rj.c : nothing u0 = copy(prob.u0) tspan = prob.tspan p = prob.p From d957618f4b0eb99353b0a63e29f6b1749dcd0192 Mon Sep 17 00:00:00 2001 From: sivasathyaseeelan Date: Sat, 6 Sep 2025 19:40:29 +0530 Subject: [PATCH 3/9] added jacobian stifness condition --- src/JumpProcesses.jl | 4 +- src/simple_regular_solve.jl | 206 +++++++++++++++++++++--------------- 2 files changed, 123 insertions(+), 87 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 741015a3..15940112 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -4,8 +4,8 @@ using Reexport: Reexport, @reexport @reexport using DiffEqBase # Explicit imports from standard libraries -using LinearAlgebra: LinearAlgebra, mul! -using Random: Random, randexp, seed! +using LinearAlgebra: LinearAlgebra, I, mul!, eigvals +using Random: Random, randexp, randexp!, seed! # Explicit imports from external packages using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 0e157f72..fc5e9533 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -15,13 +15,17 @@ struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm epsilon::T # Error control parameter for tau selection solver::AbstractImplicitSolver # Solver type for implicit method + eigenvalue_check::Bool # Enable eigenvalue-based stiffness detection + stiffness_ratio_threshold::T # # Stiffness ratio threshold + implicit_epsilon_factor::T # Scaling factor for implicit tau-selection end -# Stiffness detection threshold is computed dynamically as epsilon * sum(u), where u is the current state, -# as inspired by Cao et al. (2007), Section III.B, which uses the ratio of propensity functions to identify -# fast and slow time scales in stiff systems, scaled by system size for robustness. -SimpleAdaptiveTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver()) = - SimpleAdaptiveTauLeaping(epsilon, solver) +# Stiffness detection uses a dynamic threshold epsilon * sum(u) for propensity ratios, +# as inspired by Cao et al. (2007), Section III.B. Optional eigenvalue-based check +# uses the Jacobian's eigenvalue ratio. implicit_epsilon_factor=10.0 relaxes tau-selection +# for implicit tau-leaping, per Cao et al. (2007), Section III.A. +SimpleAdaptiveTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver(), eigenvalue_check=false, stiffness_ratio_threshold=1e4, implicit_epsilon_factor=10.0) = + SimpleAdaptiveTauLeaping(epsilon, solver, eigenvalue_check, stiffness_ratio_threshold, implicit_epsilon_factor) function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @@ -436,9 +440,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end +# Compute highest order of reaction (HOR) +# Reference: Cao et al. (2006), J. Chem. Phys. 124, 044109, Section IV function compute_hor(reactant_stoch, numjumps) - # Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. - # HOR is the sum of stoichiometric coefficients of reactants in reaction j. hor = zeros(Int, numjumps) for j in 1:numjumps order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) @@ -450,16 +454,14 @@ function compute_hor(reactant_stoch, numjumps) return hor end +# Precompute max_hor and max_stoich for g_i calculation +# Reference: Cao et al. (2006), Section IV, equation (27) function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) - # Precompute reaction conditions for each species i, including: - # - max_hor: the highest order of reaction (HOR) where species i is a reactant. - # - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. - # Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). max_hor = zeros(Int, numspecies) max_stoich = zeros(Int, numspecies) for j in 1:numjumps for (spec_idx, stoch) in reactant_stoch[j] - if stoch > 0 # Species is a reactant + if stoch > 0 if hor[j] > max_hor[spec_idx] max_hor[spec_idx] = hor[j] max_stoich[spec_idx] = stoch @@ -472,20 +474,10 @@ function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjump return max_hor, max_stoich end +# Compute g_i to bound propensity changes +# Reference: Cao et al. (2006), Section IV, equation (27) function compute_gi(u, max_hor, max_stoich, i, t) - # Compute g_i for species i to bound the relative change in propensity functions, - # as per Cao et al. (2006), Section IV, equation (27). - # g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant: - # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 - # - HOR = 2 (second-order): - # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 - # - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1) - # - HOR = 3 (third-order): - # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 - # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) - # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) - # Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep. - if max_hor[i] == 0 # No reactions involve species i as a reactant + if max_hor[i] == 0 return 1.0 elseif max_hor[i] == 1 return 1.0 @@ -493,50 +485,65 @@ function compute_gi(u, max_hor, max_stoich, i, t) if max_stoich[i] == 1 return 2.0 elseif max_stoich[i] == 2 - return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1 + return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 end elseif max_hor[i] == 3 if max_stoich[i] == 1 return 3.0 elseif max_stoich[i] == 2 - return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1 + return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 elseif max_stoich[i] == 3 - return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 # Fallback to 3.0 if x_i <= 2 + return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 end end - return 1.0 # Default case + return 1.0 end +# Compute tau for explicit tau-leaping +# Reference: Cao et al. (2006), equation (8), using equations (9a) and (9b) function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) - # Compute the tau-leaping step-size using equation (20) from Cao et al. (2006): - # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } - # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): - # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) - # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). rate(rate_cache, u, p, t) - if all(==(0.0), rate_cache) # Handle case where all rates are zero + if all(==(0.0), rate_cache) return dtmin end tau = Inf for i in 1:length(u) - mu = zero(eltype(u)) - sigma2 = zero(eltype(u)) + mu = zero(eltype(rate_cache)) # Equation (9a): mu_i(x) = sum_j nu_ij * a_j(x) + sigma2 = zero(eltype(rate_cache)) # Equation (9b): sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) for j in 1:size(nu, 2) - mu += nu[i, j] * rate_cache[j] # Equation (9a) - sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) + mu += nu[i, j] * rate_cache[j] + sigma2 += nu[i, j]^2 * rate_cache[j] + end + gi = compute_gi(u, max_hor, max_stoich, i, t) # Equation (27) + bound = max(epsilon * max(u[i], 0.0) / gi, 1.0) + mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) + tau = min(tau, mu_term, sigma_term) # Equation (8) + end + return max(tau, dtmin) +end + +# Compute tau for implicit tau-leaping with relaxed error control +# Reference: Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.A, using relaxed epsilon for larger steps +function compute_tau_implicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps, implicit_epsilon_factor) + tau_explicit = compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + u_predict = float.(copy(u)) # Initialize as Float64 to handle fractional updates + rate(rate_cache, u, p, t) + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + u_predict[spec_idx] += nu[spec_idx, j] * rate_cache[j] * tau_explicit end - gi = compute_gi(u, max_hor, max_stoich, i, t) - bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1) - mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) - sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) - tau = min(tau, mu_term, sigma_term) # Equation (8) end + u_predict = max.(u_predict, 0.0) + + relaxed_epsilon = epsilon * implicit_epsilon_factor + # Reuse compute_tau with predicted state, time, and relaxed epsilon + tau = compute_tau(u_predict, rate_cache, nu, hor, p, t + tau_explicit, relaxed_epsilon, rate, dtmin, max_hor, max_stoich, numjumps) return max(tau, dtmin) end # Define residual for implicit equation -# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004) -# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau +# Reference: Cao et al. (2004), J. Chem. Phys. 121, 4059 function implicit_equation!(resid, u_new, params) u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params rate(rate_cache, u_new, p, t + tau) @@ -555,50 +562,73 @@ function implicit_equation!(resid, u_new, params) resid .= max.(resid, -u_new) # Ensure non-negative solution end -# Solve implicit equation using SimpleNonlinearSolve +# Solve implicit equation function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver) - u_new = convert(Vector{Float64}, u_current) + u_new = float.(copy(u_current)) prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)) sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6) return sol.u, sol.retcode == ReturnCode.Success end -# Stiffness detection based on propensity ratio scaled by system size -# Reference: Cao et al. (2007), Section III.B, using ratio of propensity functions to detect stiffness -function is_stiff(rate_cache, u, epsilon) +# Compute Jacobian for eigenvalue-based stiffness detection +# Reference: Cao et al. (2007), Section III.B +function compute_jacobian(u, rate, numjumps, numspecies, p, t) + J = zeros(numjumps, numspecies) + rate_cache = zeros(numjumps) + rate(rate_cache, u, p, t) + h = 1e-6 + for i in 1:numspecies + u_plus = float.(copy(u)) + u_plus[i] += h + rate_plus = zeros(numjumps) + rate(rate_plus, u_plus, p, t) + for j in 1:numjumps + J[j, i] = (rate_plus[j] - rate_cache[j]) / h + end + end + return J +end + +# Stiffness detection using propensity ratio or eigenvalues +# Reference: Cao et al. (2007), Section III.B +function is_stiff(rate_cache, u, epsilon, eigenvalue_check, stiffness_ratio_threshold, p, t, rate, numjumps, numspecies) non_zero_rates = [rate for rate in rate_cache if rate > 0] if length(non_zero_rates) <= 1 - return false # Use explicit method if no or only one non-zero propensity + return false + end + if eigenvalue_check + J = compute_jacobian(u, rate, numjumps, numspecies, p, t) + eigvals = real.(LinearAlgebra.eigvals(J)) + non_zero_eigvals = [abs(λ) for λ in eigvals if abs(λ) > 1e-10] + if length(non_zero_eigvals) <= 1 + return false + end + max_eig = maximum(non_zero_eigvals) + min_eig = minimum(non_zero_eigvals) + return max_eig / min_eig > stiffness_ratio_threshold # Stiffness ratio threshold, Petzold (1983), SIAM J. Sci. Stat. Comput. 4(1), 136–148 + else + max_rate = maximum(non_zero_rates) + min_rate = minimum(non_zero_rates) + threshold = epsilon * sum(u) + return max_rate / min_rate > threshold # Propensity ratio threshold, Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.B end - max_rate = maximum(non_zero_rates) - min_rate = minimum(non_zero_rates) - threshold = epsilon * sum(u) # Dynamic threshold based on system size - return max_rate / min_rate > threshold end # Adaptive tau-leaping solver -# Reference: Cao et al. (2007) for adaptive strategy, Cao et al. (2004) for implicit method, Cao et al. (2006) for tau selection -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; - seed = nothing, - dtmin = 1e-10, - saveat = nothing) - validate_pure_leaping_inputs(jump_prob, alg) || - error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") +# Reference: Cao et al. (2007), Cao et al. (2004), Cao et al. (2006) +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing) + validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") @unpack prob, rng = jump_prob (seed !== nothing) && seed!(rng, seed) maj = jump_prob.massaction_jump numjumps = get_num_majumps(maj) - rj = jump_prob.regular_jump - # Extract rates - rate = rj !== nothing ? rj.rate : - (out, u, p, t) -> begin - for j in 1:numjumps - out[j] = evalrxrate(u, j, maj) - end + rate = (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) end - c = rj !== nothing ? rj.c : nothing + end u0 = copy(prob.u0) tspan = prob.tspan p = prob.p @@ -613,6 +643,9 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; t_end = tspan[2] epsilon = alg.epsilon solver = alg.solver + eigenvalue_check = alg.eigenvalue_check + stiffness_ratio_threshold = alg.stiffness_ratio_threshold + implicit_epsilon_factor = alg.implicit_epsilon_factor nu = zeros(Int64, length(u0), numjumps) for j in 1:numjumps @@ -623,6 +656,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; reactant_stoch = maj.reactant_stoch hor = compute_hor(reactant_stoch, numjumps) max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) + numspecies = length(u0) saveat_times = isnothing(saveat) ? Vector{Float64}() : (saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)) @@ -630,38 +664,40 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; while t_current < t_end rate(rate_cache, u_current, p, t_current) - tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + use_implicit = is_stiff(rate_cache, u_current, epsilon, eigenvalue_check, stiffness_ratio_threshold, p, t_current, rate, numjumps, numspecies) + tau = use_implicit ? + compute_tau_implicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps, implicit_epsilon_factor) : + compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) tau = min(tau, t_end - t_current) if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] tau = saveat_times[save_idx] - t_current end - # Stiffness detection using dynamic propensity ratio threshold - use_implicit = is_stiff(rate_cache, u_current, epsilon) - if use_implicit - # Implicit tau-leaping (Cao et al., 2004) u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver) if !converged tau /= 2 continue end rate(rate_cache, u_new_float, p, t_current + tau) - counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Cao et al. (2004): Poisson sampling + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Cao et al. (2004) du .= zero(eltype(u_current)) for j in 1:numjumps - for (spec_idx, stoch) in maj.net_stoch[j] - du[spec_idx] += stoch * counts[j] + for spec_idx in 1:size(nu, 1) + if nu[spec_idx, j] != 0 + du[spec_idx] += nu[spec_idx, j] * counts[j] + end end end - u_new = u_current + du # Cao et al. (2004): Final state update + u_new = u_current + du # Cao et al. (2004) else - # Explicit tau-leaping (Cao et al., 2006) - counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Equation (8): Poisson sampling + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Cao et al. (2006), equation (8) du .= zero(eltype(u_current)) for j in 1:numjumps - for (spec_idx, stoch) in maj.net_stoch[j] - du[spec_idx] += stoch * counts[j] + for spec_idx in 1:size(nu, 1) + if nu[spec_idx, j] != 0 + du[spec_idx] += nu[spec_idx, j] * counts[j] + end end end u_new = u_current + du @@ -674,7 +710,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; t_new = t_current + tau if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) - push!(usave, copy(ceil.(u_new))) # Ensure integer solutions for molecular populations + push!(usave, copy(u_new)) # Ensure integer solutions push!(tsave, t_new) if !isempty(saveat_times) && t_new >= saveat_times[save_idx] save_idx += 1 From cb9a67e9a542466c4e1dbdf64c7e46c7ce565cd5 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Wed, 11 Mar 2026 01:06:26 +0530 Subject: [PATCH 4/9] refactor --- Project.toml | 4 +- src/simple_regular_solve.jl | 328 ++++++++++++++++++------------------ test/regular_jumps.jl | 312 ++++++++++++++++++---------------- 3 files changed, 337 insertions(+), 307 deletions(-) diff --git a/Project.toml b/Project.toml index d2ed6e0c..899eb285 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas "] version = "9.25.1" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" @@ -67,7 +68,6 @@ Test = "1" julia = "1.10" [extras] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" @@ -82,4 +82,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ADTypes", "Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"] +test = ["Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"] diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index fc5e9533..2747212c 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -440,13 +440,16 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end -# Compute highest order of reaction (HOR) -# Reference: Cao et al. (2006), J. Chem. Phys. 124, 044109, Section IV +# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. +# HOR is the sum of stoichiometric coefficients of reactants in reaction j. +# Extract the element type from reactant_stoch to avoid hardcoding type assumptions. function compute_hor(reactant_stoch, numjumps) - hor = zeros(Int, numjumps) + stoch_type = eltype(first(first(reactant_stoch))) + hor = zeros(stoch_type, numjumps) + max_order = 3 * one(stoch_type) # Maximum supported reaction order (type-aware) for j in 1:numjumps - order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) - if order > 3 + order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=zero(stoch_type)) + if order > max_order error("Reaction $j has order $order, which is not supported (maximum order is 3).") end hor[j] = order @@ -454,14 +457,17 @@ function compute_hor(reactant_stoch, numjumps) return hor end -# Precompute max_hor and max_stoich for g_i calculation -# Reference: Cao et al. (2006), Section IV, equation (27) +# Precompute reaction conditions for each species i, including: +# - max_hor: the highest order of reaction (HOR) where species i is a reactant. +# - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. +# Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) - max_hor = zeros(Int, numspecies) - max_stoich = zeros(Int, numspecies) + hor_type = eltype(hor) + max_hor = zeros(hor_type, numspecies) + max_stoich = zeros(hor_type, numspecies) for j in 1:numjumps for (spec_idx, stoch) in reactant_stoch[j] - if stoch > 0 + if stoch > 0 # Species is a reactant if hor[j] > max_hor[spec_idx] max_hor[spec_idx] = hor[j] max_stoich[spec_idx] = stoch @@ -474,117 +480,80 @@ function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjump return max_hor, max_stoich end -# Compute g_i to bound propensity changes -# Reference: Cao et al. (2006), Section IV, equation (27) -function compute_gi(u, max_hor, max_stoich, i, t) - if max_hor[i] == 0 - return 1.0 - elseif max_hor[i] == 1 - return 1.0 - elseif max_hor[i] == 2 - if max_stoich[i] == 1 - return 2.0 - elseif max_stoich[i] == 2 - return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 - end - elseif max_hor[i] == 3 - if max_stoich[i] == 1 - return 3.0 - elseif max_stoich[i] == 2 - return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 - elseif max_stoich[i] == 3 - return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 - end - end - return 1.0 -end - -# Compute tau for explicit tau-leaping -# Reference: Cao et al. (2006), equation (8), using equations (9a) and (9b) -function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) - rate(rate_cache, u, p, t) - if all(==(0.0), rate_cache) - return dtmin - end - tau = Inf - for i in 1:length(u) - mu = zero(eltype(rate_cache)) # Equation (9a): mu_i(x) = sum_j nu_ij * a_j(x) - sigma2 = zero(eltype(rate_cache)) # Equation (9b): sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) - for j in 1:size(nu, 2) - mu += nu[i, j] * rate_cache[j] - sigma2 += nu[i, j]^2 * rate_cache[j] - end - gi = compute_gi(u, max_hor, max_stoich, i, t) # Equation (27) - bound = max(epsilon * max(u[i], 0.0) / gi, 1.0) - mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) - sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) - tau = min(tau, mu_term, sigma_term) # Equation (8) - end - return max(tau, dtmin) -end # Compute tau for implicit tau-leaping with relaxed error control -# Reference: Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.A, using relaxed epsilon for larger steps -function compute_tau_implicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps, implicit_epsilon_factor) - tau_explicit = compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) - u_predict = float.(copy(u)) # Initialize as Float64 to handle fractional updates +# Reference: Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.A +function compute_tau_implicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, + max_hor, max_stoich, numjumps, implicit_epsilon_factor) + tau_explicit = compute_tau( + u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + u_predict = float.(u) rate(rate_cache, u, p, t) for j in 1:numjumps for spec_idx in 1:size(nu, 1) u_predict[spec_idx] += nu[spec_idx, j] * rate_cache[j] * tau_explicit end end - u_predict = max.(u_predict, 0.0) - + u_predict .= max.(u_predict, zero(eltype(u_predict))) relaxed_epsilon = epsilon * implicit_epsilon_factor - # Reuse compute_tau with predicted state, time, and relaxed epsilon - tau = compute_tau(u_predict, rate_cache, nu, hor, p, t + tau_explicit, relaxed_epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + tau = compute_tau(u_predict, rate_cache, nu, hor, p, t + tau_explicit, + relaxed_epsilon, rate, dtmin, max_hor, max_stoich, numjumps) return max(tau, dtmin) end # Define residual for implicit equation -# Reference: Cao et al. (2004), J. Chem. Phys. 121, 4059 +# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004) +# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau function implicit_equation!(resid, u_new, params) u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params rate(rate_cache, u_new, p, t + tau) resid .= u_new .- u_current - for j in 1:numjumps - for spec_idx in 1:size(nu, 1) - if isa(solver, NewtonImplicitSolver) + if isa(solver, NewtonImplicitSolver) + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004) - else # TrapezoidalImplicitSolver - rate_current = similar(rate_cache) - rate(rate_current, u_current, p, t) - resid[spec_idx] -= nu[spec_idx, j] * 0.5 * (rate_cache[j] + rate_current[j]) * tau + end + end + else # TrapezoidalImplicitSolver + rate_current = similar(rate_cache) + rate(rate_current, u_current, p, t) + half = one(eltype(rate_cache)) / 2 + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + resid[spec_idx] -= nu[spec_idx, j] * half * (rate_cache[j] + rate_current[j]) * tau end end end resid .= max.(resid, -u_new) # Ensure non-negative solution end -# Solve implicit equation +# Solve implicit equation using SimpleNonlinearSolve function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver) - u_new = float.(copy(u_current)) + u_new = convert(Vector{float(eltype(u_current))}, u_current) prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)) sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6) return sol.u, sol.retcode == ReturnCode.Success end + # Compute Jacobian for eigenvalue-based stiffness detection # Reference: Cao et al. (2007), Section III.B function compute_jacobian(u, rate, numjumps, numspecies, p, t) - J = zeros(numjumps, numspecies) - rate_cache = zeros(numjumps) + T = float(eltype(u)) + sqrteps = sqrt(eps(T)) + J = zeros(T, numjumps, numspecies) + rate_cache = zeros(T, numjumps) rate(rate_cache, u, p, t) - h = 1e-6 + rate_plus = zeros(T, numjumps) + u_plus = float.(copy(u)) for i in 1:numspecies - u_plus = float.(copy(u)) - u_plus[i] += h - rate_plus = zeros(numjumps) + h_i = sqrteps * max(abs(u[i]), one(T)) + u_plus[i] = u[i] + h_i rate(rate_plus, u_plus, p, t) for j in 1:numjumps - J[j, i] = (rate_plus[j] - rate_cache[j]) / h + J[j, i] = (rate_plus[j] - rate_cache[j]) / h_i end + u_plus[i] = u[i] end return J end @@ -614,116 +583,155 @@ function is_stiff(rate_cache, u, epsilon, eigenvalue_check, stiffness_ratio_thre end end -# Adaptive tau-leaping solver -# Reference: Cao et al. (2007), Cao et al. (2004), Cao et al. (2006) -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing) - validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") - - @unpack prob, rng = jump_prob - (seed !== nothing) && seed!(rng, seed) - - maj = jump_prob.massaction_jump - numjumps = get_num_majumps(maj) - rate = (out, u, p, t) -> begin - for j in 1:numjumps - out[j] = evalrxrate(u, j, maj) - end - end - u0 = copy(prob.u0) - tspan = prob.tspan - p = prob.p - - u_current = copy(u0) - t_current = tspan[1] - usave = [copy(u0)] - tsave = [tspan[1]] - rate_cache = zeros(Float64, numjumps) - counts = zeros(Int64, numjumps) - du = similar(u0) - t_end = tspan[2] - epsilon = alg.epsilon - solver = alg.solver - eigenvalue_check = alg.eigenvalue_check - stiffness_ratio_threshold = alg.stiffness_ratio_threshold - implicit_epsilon_factor = alg.implicit_epsilon_factor - - nu = zeros(Int64, length(u0), numjumps) - for j in 1:numjumps - for (spec_idx, stoch) in maj.net_stoch[j] - nu[spec_idx, j] = stoch - end - end - reactant_stoch = maj.reactant_stoch - hor = compute_hor(reactant_stoch, numjumps) - max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) - numspecies = length(u0) - - saveat_times = isnothing(saveat) ? Vector{Float64}() : - (saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)) +function simple_adaptive_tau_leaping_loop!( + prob, alg, u_current, u_new, t_current, t_end, p, rng, + rate, nu, hor, max_hor, max_stoich, numjumps, numspecies, epsilon, + dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj, + solver, eigenvalue_check, stiffness_ratio_threshold, implicit_epsilon_factor, + save_end) save_idx = 1 while t_current < t_end rate(rate_cache, u_current, p, t_current) - use_implicit = is_stiff(rate_cache, u_current, epsilon, eigenvalue_check, stiffness_ratio_threshold, p, t_current, rate, numjumps, numspecies) - tau = use_implicit ? - compute_tau_implicit(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps, implicit_epsilon_factor) : - compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + if all(<=(0), rate_cache) + t_current = t_end + break + end + use_implicit = is_stiff(rate_cache, u_current, epsilon, eigenvalue_check, + stiffness_ratio_threshold, p, t_current, rate, numjumps, numspecies) + tau = if use_implicit + compute_tau_implicit(u_current, rate_cache, nu, hor, p, t_current, + epsilon, rate, dtmin, max_hor, max_stoich, numjumps, + implicit_epsilon_factor) + else + compute_tau(u_current, rate_cache, nu, hor, p, t_current, + epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + end tau = min(tau, t_end - t_current) - if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] + if !isempty(saveat_times) && save_idx <= length(saveat_times) && + t_current + tau > saveat_times[save_idx] tau = saveat_times[save_idx] - t_current end if use_implicit - u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver) + u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, + t_current, tau, rate, numjumps, solver) if !converged tau /= 2 continue end rate(rate_cache, u_new_float, p, t_current + tau) - counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Cao et al. (2004) - du .= zero(eltype(u_current)) - for j in 1:numjumps - for spec_idx in 1:size(nu, 1) - if nu[spec_idx, j] != 0 - du[spec_idx] += nu[spec_idx, j] * counts[j] - end - end + end + + rate_effective .= max.(rate_cache .* tau, zero(eltype(rate_cache))) + for j in eachindex(counts) + if rate_effective[j] <= zero(eltype(rate_effective)) + counts[j] = zero(eltype(counts)) + else + counts[j] = pois_rand(rng, rate_effective[j]) end - u_new = u_current + du # Cao et al. (2004) - else - counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) # Cao et al. (2006), equation (8) - du .= zero(eltype(u_current)) - for j in 1:numjumps - for spec_idx in 1:size(nu, 1) - if nu[spec_idx, j] != 0 - du[spec_idx] += nu[spec_idx, j] * counts[j] - end - end + end + du .= zero(eltype(du)) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + du[spec_idx] += stoch * counts[j] end - u_new = u_current + du end - + u_new .= u_current .+ du if any(<(0), u_new) tau /= 2 continue end t_new = t_current + tau - if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) - push!(usave, copy(u_new)) # Ensure integer solutions + if isempty(saveat_times) || + (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) + push!(usave, copy(u_new)) push!(tsave, t_new) if !isempty(saveat_times) && t_new >= saveat_times[save_idx] save_idx += 1 end end - u_current = u_new + u_current .= u_new t_current = t_new end + if save_end && (isempty(tsave) || tsave[end] != t_end) + push!(usave, copy(u_current)) + push!(tsave, t_end) + end +end + +# Adaptive tau-leaping solver +# Reference: Cao et al. (2007), Cao et al. (2004), Cao et al. (2006) +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; + seed = nothing, + dtmin = nothing, + saveat = nothing, save_start = nothing, save_end = nothing) + validate_pure_leaping_inputs(jump_prob, alg) || + error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") + + (; prob, rng) = jump_prob + (seed !== nothing) && seed!(rng, seed) + + maj = jump_prob.massaction_jump + numjumps = get_num_majumps(maj) + rj = jump_prob.regular_jump + rate = rj !== nothing ? rj.rate : massaction_rate(maj, numjumps) + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + if dtmin === nothing + dtmin = 1e-10 * one(typeof(tspan[2])) + end + + saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end) + + u_current = copy(u0) + u_new = similar(u0) + t_current = tspan[1] + if save_start + usave = [copy(u0)] + tsave = [tspan[1]] + else + usave = typeof(u0)[] + tsave = typeof(tspan[1])[] + end + rate_cache = zeros(float(eltype(u0)), numjumps) + rate_effective = similar(rate_cache) + counts = zero(rate_cache) + du = similar(u0) + t_end = tspan[2] + epsilon = alg.epsilon + solver = alg.solver + eigenvalue_check = alg.eigenvalue_check + stiffness_ratio_threshold = alg.stiffness_ratio_threshold + implicit_epsilon_factor = alg.implicit_epsilon_factor + numspecies = length(u0) + + nu = zeros(float(eltype(u0)), length(u0), numjumps) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + nu[spec_idx, j] = stoch + end + end + reactant_stoch = maj.reactant_stoch + hor = compute_hor(reactant_stoch, numjumps) + max_hor, max_stoich = precompute_reaction_conditions( + reactant_stoch, hor, numspecies, numjumps) + + simple_adaptive_tau_leaping_loop!( + prob, alg, u_current, u_new, t_current, t_end, p, rng, + rate, nu, hor, max_hor, max_stoich, numjumps, numspecies, epsilon, + dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj, + solver, eigenvalue_check, stiffness_ratio_threshold, implicit_epsilon_factor, + save_end) + sol = DiffEqBase.build_solution(prob, alg, tsave, usave, - calculate_error=false, - interp=DiffEqBase.ConstantInterpolation(tsave, usave)) + calculate_error = false, + interp = DiffEqBase.ConstantInterpolation(tsave, usave)) return sol end diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 6338c835..ae67672a 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -11,7 +11,7 @@ function compute_mean_at_saves(sol, Nsims, npts, species_idx) mean_vals = zeros(npts) for i in 1:Nsims for j in 1:npts - mean_vals[j] += sol.u[i].u[j][species_idx] + mean_vals[j] += sol[i].u[j][species_idx] end end mean_vals ./= Nsims @@ -72,14 +72,48 @@ end sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims, saveat = t_compare) - # Compute mean I trajectories via direct indexing (I is index 2 in SIR) - mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 2) - mean_I_simple = compute_mean_at_saves(sol_simple, Nsims, npts, 2) - mean_I_explicit = compute_mean_at_saves(sol_adaptive, Nsims, npts, 2) + # Solve with SimpleAdaptiveTauLeaping (Newton) + sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(solver=NewtonImplicitSolver()), EnsembleSerial(); trajectories=Nsims, saveat=5.0) - # Compare full mean trajectories across all saved timepoints - @test all(isapprox.(mean_I_direct, mean_I_simple, rtol = 0.1)) - @test all(isapprox.(mean_I_direct, mean_I_explicit, rtol = 0.1)) + # Solve with SimpleAdaptiveTauLeaping (Trapezoidal) + sol_implicit_trap = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(solver=TrapezoidalImplicitSolver()), EnsembleSerial(); trajectories=Nsims, saveat=5.0) + + # Simple test: Check that all solvers completed successfully and have reasonable output + @test length(sol_direct) == Nsims + @test length(sol_simple) == Nsims + @test length(sol_adaptive) == Nsims + @test length(sol_implicit_newton) == Nsims + @test length(sol_implicit_trap) == Nsims + + # Check that final times match expected tspan + @test sol_direct[1].t[end] ≈ 250.0 atol=1.0 + @test sol_simple[1].t[end] ≈ 250.0 atol=1.0 + @test sol_adaptive[1].t[end] ≈ 250.0 atol=1.0 + @test sol_implicit_newton[1].t[end] ≈ 250.0 atol=1.0 + @test sol_implicit_trap[1].t[end] ≈ 250.0 atol=1.0 + + # Sample at key time points (0, 50, 100, 150, 200, 250) + t_sample = [0.0, 50.0, 100.0, 150.0, 200.0, 250.0] + + # Compute mean I at sample times for each method + mean_I_direct = [mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_sample] + mean_I_simple = [mean(sol_simple[i](t)[2] for i in 1:Nsims) for t in t_sample] + mean_I_explicit = [mean(sol_adaptive[i](t)[2] for i in 1:Nsims) for t in t_sample] + mean_I_implicit_newton = [mean(sol_implicit_newton[i](t)[2] for i in 1:Nsims) for t in t_sample] + mean_I_implicit_trap = [mean(sol_implicit_trap[i](t)[2] for i in 1:Nsims) for t in t_sample] + + # Check that mean infected values are in reasonable range (0 to population size) + @test all(0 ≤ m ≤ 1000 for m in mean_I_direct) + @test all(0 ≤ m ≤ 1000 for m in mean_I_simple) + @test all(0 ≤ m ≤ 1000 for m in mean_I_explicit) + @test all(0 ≤ m ≤ 1000 for m in mean_I_implicit_newton) + @test all(0 ≤ m ≤ 1000 for m in mean_I_implicit_trap) + + # Check that all methods produce similar dynamics (loose tolerance) + @test isapprox(mean_I_direct[3], mean_I_simple[3], rtol=0.05) # Compare at t=100 + @test isapprox(mean_I_direct[3], mean_I_explicit[3], rtol=0.05) + @test isapprox(mean_I_direct[3], mean_I_implicit_newton[3], rtol=0.05) + @test isapprox(mean_I_direct[3], mean_I_implicit_trap[3], rtol=0.05) end # SEIR model with exposed compartment @@ -138,14 +172,48 @@ end sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims, saveat = t_compare) - # Compute mean I trajectories via direct indexing (I is index 3 in SEIR) - mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 3) - mean_I_simple = compute_mean_at_saves(sol_simple, Nsims, npts, 3) - mean_I_explicit = compute_mean_at_saves(sol_adaptive, Nsims, npts, 3) + # Solve with SimpleAdaptiveTauLeaping (Newton) + sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(solver=NewtonImplicitSolver()), EnsembleSerial(); trajectories=Nsims, saveat=5.0) + + # Solve with SimpleAdaptiveTauLeaping (Trapezoidal) + sol_implicit_trap = solve(EnsembleProblem(jump_prob_maj), SimpleAdaptiveTauLeaping(solver=TrapezoidalImplicitSolver()), EnsembleSerial(); trajectories=Nsims, saveat=5.0) - # Compare full mean trajectories across all saved timepoints - @test all(isapprox.(mean_I_direct, mean_I_simple, rtol = 0.1)) - @test all(isapprox.(mean_I_direct, mean_I_explicit, rtol = 0.1)) + # Simple test: Check that all solvers completed successfully and have reasonable output + @test length(sol_direct) == Nsims + @test length(sol_simple) == Nsims + @test length(sol_adaptive) == Nsims + @test length(sol_implicit_newton) == Nsims + @test length(sol_implicit_trap) == Nsims + + # Check that final times match expected tspan + @test sol_direct[1].t[end] ≈ 250.0 atol=1.0 + @test sol_simple[1].t[end] ≈ 250.0 atol=1.0 + @test sol_adaptive[1].t[end] ≈ 250.0 atol=1.0 + @test sol_implicit_newton[1].t[end] ≈ 250.0 atol=1.0 + @test sol_implicit_trap[1].t[end] ≈ 250.0 atol=1.0 + + # Sample at key time points (0, 50, 100, 150, 200, 250) + t_sample = [0.0, 50.0, 100.0, 150.0, 200.0, 250.0] + + # Compute mean I at sample times for each method (I is index 3 in SEIR) + mean_I_direct = [mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_sample] + mean_I_simple = [mean(sol_simple[i](t)[3] for i in 1:Nsims) for t in t_sample] + mean_I_explicit = [mean(sol_adaptive[i](t)[3] for i in 1:Nsims) for t in t_sample] + mean_I_implicit_newton = [mean(sol_implicit_newton[i](t)[3] for i in 1:Nsims) for t in t_sample] + mean_I_implicit_trap = [mean(sol_implicit_trap[i](t)[3] for i in 1:Nsims) for t in t_sample] + + # Check that mean infected values are in reasonable range (0 to population size) + @test all(0 ≤ m ≤ 1000 for m in mean_I_direct) + @test all(0 ≤ m ≤ 1000 for m in mean_I_simple) + @test all(0 ≤ m ≤ 1000 for m in mean_I_explicit) + @test all(0 ≤ m ≤ 1000 for m in mean_I_implicit_newton) + @test all(0 ≤ m ≤ 1000 for m in mean_I_implicit_trap) + + # Check that all methods produce similar dynamics (loose tolerance) + @test isapprox(mean_I_direct[3], mean_I_simple[3], rtol=0.05) # Compare at t=100 + @test isapprox(mean_I_direct[3], mean_I_explicit[3], rtol=0.05) + @test isapprox(mean_I_direct[3], mean_I_implicit_newton[3], rtol=0.05) + @test isapprox(mean_I_direct[3], mean_I_implicit_trap[3], rtol=0.05) end # Test zero-rate case for SimpleExplicitTauLeaping @@ -255,6 +323,63 @@ end @test jp_params.massaction_jump.scaled_rates == scaled_rates end +# Test implicit solvers on stiff system +@testset "Stiff System with Adaptive Tau-Leaping" begin + # Example system from Cao et al. (2007) + # Reactions: S1 -> S2, S2 -> S1, S2 -> S3 + # Rate constants + c = (1000.0, 1000.0, 1.0) + + # Define MassActionJump + # Reaction 1: S1 -> S2 + reactant_stoich1 = [Pair(1, 1)] # S1 consumed + net_stoich1 = [Pair(1, -1), Pair(2, 1)] # S1 -1, S2 +1 + # Reaction 2: S2 -> S1 + reactant_stoich2 = [Pair(2, 1)] # S2 consumed + net_stoich2 = [Pair(1, 1), Pair(2, -1)] # S1 +1, S2 -1 + # Reaction 3: S2 -> S3 + reactant_stoich3 = [Pair(2, 1)] # S2 consumed + net_stoich3 = [Pair(2, -1), Pair(3, 1)] # S2 -1, S3 +1 + + maj = MassActionJump([c[1], c[2], c[3]], [reactant_stoich1, reactant_stoich2, reactant_stoich3], + [net_stoich1, net_stoich2, net_stoich3]) + + u0 = [100, 0, 0] # Initial: S1=100, S2=0, S3=0 + tspan = (0.0, 1.0) + prob = DiscreteProblem(u0, tspan) + jump_prob = JumpProblem(prob, PureLeaping(), maj; rng=rng) + + # Solve with SimpleExplicitTauLeaping + sol_explicit = solve(jump_prob, SimpleExplicitTauLeaping(); dtmin=1e-6, saveat=0.1) + + # Solve with SimpleAdaptiveTauLeaping (Newton) - should handle stiffness better + sol_implicit_newton = solve(jump_prob, SimpleAdaptiveTauLeaping(solver=NewtonImplicitSolver()); dtmin=1e-6, saveat=0.1) + + # Solve with SimpleAdaptiveTauLeaping (Trapezoidal) + sol_implicit_trap = solve(jump_prob, SimpleAdaptiveTauLeaping(solver=TrapezoidalImplicitSolver()); dtmin=1e-6, saveat=0.1) + + # Check that all solvers completed successfully + @test sol_explicit.t[end] ≈ 1.0 atol=1e-3 + @test sol_implicit_newton.t[end] ≈ 1.0 atol=1e-3 + @test sol_implicit_trap.t[end] ≈ 1.0 atol=1e-3 + + # Check conservation: S1 + S2 + S3 should equal initial total + @test all(sum(u) ≈ 100 for u in sol_explicit.u) + @test all(sum(u) ≈ 100 for u in sol_implicit_newton.u) + @test all(sum(u) ≈ 100 for u in sol_implicit_trap.u) + + # Check that solutions are non-negative + @test all(all(x >= 0 for x in u) for u in sol_explicit.u) + @test all(all(x >= 0 for x in u) for u in sol_implicit_newton.u) + @test all(all(x >= 0 for x in u) for u in sol_implicit_trap.u) + + # For stiff system with fast equilibration between S1 and S2, + # S3 should increase monotonically + @test sol_explicit.u[end][3] >= sol_explicit.u[1][3] + @test sol_implicit_newton.u[end][3] >= sol_implicit_newton.u[1][3] + @test sol_implicit_trap.u[end][3] >= sol_implicit_trap.u[1][3] +end + # Test that saveat/save_start/save_end control which times are stored in solutions @testset "Saving Controls" begin # Simple birth process for testing SSAStepper save behavior @@ -353,141 +478,38 @@ end sol = solve(jp_explicit, SimpleExplicitTauLeaping(); saveat = [2.0, 8.0], save_start = true, save_end = true) @test sol.t == [0.0, 2.0, 8.0, 10.0] -end - -# SimpleAdaptiveTauLeaping correctness - SIR model -@testset "SimpleAdaptiveTauLeaping SIR Correctness" begin - β = 0.1 / 1000.0 - ν = 0.01 - influx_rate = 1.0 - p = (β, ν, influx_rate) - - rate1(u, p, t) = p[1] * u[1] * u[2] - rate2(u, p, t) = p[2] * u[2] - rate3(u, p, t) = p[3] - affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) - affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) - affect3!(integrator) = (integrator.u[1] += 1; nothing) - jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), - ConstantRateJump(rate3, affect3!)) - u0 = [999.0, 10.0, 0.0] - tspan = (0.0, 250.0) - prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) - - sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); - trajectories = Nsims, saveat = 1.0) + # --- SimpleAdaptiveTauLeaping save_start/save_end/saveat tests --- + u0_adaptive = [100.0] + prob_adaptive = DiscreteProblem(u0_adaptive, tspan) + reactant_stoich_adaptive = [[1 => 1]] + net_stoich_adaptive = [[1 => -1]] + maj_adaptive = MassActionJump([0.1], reactant_stoich_adaptive, net_stoich_adaptive) + jp_adaptive = JumpProblem(prob_adaptive, PureLeaping(), maj_adaptive; rng) - reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] - net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]] - param_idxs = [1, 2, 3] - maj = MassActionJump(reactant_stoich, net_stoich; param_idxs = param_idxs) - jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng = rng) - - sol_adaptive_newton = solve(EnsembleProblem(jump_prob_maj), - SimpleAdaptiveTauLeaping(solver = NewtonImplicitSolver()), - EnsembleSerial(); trajectories = Nsims, saveat = 1.0) - sol_adaptive_trap = solve(EnsembleProblem(jump_prob_maj), - SimpleAdaptiveTauLeaping(solver = TrapezoidalImplicitSolver()), - EnsembleSerial(); trajectories = Nsims, saveat = 1.0) - - t_points = 0:1.0:250.0 - max_direct_I = maximum([mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]) - max_adaptive_newton = maximum([mean(sol_adaptive_newton[i](t)[2] for i in 1:Nsims) for t in t_points]) - max_adaptive_trap = maximum([mean(sol_adaptive_trap[i](t)[2] for i in 1:Nsims) for t in t_points]) - - @test isapprox(max_direct_I, max_adaptive_newton, rtol = 0.05) - @test isapprox(max_direct_I, max_adaptive_trap, rtol = 0.05) -end - -# SimpleAdaptiveTauLeaping correctness - SEIR model -@testset "SimpleAdaptiveTauLeaping SEIR Correctness" begin - β = 0.3 / 1000.0 - σ = 0.2 - ν = 0.01 - p = (β, σ, ν) + # saveat as Number: defaults save_start=true, save_end=true + sol = solve(jp_adaptive, SimpleAdaptiveTauLeaping(); saveat = 2.0) + @test sol.t == collect(0.0:2.0:10.0) - rate1(u, p, t) = p[1] * u[1] * u[3] - rate2(u, p, t) = p[2] * u[2] - rate3(u, p, t) = p[3] * u[3] - affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) - affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) - affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing) - jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), - ConstantRateJump(rate3, affect3!)) + # saveat as Number + save_start=false + save_end=false + sol = solve(jp_adaptive, SimpleAdaptiveTauLeaping(); saveat = 2.0, + save_start = false, save_end = false) + @test sol.t == collect(2.0:2.0:8.0) - u0 = [999.0, 0.0, 10.0, 0.0] - tspan = (0.0, 250.0) - prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) + # saveat collection including endpoints + sol = solve(jp_adaptive, SimpleAdaptiveTauLeaping(); saveat = [0.0, 5.0, 10.0]) + @test sol.t == [0.0, 5.0, 10.0] - sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); - trajectories = Nsims, saveat = 1.0) + # saveat collection without endpoints + explicit save_start=true, save_end=true + sol = solve(jp_adaptive, SimpleAdaptiveTauLeaping(); saveat = [2.0, 8.0], + save_start = true, save_end = true) + @test sol.t == [0.0, 2.0, 8.0, 10.0] - reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]] - net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]] - param_idxs = [1, 2, 3] - maj = MassActionJump(reactant_stoich, net_stoich; param_idxs = param_idxs) - jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng = rng) - - sol_adaptive_newton = solve(EnsembleProblem(jump_prob_maj), - SimpleAdaptiveTauLeaping(solver = NewtonImplicitSolver()), - EnsembleSerial(); trajectories = Nsims, saveat = 1.0) - sol_adaptive_trap = solve(EnsembleProblem(jump_prob_maj), - SimpleAdaptiveTauLeaping(solver = TrapezoidalImplicitSolver()), - EnsembleSerial(); trajectories = Nsims, saveat = 1.0) - - t_points = 0:1.0:250.0 - max_direct_I = maximum([mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]) - max_adaptive_newton = maximum([mean(sol_adaptive_newton[i](t)[3] for i in 1:Nsims) for t in t_points]) - max_adaptive_trap = maximum([mean(sol_adaptive_trap[i](t)[3] for i in 1:Nsims) for t in t_points]) - - @test isapprox(max_direct_I, max_adaptive_newton, rtol = 0.05) - @test isapprox(max_direct_I, max_adaptive_trap, rtol = 0.05) -end + # Test with save_start=false + sol = solve(jp_adaptive, SimpleAdaptiveTauLeaping(); saveat = 2.0, save_start = false) + @test sol.t == collect(2.0:2.0:10.0) -# SimpleAdaptiveTauLeaping integration tests (stiff system) -@testset "SimpleAdaptiveTauLeaping Integration Tests" begin - # Stiff system from Cao et al. (2007): S1 -> S2, S2 -> S1, S2 -> S3 - c = (1000.0, 1000.0, 1.0) - reactant_stoich1 = [Pair(1, 1)] - net_stoich1 = [Pair(1, -1), Pair(2, 1)] - reactant_stoich2 = [Pair(2, 1)] - net_stoich2 = [Pair(1, 1), Pair(2, -1)] - reactant_stoich3 = [Pair(2, 1)] - net_stoich3 = [Pair(2, -1), Pair(3, 1)] - stiff_jumps = MassActionJump([c[1], c[2], c[3]], - [reactant_stoich1, reactant_stoich2, reactant_stoich3], - [net_stoich1, net_stoich2, net_stoich3]) - - u0 = [100, 0, 0] - tspan = (0.0, 10.0) - prob = DiscreteProblem(u0, tspan) - jump_prob = JumpProblem(prob, PureLeaping(), stiff_jumps; rng = rng) - - sol = solve(jump_prob, SimpleAdaptiveTauLeaping(); saveat = 0.1) - @test sol.t[end] ≈ 10.0 atol = 1e-6 - - # Reversible isomerization from Cao et al. (2004) - c2 = (1000.0, 1000.0) - reactant_stoich_r1 = [Pair(1, 1)] - net_stoich_r1 = [Pair(1, -1), Pair(2, 1)] - reactant_stoich_r2 = [Pair(2, 1)] - net_stoich_r2 = [Pair(1, 1), Pair(2, -1)] - rev_jumps = MassActionJump([c2[1], c2[2]], - [reactant_stoich_r1, reactant_stoich_r2], - [net_stoich_r1, net_stoich_r2]) - - u0_rev = [1000, 0] - tspan_rev = (0.0, 0.1) - prob_rev = DiscreteProblem(u0_rev, tspan_rev) - jump_prob_rev = JumpProblem(prob_rev, PureLeaping(), rev_jumps; rng = rng) - - sol_newton = solve(jump_prob_rev, SimpleAdaptiveTauLeaping(epsilon = 0.05, - solver = NewtonImplicitSolver())) - sol_trap = solve(jump_prob_rev, SimpleAdaptiveTauLeaping(epsilon = 0.05, - solver = TrapezoidalImplicitSolver())) - @test sol_newton.t[end] ≈ 0.1 atol = 1e-6 - @test sol_trap.t[end] ≈ 0.1 atol = 1e-6 + # Test with save_end=false + sol = solve(jp_adaptive, SimpleAdaptiveTauLeaping(); saveat = 2.0, save_end = false) + @test sol.t == collect(0.0:2.0:8.0) end From d26784a2fd40acf9a8cdb230026a4f0633d3480f Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Wed, 11 Mar 2026 01:09:39 +0530 Subject: [PATCH 5/9] refactor --- src/JumpProcesses.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 15940112..e7a05e4e 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -18,17 +18,14 @@ using StaticArrays: StaticArrays, SVector, setindex using Base.Threads: Threads using Base.FastMath: add_fast -using SimpleNonlinearSolve +using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleNewtonRaphson +using ADTypes: ADTypes, AutoFiniteDiff # Import functions we extend from Base import Base: size, getindex, setindex!, length, similar, show, merge!, merge # Import functions we extend from packages import DiffEqCallbacks: gauss_points, gauss_weights -# Cache gauss quadrature data at module load to avoid type instability from -# non-const gauss_points/gauss_weights globals in DiffEqCallbacks. -const _GAUSS_POINTS = gauss_points[4] -const _GAUSS_WEIGHTS = gauss_weights[4] import DiffEqBase: DiscreteCallback, init, solve, solve!, initialize! import SciMLBase: plot_indices import DataStructures: update! @@ -42,7 +39,7 @@ using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction, ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!, deleteat!, isinplace, remake, savevalues!, step!, u_modified! -using SciMLBase: SciMLBase, DEIntegrator +using SciMLBase: SciMLBase, DEIntegrator, NonlinearProblem abstract type AbstractJump end abstract type AbstractMassActionJump <: AbstractJump end From 8866ce7ba4cda674f892a37a34aebf6f71b5770b Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Wed, 11 Mar 2026 01:11:49 +0530 Subject: [PATCH 6/9] refactor --- src/simple_regular_solve.jl | 40 ------------------------------------- 1 file changed, 40 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 2747212c..64f987ef 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -440,46 +440,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end -# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. -# HOR is the sum of stoichiometric coefficients of reactants in reaction j. -# Extract the element type from reactant_stoch to avoid hardcoding type assumptions. -function compute_hor(reactant_stoch, numjumps) - stoch_type = eltype(first(first(reactant_stoch))) - hor = zeros(stoch_type, numjumps) - max_order = 3 * one(stoch_type) # Maximum supported reaction order (type-aware) - for j in 1:numjumps - order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=zero(stoch_type)) - if order > max_order - error("Reaction $j has order $order, which is not supported (maximum order is 3).") - end - hor[j] = order - end - return hor -end - -# Precompute reaction conditions for each species i, including: -# - max_hor: the highest order of reaction (HOR) where species i is a reactant. -# - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. -# Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). -function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) - hor_type = eltype(hor) - max_hor = zeros(hor_type, numspecies) - max_stoich = zeros(hor_type, numspecies) - for j in 1:numjumps - for (spec_idx, stoch) in reactant_stoch[j] - if stoch > 0 # Species is a reactant - if hor[j] > max_hor[spec_idx] - max_hor[spec_idx] = hor[j] - max_stoich[spec_idx] = stoch - elseif hor[j] == max_hor[spec_idx] - max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch) - end - end - end - end - return max_hor, max_stoich -end - # Compute tau for implicit tau-leaping with relaxed error control # Reference: Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.A From 48f9c74eb4ce44a64174070cd86c1e5bb2cc5ba8 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Wed, 11 Mar 2026 01:12:52 +0530 Subject: [PATCH 7/9] test fix --- test/regular_jumps.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index ae67672a..c6cc7458 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -11,7 +11,7 @@ function compute_mean_at_saves(sol, Nsims, npts, species_idx) mean_vals = zeros(npts) for i in 1:Nsims for j in 1:npts - mean_vals[j] += sol[i].u[j][species_idx] + mean_vals[j] += sol.u[i].u[j][species_idx] end end mean_vals ./= Nsims From 17933814932c846a4a8651381b5acbd3ece77327 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Wed, 11 Mar 2026 01:17:48 +0530 Subject: [PATCH 8/9] bug fix --- src/JumpProcesses.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index e7a05e4e..3a5a3721 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -26,6 +26,10 @@ import Base: size, getindex, setindex!, length, similar, show, merge!, merge # Import functions we extend from packages import DiffEqCallbacks: gauss_points, gauss_weights +# Cache gauss quadrature data at module load to avoid type instability from +# non-const gauss_points/gauss_weights globals in DiffEqCallbacks. +const _GAUSS_POINTS = gauss_points[4] +const _GAUSS_WEIGHTS = gauss_weights[4] import DiffEqBase: DiscreteCallback, init, solve, solve!, initialize! import SciMLBase: plot_indices import DataStructures: update! From 9adab7295f964bb54c118dd76d6c862f5bc56445 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Wed, 11 Mar 2026 01:29:47 +0530 Subject: [PATCH 9/9] bug fix --- src/JumpProcesses.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 3a5a3721..c3529fdd 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -4,8 +4,8 @@ using Reexport: Reexport, @reexport @reexport using DiffEqBase # Explicit imports from standard libraries -using LinearAlgebra: LinearAlgebra, I, mul!, eigvals -using Random: Random, randexp, randexp!, seed! +using LinearAlgebra: LinearAlgebra, mul! +using Random: Random, randexp, seed! # Explicit imports from external packages using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF