Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2e43bcf
kernels setup done
sivasathyaseeelan Jul 23, 2025
f402962
ImplicitTauLeaping setup done for jump problem solver
sivasathyaseeelan Jul 24, 2025
f778529
jump_problem solver fixed
sivasathyaseeelan Jul 28, 2025
a6755f0
removed equilibrium_pair logic
sivasathyaseeelan Jul 31, 2025
3a1a3c6
tranculation error fixed
sivasathyaseeelan Jul 31, 2025
1c368de
nonlinearsolver is implemented
sivasathyaseeelan Aug 1, 2025
c92e30f
changed to SimpleImplicitTauLeaping
sivasathyaseeelan Aug 9, 2025
166754d
refactor
sivasathyaseeelan Aug 9, 2025
e50a43a
SimpleAdaptiveTauLeaping is done
sivasathyaseeelan Aug 9, 2025
273460a
simple version of SimpleImplicitTauLeaping
sivasathyaseeelan Aug 19, 2025
d7c6558
removed adaptive tau leap
sivasathyaseeelan Aug 19, 2025
a21b367
poiss change
sivasathyaseeelan Aug 19, 2025
df55f1e
changed to inline non linear solver
sivasathyaseeelan Aug 19, 2025
5c9d419
refactor
sivasathyaseeelan Aug 19, 2025
55d8b91
basic version of inplicit tau leap is done
sivasathyaseeelan Aug 19, 2025
6fd9110
added critical_threshold
sivasathyaseeelan Aug 20, 2025
2b0667a
residual update
sivasathyaseeelan Aug 20, 2025
a6af972
added comment line
sivasathyaseeelan Aug 20, 2025
becb2a5
SimpleImplicitTauLeaping
sivasathyaseeelan Sep 5, 2025
fe71c61
project.toml
sivasathyaseeelan Sep 5, 2025
1e2b6e6
project.toml
sivasathyaseeelan Sep 5, 2025
c8022cb
some
sivasathyaseeelan Sep 5, 2025
2845eb4
some
sivasathyaseeelan Sep 5, 2025
e6a736e
refactor
sivasathyaseeelan Feb 13, 2026
77991ee
some changes
sivasathyaseeelan Feb 13, 2026
0277f27
refactor
sivasathyaseeelan Feb 14, 2026
5770437
comcat entries
sivasathyaseeelan Feb 14, 2026
f33a041
test fix
sivasathyaseeelan Feb 15, 2026
08e4688
saveat implementation
sivasathyaseeelan Feb 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "9.25.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand All @@ -17,6 +18,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

Expand Down Expand Up @@ -54,6 +57,8 @@ RecursiveArrayTools = "3.35, 4"
Reexport = "1.2"
SafeTestsets = "0.1"
SciMLBase = "2.115, 3.1"
Setfield = "1"
SimpleNonlinearSolve = "1, 2"
StableRNGs = "1"
StaticArrays = "1.9.8"
Statistics = "1"
Expand All @@ -63,7 +68,6 @@ Test = "1"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
Expand All @@ -78,4 +82,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ADTypes", "Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"]
test = ["Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"]
7 changes: 5 additions & 2 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ using StaticArrays: StaticArrays, SVector, setindex
using Base.Threads: Threads
using Base.FastMath: add_fast

using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleNewtonRaphson
using ADTypes: ADTypes, AutoFiniteDiff

# Import functions we extend from Base
import Base: size, getindex, setindex!, length, similar, show, merge!, merge

Expand All @@ -40,7 +43,7 @@ using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction,
ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!,
deleteat!, isinplace, remake, savevalues!, step!,
u_modified!
using SciMLBase: SciMLBase, DEIntegrator
using SciMLBase: SciMLBase, DEIntegrator, NonlinearProblem

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
Expand Down Expand Up @@ -131,7 +134,7 @@ export SSAStepper

# leaping:
include("simple_regular_solve.jl")
export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel
export SimpleTauLeaping, SimpleExplicitTauLeaping, SimpleImplicitTauLeaping, NewtonImplicitSolver, TrapezoidalImplicitSolver, EnsembleGPUKernel

# spatial:
include("spatial/spatial_massaction_jump.jl")
Expand Down
310 changes: 310 additions & 0 deletions src/simple_regular_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ end

SimpleExplicitTauLeaping(; epsilon = 0.05) = SimpleExplicitTauLeaping(epsilon)

# Define solver type hierarchy
abstract type AbstractImplicitSolver end
struct NewtonImplicitSolver <: AbstractImplicitSolver end
struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end

struct SimpleImplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm
epsilon::T # Error control parameter for tau selection
solver::AbstractImplicitSolver # Solver type: Newton or Trapezoidal
end

SimpleImplicitTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver()) = SimpleImplicitTauLeaping(epsilon, solver)

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
Expand Down Expand Up @@ -69,6 +81,19 @@ function _process_saveat(saveat, tspan, save_start, save_end)
return saveat_vec, _save_start, _save_end
end

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
end
isempty(jump_prob.jump_callback.continuous_callbacks) &&
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
jump_prob.massaction_jump !== nothing
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
seed = nothing, dt = error("dt is required for SimpleTauLeaping."),
saveat = nothing, save_start = nothing, save_end = nothing)
Expand Down Expand Up @@ -405,6 +430,291 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping;
return sol
end

function compute_hor(reactant_stoch, numjumps)
# Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV.
# HOR is the sum of stoichiometric coefficients of reactants in reaction j.
# Extract the element type from reactant_stoch to avoid hardcoding type assumptions.
stoch_type = eltype(first(first(reactant_stoch)))
hor = zeros(stoch_type, numjumps)
max_order = 3 * one(stoch_type) # Maximum supported reaction order (type-aware)
for j in 1:numjumps
order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=zero(stoch_type))
if order > max_order
error("Reaction $j has order $order, which is not supported (maximum order is 3).")
end
hor[j] = order
end
return hor
end

function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps)
# Precompute reaction conditions for each species i, including:
# - max_hor: the highest order of reaction (HOR) where species i is a reactant.
# - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor.
# Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27).
hor_type = eltype(hor)
max_hor = zeros(hor_type, numspecies)
max_stoich = zeros(hor_type, numspecies)
stoch_type = eltype(first(first(reactant_stoch)))
zero_stoch = zero(stoch_type)
for j in 1:numjumps
for (spec_idx, stoch) in reactant_stoch[j]
if stoch > zero_stoch # Species is a reactant
if hor[j] > max_hor[spec_idx]
max_hor[spec_idx] = hor[j]
max_stoich[spec_idx] = stoch
elseif hor[j] == max_hor[spec_idx]
max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch)
end
end
end
end
return max_hor, max_stoich
end

function compute_gi(u, max_hor, max_stoich, i, t)
# Compute g_i for species i to bound the relative change in propensity functions,
# as per Cao et al. (2006), Section IV, equation (27).
# g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant:
# - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1
# - HOR = 2 (second-order):
# - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2
# - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1)
# - HOR = 3 (third-order):
# - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3
# - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1))
# - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2)
# Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep.
one_max_hor = one(1 / one(eltype(u)))
hor_type = eltype(max_hor)
zero_hor = zero(hor_type)
one_hor = one(hor_type)
two_hor = one_hor + one_hor
three_hor = two_hor + one_hor

if max_hor[i] == zero_hor # No reactions involve species i as a reactant
return one_max_hor
elseif max_hor[i] == one_hor
return one_max_hor
elseif max_hor[i] == two_hor
if max_stoich[i] == one_hor
return 2 * one_max_hor
else # max_stoich[i] == 2
return u[i] > one_max_hor ?
2 * one_max_hor + one_max_hor / (u[i] - one_max_hor) : 2 * one_max_hor # Fallback to 2 if x_i <= 1
end
elseif max_hor[i] == three_hor
if max_stoich[i] == one_hor
return 3 * one_max_hor
elseif max_stoich[i] == two_hor
return u[i] > one_max_hor ?
(3 * one_max_hor / 2) * (2 * one_max_hor + one_max_hor / (u[i] - one_max_hor)) : 3 * one_max_hor # Fallback to 3 if x_i <= 1
else # max_stoich[i] == 3
return u[i] > 2 * one_max_hor ?
3 * one_max_hor + one_max_hor / (u[i] - one_max_hor) + 2 * one_max_hor / (u[i] - 2 * one_max_hor) : 3 * one_max_hor # Fallback to 3 if x_i <= 2
end
end
return one_max_hor # Default case
end

function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
# Compute the tau-leaping step-size using equation (20) from Cao et al. (2006):
# tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) }
# where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b):
# mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x)
# I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified).
rate(rate_cache, u, p, t)
if all(<=(0), rate_cache) # Handle case where all rates are zero or negative
return dtmin
end
tau = typemax(typeof(t))
for i in 1:length(u)
mu = zero(eltype(u))
sigma2 = zero(eltype(u))
for j in 1:size(nu, 2)
mu += nu[i, j] * rate_cache[j] # Equation (9a)
sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b)
end
gi = compute_gi(u, max_hor, max_stoich, i, t)
bound = max(epsilon * u[i] / gi, one(eltype(u))) # max(epsilon * x_i / g_i, 1)
mu_term = abs(mu) > 0 ? bound / abs(mu) : typemax(typeof(t)) # First term in equation (8)
sigma_term = sigma2 > 0 ? bound^2 / sigma2 : typemax(typeof(t)) # Second term in equation (8)
tau = min(tau, mu_term, sigma_term) # Equation (8)
end
return max(tau, dtmin)
end

# Define residual for implicit equation
# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004)
# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau
function implicit_equation!(resid, u_new, params)
u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params
rate(rate_cache, u_new, p, t + tau)
resid .= u_new .- u_current
if isa(solver, NewtonImplicitSolver)
for j in 1:numjumps
for spec_idx in 1:size(nu, 1)
resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004)
end
end
else # TrapezoidalImplicitSolver
rate_current = similar(rate_cache)
rate(rate_current, u_current, p, t)
half = one(eltype(rate_cache)) / 2
for j in 1:numjumps
for spec_idx in 1:size(nu, 1)
resid[spec_idx] -= nu[spec_idx, j] * half * (rate_cache[j] + rate_current[j]) * tau
end
end
end
resid .= max.(resid, -u_new) # Ensure non-negative solution
end

# Solve implicit equation using SimpleNonlinearSolve
function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)
u_new = convert(Vector{float(eltype(u_current))}, u_current)
prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver))
sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6)
return sol.u, sol.retcode == ReturnCode.Success
end

# Function to generate a mass action rate function
function massaction_rate(maj, numjumps)
return (out, u, p, t) -> begin
for j in 1:numjumps
out[j] = evalrxrate(u, j, maj)
end
end
end

function simple_implicit_tau_leaping_loop!(
prob, alg, u_current, t_current, t_end, p, rng,
rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
maj, solver, save_end)
save_idx = 1

while t_current < t_end
rate(rate_cache, u_current, p, t_current)
tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
tau = min(tau, t_end - t_current)
if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx]
tau = saveat_times[save_idx] - t_current
end

u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver)
if !converged
tau /= 2
continue
end

rate(rate_cache, u_new_float, p, t_current + tau)
zero_rate = zero(eltype(rate_cache))
counts .= pois_rand.(rng, max.(rate_cache * tau, zero_rate))
du .= zero(eltype(du))
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
du[spec_idx] += stoch * counts[j]
end
end
u_new = u_current + du

zero_pop = zero(eltype(u_new))
if any(<(zero_pop), u_new)
# Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3
tau /= 2
continue
end
# Ensure non-negativity, as per Cao et al. (2006), Section 3.3
for i in eachindex(u_new)
u_new[i] = max(u_new[i], zero_pop)
end
t_new = t_current + tau

if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
push!(usave, copy(u_new))
push!(tsave, t_new)
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
save_idx += 1
end
end

u_current = u_new
t_current = t_new
end

# Save endpoint if requested and not already saved
if save_end && (isempty(tsave) || tsave[end] != t_end)
push!(usave, copy(u_current))
push!(tsave, t_end)
end
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping;
seed = nothing,
dtmin = nothing,
saveat = nothing, save_start = nothing, save_end = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")

(; prob, rng) = jump_prob
(seed !== nothing) && seed!(rng, seed)

maj = jump_prob.massaction_jump
numjumps = get_num_majumps(maj)
rj = jump_prob.regular_jump
# Extract rates
rate = rj !== nothing ? rj.rate : massaction_rate(maj, numjumps)
c = rj !== nothing ? rj.c : nothing
u0 = copy(prob.u0)
tspan = prob.tspan
p = prob.p

if dtmin === nothing
dtmin = 1e-10 * one(typeof(tspan[2]))
end

saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end)

# Initialize current state and saved history
u_current = copy(u0)
t_current = tspan[1]
if save_start
usave = [copy(u0)]
tsave = [tspan[1]]
else
usave = typeof(u0)[]
tsave = typeof(tspan[1])[]
end
rate_cache = zeros(float(eltype(u0)), numjumps)
counts = zero(rate_cache)
du = similar(u0)
t_end = tspan[2]
epsilon = alg.epsilon
solver = alg.solver

nu = zeros(float(eltype(u0)), length(u0), numjumps)
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
nu[spec_idx, j] = stoch
end
end
reactant_stoch = maj.reactant_stoch
hor = compute_hor(reactant_stoch, numjumps)
max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps)

simple_implicit_tau_leaping_loop!(
prob, alg, u_current, t_current, t_end, p, rng,
rate, nu, hor, max_hor, max_stoich, numjumps, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache,
maj, solver, save_end)

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error=false,
interp=DiffEqBase.ConstantInterpolation(tsave, usave))
return sol
end

struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
backend::Backend
cpu_offload::Float64
Expand Down
Loading
Loading