diff --git a/Project.toml b/Project.toml index dfc8e5f5..899eb285 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas "] version = "9.25.1" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" @@ -17,6 +18,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -54,6 +57,8 @@ RecursiveArrayTools = "3.35, 4" Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.115, 3.1" +Setfield = "1" +SimpleNonlinearSolve = "1, 2" StableRNGs = "1" StaticArrays = "1.9.8" Statistics = "1" @@ -63,7 +68,6 @@ Test = "1" julia = "1.10" [extras] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" @@ -78,4 +82,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ADTypes", "Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"] +test = ["Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"] diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 87c8c2c2..e2f49fcb 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -18,6 +18,9 @@ using StaticArrays: StaticArrays, SVector, setindex using Base.Threads: Threads using Base.FastMath: add_fast +using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleNewtonRaphson +using ADTypes: ADTypes, AutoFiniteDiff + # Import functions we extend from Base import Base: size, getindex, setindex!, length, similar, show, merge!, merge @@ -40,7 +43,7 @@ using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction, ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!, deleteat!, isinplace, remake, savevalues!, step!, u_modified! -using SciMLBase: SciMLBase, DEIntegrator +using SciMLBase: SciMLBase, DEIntegrator, NonlinearProblem abstract type AbstractJump end abstract type AbstractMassActionJump <: AbstractJump end @@ -131,7 +134,7 @@ export SSAStepper # leaping: include("simple_regular_solve.jl") -export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel +export SimpleTauLeaping, SimpleExplicitTauLeaping, SimpleImplicitTauLeaping, NewtonImplicitSolver, TrapezoidalImplicitSolver, EnsembleGPUKernel # spatial: include("spatial/spatial_massaction_jump.jl") diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 3675ddad..fb865da7 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -6,6 +6,18 @@ end SimpleExplicitTauLeaping(; epsilon = 0.05) = SimpleExplicitTauLeaping(epsilon) +# Define solver type hierarchy +abstract type AbstractImplicitSolver end +struct NewtonImplicitSolver <: AbstractImplicitSolver end +struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end + +struct SimpleImplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm + epsilon::T # Error control parameter for tau selection + solver::AbstractImplicitSolver # Solver type: Newton or Trapezoidal +end + +SimpleImplicitTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver()) = SimpleImplicitTauLeaping(epsilon, solver) + function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ @@ -69,6 +81,19 @@ function _process_saveat(saveat, tspan, save_start, save_end) return saveat_vec, _save_start, _save_end end +function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping) + if !(jump_prob.aggregator isa PureLeaping) + @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ + JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ + Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release." + end + isempty(jump_prob.jump_callback.continuous_callbacks) && + isempty(jump_prob.jump_callback.discrete_callbacks) && + isempty(jump_prob.constant_jumps) && + isempty(jump_prob.variable_jumps) && + jump_prob.massaction_jump !== nothing +end + function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; seed = nothing, dt = error("dt is required for SimpleTauLeaping."), saveat = nothing, save_start = nothing, save_end = nothing) @@ -405,6 +430,291 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end +function compute_hor(reactant_stoch, numjumps) + # Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. + # HOR is the sum of stoichiometric coefficients of reactants in reaction j. + # Extract the element type from reactant_stoch to avoid hardcoding type assumptions. + stoch_type = eltype(first(first(reactant_stoch))) + hor = zeros(stoch_type, numjumps) + max_order = 3 * one(stoch_type) # Maximum supported reaction order (type-aware) + for j in 1:numjumps + order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=zero(stoch_type)) + if order > max_order + error("Reaction $j has order $order, which is not supported (maximum order is 3).") + end + hor[j] = order + end + return hor +end + +function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) + # Precompute reaction conditions for each species i, including: + # - max_hor: the highest order of reaction (HOR) where species i is a reactant. + # - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. + # Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). + hor_type = eltype(hor) + max_hor = zeros(hor_type, numspecies) + max_stoich = zeros(hor_type, numspecies) + stoch_type = eltype(first(first(reactant_stoch))) + zero_stoch = zero(stoch_type) + for j in 1:numjumps + for (spec_idx, stoch) in reactant_stoch[j] + if stoch > zero_stoch # Species is a reactant + if hor[j] > max_hor[spec_idx] + max_hor[spec_idx] = hor[j] + max_stoich[spec_idx] = stoch + elseif hor[j] == max_hor[spec_idx] + max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch) + end + end + end + end + return max_hor, max_stoich +end + +function compute_gi(u, max_hor, max_stoich, i, t) + # Compute g_i for species i to bound the relative change in propensity functions, + # as per Cao et al. (2006), Section IV, equation (27). + # g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant: + # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 + # - HOR = 2 (second-order): + # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 + # - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1) + # - HOR = 3 (third-order): + # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 + # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) + # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) + # Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep. + one_max_hor = one(1 / one(eltype(u))) + hor_type = eltype(max_hor) + zero_hor = zero(hor_type) + one_hor = one(hor_type) + two_hor = one_hor + one_hor + three_hor = two_hor + one_hor + + if max_hor[i] == zero_hor # No reactions involve species i as a reactant + return one_max_hor + elseif max_hor[i] == one_hor + return one_max_hor + elseif max_hor[i] == two_hor + if max_stoich[i] == one_hor + return 2 * one_max_hor + else # max_stoich[i] == 2 + return u[i] > one_max_hor ? + 2 * one_max_hor + one_max_hor / (u[i] - one_max_hor) : 2 * one_max_hor # Fallback to 2 if x_i <= 1 + end + elseif max_hor[i] == three_hor + if max_stoich[i] == one_hor + return 3 * one_max_hor + elseif max_stoich[i] == two_hor + return u[i] > one_max_hor ? + (3 * one_max_hor / 2) * (2 * one_max_hor + one_max_hor / (u[i] - one_max_hor)) : 3 * one_max_hor # Fallback to 3 if x_i <= 1 + else # max_stoich[i] == 3 + return u[i] > 2 * one_max_hor ? + 3 * one_max_hor + one_max_hor / (u[i] - one_max_hor) + 2 * one_max_hor / (u[i] - 2 * one_max_hor) : 3 * one_max_hor # Fallback to 3 if x_i <= 2 + end + end + return one_max_hor # Default case +end + +function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + # Compute the tau-leaping step-size using equation (20) from Cao et al. (2006): + # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } + # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): + # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) + # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). + rate(rate_cache, u, p, t) + if all(<=(0), rate_cache) # Handle case where all rates are zero or negative + return dtmin + end + tau = typemax(typeof(t)) + for i in 1:length(u) + mu = zero(eltype(u)) + sigma2 = zero(eltype(u)) + for j in 1:size(nu, 2) + mu += nu[i, j] * rate_cache[j] # Equation (9a) + sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) + end + gi = compute_gi(u, max_hor, max_stoich, i, t) + bound = max(epsilon * u[i] / gi, one(eltype(u))) # max(epsilon * x_i / g_i, 1) + mu_term = abs(mu) > 0 ? bound / abs(mu) : typemax(typeof(t)) # First term in equation (8) + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : typemax(typeof(t)) # Second term in equation (8) + tau = min(tau, mu_term, sigma_term) # Equation (8) + end + return max(tau, dtmin) +end + +# Define residual for implicit equation +# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004) +# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau +function implicit_equation!(resid, u_new, params) + u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params + rate(rate_cache, u_new, p, t + tau) + resid .= u_new .- u_current + if isa(solver, NewtonImplicitSolver) + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004) + end + end + else # TrapezoidalImplicitSolver + rate_current = similar(rate_cache) + rate(rate_current, u_current, p, t) + half = one(eltype(rate_cache)) / 2 + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + resid[spec_idx] -= nu[spec_idx, j] * half * (rate_cache[j] + rate_current[j]) * tau + end + end + end + resid .= max.(resid, -u_new) # Ensure non-negative solution +end + +# Solve implicit equation using SimpleNonlinearSolve +function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver) + u_new = convert(Vector{float(eltype(u_current))}, u_current) + prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)) + sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6) + return sol.u, sol.retcode == ReturnCode.Success +end + +# Function to generate a mass action rate function +function massaction_rate(maj, numjumps) + return (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) + end + end +end + +function simple_implicit_tau_leaping_loop!( + prob, alg, u_current, t_current, t_end, p, rng, + rate, nu, hor, max_hor, max_stoich, numjumps, epsilon, + dtmin, saveat_times, usave, tsave, du, counts, rate_cache, + maj, solver, save_end) + save_idx = 1 + + while t_current < t_end + rate(rate_cache, u_current, p, t_current) + tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + tau = min(tau, t_end - t_current) + if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] + tau = saveat_times[save_idx] - t_current + end + + u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver) + if !converged + tau /= 2 + continue + end + + rate(rate_cache, u_new_float, p, t_current + tau) + zero_rate = zero(eltype(rate_cache)) + counts .= pois_rand.(rng, max.(rate_cache * tau, zero_rate)) + du .= zero(eltype(du)) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + du[spec_idx] += stoch * counts[j] + end + end + u_new = u_current + du + + zero_pop = zero(eltype(u_new)) + if any(<(zero_pop), u_new) + # Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3 + tau /= 2 + continue + end + # Ensure non-negativity, as per Cao et al. (2006), Section 3.3 + for i in eachindex(u_new) + u_new[i] = max(u_new[i], zero_pop) + end + t_new = t_current + tau + + if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) + push!(usave, copy(u_new)) + push!(tsave, t_new) + if !isempty(saveat_times) && t_new >= saveat_times[save_idx] + save_idx += 1 + end + end + + u_current = u_new + t_current = t_new + end + + # Save endpoint if requested and not already saved + if save_end && (isempty(tsave) || tsave[end] != t_end) + push!(usave, copy(u_current)) + push!(tsave, t_end) + end +end + +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; + seed = nothing, + dtmin = nothing, + saveat = nothing, save_start = nothing, save_end = nothing) + validate_pure_leaping_inputs(jump_prob, alg) || + error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") + + (; prob, rng) = jump_prob + (seed !== nothing) && seed!(rng, seed) + + maj = jump_prob.massaction_jump + numjumps = get_num_majumps(maj) + rj = jump_prob.regular_jump + # Extract rates + rate = rj !== nothing ? rj.rate : massaction_rate(maj, numjumps) + c = rj !== nothing ? rj.c : nothing + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + if dtmin === nothing + dtmin = 1e-10 * one(typeof(tspan[2])) + end + + saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end) + + # Initialize current state and saved history + u_current = copy(u0) + t_current = tspan[1] + if save_start + usave = [copy(u0)] + tsave = [tspan[1]] + else + usave = typeof(u0)[] + tsave = typeof(tspan[1])[] + end + rate_cache = zeros(float(eltype(u0)), numjumps) + counts = zero(rate_cache) + du = similar(u0) + t_end = tspan[2] + epsilon = alg.epsilon + solver = alg.solver + + nu = zeros(float(eltype(u0)), length(u0), numjumps) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + nu[spec_idx, j] = stoch + end + end + reactant_stoch = maj.reactant_stoch + hor = compute_hor(reactant_stoch, numjumps) + max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) + + simple_implicit_tau_leaping_loop!( + prob, alg, u_current, t_current, t_end, p, rng, + rate, nu, hor, max_hor, max_stoich, numjumps, epsilon, + dtmin, saveat_times, usave, tsave, du, counts, rate_cache, + maj, solver, save_end) + + sol = DiffEqBase.build_solution(prob, alg, tsave, usave, + calculate_error=false, + interp=DiffEqBase.ConstantInterpolation(tsave, usave)) + return sol +end + struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm backend::Backend cpu_offload::Float64 diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index f4d009ea..54d4cd6a 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -61,7 +61,7 @@ end sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 0.1, saveat = t_compare) - # MassActionJump formulation for SimpleExplicitTauLeaping + # MassActionJump formulation for SimpleExplicitTauLeaping / SimpleImplicitTauLeaping reactant_stoich = [[1 => 1, 2 => 1], [2 => 1], Pair{Int, Int}[]] net_stoich = [[1 => -1, 2 => 1], [2 => -1, 3 => 1], [1 => 1]] param_idxs = [1, 2, 3] @@ -72,14 +72,26 @@ end sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims, saveat = t_compare) + # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) + sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()), EnsembleSerial(); + trajectories = Nsims, saveat = t_compare) + sol_implicit_trap = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()), EnsembleSerial(); + trajectories = Nsims, saveat = t_compare) + # Compute mean I trajectories via direct indexing (I is index 2 in SIR) mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 2) mean_I_simple = compute_mean_at_saves(sol_simple, Nsims, npts, 2) mean_I_explicit = compute_mean_at_saves(sol_adaptive, Nsims, npts, 2) + mean_I_implicit_newton = compute_mean_at_saves(sol_implicit_newton, Nsims, npts, 2) + mean_I_implicit_trap = compute_mean_at_saves(sol_implicit_trap, Nsims, npts, 2) # Compare full mean trajectories across all saved timepoints @test all(isapprox.(mean_I_direct, mean_I_simple, rtol = 0.1)) @test all(isapprox.(mean_I_direct, mean_I_explicit, rtol = 0.1)) + @test all(isapprox.(mean_I_direct, mean_I_implicit_newton, rtol = 0.1)) + @test all(isapprox.(mean_I_direct, mean_I_implicit_trap, rtol = 0.1)) end # SEIR model with exposed compartment @@ -127,7 +139,7 @@ end sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 0.1, saveat = t_compare) - # MassActionJump formulation for SimpleExplicitTauLeaping + # MassActionJump formulation for SimpleExplicitTauLeaping / SimpleImplicitTauLeaping reactant_stoich = [[1 => 1, 3 => 1], [2 => 1], [3 => 1]] net_stoich = [[1 => -1, 2 => 1], [2 => -1, 3 => 1], [3 => -1, 4 => 1]] param_idxs = [1, 2, 3] @@ -138,14 +150,26 @@ end sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims, saveat = t_compare) + # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) + sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()), EnsembleSerial(); + trajectories = Nsims, saveat = t_compare) + sol_implicit_trap = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()), EnsembleSerial(); + trajectories = Nsims, saveat = t_compare) + # Compute mean I trajectories via direct indexing (I is index 3 in SEIR) mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 3) mean_I_simple = compute_mean_at_saves(sol_simple, Nsims, npts, 3) mean_I_explicit = compute_mean_at_saves(sol_adaptive, Nsims, npts, 3) + mean_I_implicit_newton = compute_mean_at_saves(sol_implicit_newton, Nsims, npts, 3) + mean_I_implicit_trap = compute_mean_at_saves(sol_implicit_trap, Nsims, npts, 3) # Compare full mean trajectories across all saved timepoints @test all(isapprox.(mean_I_direct, mean_I_simple, rtol = 0.1)) @test all(isapprox.(mean_I_direct, mean_I_explicit, rtol = 0.1)) + @test all(isapprox.(mean_I_direct, mean_I_implicit_newton, rtol = 0.1)) + @test all(isapprox.(mean_I_direct, mean_I_implicit_trap, rtol = 0.1)) end # Test zero-rate case for SimpleExplicitTauLeaping @@ -175,62 +199,62 @@ end tspan = (0.0, 10.0) p = [0.1, 0.2] prob = DiscreteProblem(u0, tspan, p) - + # Create MassActionJump reactant_stoich = [[1 => 1], [1 => 2]] net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]] rates = [0.1, 0.05] maj = MassActionJump(rates, reactant_stoich, net_stoich) - + # Test PureLeaping JumpProblem creation jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng) @test jp_pure.aggregator isa PureLeaping @test jp_pure.discrete_jump_aggregation === nothing @test jp_pure.massaction_jump !== nothing @test length(jp_pure.jump_callback.discrete_callbacks) == 0 - + # Test with ConstantRateJump rate(u, p, t) = p[1] * u[1] affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) crj = ConstantRateJump(rate, affect!) - + jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng) @test jp_pure_crj.aggregator isa PureLeaping @test jp_pure_crj.discrete_jump_aggregation === nothing @test length(jp_pure_crj.constant_jumps) == 1 - + # Test with VariableRateJump vrate(u, p, t) = t * p[1] * u[1] vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) vrj = VariableRateJump(vrate, vaffect!) - + jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng) @test jp_pure_vrj.aggregator isa PureLeaping @test jp_pure_vrj.discrete_jump_aggregation === nothing @test length(jp_pure_vrj.variable_jumps) == 1 - + # Test with RegularJump function rj_rate(out, u, p, t) out[1] = p[1] * u[1] end - + rj_dc = zeros(3, 1) rj_dc[1, 1] = -1 rj_dc[3, 1] = 1 - + function rj_c(du, u, p, t, counts, mark) mul!(du, rj_dc, counts) end - + regj = RegularJump(rj_rate, rj_c, 1) - + jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng) @test jp_pure_regj.aggregator isa PureLeaping @test jp_pure_regj.discrete_jump_aggregation === nothing @test jp_pure_regj.regular_jump !== nothing - + # Test mixed jump types - mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), + mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), variable_jumps = (vrj,), regular_jumps = regj) jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng) @test jp_pure_mixed.aggregator isa PureLeaping @@ -239,7 +263,7 @@ end @test length(jp_pure_mixed.constant_jumps) == 1 @test length(jp_pure_mixed.variable_jumps) == 1 @test jp_pure_mixed.regular_jump !== nothing - + # Test spatial system error spatial_sys = CartesianGrid((2, 2)) hopping_consts = [1.0] @@ -247,7 +271,7 @@ end spatial_system = spatial_sys) @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, hopping_constants = hopping_consts) - + # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) @@ -255,6 +279,56 @@ end @test jp_params.massaction_jump.scaled_rates == scaled_rates end +# Test implicit solvers on stiff system (Cao et al. 2007) +@testset "Stiff System with Implicit Solvers" begin + # Reactions: S1 -> S2, S2 -> S1, S2 -> S3 + c = (1000.0, 1000.0, 1.0) + reactant_stoich1 = [Pair(1, 1)] + net_stoich1 = [Pair(1, -1), Pair(2, 1)] + reactant_stoich2 = [Pair(2, 1)] + net_stoich2 = [Pair(1, 1), Pair(2, -1)] + reactant_stoich3 = [Pair(2, 1)] + net_stoich3 = [Pair(2, -1), Pair(3, 1)] + + maj = MassActionJump([c[1], c[2], c[3]], + [reactant_stoich1, reactant_stoich2, reactant_stoich3], + [net_stoich1, net_stoich2, net_stoich3]) + + u0 = [100, 0, 0] # Initial: S1=100, S2=0, S3=0 + tspan = (0.0, 1.0) + prob = DiscreteProblem(u0, tspan) + jump_prob = JumpProblem(prob, PureLeaping(), maj; rng = rng) + + # Solve with SimpleExplicitTauLeaping + sol_explicit = solve(jump_prob, SimpleExplicitTauLeaping(); dtmin = 1e-6, saveat = 0.1) + + # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) + sol_implicit_newton = solve(jump_prob, SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()); + dtmin = 1e-6, saveat = 0.1) + sol_implicit_trap = solve(jump_prob, SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()); + dtmin = 1e-6, saveat = 0.1) + + # Check that all solvers completed successfully + @test sol_explicit.t[end] ≈ 1.0 atol = 1e-3 + @test sol_implicit_newton.t[end] ≈ 1.0 atol = 1e-3 + @test sol_implicit_trap.t[end] ≈ 1.0 atol = 1e-3 + + # Check conservation: S1 + S2 + S3 should equal initial total + @test all(sum(u) ≈ 100 for u in sol_explicit.u) + @test all(sum(u) ≈ 100 for u in sol_implicit_newton.u) + @test all(sum(u) ≈ 100 for u in sol_implicit_trap.u) + + # Check that solutions are non-negative + @test all(all(x >= 0 for x in u) for u in sol_explicit.u) + @test all(all(x >= 0 for x in u) for u in sol_implicit_newton.u) + @test all(all(x >= 0 for x in u) for u in sol_implicit_trap.u) + + # S3 should increase monotonically (it only has an influx reaction) + @test sol_explicit.u[end][3] >= sol_explicit.u[1][3] + @test sol_implicit_newton.u[end][3] >= sol_implicit_newton.u[1][3] + @test sol_implicit_trap.u[end][3] >= sol_implicit_trap.u[1][3] +end + # Test that saveat/save_start/save_end control which times are stored in solutions @testset "Saving Controls" begin # Simple birth process for testing SSAStepper save behavior @@ -353,4 +427,38 @@ end sol = solve(jp_explicit, SimpleExplicitTauLeaping(); saveat = [2.0, 8.0], save_start = true, save_end = true) @test sol.t == [0.0, 2.0, 8.0, 10.0] + + # --- SimpleImplicitTauLeaping save_start/save_end/saveat tests --- + u0_implicit = [100.0] + prob_implicit = DiscreteProblem(u0_implicit, tspan) + reactant_stoich_implicit = [[1 => 1]] + net_stoich_implicit = [[1 => -1]] + maj_implicit = MassActionJump([0.1], reactant_stoich_implicit, net_stoich_implicit) + jp_implicit = JumpProblem(prob_implicit, PureLeaping(), maj_implicit; rng) + + # saveat as Number: defaults save_start=true, save_end=true + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = 2.0) + @test sol.t == collect(0.0:2.0:10.0) + + # saveat as Number + save_start=false + save_end=false + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = 2.0, + save_start = false, save_end = false) + @test sol.t == collect(2.0:2.0:8.0) + + # saveat collection including endpoints + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = [0.0, 5.0, 10.0]) + @test sol.t == [0.0, 5.0, 10.0] + + # saveat collection without endpoints + explicit save_start=true, save_end=true + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = [2.0, 8.0], + save_start = true, save_end = true) + @test sol.t == [0.0, 2.0, 8.0, 10.0] + + # Test with save_start=false + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = 2.0, save_start = false) + @test sol.t == collect(2.0:2.0:10.0) + + # Test with save_end=false + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = 2.0, save_end = false) + @test sol.t == collect(0.0:2.0:8.0) end