From 2e43bcf1f80c602d17e59362c22f2ee900886616 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Wed, 23 Jul 2025 23:14:28 +0530 Subject: [PATCH 01/29] kernels setup done --- ext/JumpProcessesKernelAbstractionsExt.jl | 2 +- ext/implicit_tau.jl | 406 ++++++++++++++++++++++ src/simple_regular_solve.jl | 12 + test/gpu/implicit_tau.jl | 50 +++ 4 files changed, 469 insertions(+), 1 deletion(-) create mode 100644 ext/implicit_tau.jl create mode 100644 test/gpu/implicit_tau.jl diff --git a/ext/JumpProcessesKernelAbstractionsExt.jl b/ext/JumpProcessesKernelAbstractionsExt.jl index 2b345ebc..ae38cda6 100644 --- a/ext/JumpProcessesKernelAbstractionsExt.jl +++ b/ext/JumpProcessesKernelAbstractionsExt.jl @@ -1,6 +1,6 @@ module JumpProcessesKernelAbstractionsExt -using JumpProcesses, SciMLBase +using JumpProcesses, SciMLBase, DiffEqBase using KernelAbstractions, Adapt using StaticArrays using PoissonRandom, Random diff --git a/ext/implicit_tau.jl b/ext/implicit_tau.jl new file mode 100644 index 00000000..e0d083fd --- /dev/null +++ b/ext/implicit_tau.jl @@ -0,0 +1,406 @@ +function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, + alg::ImplicitTauLeaping, + ensemblealg::EnsembleGPUKernel; + trajectories, + seed = nothing, + dt = error("dt is required for ImplicitTauLeaping."), + kwargs...) + + if trajectories == 1 + return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); trajectories=1, + seed, dt, kwargs...) + end + + backend = ensemblealg.backend === nothing ? CPU() : ensemblealg.backend + + jump_prob = ensembleprob.prob + + @assert isempty(jump_prob.jump_callback.continuous_callbacks) + @assert isempty(jump_prob.jump_callback.discrete_callbacks) + prob = jump_prob.prob + + probs = [remake(jump_prob) for _ in 1:trajectories] + + ts, us = vectorized_solve(probs, jump_prob, alg; backend, trajectories, seed, dt, kwargs...) + + _ts = Array(ts) + _us = Array(us) + + time = @elapsed sol = [begin + ts = @view _ts[:, i] + us = @view _us[:, :, i] + sol_idx = findlast(x -> x != probs[i].prob.tspan[1], ts) + if sol_idx === nothing + @error "No solution found" tspan=probs[i].tspan[1] ts + error("Batch solve failed") + end + @views ensembleprob.output_func( + SciMLBase.build_solution(probs[i].prob, + alg, + ts[1:sol_idx], + [us[j, :] for j in 1:sol_idx], + k = nothing, + stats = nothing, + calculate_error = false, + retcode = sol_idx != length(ts) ? ReturnCode.Terminated : ReturnCode.Success), + i)[1] + end for i in eachindex(probs)] + return SciMLBase.EnsembleSolution(sol, time, true) +end + +struct TrajectoryDataImplicit{U <: StaticArray, P, T} + u0::U + p::P + tspan::Tuple{T, T} +end + +struct JumpDataImplicit{R, C, V} + rate::R + c::C + nu::V + numjumps::Int +end + +function compute_tau_explicit(u, rate, nu, num_jumps, epsilon, g, J_ncr, I_rs, p) + rate_cache = zeros(eltype(u), num_jumps) + rate(rate_cache, u, p, 0.0) + + mu = zeros(eltype(u), length(u)) + sigma2 = zeros(eltype(u), length(u)) + + for i in I_rs + mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_ncr; init=0.0) + sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_ncr; init=0.0) + end + + tau = Inf + for i in I_rs + denom_mu = max(epsilon * u[i] / g[i], 1.0) + denom_sigma = denom_mu^2 + if abs(mu[i]) > 0 + tau = min(tau, denom_mu / abs(mu[i])) + end + if sigma2[i] > 0 + tau = min(tau, denom_sigma / sigma2[i]) + end + end + return tau +end + +function compute_tau_implicit(u, rate, nu, num_jumps, epsilon, g, J_necr, I_rs, p) + rate_cache = zeros(eltype(u), num_jumps) + rate(rate_cache, u, p, 0.0) + + mu = zeros(eltype(u), length(u)) + sigma2 = zeros(eltype(u), length(u)) + + for i in I_rs + mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_necr; init=0.0) + sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_necr; init=0.0) + end + + tau = Inf + for i in I_rs + denom_mu = max(epsilon * u[i] / g[i], 1.0) + denom_sigma = denom_mu^2 + if abs(mu[i]) > 0 + tau = min(tau, denom_mu / abs(mu[i])) + end + if sigma2[i] > 0 + tau = min(tau, denom_sigma / sigma2[i]) + end + end + return isinf(tau) ? 1e6 : tau +end + +function identify_critical_reactions(u, nu, num_jumps, nc) + L = zeros(Int, num_jumps) + J_critical = Int[] + + for j in 1:num_jumps + min_val = Inf + for i in 1:length(u) + if nu[i,j] < 0 + val = floor(u[i] / abs(nu[i,j])) + min_val = min(min_val, val) + end + end + L[j] = min_val == Inf ? typemax(Int) : Int(min_val) + if L[j] < nc + push!(J_critical, j) + end + end + J_ncr = setdiff(1:num_jumps, J_critical) + return J_critical, J_ncr +end + +function check_partial_equilibrium(rate_cache, reversible_pairs, delta) + J_equilibrium = Int[] + for (j_plus, j_minus) in reversible_pairs + a_plus = rate_cache[j_plus] + a_minus = rate_cache[j_minus] + if abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) + push!(J_equilibrium, j_plus, j_minus) + end + end + return J_equilibrium +end + +function newton_solve!(x_new, x, rate, nu, rate_cache, counts, p, t, tau, max_iter=10, tol=1e-6) + state_dim = length(x) + num_jumps = length(counts) + + x_temp = copy(x_new) + for iter in 1:max_iter + rate(rate_cache, x_new, p, t) + rate_cache .*= tau + + residual = x_new .- x + for j in 1:num_jumps + residual .-= nu[:,j] * (counts[j] - rate_cache[j] + tau * rate_cache[j]) + end + + if norm(residual) < tol + break + end + + J = zeros(eltype(x), state_dim, state_dim) + for j in 1:num_jumps + for i in 1:state_dim + for k in 1:state_dim + J[i,k] += nu[i,j] * nu[k,j] * rate_cache[j] + end + end + end + J = I - tau * J + + delta_x = J \ residual + x_new .-= delta_x + + if norm(delta_x) < tol + break + end + end + return x_new +end + +@kernel function implicit_tau_leaping_kernel(@Const(probs_data), _us, _ts, dt, @Const(rj_data), + current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, + seed::UInt64, alg::ImplicitTauLeaping, reversible_pairs) + i = @index(Global, Linear) + + @inbounds begin + current_u = view(current_u_buf, :, i) + rate_cache = view(rate_cache_buf, :, i) + counts = view(counts_buf, :, i) + local_dc = view(local_dc_buf, :, i) + end + + @inbounds prob_data = probs_data[i] + u0 = prob_data.u0 + p = prob_data.p + tspan = prob_data.tspan + + rate = rj_data.rate + num_jumps = rj_data.numjumps + c = rj_data.c + nu = rj_data.nu + + @inbounds for k in 1:length(u0) + current_u[k] = u0[k] + end + + n = Int((tspan[2] - tspan[1]) / dt) + 1 + state_dim = length(u0) + + ts_view = @inbounds view(_ts, :, i) + us_view = @inbounds view(_us, :, :, i) + + @inbounds ts_view[1] = tspan[1] + @inbounds for k in 1:state_dim + us_view[1, k] = current_u[k] + end + + rng = Random.Xoshiro(seed + i) + + I_rs = 1:state_dim + g = ones(state_dim) + + for j in 2:n + tprev = tspan[1] + (j-2) * dt + + J_critical, J_ncr = identify_critical_reactions(current_u, nu, num_jumps, alg.nc) + + rate(rate_cache, current_u, p, tprev) + a0_critical = sum(rate_cache[j] for j in J_critical; init=0.0) + + J_equilibrium = check_partial_equilibrium(rate_cache, reversible_pairs, alg.delta) + J_necr = setdiff(J_ncr, J_equilibrium) + + tau_ex = compute_tau_explicit(current_u, rate, nu, num_jumps, alg.epsilon, g, J_ncr, I_rs, p) + tau_im = compute_tau_implicit(current_u, rate, nu, num_jumps, alg.epsilon, g, J_necr, I_rs, p) + + tau2 = a0_critical > 0 ? -log(rand(rng)) / a0_critical : Inf + + use_implicit = tau_im > alg.nstiff * tau_ex + tau1 = use_implicit ? tau_im : tau_ex + + if tau1 < 10 / sum(rate_cache; init=0.0) + a0 = sum(rate_cache; init=0.0) + if a0 > 0 + tau = -log(rand(rng)) / a0 + r = rand(rng) * a0 + cumsum_a = 0.0 + jc = 1 + for k in 1:num_jumps + cumsum_a += rate_cache[k] + if cumsum_a > r + jc = k + break + end + end + current_u .+= nu[:,jc] + else + tau = dt + end + else + tau = min(tau1, tau2, dt) + if tau == tau2 + if a0_critical > 0 + r = rand(rng) * a0_critical + cumsum_a = 0.0 + jc = J_critical[1] + for k in J_critical + cumsum_a += rate_cache[k] + if cumsum_a > r + jc = k + break + end + end + counts .= 0 + counts[jc] = 1 + if use_implicit && tau > tau_ex + for k in J_ncr + counts[k] = poisson_rand(rate_cache[k] * tau, rng) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) + else + for k in J_ncr + counts[k] = poisson_rand(rate_cache[k] * tau, rng) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .+= local_dc + end + else + tau = tau1 + if use_implicit + for k in 1:num_jumps + counts[k] = poisson_rand(rate_cache[k] * tau, rng) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) + else + for k in 1:num_jumps + counts[k] = poisson_rand(rate_cache[k] * tau, rng) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .+= local_dc + end + end + else + counts .= 0 + if use_implicit + for k in J_ncr + counts[k] = poisson_rand(rate_cache[k] * tau, rng) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) + else + for k in J_ncr + counts[k] = poisson_rand(rate_cache[k] * tau, rng) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .+= local_dc + end + end + end + + if any(current_u .< 0) + tau1 /= 2 + continue + end + + @inbounds for k in 1:state_dim + us_view[j, k] = current_u[k] + end + @inbounds ts_view[j] = tspan[1] + (j-1) * dt + end +end + +function vectorized_solve(probs, prob::JumpProblem, alg::ImplicitTauLeaping; backend, trajectories, seed, dt, kwargs...) + rj = prob.regular_jump + nu = zeros(Int, length(prob.prob.u0), rj.numjumps) + for j in 1:rj.numjumps + dc = zeros(length(prob.prob.u0)) + rj.c(dc, prob.prob.u0, prob.prob.p, 0.0, [i == j ? 1 : 0 for i in 1:rj.numjumps], nothing) + nu[:,j] = dc + end + rj_data = JumpDataImplicit(rj.rate, rj.c, nu, rj.numjumps) + + probs_data = [TrajectoryDataImplicit(SA{eltype(p.prob.u0)}[p.prob.u0...], p.prob.p, p.prob.tspan) for p in probs] + + probs_data_gpu = adapt(backend, probs_data) + rj_data_gpu = adapt(backend, rj_data) + + state_dim = length(first(probs_data).u0) + tspan = prob.prob.tspan + dt = Float64(dt) + n_steps = Int((tspan[2] - tspan[1]) / dt) + 1 + n_trajectories = length(probs) + num_jumps = rj_data.numjumps + + @assert state_dim > 0 "Dimension of state must be positive" + @assert num_jumps >= 0 "Number of jumps must be positive" + + ts = allocate(backend, eltype(prob.prob.tspan), (n_steps, n_trajectories)) + us = allocate(backend, eltype(prob.prob.u0), (n_steps, state_dim, n_trajectories)) + + current_u_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, n_trajectories)) + rate_cache_buf = allocate(backend, eltype(prob.prob.u0), (num_jumps, n_trajectories)) + counts_buf = allocate(backend, eltype(prob.prob.u0), (num_jumps, n_trajectories)) + local_dc_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, n_trajectories)) + + @kernel function init_buffers_kernel(@Const(probs_data), current_u_buf) + i = @index(Global, Linear) + @inbounds u0 = probs_data[i].u0 + @inbounds for k in 1:length(u0) + current_u_buf[k, i] = u0[k] + end + end + init_kernel = init_buffers_kernel(backend) + init_event = init_kernel(probs_data_gpu, current_u_buf; ndrange=n_trajectories) + KernelAbstractions.synchronize(backend) + + seed = seed === nothing ? UInt64(12345) : UInt64(seed) + reversible_pairs = get(kwargs, :reversible_pairs, Tuple{Int,Int}[]) + + kernel = implicit_tau_leaping_kernel(backend) + main_event = kernel(probs_data_gpu, us, ts, dt, rj_data_gpu, + current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, seed, alg, reversible_pairs; + ndrange=n_trajectories) + KernelAbstractions.synchronize(backend) + + return ts, us +end + +@inline function poisson_rand(lambda, rng) + L = exp(-lambda) + k = 0 + p = 1.0 + while p > L + k += 1 + p *= rand(rng) + end + return k - 1 +end diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 3675ddad..ca2a481d 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -417,3 +417,15 @@ end function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end + +# Define ImplicitTauLeaping algorithm +struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm + epsilon::Float64 # Error control parameter + nc::Int # Critical reaction threshold + nstiff::Int # Stiffness threshold multiplier + delta::Float64 # Partial equilibrium threshold +end + +ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100, delta=0.05) = ImplicitTauLeaping(epsilon, nc, nstiff, delta) + +export SimpleTauLeaping, EnsembleGPUKernel, ImplicitTauLeaping diff --git a/test/gpu/implicit_tau.jl b/test/gpu/implicit_tau.jl new file mode 100644 index 00000000..459578b2 --- /dev/null +++ b/test/gpu/implicit_tau.jl @@ -0,0 +1,50 @@ +using JumpProcesses, DiffEqBase +using Test, LinearAlgebra, Statistics +using KernelAbstractions, Adapt, CUDA +using StableRNGs, Plots + + +rng = StableRNG(12345) +Nsims = 10 + + +# Parameters +c1 = 1.0 # S1 -> 0 +c2 = 10.0 # S1 + S1 <- S2 +c3 = 1000.0 # S1 + S1 -> S2 +c4 = 0.1 # S2 -> S3 +p = (c1, c2, c3, c4) + +# Propensity functions +regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] # S1 -> 0 + out[2] = p[2] * u[2] # S1 + S1 <- S2 + out[3] = p[3] * u[1] * (u[1] - 1) / 2 # S1 + S1 -> S2 + out[4] = p[4] * u[2] # S2 -> S3 +end + +# State change function +regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] - 2 * counts[3] + 2 * counts[2] # S1: -decay - 2*forward + 2*backward + dc[2] = counts[3] - counts[2] - counts[4] # S2: +forward - backward - decay + dc[3] = counts[4] # S3: +decay +end + +# Initial condition +u0 = [10000.0, 0.0, 0.0] # S1, S2, S3 +tspan = (0.0, 4.0) + +# Define reversible reaction pairs (R2 and R3 are reversible: S1 + S1 <-> S2) +reversible_pairs = [(2, 3)] + +# Create JumpProblem with proper parameter passing +prob_disc = DiscreteProblem(u0, tspan, p) +rj = RegularJump(regular_rate, regular_c, 4) +jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) + +# Solve using ImplicitTauLeaping +alg = ImplicitTauLeaping(epsilon=0.05, nc=10, nstiff=100, delta=0.05) +sol = solve(EnsembleProblem(jump_prob), alg, EnsembleGPUKernel(); + trajectories=Nsims, dt=0.01, reversible_pairs=reversible_pairs) +plot(sol) From f402962aff79fea8b11fc53d48e7449d73b06d99 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Thu, 24 Jul 2025 13:32:47 +0530 Subject: [PATCH 02/29] ImplicitTauLeaping setup done for jump problem solver --- src/simple_regular_solve.jl | 305 ++++++++++++++++++++++++++++++++++-- 1 file changed, 295 insertions(+), 10 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index ca2a481d..0591ad09 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -405,6 +405,301 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end +# Define ImplicitTauLeaping algorithm +struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm + epsilon::Float64 # Error control parameter + nc::Int # Critical reaction threshold + nstiff::Int # Stiffness threshold multiplier + delta::Float64 # Partial equilibrium threshold +end + +ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100, delta=0.05) = ImplicitTauLeaping(epsilon, nc, nstiff, delta) + +function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; + seed = nothing, + dt = error("dt is required for ImplicitTauLeaping."), + kwargs...) + + # Boilerplate from SimpleTauLeaping + @assert isempty(jump_prob.jump_callback.continuous_callbacks) + @assert isempty(jump_prob.jump_callback.discrete_callbacks) + prob = jump_prob.prob + rng = DEFAULT_RNG + (seed !== nothing) && seed!(rng, seed) + + rj = jump_prob.regular_jump + rate = rj.rate # rate(out, u, p, t) + numjumps = rj.numjumps + c = rj.c # c(dc, u, p, t, counts, mark) + reversible_pairs = get(kwargs, :reversible_pairs, Tuple{Int,Int}[]) + + if !isnothing(rj.mark_dist) + error("Mark distributions are currently not supported in ImplicitTauLeaping") + end + + # Initialize state and buffers + u0 = copy(prob.u0) + p = prob.p + tspan = prob.tspan + state_dim = length(u0) + dt = Float64(dt) + + # Compute stoichiometry matrix + nu = zeros(Int, state_dim, numjumps) + for j in 1:numjumps + dc = zeros(state_dim) + c(dc, u0, p, 0.0, [i == j ? 1 : 0 for i in 1:numjumps], nothing) + nu[:, j] = dc + end + + # Initialize solution arrays + n = Int((tspan[2] - tspan[1]) / dt) + 1 + u = Vector{typeof(u0)}(undef, n) + u[1] = u0 + t = range(tspan[1], tspan[2], length=n) + + # Buffers for iteration + current_u = copy(u0) + rate_cache = zeros(Float64, numjumps) + counts = zeros(Float64, numjumps) + local_dc = zeros(Float64, state_dim) + I_rs = 1:state_dim + g = ones(state_dim) # Scaling factor for tau-leaping + + function compute_tau_explicit(u, rate, nu, num_jumps, epsilon, g, J_ncr, I_rs, p) + rate_cache = zeros(eltype(u), num_jumps) + rate(rate_cache, u, p, 0.0) + + mu = zeros(eltype(u), length(u)) + sigma2 = zeros(eltype(u), length(u)) + + for i in I_rs + mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_ncr; init=0.0) + sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_ncr; init=0.0) + end + + tau = Inf + for i in I_rs + denom_mu = max(epsilon * u[i] / g[i], 1.0) + denom_sigma = denom_mu^2 + if abs(mu[i]) > 0 + tau = min(tau, denom_mu / abs(mu[i])) + end + if sigma2[i] > 0 + tau = min(tau, denom_sigma / sigma2[i]) + end + end + return tau + end + + function compute_tau_implicit(u, rate, nu, num_jumps, epsilon, g, J_necr, I_rs, p) + rate_cache = zeros(eltype(u), num_jumps) + rate(rate_cache, u, p, 0.0) + + mu = zeros(eltype(u), length(u)) + sigma2 = zeros(eltype(u), length(u)) + + for i in I_rs + mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_necr; init=0.0) + sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_necr; init=0.0) + end + + tau = Inf + for i in I_rs + denom_mu = max(epsilon * u[i] / g[i], 1.0) + denom_sigma = denom_mu^2 + if abs(mu[i]) > 0 + tau = min(tau, denom_mu / abs(mu[i])) + end + if sigma2[i] > 0 + tau = min(tau, denom_sigma / sigma2[i]) + end + end + return isinf(tau) ? 1e6 : tau + end + + function identify_critical_reactions(u, nu, num_jumps, nc) + L = zeros(Int, num_jumps) + J_critical = Int[] + + for j in 1:num_jumps + min_val = Inf + for i in 1:length(u) + if nu[i,j] < 0 + val = floor(u[i] / abs(nu[i,j])) + min_val = min(min_val, val) + end + end + L[j] = min_val == Inf ? typemax(Int) : Int(min_val) + if L[j] < nc + push!(J_critical, j) + end + end + J_ncr = setdiff(1:num_jumps, J_critical) + return J_critical, J_ncr + end + + function check_partial_equilibrium(rate_cache, reversible_pairs, delta) + J_equilibrium = Int[] + for (j_plus, j_minus) in reversible_pairs + a_plus = rate_cache[j_plus] + a_minus = rate_cache[j_minus] + if abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) + push!(J_equilibrium, j_plus, j_minus) + end + end + return J_equilibrium + end + + function newton_solve!(x_new, x, rate, nu, rate_cache, counts, p, t, tau, max_iter=10, tol=1e-6) + state_dim = length(x) + num_jumps = length(counts) + + for iter in 1:max_iter + rate(rate_cache, x_new, p, t) + rate_cache .*= tau + + residual = x_new .- x + for j in 1:num_jumps + residual .-= nu[:,j] * (counts[j] - rate_cache[j] + tau * rate_cache[j]) + end + + if norm(residual) < tol + break + end + + J = zeros(eltype(x), state_dim, state_dim) + for j in 1:num_jumps + for i in 1:state_dim + for k in 1:state_dim + J[i,k] += nu[i,j] * nu[k,j] * rate_cache[j] + end + end + end + J = I - tau * J + + delta_x = J \ residual + x_new .-= delta_x + + if norm(delta_x) < tol + break + end + end + return x_new + end + + # Main solver loop + for i in 2:n + tprev = t[i - 1] + J_critical, J_ncr = identify_critical_reactions(current_u, nu, numjumps, alg.nc) + + rate(rate_cache, current_u, p, tprev) + a0_critical = sum(rate_cache[j] for j in J_critical; init=0.0) + + J_equilibrium = check_partial_equilibrium(rate_cache, reversible_pairs, alg.delta) + J_necr = setdiff(J_ncr, J_equilibrium) + + tau_ex = compute_tau_explicit(current_u, rate, nu, numjumps, alg.epsilon, g, J_ncr, I_rs, p) + tau_im = compute_tau_implicit(current_u, rate, nu, numjumps, alg.epsilon, g, J_necr, I_rs, p) + + tau2 = a0_critical > 0 ? -log(rand(rng)) / a0_critical : Inf + use_implicit = tau_im > alg.nstiff * tau_ex + tau1 = use_implicit ? tau_im : tau_ex + + if tau1 < 10 / sum(rate_cache; init=0.0) + a0 = sum(rate_cache; init=0.0) + if a0 > 0 + tau = -log(rand(rng)) / a0 + r = rand(rng) * a0 + cumsum_a = 0.0 + jc = 1 + for k in 1:numjumps + cumsum_a += rate_cache[k] + if cumsum_a > r + jc = k + break + end + end + current_u .+= nu[:,jc] + else + tau = dt + end + else + tau = min(tau1, tau2, dt) + if tau == tau2 + if a0_critical > 0 + r = rand(rng) * a0_critical + cumsum_a = 0.0 + jc = !isempty(J_critical) ? J_critical[1] : 1 + for k in J_critical + cumsum_a += rate_cache[k] + if cumsum_a > r + jc = k + break + end + end + counts .= 0 + counts[jc] = 1 + if use_implicit && tau > tau_ex + for k in J_ncr + counts[k] = pois_rand(rng, rate_cache[k] * tau) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) + else + for k in J_ncr + counts[k] = pois_rand(rng, rate_cache[k] * tau) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .+= local_dc + end + else + tau = tau1 + if use_implicit + for k in 1:numjumps + counts[k] = pois_rand(rng, rate_cache[k] * tau) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) + else + for k in 1:numjumps + counts[k] = pois_rand(rng, rate_cache[k] * tau) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .+= local_dc + end + end + else + counts .= 0 + if use_implicit + for k in J_ncr + counts[k] = pois_rand(rng, rate_cache[k] * tau) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) + else + for k in J_ncr + counts[k] = pois_rand(rng, rate_cache[k] * tau) + end + c(local_dc, current_u, p, tprev, counts, nothing) + current_u .+= local_dc + end + end + end + + if any(current_u .< 0) + tau1 /= 2 + continue + end + + u[i] = copy(current_u) + end + + sol = DiffEqBase.build_solution(prob, alg, t, u, + calculate_error = false, + interp = DiffEqBase.ConstantInterpolation(t, u)) +end + struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm backend::Backend cpu_offload::Float64 @@ -418,14 +713,4 @@ function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end -# Define ImplicitTauLeaping algorithm -struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm - epsilon::Float64 # Error control parameter - nc::Int # Critical reaction threshold - nstiff::Int # Stiffness threshold multiplier - delta::Float64 # Partial equilibrium threshold -end - -ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100, delta=0.05) = ImplicitTauLeaping(epsilon, nc, nstiff, delta) - export SimpleTauLeaping, EnsembleGPUKernel, ImplicitTauLeaping From f778529c4f63569103836186e3e61d7d3caed6f6 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Mon, 28 Jul 2025 18:21:51 +0530 Subject: [PATCH 03/29] jump_problem solver fixed --- src/simple_regular_solve.jl | 472 ++++++++++++++++++------------------ 1 file changed, 242 insertions(+), 230 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 0591ad09..a7ff2562 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -405,22 +405,21 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end -# Define ImplicitTauLeaping algorithm +# Define the ImplicitTauLeaping algorithm struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm epsilon::Float64 # Error control parameter nc::Int # Critical reaction threshold - nstiff::Int # Stiffness threshold multiplier + nstiff::Float64 # Stiffness threshold for switching delta::Float64 # Partial equilibrium threshold + equilibrium_pairs::Vector{Tuple{Int,Int}} # Reversible reaction pairs end -ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100, delta=0.05) = ImplicitTauLeaping(epsilon, nc, nstiff, delta) +# Default constructor +ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05, equilibrium_pairs=[(1, 2)]) = + ImplicitTauLeaping(epsilon, nc, nstiff, delta, equilibrium_pairs) -function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; - seed = nothing, - dt = error("dt is required for ImplicitTauLeaping."), - kwargs...) - - # Boilerplate from SimpleTauLeaping +function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed=nothing) + # Boilerplate setup @assert isempty(jump_prob.jump_callback.continuous_callbacks) @assert isempty(jump_prob.jump_callback.discrete_callbacks) prob = jump_prob.prob @@ -428,276 +427,289 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; (seed !== nothing) && seed!(rng, seed) rj = jump_prob.regular_jump - rate = rj.rate # rate(out, u, p, t) + rate = rj.rate numjumps = rj.numjumps - c = rj.c # c(dc, u, p, t, counts, mark) - reversible_pairs = get(kwargs, :reversible_pairs, Tuple{Int,Int}[]) - - if !isnothing(rj.mark_dist) - error("Mark distributions are currently not supported in ImplicitTauLeaping") - end - - # Initialize state and buffers + c = rj.c u0 = copy(prob.u0) - p = prob.p tspan = prob.tspan - state_dim = length(u0) - dt = Float64(dt) - - # Compute stoichiometry matrix - nu = zeros(Int, state_dim, numjumps) - for j in 1:numjumps - dc = zeros(state_dim) - c(dc, u0, p, 0.0, [i == j ? 1 : 0 for i in 1:numjumps], nothing) - nu[:, j] = dc - end - - # Initialize solution arrays - n = Int((tspan[2] - tspan[1]) / dt) + 1 - u = Vector{typeof(u0)}(undef, n) - u[1] = u0 - t = range(tspan[1], tspan[2], length=n) - - # Buffers for iteration - current_u = copy(u0) + p = prob.p + + # Initialize storage rate_cache = zeros(Float64, numjumps) - counts = zeros(Float64, numjumps) - local_dc = zeros(Float64, state_dim) - I_rs = 1:state_dim - g = ones(state_dim) # Scaling factor for tau-leaping - - function compute_tau_explicit(u, rate, nu, num_jumps, epsilon, g, J_ncr, I_rs, p) - rate_cache = zeros(eltype(u), num_jumps) - rate(rate_cache, u, p, 0.0) - - mu = zeros(eltype(u), length(u)) - sigma2 = zeros(eltype(u), length(u)) - - for i in I_rs - mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_ncr; init=0.0) - sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_ncr; init=0.0) + counts = zeros(Int, numjumps) + du = similar(u0) + u = [copy(u0)] + t = [tspan[1]] + + # Algorithm parameters + epsilon = alg.epsilon + nc = alg.nc + nstiff = alg.nstiff + delta = alg.delta + equilibrium_pairs = alg.equilibrium_pairs + t_end = tspan[2] + + # Compute stoichiometry matrix from c function + function compute_stoichiometry(c, u, numjumps) + nu = zeros(Int, length(u), numjumps) + for j in 1:numjumps + counts = zeros(numjumps) + counts[j] = 1 + du = similar(u) + c(du, u, p, t[1], counts, nothing) + nu[:, j] = round.(Int, du) end - - tau = Inf - for i in I_rs - denom_mu = max(epsilon * u[i] / g[i], 1.0) - denom_sigma = denom_mu^2 - if abs(mu[i]) > 0 - tau = min(tau, denom_mu / abs(mu[i])) - end - if sigma2[i] > 0 - tau = min(tau, denom_sigma / sigma2[i]) + return nu + end + nu = compute_stoichiometry(c, u0, numjumps) + + # Helper function to compute g_i (approximation from Cao et al., 2006) + function compute_gi(u, nu, i) + # Simplified g_i: highest order of reaction involving species i + max_order = 1.0 + for j in 1:numjumps + if abs(nu[i, j]) > 0 + # Approximate reaction order based on propensity (heuristic) + rate(rate_cache, u, p, t[end]) + if rate_cache[j] > 0 + order = 1.0 # Assume first-order for simplicity + if j == 1 # For SIR infection (S*I), assume second-order + order = 2.0 + end + max_order = max(max_order, order) + end end end - return tau + return max_order end - - function compute_tau_implicit(u, rate, nu, num_jumps, epsilon, g, J_necr, I_rs, p) - rate_cache = zeros(eltype(u), num_jumps) - rate(rate_cache, u, p, 0.0) - - mu = zeros(eltype(u), length(u)) - sigma2 = zeros(eltype(u), length(u)) - - for i in I_rs - mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_necr; init=0.0) - sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_necr; init=0.0) - end - + + # Tau-selection for explicit method (Equation 8) + function compute_tau_explicit(u, rate_cache, nu, p, t) + rate(rate_cache, u, p, t) + mu = zeros(length(u)) + sigma2 = zeros(length(u)) tau = Inf - for i in I_rs - denom_mu = max(epsilon * u[i] / g[i], 1.0) - denom_sigma = denom_mu^2 - if abs(mu[i]) > 0 - tau = min(tau, denom_mu / abs(mu[i])) - end - if sigma2[i] > 0 - tau = min(tau, denom_sigma / sigma2[i]) + for i in 1:length(u) + for j in 1:numjumps + mu[i] += nu[i, j] * rate_cache[j] + sigma2[i] += nu[i, j]^2 * rate_cache[j] end + gi = compute_gi(u, nu, i) + bound = max(epsilon * u[i] / gi, 1.0) + mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf + sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf + tau = min(tau, mu_term, sigma_term) end - return isinf(tau) ? 1e6 : tau + return tau end - - function identify_critical_reactions(u, nu, num_jumps, nc) - L = zeros(Int, num_jumps) - J_critical = Int[] - - for j in 1:num_jumps - min_val = Inf - for i in 1:length(u) - if nu[i,j] < 0 - val = floor(u[i] / abs(nu[i,j])) - min_val = min(min_val, val) - end + + # Partial equilibrium check (Equation 13) + function is_partial_equilibrium(rate_cache, j_plus, j_minus) + a_plus = rate_cache[j_plus] + a_minus = rate_cache[j_minus] + return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) + end + + # Tau-selection for implicit method (Equation 14) + function compute_tau_implicit(u, rate_cache, nu, p, t) + rate(rate_cache, u, p, t) + mu = zeros(length(u)) + sigma2 = zeros(length(u)) + non_equilibrium = trues(numjumps) + for (j_plus, j_minus) in equilibrium_pairs + if is_partial_equilibrium(rate_cache, j_plus, j_minus) + non_equilibrium[j_plus] = false + non_equilibrium[j_minus] = false end - L[j] = min_val == Inf ? typemax(Int) : Int(min_val) - if L[j] < nc - push!(J_critical, j) + end + tau = Inf + for i in 1:length(u) + for j in 1:numjumps + if non_equilibrium[j] + mu[i] += nu[i, j] * rate_cache[j] + sigma2[i] += nu[i, j]^2 * rate_cache[j] + end end + gi = compute_gi(u, nu, i) + bound = max(epsilon * u[i] / gi, 1.0) + mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf + sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf + tau = min(tau, mu_term, sigma_term) end - J_ncr = setdiff(1:num_jumps, J_critical) - return J_critical, J_ncr + return tau end - - function check_partial_equilibrium(rate_cache, reversible_pairs, delta) - J_equilibrium = Int[] - for (j_plus, j_minus) in reversible_pairs - a_plus = rate_cache[j_plus] - a_minus = rate_cache[j_minus] - if abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) - push!(J_equilibrium, j_plus, j_minus) + + # Identify critical reactions + function identify_critical_reactions(u, rate_cache, nu) + critical = falses(numjumps) + for j in 1:numjumps + if rate_cache[j] > 0 + Lj = Inf + for i in 1:length(u) + if nu[i, j] < 0 + Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j]))) + end + end + if Lj < nc + critical[j] = true + end end end - return J_equilibrium + return critical end - - function newton_solve!(x_new, x, rate, nu, rate_cache, counts, p, t, tau, max_iter=10, tol=1e-6) - state_dim = length(x) - num_jumps = length(counts) - + + # Implicit tau-leaping step with Newton's method + function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p) + u_new = copy(u_prev) + rate_new = zeros(numjumps) + tol = 1e-6 + max_iter = 100 for iter in 1:max_iter - rate(rate_cache, x_new, p, t) - rate_cache .*= tau - - residual = x_new .- x - for j in 1:num_jumps - residual .-= nu[:,j] * (counts[j] - rate_cache[j] + tau * rate_cache[j]) + rate(rate_new, u_new, p, t_prev + tau) + residual = u_new - u_prev + for j in 1:numjumps + residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) end - if norm(residual) < tol break end - - J = zeros(eltype(x), state_dim, state_dim) - for j in 1:num_jumps - for i in 1:state_dim - for k in 1:state_dim - J[i,k] += nu[i,j] * nu[k,j] * rate_cache[j] + # Approximate Jacobian + J = Diagonal(ones(length(u_new))) + for j in 1:numjumps + for i in 1:length(u_new) + if j == 1 && i in [1, 2] # Infection: β*S*I + J[i, i] += nu[i, j] * tau * p[1] * (i == 1 ? u_new[2] : u_new[1]) + elseif j == 2 && i == 2 # Recovery: ν*I + J[i, i] += nu[i, j] * tau * p[2] end end end - J = I - tau * J - - delta_x = J \ residual - x_new .-= delta_x - - if norm(delta_x) < tol - break - end + u_new -= J \ residual + u_new = max.(u_new, 0.0) end - return x_new + return round.(Int, u_new) end - - # Main solver loop - for i in 2:n - tprev = t[i - 1] - J_critical, J_ncr = identify_critical_reactions(current_u, nu, numjumps, alg.nc) - - rate(rate_cache, current_u, p, tprev) - a0_critical = sum(rate_cache[j] for j in J_critical; init=0.0) - - J_equilibrium = check_partial_equilibrium(rate_cache, reversible_pairs, alg.delta) - J_necr = setdiff(J_ncr, J_equilibrium) - - tau_ex = compute_tau_explicit(current_u, rate, nu, numjumps, alg.epsilon, g, J_ncr, I_rs, p) - tau_im = compute_tau_implicit(current_u, rate, nu, numjumps, alg.epsilon, g, J_necr, I_rs, p) - - tau2 = a0_critical > 0 ? -log(rand(rng)) / a0_critical : Inf - use_implicit = tau_im > alg.nstiff * tau_ex + + # Down-shifting condition (Equation 19) + function use_down_shifting(t, tau_im, tau_ex, a0, t_end) + return t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) + end + + # Main simulation loop + while t[end] < t_end + u_prev = u[end] + t_prev = t[end] + + # Compute propensities + rate(rate_cache, u_prev, p, t_prev) + + # Identify critical reactions + critical = identify_critical_reactions(u_prev, rate_cache, nu) + + # Compute tau values + tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev) + tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev) + + # Compute critical propensity sum + ac0 = sum(rate_cache[critical]) + tau2 = ac0 > 0 ? randexp(rng) / ac0 : Inf + + # Choose method and stepsize + a0 = sum(rate_cache) + use_implicit = tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end) tau1 = use_implicit ? tau_im : tau_ex - - if tau1 < 10 / sum(rate_cache; init=0.0) - a0 = sum(rate_cache; init=0.0) - if a0 > 0 - tau = -log(rand(rng)) / a0 + method = use_implicit ? :implicit : :explicit + + # Check if tau1 is too small + if tau1 < 10 / a0 + # Use SSA for a few steps + steps = method == :implicit ? 10 : 100 + for _ in 1:steps + if t_prev >= t_end + break + end + rate(rate_cache, u_prev, p, t_prev) + a0 = sum(rate_cache) + if a0 == 0 + break + end + tau = randexp(rng) / a0 r = rand(rng) * a0 - cumsum_a = 0.0 - jc = 1 - for k in 1:numjumps - cumsum_a += rate_cache[k] - if cumsum_a > r - jc = k + cumsum_rate = 0.0 + for j in 1:numjumps + cumsum_rate += rate_cache[j] + if cumsum_rate > r + u_prev += nu[:, j] break end end - current_u .+= nu[:,jc] + t_prev += tau + push!(u, copy(u_prev)) + push!(t, t_prev) + end + continue + end + + # Choose stepsize and compute firings + if tau2 > tau1 + tau = min(tau1, t_end - t_prev) + counts .= 0 + for j in 1:numjumps + if !critical[j] + counts[j] = pois_rand(rng, rate_cache[j] * tau) + end + end + if method == :implicit + u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p) else - tau = dt + c(du, u_prev, p, t_prev, counts, nothing) + u_new = u_prev + du end else - tau = min(tau1, tau2, dt) - if tau == tau2 - if a0_critical > 0 - r = rand(rng) * a0_critical - cumsum_a = 0.0 - jc = !isempty(J_critical) ? J_critical[1] : 1 - for k in J_critical - cumsum_a += rate_cache[k] - if cumsum_a > r - jc = k + tau = min(tau2, t_end - t_prev) + counts .= 0 + if ac0 > 0 + r = rand(rng) * ac0 + cumsum_rate = 0.0 + for j in 1:numjumps + if critical[j] + cumsum_rate += rate_cache[j] + if cumsum_rate > r + counts[j] = 1 break end end - counts .= 0 - counts[jc] = 1 - if use_implicit && tau > tau_ex - for k in J_ncr - counts[k] = pois_rand(rng, rate_cache[k] * tau) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) - else - for k in J_ncr - counts[k] = pois_rand(rng, rate_cache[k] * tau) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .+= local_dc - end - else - tau = tau1 - if use_implicit - for k in 1:numjumps - counts[k] = pois_rand(rng, rate_cache[k] * tau) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) - else - for k in 1:numjumps - counts[k] = pois_rand(rng, rate_cache[k] * tau) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .+= local_dc - end end - else - counts .= 0 - if use_implicit - for k in J_ncr - counts[k] = pois_rand(rng, rate_cache[k] * tau) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) - else - for k in J_ncr - counts[k] = pois_rand(rng, rate_cache[k] * tau) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .+= local_dc + end + for j in 1:numjumps + if !critical[j] + counts[j] = pois_rand(rng, rate_cache[j] * tau) end end + if method == :implicit && tau > tau_ex + u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p) + else + c(du, u_prev, p, t_prev, counts, nothing) + u_new = u_prev + du + end end - - if any(current_u .< 0) + + # Check for negative populations + if any(u_new .< 0) tau1 /= 2 continue end - - u[i] = copy(current_u) + + # Update state and time + push!(u, u_new) + push!(t, t_prev + tau) end - + + # Build solution sol = DiffEqBase.build_solution(prob, alg, t, u, - calculate_error = false, - interp = DiffEqBase.ConstantInterpolation(t, u)) + calculate_error=false, + interp=DiffEqBase.ConstantInterpolation(t, u)) + return sol end struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm From a6755f0b06286df902af4e6842c362b716e9a5f2 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Thu, 31 Jul 2025 17:49:13 +0530 Subject: [PATCH 04/29] removed equilibrium_pair logic --- src/simple_regular_solve.jl | 40 ++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index a7ff2562..0dc625a3 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -411,12 +411,11 @@ struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm nc::Int # Critical reaction threshold nstiff::Float64 # Stiffness threshold for switching delta::Float64 # Partial equilibrium threshold - equilibrium_pairs::Vector{Tuple{Int,Int}} # Reversible reaction pairs end # Default constructor -ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05, equilibrium_pairs=[(1, 2)]) = - ImplicitTauLeaping(epsilon, nc, nstiff, delta, equilibrium_pairs) +ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) = + ImplicitTauLeaping(epsilon, nc, nstiff, delta) function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed=nothing) # Boilerplate setup @@ -446,7 +445,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= nc = alg.nc nstiff = alg.nstiff delta = alg.delta - equilibrium_pairs = alg.equilibrium_pairs t_end = tspan[2] # Compute stoichiometry matrix from c function @@ -463,17 +461,30 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= end nu = compute_stoichiometry(c, u0, numjumps) + # Detect reversible reaction pairs + function find_reversible_pairs(nu) + pairs = Vector{Tuple{Int,Int}}() + for j in 1:numjumps + for k in (j+1):numjumps + if nu[:, j] == -nu[:, k] + push!(pairs, (j, k)) + end + end + end + return pairs + end + equilibrium_pairs = find_reversible_pairs(nu) + # Helper function to compute g_i (approximation from Cao et al., 2006) function compute_gi(u, nu, i) - # Simplified g_i: highest order of reaction involving species i max_order = 1.0 for j in 1:numjumps if abs(nu[i, j]) > 0 - # Approximate reaction order based on propensity (heuristic) rate(rate_cache, u, p, t[end]) if rate_cache[j] > 0 - order = 1.0 # Assume first-order for simplicity - if j == 1 # For SIR infection (S*I), assume second-order + order = 1.0 + # Heuristic: if reaction involves multiple species, assume higher order + if sum(abs.(nu[:, j])) > abs(nu[i, j]) order = 2.0 end max_order = max(max_order, order) @@ -573,14 +584,15 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= if norm(residual) < tol break end - # Approximate Jacobian + # Approximate Jacobian (diagonal approximation for simplicity) J = Diagonal(ones(length(u_new))) for j in 1:numjumps for i in 1:length(u_new) - if j == 1 && i in [1, 2] # Infection: β*S*I - J[i, i] += nu[i, j] * tau * p[1] * (i == 1 ? u_new[2] : u_new[1]) - elseif j == 2 && i == 2 # Recovery: ν*I - J[i, i] += nu[i, j] * tau * p[2] + # Heuristic derivative: assume linear or quadratic propensity + rate(rate_new, u_new, p, t_prev + tau) + if rate_new[j] > 0 + # Estimate derivative based on stoichiometry + J[i, i] += nu[i, j] * tau * rate_new[j] / max(u_new[i], 1.0) end end end @@ -621,7 +633,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= method = use_implicit ? :implicit : :explicit # Check if tau1 is too small - if tau1 < 10 / a0 + if a0 > 0 && tau1 < 10 / a0 # Use SSA for a few steps steps = method == :implicit ? 10 : 100 for _ in 1:steps From 3a1a3c616189eac46b0afde01c675d80e7fe0ecb Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Fri, 1 Aug 2025 04:46:55 +0530 Subject: [PATCH 05/29] tranculation error fixed --- Project.toml | 2 + ext/implicit_tau.jl | 645 +++++++++++++++++++++--------------- src/simple_regular_solve.jl | 43 ++- test/gpu/implicit_tau.jl | 30 +- 4 files changed, 413 insertions(+), 307 deletions(-) diff --git a/Project.toml b/Project.toml index dfc8e5f5..91bc133f 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,8 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" diff --git a/ext/implicit_tau.jl b/ext/implicit_tau.jl index e0d083fd..05a3cca9 100644 --- a/ext/implicit_tau.jl +++ b/ext/implicit_tau.jl @@ -1,329 +1,412 @@ +# Ensemble solver for ImplicitTauLeaping function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, alg::ImplicitTauLeaping, ensemblealg::EnsembleGPUKernel; trajectories, seed = nothing, - dt = error("dt is required for ImplicitTauLeaping."), + max_steps = 10000, kwargs...) if trajectories == 1 - return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); trajectories=1, - seed, dt, kwargs...) + return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); trajectories = 1, + seed, kwargs...) end - backend = ensemblealg.backend === nothing ? CPU() : ensemblealg.backend + ensemblealg.backend === nothing ? backend = CPU() : + backend = ensemblealg.backend jump_prob = ensembleprob.prob - - @assert isempty(jump_prob.jump_callback.continuous_callbacks) - @assert isempty(jump_prob.jump_callback.discrete_callbacks) - prob = jump_prob.prob probs = [remake(jump_prob) for _ in 1:trajectories] + # Debug: Verify p in probs + for i in 1:trajectories + @assert typeof(probs[i].prob.p) == NTuple{4, Float64} "p in probs[$i] must be NTuple{4, Float64}, got $(typeof(probs[i].prob.p)), p = $(probs[i].prob.p)" + end - ts, us = vectorized_solve(probs, jump_prob, alg; backend, trajectories, seed, dt, kwargs...) + ts, us = vectorized_solve(probs, jump_prob, alg; backend, trajectories, seed, max_steps) _ts = Array(ts) _us = Array(us) time = @elapsed sol = [begin - ts = @view _ts[:, i] - us = @view _us[:, :, i] - sol_idx = findlast(x -> x != probs[i].prob.tspan[1], ts) - if sol_idx === nothing - @error "No solution found" tspan=probs[i].tspan[1] ts - error("Batch solve failed") - end - @views ensembleprob.output_func( - SciMLBase.build_solution(probs[i].prob, - alg, - ts[1:sol_idx], - [us[j, :] for j in 1:sol_idx], - k = nothing, - stats = nothing, - calculate_error = false, - retcode = sol_idx != length(ts) ? ReturnCode.Terminated : ReturnCode.Success), - i)[1] - end for i in eachindex(probs)] + ts = @view _ts[:, i] + us = @view _us[:, :, i] + sol_idx = findlast(x -> x != probs[i].prob.tspan[1], ts) + if sol_idx === nothing + @error "No solution found" tspan=probs[i].tspan[1] ts + error("Batch solve failed") + end + @views ensembleprob.output_func( + SciMLBase.build_solution(probs[i].prob, + alg, + ts[1:sol_idx], + [us[j, :] for j in 1:sol_idx], + k = nothing, + stats = nothing, + calculate_error = false, + retcode = sol_idx != + length(ts) ? + ReturnCode.Terminated : + ReturnCode.Success), + i)[1] + end + for i in eachindex(probs)] return SciMLBase.EnsembleSolution(sol, time, true) end -struct TrajectoryDataImplicit{U <: StaticArray, P, T} +# Structs for trajectory and jump data +struct ImplicitTauLeapingTrajectoryData{U <: StaticArray, P, T} u0::U p::P tspan::Tuple{T, T} end -struct JumpDataImplicit{R, C, V} +struct ImplicitTauLeapingJumpData{R, C, N} rate::R c::C - nu::V numjumps::Int + nu::N end -function compute_tau_explicit(u, rate, nu, num_jumps, epsilon, g, J_ncr, I_rs, p) - rate_cache = zeros(eltype(u), num_jumps) - rate(rate_cache, u, p, 0.0) - - mu = zeros(eltype(u), length(u)) - sigma2 = zeros(eltype(u), length(u)) - - for i in I_rs - mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_ncr; init=0.0) - sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_ncr; init=0.0) - end - - tau = Inf - for i in I_rs - denom_mu = max(epsilon * u[i] / g[i], 1.0) - denom_sigma = denom_mu^2 - if abs(mu[i]) > 0 - tau = min(tau, denom_mu / abs(mu[i])) - end - if sigma2[i] > 0 - tau = min(tau, denom_sigma / sigma2[i]) - end - end - return tau -end - -function compute_tau_implicit(u, rate, nu, num_jumps, epsilon, g, J_necr, I_rs, p) - rate_cache = zeros(eltype(u), num_jumps) - rate(rate_cache, u, p, 0.0) - - mu = zeros(eltype(u), length(u)) - sigma2 = zeros(eltype(u), length(u)) - - for i in I_rs - mu[i] = sum(nu[i,j] * rate_cache[j] for j in J_necr; init=0.0) - sigma2[i] = sum(nu[i,j]^2 * rate_cache[j] for j in J_necr; init=0.0) - end - - tau = Inf - for i in I_rs - denom_mu = max(epsilon * u[i] / g[i], 1.0) - denom_sigma = denom_mu^2 - if abs(mu[i]) > 0 - tau = min(tau, denom_mu / abs(mu[i])) - end - if sigma2[i] > 0 - tau = min(tau, denom_sigma / sigma2[i]) - end - end - return isinf(tau) ? 1e6 : tau -end - -function identify_critical_reactions(u, nu, num_jumps, nc) - L = zeros(Int, num_jumps) - J_critical = Int[] - - for j in 1:num_jumps - min_val = Inf - for i in 1:length(u) - if nu[i,j] < 0 - val = floor(u[i] / abs(nu[i,j])) - min_val = min(min_val, val) - end - end - L[j] = min_val == Inf ? typemax(Int) : Int(min_val) - if L[j] < nc - push!(J_critical, j) - end - end - J_ncr = setdiff(1:num_jumps, J_critical) - return J_critical, J_ncr +struct ImplicitTauLeapingData + epsilon::Float64 + nc::Int + nstiff::Float64 + delta::Float64 end -function check_partial_equilibrium(rate_cache, reversible_pairs, delta) - J_equilibrium = Int[] - for (j_plus, j_minus) in reversible_pairs - a_plus = rate_cache[j_plus] - a_minus = rate_cache[j_minus] - if abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) - push!(J_equilibrium, j_plus, j_minus) - end - end - return J_equilibrium -end - -function newton_solve!(x_new, x, rate, nu, rate_cache, counts, p, t, tau, max_iter=10, tol=1e-6) - state_dim = length(x) - num_jumps = length(counts) - - x_temp = copy(x_new) - for iter in 1:max_iter - rate(rate_cache, x_new, p, t) - rate_cache .*= tau - - residual = x_new .- x - for j in 1:num_jumps - residual .-= nu[:,j] * (counts[j] - rate_cache[j] + tau * rate_cache[j]) - end - - if norm(residual) < tol - break - end - - J = zeros(eltype(x), state_dim, state_dim) - for j in 1:num_jumps - for i in 1:state_dim - for k in 1:state_dim - J[i,k] += nu[i,j] * nu[k,j] * rate_cache[j] - end - end - end - J = I - tau * J - - delta_x = J \ residual - x_new .-= delta_x - - if norm(delta_x) < tol - break - end - end - return x_new -end - -@kernel function implicit_tau_leaping_kernel(@Const(probs_data), _us, _ts, dt, @Const(rj_data), +# ImplicitTauLeaping kernel +@kernel function implicit_tau_leaping_kernel(@Const(probs_data), _us, _ts, @Const(rj_data), @Const(alg_data), current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, - seed::UInt64, alg::ImplicitTauLeaping, reversible_pairs) + mu_buf, sigma2_buf, critical_buf, rate_new_buf, residual_buf, J_buf, + seed::UInt64, max_steps) i = @index(Global, Linear) + # Thread-local buffers @inbounds begin current_u = view(current_u_buf, :, i) rate_cache = view(rate_cache_buf, :, i) counts = view(counts_buf, :, i) local_dc = view(local_dc_buf, :, i) + mu = view(mu_buf, :, i) + sigma2 = view(sigma2_buf, :, i) + critical = view(critical_buf, :, i) + rate_new = view(rate_new_buf, :, i) + residual = view(residual_buf, :, i) + J = view(J_buf, :, :, i) end + # Problem data @inbounds prob_data = probs_data[i] u0 = prob_data.u0 p = prob_data.p tspan = prob_data.tspan + t_end = tspan[2] + # Jump data rate = rj_data.rate num_jumps = rj_data.numjumps c = rj_data.c nu = rj_data.nu + # Algorithm parameters + epsilon = alg_data.epsilon + nc = alg_data.nc + nstiff = alg_data.nstiff + delta = alg_data.delta + + # Initialize state @inbounds for k in 1:length(u0) current_u[k] = u0[k] end - - n = Int((tspan[2] - tspan[1]) / dt) + 1 state_dim = length(u0) + # Output arrays ts_view = @inbounds view(_ts, :, i) us_view = @inbounds view(_us, :, :, i) - @inbounds ts_view[1] = tspan[1] @inbounds for k in 1:state_dim us_view[1, k] = current_u[k] end - rng = Random.Xoshiro(seed + i) + # Debug: Check parameter type + if i == 1 + @assert typeof(p) == NTuple{4, Float64} "p must be a tuple of 4 Float64 values, got $(typeof(p)), p = $p" + @show p + @show typeof(rate) + end + + # Find reversible pairs + equilibrium_pairs = Tuple{Int,Int}[] + for j in 1:num_jumps + for k in (j+1):num_jumps + if all(nu[l, j] == -nu[l, k] for l in 1:state_dim) + push!(equilibrium_pairs, (j, k)) + end + end + end - I_rs = 1:state_dim - g = ones(state_dim) + # Helper functions + function compute_gi(u, t) + max_order = 1.0 + for j in 1:num_jumps + if any(abs.(nu[:, j]) .> 0) + # Debug: Check p and rate before calling + if i == 1 + @show p + @show typeof(rate) + end + rate(rate_cache, u, p, t) + if rate_cache[j] > 0 + order = sum(abs.(nu[:, j])) > abs(nu[findfirst(abs.(nu[:, j]) .> 0), j]) ? 2.0 : 1.0 + max_order = max(max_order, order) + end + end + end + max_order + end + + function compute_tau_explicit(u, t) + # Debug: Check p and rate before calling + if i == 1 + @show p + @show typeof(rate) + end + rate(rate_cache, u, p, t) + mu .= 0.0 + sigma2 .= 0.0 + tau = Inf + for l in 1:state_dim + for j in 1:num_jumps + mu[l] += nu[l, j] * rate_cache[j] + sigma2[l] += nu[l, j]^2 * rate_cache[j] + end + gi = compute_gi(u, t) + bound = max(epsilon * u[l] / gi, 1.0) + mu_term = abs(mu[l]) > 0 ? bound / abs(mu[l]) : Inf + sigma_term = sigma2[l] > 0 ? bound^2 / sigma2[l] : Inf + tau = min(tau, mu_term, sigma_term) + end + tau + end - for j in 2:n - tprev = tspan[1] + (j-2) * dt + function is_partial_equilibrium(rate_cache, j_plus, j_minus) + a_plus = rate_cache[j_plus] + a_minus = rate_cache[j_minus] + abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) + end - J_critical, J_ncr = identify_critical_reactions(current_u, nu, num_jumps, alg.nc) + function compute_tau_implicit(u, t) + # Debug: Check p and rate before calling + if i == 1 + @show p + @show typeof(rate) + end + rate(rate_cache, u, p, t) + mu .= 0.0 + sigma2 .= 0.0 + non_equilibrium = trues(num_jumps) + for (j_plus, j_minus) in equilibrium_pairs + if is_partial_equilibrium(rate_cache, j_plus, j_minus) + non_equilibrium[j_plus] = false + non_equilibrium[j_minus] = false + end + end + tau = Inf + for l in 1:state_dim + for j in 1:num_jumps + if non_equilibrium[j] + mu[l] += nu[l, j] * rate_cache[j] + sigma2[l] += nu[l, j]^2 * rate_cache[j] + end + end + gi = compute_gi(u, t) + bound = max(epsilon * u[l] / gi, 1.0) + mu_term = abs(mu[l]) > 0 ? bound / abs(mu[l]) : Inf + sigma_term = sigma2[l] > 0 ? bound^2 / sigma2[l] : Inf + tau = min(tau, mu_term, sigma_term) + end + tau + end - rate(rate_cache, current_u, p, tprev) - a0_critical = sum(rate_cache[j] for j in J_critical; init=0.0) + function identify_critical_reactions(u) + critical .= false + for j in 1:num_jumps + if rate_cache[j] > 0 + Lj = Inf + for l in 1:state_dim + if nu[l, j] < 0 + Lj = min(Lj, floor(Int, u[l] / abs(nu[l, j]))) + end + end + if Lj < nc + critical[j] = true + end + end + end + end - J_equilibrium = check_partial_equilibrium(rate_cache, reversible_pairs, alg.delta) - J_necr = setdiff(J_ncr, J_equilibrium) + function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, u_new) + u_new .= u_prev + tol = 1e-6 + max_iter = 100 + for iter in 1:max_iter + # Debug: Check p and rate before calling + if i == 1 + @show p + @show typeof(rate) + end + rate(rate_new, u_new, p, t_prev + tau) + residual .= u_new .- u_prev + for j in 1:num_jumps + for l in 1:state_dim + residual[l] -= nu[l, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) + end + end + if norm(residual) < tol + break + end + J .= 0.0 + for l in 1:state_dim + J[l, l] = 1.0 + for j in 1:num_jumps + if rate_new[j] > 0 + J[l, l] += nu[l, j] * tau * rate_new[j] / max(u_new[l], 1.0) + end + end + end + u_new .-= J \ residual + u_new .= max.(u_new, 0.0) + end + u_new .= round.(Int, u_new) + # Debug: Check p and c before calling + if i == 1 + @assert typeof(p) == NTuple{4, Float64} "p must be a tuple of 4 Float64 values before c, got $(typeof(p)), p = $p" + @show p + @show typeof(c) + end + c(local_dc, u_new, p, t_prev + tau, counts, nothing) + u_new .+= local_dc + end - tau_ex = compute_tau_explicit(current_u, rate, nu, num_jumps, alg.epsilon, g, J_ncr, I_rs, p) - tau_im = compute_tau_implicit(current_u, rate, nu, num_jumps, alg.epsilon, g, J_necr, I_rs, p) + function use_down_shifting(t, tau_im, tau_ex, a0) + t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) + end - tau2 = a0_critical > 0 ? -log(rand(rng)) / a0_critical : Inf + # Thread-local RNG + local rng_state = seed ⊻ UInt64(i) + function thread_rand() + rng_state = (1103515245 * rng_state + 12345) & 0x7fffffff + rng_state / 0x7fffffff + end + function thread_randexp() + -log(thread_rand()) + end + function thread_poisson(lambda) + L = exp(-lambda) + k = 0 + p = 1.0 + while p > L + k += 1 + p *= thread_rand() + end + k - 1 + end - use_implicit = tau_im > alg.nstiff * tau_ex + # Main simulation loop + step = 1 + t = tspan[1] + while t < t_end && step < max_steps + step += 1 + # Debug: Check p and rate before calling + if i == 1 + @show p + @show typeof(rate) + end + rate(rate_cache, current_u, p, t) + identify_critical_reactions(current_u) + tau_ex = compute_tau_explicit(current_u, t) + tau_im = compute_tau_implicit(current_u, t) + ac0 = sum(rate_cache[critical]) + tau2 = ac0 > 0 ? thread_randexp() / ac0 : Inf + a0 = sum(rate_cache) + use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t, tau_im, tau_ex, a0) tau1 = use_implicit ? tau_im : tau_ex - if tau1 < 10 / sum(rate_cache; init=0.0) - a0 = sum(rate_cache; init=0.0) - if a0 > 0 - tau = -log(rand(rng)) / a0 - r = rand(rng) * a0 - cumsum_a = 0.0 - jc = 1 - for k in 1:num_jumps - cumsum_a += rate_cache[k] - if cumsum_a > r - jc = k + if a0 > 0 && tau1 < 10 / a0 + steps = use_implicit ? 10 : 100 + for _ in 1:steps + if t >= t_end + break + end + rate(rate_cache, current_u, p, t) + a0 = sum(rate_cache) + if a0 == 0 + break + end + tau = thread_randexp() / a0 + r = thread_rand() * a0 + cumsum_rate = 0.0 + for j in 1:num_jumps + cumsum_rate += rate_cache[j] + if cumsum_rate > r + current_u .+= nu[:, j] break end end - current_u .+= nu[:,jc] + t += tau + if step <= max_steps + @inbounds ts_view[step] = t + @inbounds for k in 1:state_dim + us_view[step, k] = current_u[k] + end + step += 1 + end + end + continue + end + + if tau2 > tau1 + tau = min(1.0, t_end - t) + counts .= 0 + for j in 1:num_jumps + if !critical[j] + counts[j] = thread_poisson(rate_cache[j] * tau) + end + end + if use_implicit + implicit_tau_step(current_u, t, tau, rate_cache, counts, current_u) else - tau = dt + c(local_dc, current_u, p, t, counts, nothing) + current_u .+= local_dc end else - tau = min(tau1, tau2, dt) - if tau == tau2 - if a0_critical > 0 - r = rand(rng) * a0_critical - cumsum_a = 0.0 - jc = J_critical[1] - for k in J_critical - cumsum_a += rate_cache[k] - if cumsum_a > r - jc = k + tau = min(1.0, t_end - t) + counts .= 0 + if ac0 > 0 + r = thread_rand() * ac0 + cumsum_rate = 0.0 + for j in 1:num_jumps + if critical[j] + cumsum_rate += rate_cache[j] + if cumsum_rate > r + counts[j] = 1 break end end - counts .= 0 - counts[jc] = 1 - if use_implicit && tau > tau_ex - for k in J_ncr - counts[k] = poisson_rand(rate_cache[k] * tau, rng) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) - else - for k in J_ncr - counts[k] = poisson_rand(rate_cache[k] * tau, rng) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .+= local_dc - end - else - tau = tau1 - if use_implicit - for k in 1:num_jumps - counts[k] = poisson_rand(rate_cache[k] * tau, rng) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) - else - for k in 1:num_jumps - counts[k] = poisson_rand(rate_cache[k] * tau, rng) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .+= local_dc - end end - else - counts .= 0 - if use_implicit - for k in J_ncr - counts[k] = poisson_rand(rate_cache[k] * tau, rng) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .= newton_solve!(current_u .+ local_dc, current_u, rate, nu, rate_cache, counts, p, tprev, tau) - else - for k in J_ncr - counts[k] = poisson_rand(rate_cache[k] * tau, rng) - end - c(local_dc, current_u, p, tprev, counts, nothing) - current_u .+= local_dc + end + for j in 1:num_jumps + if !critical[j] + counts[j] = thread_poisson(rate_cache[j] * tau) end end + if use_implicit && tau > tau_ex + implicit_tau_step(current_u, t, tau, rate_cache, counts, current_u) + else + c(local_dc, current_u, p, t, counts, nothing) + current_u .+= local_dc + end end if any(current_u .< 0) @@ -331,45 +414,63 @@ end continue end - @inbounds for k in 1:state_dim - us_view[j, k] = current_u[k] + t += tau + if step <= max_steps + @inbounds ts_view[step] = t + @inbounds for k in 1:state_dim + us_view[step, k] = current_u[k] + end end - @inbounds ts_view[j] = tspan[1] + (j-1) * dt end end -function vectorized_solve(probs, prob::JumpProblem, alg::ImplicitTauLeaping; backend, trajectories, seed, dt, kwargs...) +# Vectorized solve for ImplicitTauLeaping +function vectorized_solve(probs, prob::JumpProblem, alg::ImplicitTauLeaping; backend, trajectories, seed, max_steps) rj = prob.regular_jump - nu = zeros(Int, length(prob.prob.u0), rj.numjumps) - for j in 1:rj.numjumps - dc = zeros(length(prob.prob.u0)) - rj.c(dc, prob.prob.u0, prob.prob.p, 0.0, [i == j ? 1 : 0 for i in 1:rj.numjumps], nothing) - nu[:,j] = dc + state_dim = length(prob.prob.u0) + p_correct = prob.prob.p # Store correct p + nu = let c = rj.c, u0 = prob.prob.u0, numjumps = rj.numjumps + nu = zeros(Int, state_dim, numjumps) + for j in 1:numjumps + counts = zeros(numjumps) + counts[j] = 1 + du = similar(u0) + c(du, u0, p_correct, prob.prob.tspan[1], counts, nothing) + nu[:, j] = round.(Int, du) + end + nu end - rj_data = JumpDataImplicit(rj.rate, rj.c, nu, rj.numjumps) + # Explicitly bind p_correct to both c and rate + c_fixed = (du, u, p, t, counts, mark) -> rj.c(du, u, p_correct, t, counts, mark) + rate_fixed = (out, u, p, t) -> rj.rate(out, u, p_correct, t) + rj_data = ImplicitTauLeapingJumpData(rate_fixed, c_fixed, rj.numjumps, nu) + alg_data = ImplicitTauLeapingData(alg.epsilon, alg.nc, alg.nstiff, alg.delta) - probs_data = [TrajectoryDataImplicit(SA{eltype(p.prob.u0)}[p.prob.u0...], p.prob.p, p.prob.tspan) for p in probs] + probs_data = [ImplicitTauLeapingTrajectoryData(SA{eltype(p.prob.u0)}[p.prob.u0...], p_correct, p.prob.tspan) for p in probs] probs_data_gpu = adapt(backend, probs_data) rj_data_gpu = adapt(backend, rj_data) + alg_data_gpu = adapt(backend, alg_data) - state_dim = length(first(probs_data).u0) tspan = prob.prob.tspan - dt = Float64(dt) - n_steps = Int((tspan[2] - tspan[1]) / dt) + 1 - n_trajectories = length(probs) num_jumps = rj_data.numjumps @assert state_dim > 0 "Dimension of state must be positive" @assert num_jumps >= 0 "Number of jumps must be positive" - ts = allocate(backend, eltype(prob.prob.tspan), (n_steps, n_trajectories)) - us = allocate(backend, eltype(prob.prob.u0), (n_steps, state_dim, n_trajectories)) + ts = allocate(backend, eltype(prob.prob.tspan), (max_steps, trajectories)) + us = allocate(backend, eltype(prob.prob.u0), (max_steps, state_dim, trajectories)) - current_u_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, n_trajectories)) - rate_cache_buf = allocate(backend, eltype(prob.prob.u0), (num_jumps, n_trajectories)) - counts_buf = allocate(backend, eltype(prob.prob.u0), (num_jumps, n_trajectories)) - local_dc_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, n_trajectories)) + current_u_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) + rate_cache_buf = allocate(backend, eltype(prob.prob.u0), (num_jumps, trajectories)) + counts_buf = allocate(backend, Int, (num_jumps, trajectories)) + local_dc_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) + mu_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) + sigma2_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) + critical_buf = allocate(backend, Bool, (num_jumps, trajectories)) + rate_new_buf = allocate(backend, eltype(prob.prob.u0), (num_jumps, trajectories)) + residual_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) + J_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, state_dim, trajectories)) @kernel function init_buffers_kernel(@Const(probs_data), current_u_buf) i = @index(Global, Linear) @@ -379,28 +480,24 @@ function vectorized_solve(probs, prob::JumpProblem, alg::ImplicitTauLeaping; bac end end init_kernel = init_buffers_kernel(backend) - init_event = init_kernel(probs_data_gpu, current_u_buf; ndrange=n_trajectories) + init_event = init_kernel(probs_data_gpu, current_u_buf; ndrange=trajectories) KernelAbstractions.synchronize(backend) seed = seed === nothing ? UInt64(12345) : UInt64(seed) - reversible_pairs = get(kwargs, :reversible_pairs, Tuple{Int,Int}[]) + + # Debug: Verify parameters before kernel launch + @assert all(typeof(p.prob.p) == NTuple{4, Float64} for p in probs) "All problems must have p as NTuple{4, Float64}" + @show typeof(probs[1].prob.p) + @show probs[1].prob.p + @show typeof(rj_data.rate) + @show typeof(rj_data.c) kernel = implicit_tau_leaping_kernel(backend) - main_event = kernel(probs_data_gpu, us, ts, dt, rj_data_gpu, - current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, seed, alg, reversible_pairs; - ndrange=n_trajectories) + main_event = kernel(probs_data_gpu, us, ts, rj_data_gpu, alg_data_gpu, + current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, + mu_buf, sigma2_buf, critical_buf, rate_new_buf, residual_buf, J_buf, + seed, max_steps; ndrange=trajectories) KernelAbstractions.synchronize(backend) return ts, us end - -@inline function poisson_rand(lambda, rng) - L = exp(-lambda) - k = 0 - p = 1.0 - while p > L - k += 1 - p *= rand(rng) - end - return k - 1 -end diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 0dc625a3..df016ded 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -483,7 +483,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= rate(rate_cache, u, p, t[end]) if rate_cache[j] > 0 order = 1.0 - # Heuristic: if reaction involves multiple species, assume higher order if sum(abs.(nu[:, j])) > abs(nu[i, j]) order = 2.0 end @@ -511,7 +510,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf tau = min(tau, mu_term, sigma_term) end - return tau + return max(tau, 1e-10) # Prevent zero or negative tau end # Partial equilibrium check (Equation 13) @@ -547,7 +546,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf tau = min(tau, mu_term, sigma_term) end - return tau + return max(tau, 1e-10) # Prevent zero or negative tau end # Identify critical reactions @@ -574,7 +573,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= u_new = copy(u_prev) rate_new = zeros(numjumps) tol = 1e-6 - max_iter = 100 + max_iter = 50 for iter in 1:max_iter rate(rate_new, u_new, p, t_prev + tau) residual = u_new - u_prev @@ -584,27 +583,36 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= if norm(residual) < tol break end - # Approximate Jacobian (diagonal approximation for simplicity) + # Improved Jacobian approximation J = Diagonal(ones(length(u_new))) for j in 1:numjumps for i in 1:length(u_new) - # Heuristic derivative: assume linear or quadratic propensity - rate(rate_new, u_new, p, t_prev + tau) - if rate_new[j] > 0 - # Estimate derivative based on stoichiometry - J[i, i] += nu[i, j] * tau * rate_new[j] / max(u_new[i], 1.0) + if rate_new[j] > 0 && u_new[i] > 0 + # Scale derivative to prevent overflow + J[i, i] += nu[i, j] * tau * min(rate_new[j] / u_new[i], 1e3) end end end - u_new -= J \ residual + # Check for singular or ill-conditioned Jacobian + if any(abs.(diag(J)) .< 1e-10) + return u_prev # Revert to previous state if Jacobian is singular + end + delta_u = J \ residual + # Limit step size to prevent overflow + delta_u = clamp.(delta_u, -1e3, 1e3) + u_new -= delta_u u_new = max.(u_new, 0.0) + # Check for numerical overflow + if any(isnan.(u_new)) || any(isinf.(u_new)) + return u_prev + end end - return round.(Int, u_new) + return round.(Int, max.(u_new, 0.0)) end # Down-shifting condition (Equation 19) function use_down_shifting(t, tau_im, tau_ex, a0, t_end) - return t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) + return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) end # Main simulation loop @@ -628,10 +636,13 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= # Choose method and stepsize a0 = sum(rate_cache) - use_implicit = tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end) + use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end) tau1 = use_implicit ? tau_im : tau_ex method = use_implicit ? :implicit : :explicit + # Cap tau to prevent large updates + tau1 = min(tau1, 1.0) + # Check if tau1 is too small if a0 > 0 && tau1 < 10 / a0 # Use SSA for a few steps @@ -668,7 +679,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= counts .= 0 for j in 1:numjumps if !critical[j] - counts[j] = pois_rand(rng, rate_cache[j] * tau) + counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0)) end end if method == :implicit @@ -695,7 +706,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= end for j in 1:numjumps if !critical[j] - counts[j] = pois_rand(rng, rate_cache[j] * tau) + counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0)) end end if method == :implicit && tau > tau_ex diff --git a/test/gpu/implicit_tau.jl b/test/gpu/implicit_tau.jl index 459578b2..d150c00f 100644 --- a/test/gpu/implicit_tau.jl +++ b/test/gpu/implicit_tau.jl @@ -1,14 +1,14 @@ using JumpProcesses, DiffEqBase using Test, LinearAlgebra, Statistics using KernelAbstractions, Adapt, CUDA -using StableRNGs, Plots +using StableRNGs rng = StableRNG(12345) -Nsims = 10 +Nsims = 1 -# Parameters +# Decaying Dimerization Model c1 = 1.0 # S1 -> 0 c2 = 10.0 # S1 + S1 <- S2 c3 = 1000.0 # S1 + S1 -> S2 @@ -17,6 +17,7 @@ p = (c1, c2, c3, c4) # Propensity functions regular_rate = (out, u, p, t) -> begin + @assert typeof(p) == NTuple{4, Float64} "p must be a tuple of 4 Float64 values" out[1] = p[1] * u[1] # S1 -> 0 out[2] = p[2] * u[2] # S1 + S1 <- S2 out[3] = p[3] * u[1] * (u[1] - 1) / 2 # S1 + S1 -> S2 @@ -24,27 +25,22 @@ regular_rate = (out, u, p, t) -> begin end # State change function -regular_c = (dc, u, p, t, counts, mark) -> begin - dc .= 0.0 - dc[1] = -counts[1] - 2 * counts[3] + 2 * counts[2] # S1: -decay - 2*forward + 2*backward - dc[2] = counts[3] - counts[2] - counts[4] # S2: +forward - backward - decay - dc[3] = counts[4] # S3: +decay +regular_c = (du, u, p, t, counts, mark) -> begin + @assert typeof(p) == NTuple{4, Float64} "p must be a tuple of 4 Float64 values" + du .= 0.0 + du[1] = -counts[1] - 2 * counts[3] + 2 * counts[2] # S1: -decay - 2*forward + 2*backward + du[2] = counts[3] - counts[2] - counts[4] # S2: +forward - backward - decay + du[3] = counts[4] # S3: +decay end # Initial condition u0 = [10000.0, 0.0, 0.0] # S1, S2, S3 tspan = (0.0, 4.0) -# Define reversible reaction pairs (R2 and R3 are reversible: S1 + S1 <-> S2) -reversible_pairs = [(2, 3)] - -# Create JumpProblem with proper parameter passing +# Create JumpProblem prob_disc = DiscreteProblem(u0, tspan, p) rj = RegularJump(regular_rate, regular_c, 4) -jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) +jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=rng) # Solve using ImplicitTauLeaping -alg = ImplicitTauLeaping(epsilon=0.05, nc=10, nstiff=100, delta=0.05) -sol = solve(EnsembleProblem(jump_prob), alg, EnsembleGPUKernel(); - trajectories=Nsims, dt=0.01, reversible_pairs=reversible_pairs) -plot(sol) +sol = solve(EnsembleProblem(jump_prob), ImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims) From 1c368dec54071242eaf7d8777b2499b34e045429 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Fri, 1 Aug 2025 18:24:38 +0530 Subject: [PATCH 06/29] nonlinearsolver is implemented --- src/simple_regular_solve.jl | 63 ++++++++++++++----------------------- 1 file changed, 24 insertions(+), 39 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index df016ded..7416713f 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -417,7 +417,7 @@ end ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) = ImplicitTauLeaping(epsilon, nc, nstiff, delta) -function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed=nothing) +function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed = nothing) # Boilerplate setup @assert isempty(jump_prob.jump_callback.continuous_callbacks) @assert isempty(jump_prob.jump_callback.discrete_callbacks) @@ -510,7 +510,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf tau = min(tau, mu_term, sigma_term) end - return max(tau, 1e-10) # Prevent zero or negative tau + return max(tau, 1e-10) end # Partial equilibrium check (Equation 13) @@ -546,7 +546,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf tau = min(tau, mu_term, sigma_term) end - return max(tau, 1e-10) # Prevent zero or negative tau + return max(tau, 1e-10) end # Identify critical reactions @@ -568,46 +568,32 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= return critical end - # Implicit tau-leaping step with Newton's method + # Implicit tau-leaping step using NonlinearSolve function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p) - u_new = copy(u_prev) - rate_new = zeros(numjumps) - tol = 1e-6 - max_iter = 50 - for iter in 1:max_iter + # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0 + function f(u_new, params) + rate_new = similar(rate_cache, eltype(u_new)) rate(rate_new, u_new, p, t_prev + tau) residual = u_new - u_prev for j in 1:numjumps residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) end - if norm(residual) < tol - break - end - # Improved Jacobian approximation - J = Diagonal(ones(length(u_new))) - for j in 1:numjumps - for i in 1:length(u_new) - if rate_new[j] > 0 && u_new[i] > 0 - # Scale derivative to prevent overflow - J[i, i] += nu[i, j] * tau * min(rate_new[j] / u_new[i], 1e3) - end - end - end - # Check for singular or ill-conditioned Jacobian - if any(abs.(diag(J)) .< 1e-10) - return u_prev # Revert to previous state if Jacobian is singular - end - delta_u = J \ residual - # Limit step size to prevent overflow - delta_u = clamp.(delta_u, -1e3, 1e3) - u_new -= delta_u - u_new = max.(u_new, 0.0) - # Check for numerical overflow - if any(isnan.(u_new)) || any(isinf.(u_new)) - return u_prev - end + return residual end - return round.(Int, max.(u_new, 0.0)) + + # Initial guess + u_new = copy(u_prev) + + # Solve the nonlinear system + prob = NonlinearProblem(f, u_new, nothing) + sol = solve(prob, NewtonRaphson()) + + # Check for convergence and numerical stability + if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) + return round.(Int, max.(u_prev, 0.0)) # Revert to previous state + end + + return round.(Int, max.(sol.u, 0.0)) end # Down-shifting condition (Equation 19) @@ -730,9 +716,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed= # Build solution sol = DiffEqBase.build_solution(prob, alg, t, u, - calculate_error=false, - interp=DiffEqBase.ConstantInterpolation(t, u)) - return sol + calculate_error = false, + interp = DiffEqBase.ConstantInterpolation(t, u)) end struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm From c92e30fab701aa81623419ca4c6a4cecf5f1fbdc Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sun, 10 Aug 2025 03:35:32 +0530 Subject: [PATCH 07/29] changed to SimpleImplicitTauLeaping --- ext/JumpProcessesKernelAbstractionsExt.jl | 2 +- ext/implicit_tau.jl | 503 ---------------------- src/simple_regular_solve.jl | 12 +- test/gpu/implicit_tau.jl | 46 -- 4 files changed, 7 insertions(+), 556 deletions(-) delete mode 100644 ext/implicit_tau.jl delete mode 100644 test/gpu/implicit_tau.jl diff --git a/ext/JumpProcessesKernelAbstractionsExt.jl b/ext/JumpProcessesKernelAbstractionsExt.jl index ae38cda6..2b345ebc 100644 --- a/ext/JumpProcessesKernelAbstractionsExt.jl +++ b/ext/JumpProcessesKernelAbstractionsExt.jl @@ -1,6 +1,6 @@ module JumpProcessesKernelAbstractionsExt -using JumpProcesses, SciMLBase, DiffEqBase +using JumpProcesses, SciMLBase using KernelAbstractions, Adapt using StaticArrays using PoissonRandom, Random diff --git a/ext/implicit_tau.jl b/ext/implicit_tau.jl deleted file mode 100644 index 05a3cca9..00000000 --- a/ext/implicit_tau.jl +++ /dev/null @@ -1,503 +0,0 @@ -# Ensemble solver for ImplicitTauLeaping -function SciMLBase.__solve(ensembleprob::SciMLBase.AbstractEnsembleProblem, - alg::ImplicitTauLeaping, - ensemblealg::EnsembleGPUKernel; - trajectories, - seed = nothing, - max_steps = 10000, - kwargs...) - - if trajectories == 1 - return SciMLBase.__solve(ensembleprob, alg, EnsembleSerial(); trajectories = 1, - seed, kwargs...) - end - - ensemblealg.backend === nothing ? backend = CPU() : - backend = ensemblealg.backend - - jump_prob = ensembleprob.prob - - probs = [remake(jump_prob) for _ in 1:trajectories] - # Debug: Verify p in probs - for i in 1:trajectories - @assert typeof(probs[i].prob.p) == NTuple{4, Float64} "p in probs[$i] must be NTuple{4, Float64}, got $(typeof(probs[i].prob.p)), p = $(probs[i].prob.p)" - end - - ts, us = vectorized_solve(probs, jump_prob, alg; backend, trajectories, seed, max_steps) - - _ts = Array(ts) - _us = Array(us) - - time = @elapsed sol = [begin - ts = @view _ts[:, i] - us = @view _us[:, :, i] - sol_idx = findlast(x -> x != probs[i].prob.tspan[1], ts) - if sol_idx === nothing - @error "No solution found" tspan=probs[i].tspan[1] ts - error("Batch solve failed") - end - @views ensembleprob.output_func( - SciMLBase.build_solution(probs[i].prob, - alg, - ts[1:sol_idx], - [us[j, :] for j in 1:sol_idx], - k = nothing, - stats = nothing, - calculate_error = false, - retcode = sol_idx != - length(ts) ? - ReturnCode.Terminated : - ReturnCode.Success), - i)[1] - end - for i in eachindex(probs)] - return SciMLBase.EnsembleSolution(sol, time, true) -end - -# Structs for trajectory and jump data -struct ImplicitTauLeapingTrajectoryData{U <: StaticArray, P, T} - u0::U - p::P - tspan::Tuple{T, T} -end - -struct ImplicitTauLeapingJumpData{R, C, N} - rate::R - c::C - numjumps::Int - nu::N -end - -struct ImplicitTauLeapingData - epsilon::Float64 - nc::Int - nstiff::Float64 - delta::Float64 -end - -# ImplicitTauLeaping kernel -@kernel function implicit_tau_leaping_kernel(@Const(probs_data), _us, _ts, @Const(rj_data), @Const(alg_data), - current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, - mu_buf, sigma2_buf, critical_buf, rate_new_buf, residual_buf, J_buf, - seed::UInt64, max_steps) - i = @index(Global, Linear) - - # Thread-local buffers - @inbounds begin - current_u = view(current_u_buf, :, i) - rate_cache = view(rate_cache_buf, :, i) - counts = view(counts_buf, :, i) - local_dc = view(local_dc_buf, :, i) - mu = view(mu_buf, :, i) - sigma2 = view(sigma2_buf, :, i) - critical = view(critical_buf, :, i) - rate_new = view(rate_new_buf, :, i) - residual = view(residual_buf, :, i) - J = view(J_buf, :, :, i) - end - - # Problem data - @inbounds prob_data = probs_data[i] - u0 = prob_data.u0 - p = prob_data.p - tspan = prob_data.tspan - t_end = tspan[2] - - # Jump data - rate = rj_data.rate - num_jumps = rj_data.numjumps - c = rj_data.c - nu = rj_data.nu - - # Algorithm parameters - epsilon = alg_data.epsilon - nc = alg_data.nc - nstiff = alg_data.nstiff - delta = alg_data.delta - - # Initialize state - @inbounds for k in 1:length(u0) - current_u[k] = u0[k] - end - state_dim = length(u0) - - # Output arrays - ts_view = @inbounds view(_ts, :, i) - us_view = @inbounds view(_us, :, :, i) - @inbounds ts_view[1] = tspan[1] - @inbounds for k in 1:state_dim - us_view[1, k] = current_u[k] - end - - # Debug: Check parameter type - if i == 1 - @assert typeof(p) == NTuple{4, Float64} "p must be a tuple of 4 Float64 values, got $(typeof(p)), p = $p" - @show p - @show typeof(rate) - end - - # Find reversible pairs - equilibrium_pairs = Tuple{Int,Int}[] - for j in 1:num_jumps - for k in (j+1):num_jumps - if all(nu[l, j] == -nu[l, k] for l in 1:state_dim) - push!(equilibrium_pairs, (j, k)) - end - end - end - - # Helper functions - function compute_gi(u, t) - max_order = 1.0 - for j in 1:num_jumps - if any(abs.(nu[:, j]) .> 0) - # Debug: Check p and rate before calling - if i == 1 - @show p - @show typeof(rate) - end - rate(rate_cache, u, p, t) - if rate_cache[j] > 0 - order = sum(abs.(nu[:, j])) > abs(nu[findfirst(abs.(nu[:, j]) .> 0), j]) ? 2.0 : 1.0 - max_order = max(max_order, order) - end - end - end - max_order - end - - function compute_tau_explicit(u, t) - # Debug: Check p and rate before calling - if i == 1 - @show p - @show typeof(rate) - end - rate(rate_cache, u, p, t) - mu .= 0.0 - sigma2 .= 0.0 - tau = Inf - for l in 1:state_dim - for j in 1:num_jumps - mu[l] += nu[l, j] * rate_cache[j] - sigma2[l] += nu[l, j]^2 * rate_cache[j] - end - gi = compute_gi(u, t) - bound = max(epsilon * u[l] / gi, 1.0) - mu_term = abs(mu[l]) > 0 ? bound / abs(mu[l]) : Inf - sigma_term = sigma2[l] > 0 ? bound^2 / sigma2[l] : Inf - tau = min(tau, mu_term, sigma_term) - end - tau - end - - function is_partial_equilibrium(rate_cache, j_plus, j_minus) - a_plus = rate_cache[j_plus] - a_minus = rate_cache[j_minus] - abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) - end - - function compute_tau_implicit(u, t) - # Debug: Check p and rate before calling - if i == 1 - @show p - @show typeof(rate) - end - rate(rate_cache, u, p, t) - mu .= 0.0 - sigma2 .= 0.0 - non_equilibrium = trues(num_jumps) - for (j_plus, j_minus) in equilibrium_pairs - if is_partial_equilibrium(rate_cache, j_plus, j_minus) - non_equilibrium[j_plus] = false - non_equilibrium[j_minus] = false - end - end - tau = Inf - for l in 1:state_dim - for j in 1:num_jumps - if non_equilibrium[j] - mu[l] += nu[l, j] * rate_cache[j] - sigma2[l] += nu[l, j]^2 * rate_cache[j] - end - end - gi = compute_gi(u, t) - bound = max(epsilon * u[l] / gi, 1.0) - mu_term = abs(mu[l]) > 0 ? bound / abs(mu[l]) : Inf - sigma_term = sigma2[l] > 0 ? bound^2 / sigma2[l] : Inf - tau = min(tau, mu_term, sigma_term) - end - tau - end - - function identify_critical_reactions(u) - critical .= false - for j in 1:num_jumps - if rate_cache[j] > 0 - Lj = Inf - for l in 1:state_dim - if nu[l, j] < 0 - Lj = min(Lj, floor(Int, u[l] / abs(nu[l, j]))) - end - end - if Lj < nc - critical[j] = true - end - end - end - end - - function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, u_new) - u_new .= u_prev - tol = 1e-6 - max_iter = 100 - for iter in 1:max_iter - # Debug: Check p and rate before calling - if i == 1 - @show p - @show typeof(rate) - end - rate(rate_new, u_new, p, t_prev + tau) - residual .= u_new .- u_prev - for j in 1:num_jumps - for l in 1:state_dim - residual[l] -= nu[l, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) - end - end - if norm(residual) < tol - break - end - J .= 0.0 - for l in 1:state_dim - J[l, l] = 1.0 - for j in 1:num_jumps - if rate_new[j] > 0 - J[l, l] += nu[l, j] * tau * rate_new[j] / max(u_new[l], 1.0) - end - end - end - u_new .-= J \ residual - u_new .= max.(u_new, 0.0) - end - u_new .= round.(Int, u_new) - # Debug: Check p and c before calling - if i == 1 - @assert typeof(p) == NTuple{4, Float64} "p must be a tuple of 4 Float64 values before c, got $(typeof(p)), p = $p" - @show p - @show typeof(c) - end - c(local_dc, u_new, p, t_prev + tau, counts, nothing) - u_new .+= local_dc - end - - function use_down_shifting(t, tau_im, tau_ex, a0) - t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) - end - - # Thread-local RNG - local rng_state = seed ⊻ UInt64(i) - function thread_rand() - rng_state = (1103515245 * rng_state + 12345) & 0x7fffffff - rng_state / 0x7fffffff - end - function thread_randexp() - -log(thread_rand()) - end - function thread_poisson(lambda) - L = exp(-lambda) - k = 0 - p = 1.0 - while p > L - k += 1 - p *= thread_rand() - end - k - 1 - end - - # Main simulation loop - step = 1 - t = tspan[1] - while t < t_end && step < max_steps - step += 1 - # Debug: Check p and rate before calling - if i == 1 - @show p - @show typeof(rate) - end - rate(rate_cache, current_u, p, t) - identify_critical_reactions(current_u) - tau_ex = compute_tau_explicit(current_u, t) - tau_im = compute_tau_implicit(current_u, t) - ac0 = sum(rate_cache[critical]) - tau2 = ac0 > 0 ? thread_randexp() / ac0 : Inf - a0 = sum(rate_cache) - use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t, tau_im, tau_ex, a0) - tau1 = use_implicit ? tau_im : tau_ex - - if a0 > 0 && tau1 < 10 / a0 - steps = use_implicit ? 10 : 100 - for _ in 1:steps - if t >= t_end - break - end - rate(rate_cache, current_u, p, t) - a0 = sum(rate_cache) - if a0 == 0 - break - end - tau = thread_randexp() / a0 - r = thread_rand() * a0 - cumsum_rate = 0.0 - for j in 1:num_jumps - cumsum_rate += rate_cache[j] - if cumsum_rate > r - current_u .+= nu[:, j] - break - end - end - t += tau - if step <= max_steps - @inbounds ts_view[step] = t - @inbounds for k in 1:state_dim - us_view[step, k] = current_u[k] - end - step += 1 - end - end - continue - end - - if tau2 > tau1 - tau = min(1.0, t_end - t) - counts .= 0 - for j in 1:num_jumps - if !critical[j] - counts[j] = thread_poisson(rate_cache[j] * tau) - end - end - if use_implicit - implicit_tau_step(current_u, t, tau, rate_cache, counts, current_u) - else - c(local_dc, current_u, p, t, counts, nothing) - current_u .+= local_dc - end - else - tau = min(1.0, t_end - t) - counts .= 0 - if ac0 > 0 - r = thread_rand() * ac0 - cumsum_rate = 0.0 - for j in 1:num_jumps - if critical[j] - cumsum_rate += rate_cache[j] - if cumsum_rate > r - counts[j] = 1 - break - end - end - end - end - for j in 1:num_jumps - if !critical[j] - counts[j] = thread_poisson(rate_cache[j] * tau) - end - end - if use_implicit && tau > tau_ex - implicit_tau_step(current_u, t, tau, rate_cache, counts, current_u) - else - c(local_dc, current_u, p, t, counts, nothing) - current_u .+= local_dc - end - end - - if any(current_u .< 0) - tau1 /= 2 - continue - end - - t += tau - if step <= max_steps - @inbounds ts_view[step] = t - @inbounds for k in 1:state_dim - us_view[step, k] = current_u[k] - end - end - end -end - -# Vectorized solve for ImplicitTauLeaping -function vectorized_solve(probs, prob::JumpProblem, alg::ImplicitTauLeaping; backend, trajectories, seed, max_steps) - rj = prob.regular_jump - state_dim = length(prob.prob.u0) - p_correct = prob.prob.p # Store correct p - nu = let c = rj.c, u0 = prob.prob.u0, numjumps = rj.numjumps - nu = zeros(Int, state_dim, numjumps) - for j in 1:numjumps - counts = zeros(numjumps) - counts[j] = 1 - du = similar(u0) - c(du, u0, p_correct, prob.prob.tspan[1], counts, nothing) - nu[:, j] = round.(Int, du) - end - nu - end - # Explicitly bind p_correct to both c and rate - c_fixed = (du, u, p, t, counts, mark) -> rj.c(du, u, p_correct, t, counts, mark) - rate_fixed = (out, u, p, t) -> rj.rate(out, u, p_correct, t) - rj_data = ImplicitTauLeapingJumpData(rate_fixed, c_fixed, rj.numjumps, nu) - alg_data = ImplicitTauLeapingData(alg.epsilon, alg.nc, alg.nstiff, alg.delta) - - probs_data = [ImplicitTauLeapingTrajectoryData(SA{eltype(p.prob.u0)}[p.prob.u0...], p_correct, p.prob.tspan) for p in probs] - - probs_data_gpu = adapt(backend, probs_data) - rj_data_gpu = adapt(backend, rj_data) - alg_data_gpu = adapt(backend, alg_data) - - tspan = prob.prob.tspan - num_jumps = rj_data.numjumps - - @assert state_dim > 0 "Dimension of state must be positive" - @assert num_jumps >= 0 "Number of jumps must be positive" - - ts = allocate(backend, eltype(prob.prob.tspan), (max_steps, trajectories)) - us = allocate(backend, eltype(prob.prob.u0), (max_steps, state_dim, trajectories)) - - current_u_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) - rate_cache_buf = allocate(backend, eltype(prob.prob.u0), (num_jumps, trajectories)) - counts_buf = allocate(backend, Int, (num_jumps, trajectories)) - local_dc_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) - mu_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) - sigma2_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) - critical_buf = allocate(backend, Bool, (num_jumps, trajectories)) - rate_new_buf = allocate(backend, eltype(prob.prob.u0), (num_jumps, trajectories)) - residual_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, trajectories)) - J_buf = allocate(backend, eltype(prob.prob.u0), (state_dim, state_dim, trajectories)) - - @kernel function init_buffers_kernel(@Const(probs_data), current_u_buf) - i = @index(Global, Linear) - @inbounds u0 = probs_data[i].u0 - @inbounds for k in 1:length(u0) - current_u_buf[k, i] = u0[k] - end - end - init_kernel = init_buffers_kernel(backend) - init_event = init_kernel(probs_data_gpu, current_u_buf; ndrange=trajectories) - KernelAbstractions.synchronize(backend) - - seed = seed === nothing ? UInt64(12345) : UInt64(seed) - - # Debug: Verify parameters before kernel launch - @assert all(typeof(p.prob.p) == NTuple{4, Float64} for p in probs) "All problems must have p as NTuple{4, Float64}" - @show typeof(probs[1].prob.p) - @show probs[1].prob.p - @show typeof(rj_data.rate) - @show typeof(rj_data.c) - - kernel = implicit_tau_leaping_kernel(backend) - main_event = kernel(probs_data_gpu, us, ts, rj_data_gpu, alg_data_gpu, - current_u_buf, rate_cache_buf, counts_buf, local_dc_buf, - mu_buf, sigma2_buf, critical_buf, rate_new_buf, residual_buf, J_buf, - seed, max_steps; ndrange=trajectories) - KernelAbstractions.synchronize(backend) - - return ts, us -end diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 7416713f..0e5f2c18 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -405,8 +405,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end -# Define the ImplicitTauLeaping algorithm -struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm +# Define the SimpleImplicitTauLeaping algorithm +struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm epsilon::Float64 # Error control parameter nc::Int # Critical reaction threshold nstiff::Float64 # Stiffness threshold for switching @@ -414,10 +414,10 @@ struct ImplicitTauLeaping <: DiffEqBase.DEAlgorithm end # Default constructor -ImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) = - ImplicitTauLeaping(epsilon, nc, nstiff, delta) +SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) = + SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta) -function DiffEqBase.solve(jump_prob::JumpProblem, alg::ImplicitTauLeaping; seed = nothing) +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed = nothing) # Boilerplate setup @assert isempty(jump_prob.jump_callback.continuous_callbacks) @assert isempty(jump_prob.jump_callback.discrete_callbacks) @@ -733,4 +733,4 @@ function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end -export SimpleTauLeaping, EnsembleGPUKernel, ImplicitTauLeaping +export SimpleTauLeaping, EnsembleGPUKernel, SimpleImplicitTauLeaping diff --git a/test/gpu/implicit_tau.jl b/test/gpu/implicit_tau.jl deleted file mode 100644 index d150c00f..00000000 --- a/test/gpu/implicit_tau.jl +++ /dev/null @@ -1,46 +0,0 @@ -using JumpProcesses, DiffEqBase -using Test, LinearAlgebra, Statistics -using KernelAbstractions, Adapt, CUDA -using StableRNGs - - -rng = StableRNG(12345) -Nsims = 1 - - -# Decaying Dimerization Model -c1 = 1.0 # S1 -> 0 -c2 = 10.0 # S1 + S1 <- S2 -c3 = 1000.0 # S1 + S1 -> S2 -c4 = 0.1 # S2 -> S3 -p = (c1, c2, c3, c4) - -# Propensity functions -regular_rate = (out, u, p, t) -> begin - @assert typeof(p) == NTuple{4, Float64} "p must be a tuple of 4 Float64 values" - out[1] = p[1] * u[1] # S1 -> 0 - out[2] = p[2] * u[2] # S1 + S1 <- S2 - out[3] = p[3] * u[1] * (u[1] - 1) / 2 # S1 + S1 -> S2 - out[4] = p[4] * u[2] # S2 -> S3 -end - -# State change function -regular_c = (du, u, p, t, counts, mark) -> begin - @assert typeof(p) == NTuple{4, Float64} "p must be a tuple of 4 Float64 values" - du .= 0.0 - du[1] = -counts[1] - 2 * counts[3] + 2 * counts[2] # S1: -decay - 2*forward + 2*backward - du[2] = counts[3] - counts[2] - counts[4] # S2: +forward - backward - decay - du[3] = counts[4] # S3: +decay -end - -# Initial condition -u0 = [10000.0, 0.0, 0.0] # S1, S2, S3 -tspan = (0.0, 4.0) - -# Create JumpProblem -prob_disc = DiscreteProblem(u0, tspan, p) -rj = RegularJump(regular_rate, regular_c, 4) -jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=rng) - -# Solve using ImplicitTauLeaping -sol = solve(EnsembleProblem(jump_prob), ImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims) From 166754de1ce944aad5067560ea8ead82efe3b6a2 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sun, 10 Aug 2025 04:07:34 +0530 Subject: [PATCH 08/29] refactor --- src/simple_regular_solve.jl | 319 ++++++++++++++++++------------------ 1 file changed, 160 insertions(+), 159 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 0e5f2c18..6ec9f79a 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -405,7 +405,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end -# Define the SimpleImplicitTauLeaping algorithm struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm epsilon::Float64 # Error control parameter nc::Int # Critical reaction threshold @@ -413,12 +412,162 @@ struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm delta::Float64 # Partial equilibrium threshold end -# Default constructor SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) = SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta) -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed = nothing) - # Boilerplate setup +# Compute stoichiometry matrix from c function +function compute_stoichiometry(c, u, numjumps, p, t) + nu = zeros(Int, length(u), numjumps) + for j in 1:numjumps + counts = zeros(numjumps) + counts[j] = 1 + du = similar(u) + c(du, u, p, t, counts, nothing) + nu[:, j] = round.(Int, du) + end + return nu +end + +# Detect reversible reaction pairs +function find_reversible_pairs(nu) + pairs = Vector{Tuple{Int,Int}}() + for j in 1:size(nu, 2) + for k in (j+1):size(nu, 2) + if nu[:, j] == -nu[:, k] + push!(pairs, (j, k)) + end + end + end + return pairs +end + +# Compute g_i (approximation from Cao et al., 2006) +function compute_gi(u, nu, i, rate, rate_cache, p, t) + max_order = 1.0 + for j in 1:size(nu, 2) + if abs(nu[i, j]) > 0 + rate(rate_cache, u, p, t) + if rate_cache[j] > 0 + order = 1.0 + if sum(abs.(nu[:, j])) > abs(nu[i, j]) + order = 2.0 + end + max_order = max(max_order, order) + end + end + end + return max_order +end + +# Tau-selection for explicit method (Equation 8) +function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate) + rate(rate_cache, u, p, t) + mu = zeros(length(u)) + sigma2 = zeros(length(u)) + tau = Inf + for i in 1:length(u) + for j in 1:size(nu, 2) + mu[i] += nu[i, j] * rate_cache[j] + sigma2[i] += nu[i, j]^2 * rate_cache[j] + end + gi = compute_gi(u, nu, i, rate, rate_cache, p, t) + bound = max(epsilon * u[i] / gi, 1.0) + mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf + sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf + tau = min(tau, mu_term, sigma_term) + end + return max(tau, 1e-10) +end + +# Partial equilibrium check (Equation 13) +function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta) + a_plus = rate_cache[j_plus] + a_minus = rate_cache[j_minus] + return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) +end + +# Tau-selection for implicit method (Equation 14) +function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta) + rate(rate_cache, u, p, t) + mu = zeros(length(u)) + sigma2 = zeros(length(u)) + non_equilibrium = trues(size(nu, 2)) + for (j_plus, j_minus) in equilibrium_pairs + if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta) + non_equilibrium[j_plus] = false + non_equilibrium[j_minus] = false + end + end + tau = Inf + for i in 1:length(u) + for j in 1:size(nu, 2) + if non_equilibrium[j] + mu[i] += nu[i, j] * rate_cache[j] + sigma2[i] += nu[i, j]^2 * rate_cache[j] + end + end + gi = compute_gi(u, nu, i, rate, rate_cache, p, t) + bound = max(epsilon * u[i] / gi, 1.0) + mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf + sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf + tau = min(tau, mu_term, sigma_term) + end + return max(tau, 1e-10) +end + +# Identify critical reactions +function identify_critical_reactions(u, rate_cache, nu, nc) + critical = falses(size(nu, 2)) + for j in 1:size(nu, 2) + if rate_cache[j] > 0 + Lj = Inf + for i in 1:length(u) + if nu[i, j] < 0 + Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j]))) + end + end + if Lj < nc + critical[j] = true + end + end + end + return critical +end + +# Implicit tau-leaping step using NonlinearSolve +function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) + # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0 + function f(u_new, params) + rate_new = zeros(eltype(u_new), numjumps) + rate(rate_new, u_new, p, t_prev + tau) + residual = u_new - u_prev + for j in 1:numjumps + residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) + end + return residual + end + + # Initial guess + u_new = copy(u_prev) + + # Solve the nonlinear system + prob = NonlinearProblem(f, u_new, nothing) + sol = solve(prob, NewtonRaphson()) + + # Check for convergence and numerical stability + if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) + return round.(Int, max.(u_prev, 0.0)) # Revert to previous state + end + + return round.(Int, max.(sol.u, 0.0)) +end + +# Down-shifting condition (Equation 19) +function use_down_shifting(t, tau_im, tau_ex, a0, t_end) + return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) +end + +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing) @assert isempty(jump_prob.jump_callback.continuous_callbacks) @assert isempty(jump_prob.jump_callback.discrete_callbacks) prob = jump_prob.prob @@ -447,160 +596,12 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; delta = alg.delta t_end = tspan[2] - # Compute stoichiometry matrix from c function - function compute_stoichiometry(c, u, numjumps) - nu = zeros(Int, length(u), numjumps) - for j in 1:numjumps - counts = zeros(numjumps) - counts[j] = 1 - du = similar(u) - c(du, u, p, t[1], counts, nothing) - nu[:, j] = round.(Int, du) - end - return nu - end - nu = compute_stoichiometry(c, u0, numjumps) + # Compute stoichiometry matrix + nu = compute_stoichiometry(c, u0, numjumps, p, t[1]) # Detect reversible reaction pairs - function find_reversible_pairs(nu) - pairs = Vector{Tuple{Int,Int}}() - for j in 1:numjumps - for k in (j+1):numjumps - if nu[:, j] == -nu[:, k] - push!(pairs, (j, k)) - end - end - end - return pairs - end equilibrium_pairs = find_reversible_pairs(nu) - # Helper function to compute g_i (approximation from Cao et al., 2006) - function compute_gi(u, nu, i) - max_order = 1.0 - for j in 1:numjumps - if abs(nu[i, j]) > 0 - rate(rate_cache, u, p, t[end]) - if rate_cache[j] > 0 - order = 1.0 - if sum(abs.(nu[:, j])) > abs(nu[i, j]) - order = 2.0 - end - max_order = max(max_order, order) - end - end - end - return max_order - end - - # Tau-selection for explicit method (Equation 8) - function compute_tau_explicit(u, rate_cache, nu, p, t) - rate(rate_cache, u, p, t) - mu = zeros(length(u)) - sigma2 = zeros(length(u)) - tau = Inf - for i in 1:length(u) - for j in 1:numjumps - mu[i] += nu[i, j] * rate_cache[j] - sigma2[i] += nu[i, j]^2 * rate_cache[j] - end - gi = compute_gi(u, nu, i) - bound = max(epsilon * u[i] / gi, 1.0) - mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf - sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf - tau = min(tau, mu_term, sigma_term) - end - return max(tau, 1e-10) - end - - # Partial equilibrium check (Equation 13) - function is_partial_equilibrium(rate_cache, j_plus, j_minus) - a_plus = rate_cache[j_plus] - a_minus = rate_cache[j_minus] - return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) - end - - # Tau-selection for implicit method (Equation 14) - function compute_tau_implicit(u, rate_cache, nu, p, t) - rate(rate_cache, u, p, t) - mu = zeros(length(u)) - sigma2 = zeros(length(u)) - non_equilibrium = trues(numjumps) - for (j_plus, j_minus) in equilibrium_pairs - if is_partial_equilibrium(rate_cache, j_plus, j_minus) - non_equilibrium[j_plus] = false - non_equilibrium[j_minus] = false - end - end - tau = Inf - for i in 1:length(u) - for j in 1:numjumps - if non_equilibrium[j] - mu[i] += nu[i, j] * rate_cache[j] - sigma2[i] += nu[i, j]^2 * rate_cache[j] - end - end - gi = compute_gi(u, nu, i) - bound = max(epsilon * u[i] / gi, 1.0) - mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf - sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf - tau = min(tau, mu_term, sigma_term) - end - return max(tau, 1e-10) - end - - # Identify critical reactions - function identify_critical_reactions(u, rate_cache, nu) - critical = falses(numjumps) - for j in 1:numjumps - if rate_cache[j] > 0 - Lj = Inf - for i in 1:length(u) - if nu[i, j] < 0 - Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j]))) - end - end - if Lj < nc - critical[j] = true - end - end - end - return critical - end - - # Implicit tau-leaping step using NonlinearSolve - function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p) - # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0 - function f(u_new, params) - rate_new = similar(rate_cache, eltype(u_new)) - rate(rate_new, u_new, p, t_prev + tau) - residual = u_new - u_prev - for j in 1:numjumps - residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) - end - return residual - end - - # Initial guess - u_new = copy(u_prev) - - # Solve the nonlinear system - prob = NonlinearProblem(f, u_new, nothing) - sol = solve(prob, NewtonRaphson()) - - # Check for convergence and numerical stability - if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) - return round.(Int, max.(u_prev, 0.0)) # Revert to previous state - end - - return round.(Int, max.(sol.u, 0.0)) - end - - # Down-shifting condition (Equation 19) - function use_down_shifting(t, tau_im, tau_ex, a0, t_end) - return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) - end - # Main simulation loop while t[end] < t_end u_prev = u[end] @@ -610,11 +611,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; rate(rate_cache, u_prev, p, t_prev) # Identify critical reactions - critical = identify_critical_reactions(u_prev, rate_cache, nu) + critical = identify_critical_reactions(u_prev, rate_cache, nu, nc) # Compute tau values - tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev) - tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev) + tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate) + tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta) # Compute critical propensity sum ac0 = sum(rate_cache[critical]) @@ -669,7 +670,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; end end if method == :implicit - u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p) + u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) else c(du, u_prev, p, t_prev, counts, nothing) u_new = u_prev + du @@ -696,7 +697,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; end end if method == :implicit && tau > tau_ex - u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p) + u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) else c(du, u_prev, p, t_prev, counts, nothing) u_new = u_prev + du From e50a43ae40e52e7c40f0f6e13e2eaa3f4aad5150 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sun, 10 Aug 2025 05:06:32 +0530 Subject: [PATCH 09/29] SimpleAdaptiveTauLeaping is done --- src/simple_regular_solve.jl | 56 ++++++++++++++++++- test/regular_jumps.jl | 105 ++++++++++++++++++++++++++++++------ 2 files changed, 145 insertions(+), 16 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 6ec9f79a..ca1459a1 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -405,6 +405,60 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end +struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm + epsilon::Float64 # Error control parameter +end + +SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) + +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing) + @assert isempty(jump_prob.jump_callback.continuous_callbacks) + @assert isempty(jump_prob.jump_callback.discrete_callbacks) + prob = jump_prob.prob + rng = DEFAULT_RNG + (seed !== nothing) && seed!(rng, seed) + + rj = jump_prob.regular_jump + rate = rj.rate + numjumps = rj.numjumps + c = rj.c + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + u = [copy(u0)] + t = [tspan[1]] + rate_cache = zeros(Float64, numjumps) + counts = zeros(Int, numjumps) + du = similar(u0) + t_end = tspan[2] + epsilon = alg.epsilon + + nu = compute_stoichiometry(c, u0, numjumps, p, t[1]) + + while t[end] < t_end + u_prev = u[end] + t_prev = t[end] + rate(rate_cache, u_prev, p, t_prev) + tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate) + tau = min(tau, t_end - t_prev) + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) + c(du, u_prev, p, t_prev, counts, nothing) + u_new = u_prev + du + if any(u_new .< 0) + tau /= 2 + continue + end + push!(u, u_new) + push!(t, t_prev + tau) + end + + sol = DiffEqBase.build_solution(prob, alg, t, u, + calculate_error=false, + interp=DiffEqBase.ConstantInterpolation(t, u)) + return sol +end + struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm epsilon::Float64 # Error control parameter nc::Int # Critical reaction threshold @@ -734,4 +788,4 @@ function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end -export SimpleTauLeaping, EnsembleGPUKernel, SimpleImplicitTauLeaping +export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index f4d009ea..c068fab4 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -175,62 +175,62 @@ end tspan = (0.0, 10.0) p = [0.1, 0.2] prob = DiscreteProblem(u0, tspan, p) - + # Create MassActionJump reactant_stoich = [[1 => 1], [1 => 2]] net_stoich = [[1 => -1, 2 => 1], [1 => -2, 3 => 1]] rates = [0.1, 0.05] maj = MassActionJump(rates, reactant_stoich, net_stoich) - + # Test PureLeaping JumpProblem creation jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng) @test jp_pure.aggregator isa PureLeaping @test jp_pure.discrete_jump_aggregation === nothing @test jp_pure.massaction_jump !== nothing @test length(jp_pure.jump_callback.discrete_callbacks) == 0 - + # Test with ConstantRateJump rate(u, p, t) = p[1] * u[1] affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) crj = ConstantRateJump(rate, affect!) - + jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng) @test jp_pure_crj.aggregator isa PureLeaping @test jp_pure_crj.discrete_jump_aggregation === nothing @test length(jp_pure_crj.constant_jumps) == 1 - + # Test with VariableRateJump vrate(u, p, t) = t * p[1] * u[1] vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) vrj = VariableRateJump(vrate, vaffect!) - + jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng) @test jp_pure_vrj.aggregator isa PureLeaping @test jp_pure_vrj.discrete_jump_aggregation === nothing @test length(jp_pure_vrj.variable_jumps) == 1 - + # Test with RegularJump function rj_rate(out, u, p, t) out[1] = p[1] * u[1] end - + rj_dc = zeros(3, 1) rj_dc[1, 1] = -1 rj_dc[3, 1] = 1 - + function rj_c(du, u, p, t, counts, mark) mul!(du, rj_dc, counts) end - + regj = RegularJump(rj_rate, rj_c, 1) - + jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng) @test jp_pure_regj.aggregator isa PureLeaping @test jp_pure_regj.discrete_jump_aggregation === nothing @test jp_pure_regj.regular_jump !== nothing - + # Test mixed jump types - mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), + mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), variable_jumps = (vrj,), regular_jumps = regj) jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng) @test jp_pure_mixed.aggregator isa PureLeaping @@ -239,7 +239,7 @@ end @test length(jp_pure_mixed.constant_jumps) == 1 @test length(jp_pure_mixed.variable_jumps) == 1 @test jp_pure_mixed.regular_jump !== nothing - + # Test spatial system error spatial_sys = CartesianGrid((2, 2)) hopping_consts = [1.0] @@ -247,7 +247,7 @@ end spatial_system = spatial_sys) @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, hopping_constants = hopping_consts) - + # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) @@ -255,6 +255,81 @@ end @test jp_params.massaction_jump.scaled_rates == scaled_rates end +# SIR model with influx - SimpleImplicitTauLeaping vs SimpleTauLeaping +let + β = 0.1 / 1000.0 + ν = 0.01 + influx_rate = 1.0 + p = (β, ν, influx_rate) + + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[2] # β*S*I (infection) + out[2] = p[2] * u[2] # ν*I (recovery) + out[3] = p[3] # influx_rate + end + + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] + counts[3] # S: -infection + influx + dc[2] = counts[1] - counts[2] # I: +infection - recovery + dc[3] = counts[2] # R: +recovery + end + + u0 = [999.0, 10.0, 0.0] # S, I, R + tspan = (0.0, 250.0) + + prob_disc = DiscreteProblem(u0, tspan, p) + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob = JumpProblem(prob_disc, Direct(), rj) + + sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) + mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims) + + sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims) + mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims) + + @test isapprox(mean_simple, mean_implicit, rtol=0.05) +end + +# SEIR model with exposed compartment - SimpleImplicitTauLeaping vs SimpleTauLeaping +let + β = 0.3 / 1000.0 + σ = 0.2 + ν = 0.01 + p = (β, σ, ν) + + regular_rate = (out, u, p, t) -> begin + out[1] = p[1] * u[1] * u[3] # β*S*I (infection) + out[2] = p[2] * u[2] # σ*E (progression) + out[3] = p[3] * u[3] # ν*I (recovery) + end + + regular_c = (dc, u, p, t, counts, mark) -> begin + dc .= 0.0 + dc[1] = -counts[1] # S: -infection + dc[2] = counts[1] - counts[2] # E: +infection - progression + dc[3] = counts[2] - counts[3] # I: +progression - recovery + dc[4] = counts[3] # R: +recovery + end + + # Initial state + u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R + tspan = (0.0, 250.0) + + # Create JumpProblem + prob_disc = DiscreteProblem(u0, tspan, p) + rj = RegularJump(regular_rate, regular_c, 3) + jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) + + sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) + mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims) + + sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims) + mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims) + + @test isapprox(mean_simple, mean_implicit, rtol=0.05) +end + # Test that saveat/save_start/save_end control which times are stored in solutions @testset "Saving Controls" begin # Simple birth process for testing SSAStepper save behavior From 273460a2247989a257e72a61ee45da20a3a7ca82 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Tue, 19 Aug 2025 06:29:58 +0530 Subject: [PATCH 10/29] simple version of SimpleImplicitTauLeaping --- Project.toml | 3 +- src/simple_regular_solve.jl | 319 +++++++++++------------------------- 2 files changed, 98 insertions(+), 224 deletions(-) diff --git a/Project.toml b/Project.toml index 91bc133f..3d442047 100644 --- a/Project.toml +++ b/Project.toml @@ -13,12 +13,13 @@ FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" -NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index ca1459a1..c86b7617 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -459,62 +459,32 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; return sol end +# SimpleImplicitTauLeaping implementation struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm epsilon::Float64 # Error control parameter - nc::Int # Critical reaction threshold - nstiff::Float64 # Stiffness threshold for switching - delta::Float64 # Partial equilibrium threshold end -SimpleImplicitTauLeaping(; epsilon=0.05, nc=10, nstiff=100.0, delta=0.05) = - SimpleImplicitTauLeaping(epsilon, nc, nstiff, delta) +SimpleImplicitTauLeaping(; epsilon=0.05) = SimpleImplicitTauLeaping(epsilon) -# Compute stoichiometry matrix from c function -function compute_stoichiometry(c, u, numjumps, p, t) - nu = zeros(Int, length(u), numjumps) - for j in 1:numjumps - counts = zeros(numjumps) - counts[j] = 1 - du = similar(u) - c(du, u, p, t, counts, nothing) - nu[:, j] = round.(Int, du) - end - return nu -end - -# Detect reversible reaction pairs -function find_reversible_pairs(nu) - pairs = Vector{Tuple{Int,Int}}() +function compute_hor(nu) + hor = zeros(Int, size(nu, 2)) for j in 1:size(nu, 2) - for k in (j+1):size(nu, 2) - if nu[:, j] == -nu[:, k] - push!(pairs, (j, k)) - end - end + hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 end - return pairs + return hor end -# Compute g_i (approximation from Cao et al., 2006) -function compute_gi(u, nu, i, rate, rate_cache, p, t) +function compute_gi(u, nu, hor, i) max_order = 1.0 for j in 1:size(nu, 2) if abs(nu[i, j]) > 0 - rate(rate_cache, u, p, t) - if rate_cache[j] > 0 - order = 1.0 - if sum(abs.(nu[:, j])) > abs(nu[i, j]) - order = 2.0 - end - max_order = max(max_order, order) - end + max_order = max(max_order, Float64(hor[j])) end end return max_order end -# Tau-selection for explicit method (Equation 8) -function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate) +function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate) rate(rate_cache, u, p, t) mu = zeros(length(u)) sigma2 = zeros(length(u)) @@ -524,104 +494,58 @@ function compute_tau_explicit(u, rate_cache, nu, p, t, epsilon, rate) mu[i] += nu[i, j] * rate_cache[j] sigma2[i] += nu[i, j]^2 * rate_cache[j] end - gi = compute_gi(u, nu, i, rate, rate_cache, p, t) + gi = compute_gi(u, nu, hor, i) bound = max(epsilon * u[i] / gi, 1.0) mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf tau = min(tau, mu_term, sigma_term) end - return max(tau, 1e-10) -end - -# Partial equilibrium check (Equation 13) -function is_partial_equilibrium(rate_cache, j_plus, j_minus, delta) - a_plus = rate_cache[j_plus] - a_minus = rate_cache[j_minus] - return abs(a_plus - a_minus) <= delta * min(a_plus, a_minus) + return tau end -# Tau-selection for implicit method (Equation 14) -function compute_tau_implicit(u, rate_cache, nu, p, t, epsilon, rate, equilibrium_pairs, delta) +function compute_tau_implicit(u, rate_cache, nu, p, t, rate) rate(rate_cache, u, p, t) - mu = zeros(length(u)) - sigma2 = zeros(length(u)) - non_equilibrium = trues(size(nu, 2)) - for (j_plus, j_minus) in equilibrium_pairs - if is_partial_equilibrium(rate_cache, j_plus, j_minus, delta) - non_equilibrium[j_plus] = false - non_equilibrium[j_minus] = false - end - end tau = Inf for i in 1:length(u) + sum_nu_a = 0.0 for j in 1:size(nu, 2) - if non_equilibrium[j] - mu[i] += nu[i, j] * rate_cache[j] - sigma2[i] += nu[i, j]^2 * rate_cache[j] - end + sum_nu_a += abs(nu[i, j]) * rate_cache[j] end - gi = compute_gi(u, nu, i, rate, rate_cache, p, t) - bound = max(epsilon * u[i] / gi, 1.0) - mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf - sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf - tau = min(tau, mu_term, sigma_term) - end - return max(tau, 1e-10) -end - -# Identify critical reactions -function identify_critical_reactions(u, rate_cache, nu, nc) - critical = falses(size(nu, 2)) - for j in 1:size(nu, 2) - if rate_cache[j] > 0 - Lj = Inf - for i in 1:length(u) - if nu[i, j] < 0 - Lj = min(Lj, floor(Int, u[i] / abs(nu[i, j]))) - end - end - if Lj < nc - critical[j] = true - end + if sum_nu_a > 0 + tau = min(tau, 1.0 / sum_nu_a) end end - return critical + return tau end -# Implicit tau-leaping step using NonlinearSolve function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0 + # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0 function f(u_new, params) rate_new = zeros(eltype(u_new), numjumps) rate(rate_new, u_new, p, t_prev + tau) residual = u_new - u_prev for j in 1:numjumps - residual -= nu[:, j] * (counts[j] - tau * rate_cache[j] + tau * rate_new[j]) + residual -= nu[:, j] * (counts[j] - tau * (rate_cache[j] - rate_new[j])) end return residual end - + # Initial guess u_new = copy(u_prev) - + # Solve the nonlinear system prob = NonlinearProblem(f, u_new, nothing) - sol = solve(prob, NewtonRaphson()) - + sol = solve(prob, SimpleNewtonRaphson(), tol=1e-6) + # Check for convergence and numerical stability if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) - return round.(Int, max.(u_prev, 0.0)) # Revert to previous state + return nothing # Signal failure to trigger tau halving end - - return round.(Int, max.(sol.u, 0.0)) -end -# Down-shifting condition (Equation 19) -function use_down_shifting(t, tau_im, tau_ex, a0, t_end) - return a0 > 0 && t + tau_im >= t_end - 100 * (tau_ex + 1 / a0) + return round.(Int, max.(sol.u, 0.0)) end -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing) +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing) @assert isempty(jump_prob.jump_callback.continuous_callbacks) @assert isempty(jump_prob.jump_callback.discrete_callbacks) prob = jump_prob.prob @@ -635,144 +559,93 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; u0 = copy(prob.u0) tspan = prob.tspan p = prob.p - - # Initialize storage - rate_cache = zeros(Float64, numjumps) - counts = zeros(Int, numjumps) - du = similar(u0) + u = [copy(u0)] t = [tspan[1]] - - # Algorithm parameters - epsilon = alg.epsilon - nc = alg.nc - nstiff = alg.nstiff - delta = alg.delta + rate_cache = zeros(Float64, numjumps) + counts = zeros(Int, numjumps) + du = similar(u0, Int) t_end = tspan[2] - - # Compute stoichiometry matrix - nu = compute_stoichiometry(c, u0, numjumps, p, t[1]) - - # Detect reversible reaction pairs - equilibrium_pairs = find_reversible_pairs(nu) - - # Main simulation loop + epsilon = alg.epsilon + + # Compute initial stoichiometry and HOR + nu = zeros(Int, length(u0), numjumps) + counts_temp = zeros(Int, numjumps) + for j in 1:numjumps + fill!(counts_temp, 0) + counts_temp[j] = 1 + c(du, u0, p, t[1], counts_temp, nothing) + nu[:, j] = du + end + hor = compute_hor(nu) + + saveat_times = isnothing(saveat) ? Float64[] : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat) + save_idx = 1 + while t[end] < t_end u_prev = u[end] t_prev = t[end] - - # Compute propensities + # Recompute stoichiometry + for j in 1:numjumps + fill!(counts_temp, 0) + counts_temp[j] = 1 + c(du, u_prev, p, t_prev, counts_temp, nothing) + nu[:, j] = du + end rate(rate_cache, u_prev, p, t_prev) - - # Identify critical reactions - critical = identify_critical_reactions(u_prev, rate_cache, nu, nc) - - # Compute tau values - tau_ex = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate) - tau_im = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate, equilibrium_pairs, delta) - - # Compute critical propensity sum - ac0 = sum(rate_cache[critical]) - tau2 = ac0 > 0 ? randexp(rng) / ac0 : Inf - - # Choose method and stepsize - a0 = sum(rate_cache) - use_implicit = a0 > 0 && tau_im > nstiff * tau_ex && !use_down_shifting(t_prev, tau_im, tau_ex, a0, t_end) - tau1 = use_implicit ? tau_im : tau_ex - method = use_implicit ? :implicit : :explicit - - # Cap tau to prevent large updates - tau1 = min(tau1, 1.0) - - # Check if tau1 is too small - if a0 > 0 && tau1 < 10 / a0 - # Use SSA for a few steps - steps = method == :implicit ? 10 : 100 - for _ in 1:steps - if t_prev >= t_end - break - end - rate(rate_cache, u_prev, p, t_prev) - a0 = sum(rate_cache) - if a0 == 0 - break - end - tau = randexp(rng) / a0 - r = rand(rng) * a0 - cumsum_rate = 0.0 - for j in 1:numjumps - cumsum_rate += rate_cache[j] - if cumsum_rate > r - u_prev += nu[:, j] - break - end - end - t_prev += tau - push!(u, copy(u_prev)) - push!(t, t_prev) + tau_prime = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate) + tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate) + tau = min(tau_prime, tau_double_prime / 10.0) + tau = max(tau, dtmin) + tau = min(tau, t_end - t_prev) + if !isempty(saveat_times) + if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx] + tau = saveat_times[save_idx] - t_prev end - continue end - - # Choose stepsize and compute firings - if tau2 > tau1 - tau = min(tau1, t_end - t_prev) - counts .= 0 - for j in 1:numjumps - if !critical[j] - counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0)) - end - end - if method == :implicit - u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - else - c(du, u_prev, p, t_prev, counts, nothing) - u_new = u_prev + du + counts .= rand(rng, Poisson.(max.(rate_cache * tau, 0.0))) + c(du, u_prev, p, t_prev, counts, nothing) + u_new = u_prev + du + if tau_prime <= tau_double_prime / 10.0 + # Explicit update + if any(u_new .< 0) + # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) + tau /= 2 + continue end else - tau = min(tau2, t_end - t_prev) - counts .= 0 - if ac0 > 0 - r = rand(rng) * ac0 - cumsum_rate = 0.0 - for j in 1:numjumps - if critical[j] - cumsum_rate += rate_cache[j] - if cumsum_rate > r - counts[j] = 1 - break - end - end - end - end - for j in 1:numjumps - if !critical[j] - counts[j] = pois_rand(rng, max(rate_cache[j] * tau, 0.0)) - end + # Implicit update using NonlinearSolve + u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) + if u_new === nothing || any(u_new .< 0) + # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) + tau /= 2 + continue end - if method == :implicit && tau > tau_ex - u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - else - c(du, u_prev, p, t_prev, counts, nothing) - u_new = u_prev + du - end - end - - # Check for negative populations - if any(u_new .< 0) - tau1 /= 2 - continue end - - # Update state and time + u_new = max.(u_new, 0) push!(u, u_new) push!(t, t_prev + tau) + if !isempty(saveat_times) && save_idx <= length(saveat_times) && t[end] >= saveat_times[save_idx] + save_idx += 1 + end end - - # Build solution + + # Interpolate to saveat times if specified + if !isempty(saveat_times) + t_out = saveat_times + u_out = [u[end]] + for t_save in saveat_times + idx = findlast(ti -> ti <= t_save, t) + push!(u_out, u[idx]) + end + t = t_out + u = u_out[2:end] + end + sol = DiffEqBase.build_solution(prob, alg, t, u, - calculate_error = false, - interp = DiffEqBase.ConstantInterpolation(t, u)) + calculate_error=false, + interp=DiffEqBase.ConstantInterpolation(t, u)) + return sol end struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm From d7c65589676534a24ec51e7dd6bd05604460df54 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Tue, 19 Aug 2025 06:33:38 +0530 Subject: [PATCH 11/29] removed adaptive tau leap --- src/simple_regular_solve.jl | 56 +------------------ test/regular_jumps.jl | 108 +++++++++++++++++++++--------------- 2 files changed, 64 insertions(+), 100 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index c86b7617..0113b717 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -405,60 +405,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end -struct SimpleAdaptiveTauLeaping <: DiffEqBase.DEAlgorithm - epsilon::Float64 # Error control parameter -end - -SimpleAdaptiveTauLeaping(; epsilon=0.05) = SimpleAdaptiveTauLeaping(epsilon) - -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping; seed=nothing) - @assert isempty(jump_prob.jump_callback.continuous_callbacks) - @assert isempty(jump_prob.jump_callback.discrete_callbacks) - prob = jump_prob.prob - rng = DEFAULT_RNG - (seed !== nothing) && seed!(rng, seed) - - rj = jump_prob.regular_jump - rate = rj.rate - numjumps = rj.numjumps - c = rj.c - u0 = copy(prob.u0) - tspan = prob.tspan - p = prob.p - - u = [copy(u0)] - t = [tspan[1]] - rate_cache = zeros(Float64, numjumps) - counts = zeros(Int, numjumps) - du = similar(u0) - t_end = tspan[2] - epsilon = alg.epsilon - - nu = compute_stoichiometry(c, u0, numjumps, p, t[1]) - - while t[end] < t_end - u_prev = u[end] - t_prev = t[end] - rate(rate_cache, u_prev, p, t_prev) - tau = compute_tau_explicit(u_prev, rate_cache, nu, p, t_prev, epsilon, rate) - tau = min(tau, t_end - t_prev) - counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) - c(du, u_prev, p, t_prev, counts, nothing) - u_new = u_prev + du - if any(u_new .< 0) - tau /= 2 - continue - end - push!(u, u_new) - push!(t, t_prev + tau) - end - - sol = DiffEqBase.build_solution(prob, alg, t, u, - calculate_error=false, - interp=DiffEqBase.ConstantInterpolation(t, u)) - return sol -end - # SimpleImplicitTauLeaping implementation struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm epsilon::Float64 # Error control parameter @@ -661,4 +607,4 @@ function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end -export SimpleTauLeaping, EnsembleGPUKernel, SimpleAdaptiveTauLeaping, SimpleImplicitTauLeaping +export SimpleTauLeaping, EnsembleGPUKernel, SimpleImplicitTauLeaping diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index c068fab4..738c9089 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -255,79 +255,97 @@ end @test jp_params.massaction_jump.scaled_rates == scaled_rates end -# SIR model with influx - SimpleImplicitTauLeaping vs SimpleTauLeaping -let +# SimpleImplicitTauLeaping correctness - SIR model +@testset "SimpleImplicitTauLeaping SIR Correctness" begin β = 0.1 / 1000.0 ν = 0.01 influx_rate = 1.0 p = (β, ν, influx_rate) + rate1(u, p, t) = p[1] * u[1] * u[2] + rate2(u, p, t) = p[2] * u[2] + rate3(u, p, t) = p[3] + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[1] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) + + u0 = [999, 10, 0] # Integer initial conditions + tspan = (0.0, 250.0) + prob_disc = DiscreteProblem(u0, tspan, p) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345)) + + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims) + regular_rate = (out, u, p, t) -> begin - out[1] = p[1] * u[1] * u[2] # β*S*I (infection) - out[2] = p[2] * u[2] # ν*I (recovery) - out[3] = p[3] # influx_rate + out[1] = p[1] * u[1] * u[2] + out[2] = p[2] * u[2] + out[3] = p[3] end - regular_c = (dc, u, p, t, counts, mark) -> begin - dc .= 0.0 - dc[1] = -counts[1] + counts[3] # S: -infection + influx - dc[2] = counts[1] - counts[2] # I: +infection - recovery - dc[3] = counts[2] # R: +recovery + dc .= 0 + dc[1] = -counts[1] + counts[3] + dc[2] = counts[1] - counts[2] + dc[3] = counts[2] end - - u0 = [999.0, 10.0, 0.0] # S, I, R - tspan = (0.0, 250.0) - - prob_disc = DiscreteProblem(u0, tspan, p) rj = RegularJump(regular_rate, regular_c, 3) - jump_prob = JumpProblem(prob_disc, Direct(), rj) + jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) - sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) - mean_simple = mean(sol.u[i][1,end] for i in 1:Nsims) + sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) - sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims) - mean_implicit = mean(sol.u[i][1,end] for i in 1:Nsims) + t_points = 0:1.0:250.0 + mean_direct_S = [mean(sol_direct[i](t)[1] for i in 1:Nsims) for t in t_points] + mean_implicit_S = [mean(sol_implicit[i](t)[1] for i in 1:Nsims) for t in t_points] - @test isapprox(mean_simple, mean_implicit, rtol=0.05) + max_error_implicit = maximum(abs.(mean_direct_S .- mean_implicit_S)) + @test max_error_implicit < 0.01 * mean(mean_direct_S) end -# SEIR model with exposed compartment - SimpleImplicitTauLeaping vs SimpleTauLeaping -let +# SimpleImplicitTauLeaping correctness - SEIR model +@testset "SimpleImplicitTauLeaping SEIR Correctness" begin β = 0.3 / 1000.0 σ = 0.2 ν = 0.01 p = (β, σ, ν) + rate1(u, p, t) = p[1] * u[1] * u[3] + rate2(u, p, t) = p[2] * u[2] + rate3(u, p, t) = p[3] * u[3] + affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) + affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) + affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing) + jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) + + u0 = [999, 0, 10, 0] # Integer initial conditions + tspan = (0.0, 250.0) + prob_disc = DiscreteProblem(u0, tspan, p) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345)) + + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims) + regular_rate = (out, u, p, t) -> begin - out[1] = p[1] * u[1] * u[3] # β*S*I (infection) - out[2] = p[2] * u[2] # σ*E (progression) - out[3] = p[3] * u[3] # ν*I (recovery) + out[1] = p[1] * u[1] * u[3] + out[2] = p[2] * u[2] + out[3] = p[3] * u[3] end - regular_c = (dc, u, p, t, counts, mark) -> begin - dc .= 0.0 - dc[1] = -counts[1] # S: -infection - dc[2] = counts[1] - counts[2] # E: +infection - progression - dc[3] = counts[2] - counts[3] # I: +progression - recovery - dc[4] = counts[3] # R: +recovery + dc .= 0 + dc[1] = -counts[1] + dc[2] = counts[1] - counts[2] + dc[3] = counts[2] - counts[3] + dc[4] = counts[3] end - - # Initial state - u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R - tspan = (0.0, 250.0) - - # Create JumpProblem - prob_disc = DiscreteProblem(u0, tspan, p) rj = RegularJump(regular_rate, regular_c, 3) - jump_prob = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) + jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) - sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 1.0) - mean_simple = mean(sol.u[i][end,end] for i in 1:Nsims) + sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) - sol = solve(EnsembleProblem(jump_prob), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims) - mean_implicit = mean(sol.u[i][end,end] for i in 1:Nsims) + t_points = 0:1.0:250.0 + mean_direct_R = [mean(sol_direct[i](t)[4] for i in 1:Nsims) for t in t_points] + mean_implicit_R = [mean(sol_implicit[i](t)[4] for i in 1:Nsims) for t in t_points] - @test isapprox(mean_simple, mean_implicit, rtol=0.05) + max_error_implicit = maximum(abs.(mean_direct_R .- mean_implicit_R)) + @test max_error_implicit < 0.01 * mean(mean_direct_R) end # Test that saveat/save_start/save_end control which times are stored in solutions From a21b3677a07f75961737e27f2df74e09a269ace1 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Tue, 19 Aug 2025 06:41:04 +0530 Subject: [PATCH 12/29] poiss change --- src/simple_regular_solve.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 0113b717..04bb53f3 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -481,7 +481,7 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, # Solve the nonlinear system prob = NonlinearProblem(f, u_new, nothing) - sol = solve(prob, SimpleNewtonRaphson(), tol=1e-6) + sol = solve(prob, SimpleNewtonRaphson()) # Check for convergence and numerical stability if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) @@ -549,7 +549,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; tau = saveat_times[save_idx] - t_prev end end - counts .= rand(rng, Poisson.(max.(rate_cache * tau, 0.0))) + counts .= counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0)) c(du, u_prev, p, t_prev, counts, nothing) u_new = u_prev + du if tau_prime <= tau_double_prime / 10.0 From df55f1e3019bfc669a761ec2886d795b2a36b17e Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Tue, 19 Aug 2025 07:11:42 +0530 Subject: [PATCH 13/29] changed to inline non linear solver --- src/simple_regular_solve.jl | 49 ++++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 04bb53f3..359cdf7b 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -466,8 +466,8 @@ end function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0 - function f(u_new, params) - rate_new = zeros(eltype(u_new), numjumps) + function f(u_new) + rate_new = zeros(Float64, numjumps) rate(rate_new, u_new, p, t_prev + tau) residual = u_new - u_prev for j in 1:numjumps @@ -476,19 +476,41 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, return residual end - # Initial guess - u_new = copy(u_prev) - - # Solve the nonlinear system - prob = NonlinearProblem(f, u_new, nothing) - sol = solve(prob, SimpleNewtonRaphson()) - - # Check for convergence and numerical stability - if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) - return nothing # Signal failure to trigger tau halving + # Compute Jacobian using finite differences + function compute_jacobian(u_new) + n = length(u_new) + J = zeros(Float64, n, n) + h = 1e-6 + f_u = f(u_new) + for j in 1:n + u_pert = copy(u_new) + u_pert[j] += h + f_pert = f(u_pert) + J[:, j] = (f_pert - f_u) / h + end + return J end - return round.(Int, max.(sol.u, 0.0)) + # Inline Newton-Raphson + u_new = float.(u_prev + sum(nu[:, j] * counts[j] for j in 1:numjumps)) # Initial guess: explicit step + tol = 1e-6 + maxiters = 100 + for iter in 1:maxiters + F = f(u_new) + if norm(F) < tol + return round.(Int, max.(u_new, 0.0)) # Converged + end + J = compute_jacobian(u_new) + if abs(det(J)) < 1e-10 # Check for singular Jacobian + return nothing # Signal failure + end + delta = J \ F + u_new -= delta + if any(isnan.(u_new)) || any(isinf.(u_new)) + return nothing # Signal failure + end + end + return nothing # Failed to converge end function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing) @@ -560,7 +582,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; continue end else - # Implicit update using NonlinearSolve u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) if u_new === nothing || any(u_new .< 0) # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) From 5c9d419b4c668c6a6d212026f32368cfb2817830 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Tue, 19 Aug 2025 08:11:00 +0530 Subject: [PATCH 14/29] refactor --- src/simple_regular_solve.jl | 41 ++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 359cdf7b..bb5025e5 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -455,17 +455,19 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate) for i in 1:length(u) sum_nu_a = 0.0 for j in 1:size(nu, 2) - sum_nu_a += abs(nu[i, j]) * rate_cache[j] + if nu[i, j] < 0 # Only sum negative stoichiometry + sum_nu_a += abs(nu[i, j]) * rate_cache[j] + end end - if sum_nu_a > 0 - tau = min(tau, 1.0 / sum_nu_a) + if sum_nu_a > 0 && u[i] > 0 # Avoid division by zero + tau = min(tau, u[i] / sum_nu_a) end end return tau end function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0 + # Nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (k_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0 function f(u_new) rate_new = zeros(Float64, numjumps) rate(rate_new, u_new, p, t_prev + tau) @@ -476,7 +478,7 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, return residual end - # Compute Jacobian using finite differences + # Numerical Jacobian function compute_jacobian(u_new) n = length(u_new) J = zeros(Float64, n, n) @@ -502,12 +504,12 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, end J = compute_jacobian(u_new) if abs(det(J)) < 1e-10 # Check for singular Jacobian - return nothing # Signal failure + return nothing end delta = J \ F u_new -= delta if any(isnan.(u_new)) || any(isinf.(u_new)) - return nothing # Signal failure + return nothing end end return nothing # Failed to converge @@ -563,7 +565,13 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; rate(rate_cache, u_prev, p, t_prev) tau_prime = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate) tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate) - tau = min(tau_prime, tau_double_prime / 10.0) + # Cao et al. (2007): Use tau_prime for explicit, tau_double_prime for implicit + use_implicit = false + tau = tau_prime # Default to explicit + if tau_double_prime < tau_prime && any(u_prev .< 10) # Implicit if populations are low + tau = tau_double_prime + use_implicit = true + end tau = max(tau, dtmin) tau = min(tau, t_end - t_prev) if !isempty(saveat_times) @@ -571,23 +579,18 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; tau = saveat_times[save_idx] - t_prev end end - counts .= counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0)) + counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0)) c(du, u_prev, p, t_prev, counts, nothing) u_new = u_prev + du - if tau_prime <= tau_double_prime / 10.0 - # Explicit update - if any(u_new .< 0) - # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) - tau /= 2 - continue - end - else + if use_implicit u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) if u_new === nothing || any(u_new .< 0) - # Halve tau to avoid negative populations, as per Cao et al. (2006, J. Chem. Phys., DOI: 10.1063/1.2159468) - tau /= 2 + tau /= 2 # Halve tau if implicit fails or produces negative populations continue end + elseif any(u_new .< 0) + tau /= 2 # Halve tau if explicit produces negative populations + continue end u_new = max.(u_new, 0) push!(u, u_new) From 55d8b9131695751f0e0fcbb05a9a7c6889711d36 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Tue, 19 Aug 2025 11:21:18 +0530 Subject: [PATCH 15/29] basic version of inplicit tau leap is done --- src/JumpProcesses.jl | 2 ++ src/simple_regular_solve.jl | 66 ++++++++++--------------------------- 2 files changed, 20 insertions(+), 48 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 87c8c2c2..9631fbdb 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -18,6 +18,8 @@ using StaticArrays: StaticArrays, SVector, setindex using Base.Threads: Threads using Base.FastMath: add_fast +using SimpleNonlinearSolve + # Import functions we extend from Base import Base: size, getindex, setindex!, length, similar, show, merge!, merge diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index bb5025e5..b2ec2b5f 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -413,7 +413,7 @@ end SimpleImplicitTauLeaping(; epsilon=0.05) = SimpleImplicitTauLeaping(epsilon) function compute_hor(nu) - hor = zeros(Int, size(nu, 2)) + hor = zeros(Int64, size(nu, 2)) for j in 1:size(nu, 2) hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 end @@ -432,8 +432,8 @@ end function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate) rate(rate_cache, u, p, t) - mu = zeros(length(u)) - sigma2 = zeros(length(u)) + mu = zeros(Float64, length(u)) + sigma2 = zeros(Float64, length(u)) tau = Inf for i in 1:length(u) for j in 1:size(nu, 2) @@ -455,11 +455,11 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate) for i in 1:length(u) sum_nu_a = 0.0 for j in 1:size(nu, 2) - if nu[i, j] < 0 # Only sum negative stoichiometry + if nu[i, j] < 0 sum_nu_a += abs(nu[i, j]) * rate_cache[j] end end - if sum_nu_a > 0 && u[i] > 0 # Avoid division by zero + if sum_nu_a > 0 && u[i] > 0 tau = min(tau, u[i] / sum_nu_a) end end @@ -467,9 +467,8 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate) end function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - # Nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (k_j - tau * (a_j(u_prev) - a_j(u_new)))) = 0 - function f(u_new) - rate_new = zeros(Float64, numjumps) + function f(u_new, p) + rate_new = zeros(eltype(u_new), numjumps) rate(rate_new, u_new, p, t_prev + tau) residual = u_new - u_prev for j in 1:numjumps @@ -478,41 +477,14 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, return residual end - # Numerical Jacobian - function compute_jacobian(u_new) - n = length(u_new) - J = zeros(Float64, n, n) - h = 1e-6 - f_u = f(u_new) - for j in 1:n - u_pert = copy(u_new) - u_pert[j] += h - f_pert = f(u_pert) - J[:, j] = (f_pert - f_u) / h - end - return J - end + u_new = float.(u_prev + sum(nu[:, j] * counts[j] for j in 1:numjumps)) + prob = NonlinearProblem{false}(f, u_new, p) + sol = solve(prob, SimpleNewtonRaphson(), abstol=1e-6, maxiters=100) - # Inline Newton-Raphson - u_new = float.(u_prev + sum(nu[:, j] * counts[j] for j in 1:numjumps)) # Initial guess: explicit step - tol = 1e-6 - maxiters = 100 - for iter in 1:maxiters - F = f(u_new) - if norm(F) < tol - return round.(Int, max.(u_new, 0.0)) # Converged - end - J = compute_jacobian(u_new) - if abs(det(J)) < 1e-10 # Check for singular Jacobian - return nothing - end - delta = J \ F - u_new -= delta - if any(isnan.(u_new)) || any(isinf.(u_new)) - return nothing - end + if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) + return nothing end - return nothing # Failed to converge + return round.(Int64, max.(sol.u, 0.0)) end function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing) @@ -555,7 +527,6 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; while t[end] < t_end u_prev = u[end] t_prev = t[end] - # Recompute stoichiometry for j in 1:numjumps fill!(counts_temp, 0) counts_temp[j] = 1 @@ -565,11 +536,10 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; rate(rate_cache, u_prev, p, t_prev) tau_prime = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate) tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate) - # Cao et al. (2007): Use tau_prime for explicit, tau_double_prime for implicit use_implicit = false - tau = tau_prime # Default to explicit - if tau_double_prime < tau_prime && any(u_prev .< 10) # Implicit if populations are low - tau = tau_double_prime + tau = tau_prime + if any(u_prev .< 10) + tau = min(tau_double_prime, tau_prime) # Tighter cap for accuracy use_implicit = true end tau = max(tau, dtmin) @@ -585,11 +555,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; if use_implicit u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) if u_new === nothing || any(u_new .< 0) - tau /= 2 # Halve tau if implicit fails or produces negative populations + tau /= 2 continue end elseif any(u_new .< 0) - tau /= 2 # Halve tau if explicit produces negative populations + tau /= 2 continue end u_new = max.(u_new, 0) From 6fd91106c849fc426ed671290ab6d7f7084228b4 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Wed, 20 Aug 2025 11:08:58 +0530 Subject: [PATCH 16/29] added critical_threshold --- src/simple_regular_solve.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index b2ec2b5f..319f11aa 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -407,10 +407,11 @@ end # SimpleImplicitTauLeaping implementation struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm - epsilon::Float64 # Error control parameter + epsilon::Float64 + critical_threshold::Float64 end -SimpleImplicitTauLeaping(; epsilon=0.05) = SimpleImplicitTauLeaping(epsilon) +SimpleImplicitTauLeaping(; epsilon=0.05, L=10.0) = SimpleImplicitTauLeaping(epsilon, L) function compute_hor(nu) hor = zeros(Int64, size(nu, 2)) @@ -509,6 +510,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; du = similar(u0, Int) t_end = tspan[2] epsilon = alg.epsilon + critical_threshold = alg.critical_threshold # Compute initial stoichiometry and HOR nu = zeros(Int, length(u0), numjumps) @@ -538,8 +540,8 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate) use_implicit = false tau = tau_prime - if any(u_prev .< 10) - tau = min(tau_double_prime, tau_prime) # Tighter cap for accuracy + if any(u_prev .< critical_threshold) + tau = min(tau_double_prime, tau_prime) use_implicit = true end tau = max(tau, dtmin) From 2b0667a7c7c5be26f2568ffa697c83f978e23bd5 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Wed, 20 Aug 2025 12:03:38 +0530 Subject: [PATCH 17/29] residual update --- src/simple_regular_solve.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 319f11aa..575f8fd4 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -411,7 +411,7 @@ struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm critical_threshold::Float64 end -SimpleImplicitTauLeaping(; epsilon=0.05, L=10.0) = SimpleImplicitTauLeaping(epsilon, L) +SimpleImplicitTauLeaping(; epsilon=0.05, critical_threshold=10.0) = SimpleImplicitTauLeaping(epsilon, critical_threshold) function compute_hor(nu) hor = zeros(Int64, size(nu, 2)) @@ -471,9 +471,10 @@ function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, function f(u_new, p) rate_new = zeros(eltype(u_new), numjumps) rate(rate_new, u_new, p, t_prev + tau) - residual = u_new - u_prev + residual = zeros(eltype(u_new), length(u_new)) + residual .= u_new - u_prev for j in 1:numjumps - residual -= nu[:, j] * (counts[j] - tau * (rate_cache[j] - rate_new[j])) + residual .-= nu[:, j] * (counts[j] - tau * (rate_cache[j] - rate_new[j])) end return residual end From a6af972fafb6fcb563474ca1a8350593959029ae Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Wed, 20 Aug 2025 12:05:35 +0530 Subject: [PATCH 18/29] added comment line --- src/simple_regular_solve.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 575f8fd4..6ff6a4a7 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -468,6 +468,7 @@ function compute_tau_implicit(u, rate_cache, nu, p, t, rate) end function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) + # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0 function f(u_new, p) rate_new = zeros(eltype(u_new), numjumps) rate(rate_new, u_new, p, t_prev + tau) From becb2a55f5ed60ea7d9c19aaea663dc390c6b172 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:28:10 +0530 Subject: [PATCH 19/29] SimpleImplicitTauLeaping --- src/JumpProcesses.jl | 2 +- src/simple_regular_solve.jl | 330 +++++++++++++++++++++--------------- test/regular_jumps.jl | 145 +++++++++++----- 3 files changed, 297 insertions(+), 180 deletions(-) diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 9631fbdb..70e5306f 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -133,7 +133,7 @@ export SSAStepper # leaping: include("simple_regular_solve.jl") -export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel +export SimpleTauLeaping, SimpleExplicitTauLeaping, SimpleImplicitTauLeaping, NewtonImplicitSolver, TrapezoidalImplicitSolver, EnsembleGPUKernel # spatial: include("spatial/spatial_massaction_jump.jl") diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 6ff6a4a7..9b606c85 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -6,6 +6,18 @@ end SimpleExplicitTauLeaping(; epsilon = 0.05) = SimpleExplicitTauLeaping(epsilon) +# Define solver type hierarchy +abstract type AbstractImplicitSolver end +struct NewtonImplicitSolver <: AbstractImplicitSolver end +struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end + +struct SimpleImplicitTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm + epsilon::T # Error control parameter for tau selection + solver::AbstractImplicitSolver # Solver type: Newton or Trapezoidal +end + +SimpleImplicitTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver()) = SimpleImplicitTauLeaping(epsilon, solver) + function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ @@ -69,6 +81,19 @@ function _process_saveat(saveat, tspan, save_start, save_end) return saveat_vec, _save_start, _save_end end +function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping) + if !(jump_prob.aggregator isa PureLeaping) + @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ + JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \ + Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release." + end + isempty(jump_prob.jump_callback.continuous_callbacks) && + isempty(jump_prob.jump_callback.discrete_callbacks) && + isempty(jump_prob.constant_jumps) && + isempty(jump_prob.variable_jumps) && + jump_prob.massaction_jump !== nothing +end + function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; seed = nothing, dt = error("dt is required for SimpleTauLeaping."), saveat = nothing, save_start = nothing, save_end = nothing) @@ -405,190 +430,233 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; return sol end -# SimpleImplicitTauLeaping implementation -struct SimpleImplicitTauLeaping <: DiffEqBase.DEAlgorithm - epsilon::Float64 - critical_threshold::Float64 -end - -SimpleImplicitTauLeaping(; epsilon=0.05, critical_threshold=10.0) = SimpleImplicitTauLeaping(epsilon, critical_threshold) - -function compute_hor(nu) - hor = zeros(Int64, size(nu, 2)) - for j in 1:size(nu, 2) - hor[j] = sum(abs.(nu[:, j])) > maximum(abs.(nu[:, j])) ? 2 : 1 +function compute_hor(reactant_stoch, numjumps) + # Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. + # HOR is the sum of stoichiometric coefficients of reactants in reaction j. + hor = zeros(Int, numjumps) + for j in 1:numjumps + order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) + if order > 3 + error("Reaction $j has order $order, which is not supported (maximum order is 3).") + end + hor[j] = order end return hor end -function compute_gi(u, nu, hor, i) - max_order = 1.0 - for j in 1:size(nu, 2) - if abs(nu[i, j]) > 0 - max_order = max(max_order, Float64(hor[j])) +function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjumps) + # Precompute reaction conditions for each species i, including: + # - max_hor: the highest order of reaction (HOR) where species i is a reactant. + # - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. + # Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). + max_hor = zeros(Int, numspecies) + max_stoich = zeros(Int, numspecies) + for j in 1:numjumps + for (spec_idx, stoch) in reactant_stoch[j] + if stoch > 0 # Species is a reactant + if hor[j] > max_hor[spec_idx] + max_hor[spec_idx] = hor[j] + max_stoich[spec_idx] = stoch + elseif hor[j] == max_hor[spec_idx] + max_stoich[spec_idx] = max(max_stoich[spec_idx], stoch) + end + end end end - return max_order + return max_hor, max_stoich end -function compute_tau_explicit(u, rate_cache, nu, hor, p, t, epsilon, rate) - rate(rate_cache, u, p, t) - mu = zeros(Float64, length(u)) - sigma2 = zeros(Float64, length(u)) - tau = Inf - for i in 1:length(u) - for j in 1:size(nu, 2) - mu[i] += nu[i, j] * rate_cache[j] - sigma2[i] += nu[i, j]^2 * rate_cache[j] +function compute_gi(u, max_hor, max_stoich, i, t) + # Compute g_i for species i to bound the relative change in propensity functions, + # as per Cao et al. (2006), Section IV, equation (27). + # g_i is determined by the highest order of reaction (HOR) and maximum stoichiometry (nu_ij) where species i is a reactant: + # - HOR = 1 (first-order, e.g., S_i -> products): g_i = 1 + # - HOR = 2 (second-order): + # - nu_ij = 1 (e.g., S_i + S_k -> products): g_i = 2 + # - nu_ij = 2 (e.g., 2S_i -> products): g_i = 2 + 1/(x_i - 1) + # - HOR = 3 (third-order): + # - nu_ij = 1 (e.g., S_i + S_k + S_m -> products): g_i = 3 + # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) + # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) + # Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep. + if max_hor[i] == 0 # No reactions involve species i as a reactant + return 1.0 + elseif max_hor[i] == 1 + return 1.0 + elseif max_hor[i] == 2 + if max_stoich[i] == 1 + return 2.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1 + end + elseif max_hor[i] == 3 + if max_stoich[i] == 1 + return 3.0 + elseif max_stoich[i] == 2 + return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1 + elseif max_stoich[i] == 3 + return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 # Fallback to 3.0 if x_i <= 2 end - gi = compute_gi(u, nu, hor, i) - bound = max(epsilon * u[i] / gi, 1.0) - mu_term = abs(mu[i]) > 0 ? bound / abs(mu[i]) : Inf - sigma_term = sigma2[i] > 0 ? bound^2 / sigma2[i] : Inf - tau = min(tau, mu_term, sigma_term) end - return tau + return 1.0 # Default case end -function compute_tau_implicit(u, rate_cache, nu, p, t, rate) +function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + # Compute the tau-leaping step-size using equation (20) from Cao et al. (2006): + # tau = min_{i in I_rs} { max(epsilon * x_i / g_i, 1) / |mu_i(x)|, max(epsilon * x_i / g_i, 1)^2 / sigma_i^2(x) } + # where mu_i(x) and sigma_i^2(x) are defined in equations (9a) and (9b): + # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) + # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). rate(rate_cache, u, p, t) + if all(==(0.0), rate_cache) # Handle case where all rates are zero + return dtmin + end tau = Inf for i in 1:length(u) - sum_nu_a = 0.0 + mu = zero(eltype(u)) + sigma2 = zero(eltype(u)) for j in 1:size(nu, 2) - if nu[i, j] < 0 - sum_nu_a += abs(nu[i, j]) * rate_cache[j] - end - end - if sum_nu_a > 0 && u[i] > 0 - tau = min(tau, u[i] / sum_nu_a) + mu += nu[i, j] * rate_cache[j] # Equation (9a) + sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) end + gi = compute_gi(u, max_hor, max_stoich, i, t) + bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1) + mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) + tau = min(tau, mu_term, sigma_term) # Equation (8) end - return tau + return max(tau, dtmin) end -function implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - # Define the nonlinear system: F(u_new) = u_new - u_prev - sum(nu_j * (counts_j - tau * a_j(u_prev) + tau * a_j(u_new))) = 0 - function f(u_new, p) - rate_new = zeros(eltype(u_new), numjumps) - rate(rate_new, u_new, p, t_prev + tau) - residual = zeros(eltype(u_new), length(u_new)) - residual .= u_new - u_prev - for j in 1:numjumps - residual .-= nu[:, j] * (counts[j] - tau * (rate_cache[j] - rate_new[j])) +# Define residual for implicit equation +# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004) +# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau +function implicit_equation!(resid, u_new, params) + u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params + rate(rate_cache, u_new, p, t + tau) + resid .= u_new .- u_current + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + if isa(solver, NewtonImplicitSolver) + resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004) + else # TrapezoidalImplicitSolver + rate_current = similar(rate_cache) + rate(rate_current, u_current, p, t) + resid[spec_idx] -= nu[spec_idx, j] * 0.5 * (rate_cache[j] + rate_current[j]) * tau + end end - return residual end + resid .= max.(resid, -u_new) # Ensure non-negative solution +end - u_new = float.(u_prev + sum(nu[:, j] * counts[j] for j in 1:numjumps)) - prob = NonlinearProblem{false}(f, u_new, p) - sol = solve(prob, SimpleNewtonRaphson(), abstol=1e-6, maxiters=100) - - if sol.retcode != ReturnCode.Success || any(isnan.(sol.u)) || any(isinf.(sol.u)) - return nothing - end - return round.(Int64, max.(sol.u, 0.0)) +# Solve implicit equation using SimpleNonlinearSolve +function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver) + u_new = convert(Vector{Float64}, u_current) + prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)) + sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6) + return sol.u, sol.retcode == ReturnCode.Success end -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed=nothing, dtmin=1e-10, saveat=nothing) - @assert isempty(jump_prob.jump_callback.continuous_callbacks) - @assert isempty(jump_prob.jump_callback.discrete_callbacks) - prob = jump_prob.prob - rng = DEFAULT_RNG +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; + seed = nothing, + dtmin = 1e-10, + saveat = nothing) + validate_pure_leaping_inputs(jump_prob, alg) || + error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") + + @unpack prob, rng = jump_prob (seed !== nothing) && seed!(rng, seed) + maj = jump_prob.massaction_jump + numjumps = get_num_majumps(maj) rj = jump_prob.regular_jump - rate = rj.rate - numjumps = rj.numjumps - c = rj.c + # Extract rates + rate = rj !== nothing ? rj.rate : + (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) + end + end + c = rj !== nothing ? rj.c : nothing u0 = copy(prob.u0) tspan = prob.tspan p = prob.p - u = [copy(u0)] - t = [tspan[1]] + u_current = copy(u0) + t_current = tspan[1] + usave = [copy(u0)] + tsave = [tspan[1]] rate_cache = zeros(Float64, numjumps) - counts = zeros(Int, numjumps) - du = similar(u0, Int) + counts = zeros(Int64, numjumps) + du = similar(u0) t_end = tspan[2] epsilon = alg.epsilon - critical_threshold = alg.critical_threshold + solver = alg.solver - # Compute initial stoichiometry and HOR - nu = zeros(Int, length(u0), numjumps) - counts_temp = zeros(Int, numjumps) + nu = zeros(Int64, length(u0), numjumps) for j in 1:numjumps - fill!(counts_temp, 0) - counts_temp[j] = 1 - c(du, u0, p, t[1], counts_temp, nothing) - nu[:, j] = du + for (spec_idx, stoch) in maj.net_stoch[j] + nu[spec_idx, j] = stoch + end end - hor = compute_hor(nu) + reactant_stoch = maj.reactant_stoch + hor = compute_hor(reactant_stoch, numjumps) + max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) - saveat_times = isnothing(saveat) ? Float64[] : saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat) + saveat_times = isnothing(saveat) ? Vector{Float64}() : + (saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)) save_idx = 1 - while t[end] < t_end - u_prev = u[end] - t_prev = t[end] - for j in 1:numjumps - fill!(counts_temp, 0) - counts_temp[j] = 1 - c(du, u_prev, p, t_prev, counts_temp, nothing) - nu[:, j] = du + while t_current < t_end + rate(rate_cache, u_current, p, t_current) + tau = compute_tau(u_current, rate_cache, nu, hor, p, t_current, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) + tau = min(tau, t_end - t_current) + if !isempty(saveat_times) && save_idx <= length(saveat_times) && t_current + tau > saveat_times[save_idx] + tau = saveat_times[save_idx] - t_current end - rate(rate_cache, u_prev, p, t_prev) - tau_prime = compute_tau_explicit(u_prev, rate_cache, nu, hor, p, t_prev, epsilon, rate) - tau_double_prime = compute_tau_implicit(u_prev, rate_cache, nu, p, t_prev, rate) - use_implicit = false - tau = tau_prime - if any(u_prev .< critical_threshold) - tau = min(tau_double_prime, tau_prime) - use_implicit = true + + u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p, t_current, tau, rate, numjumps, solver) + if !converged + tau /= 2 + continue end - tau = max(tau, dtmin) - tau = min(tau, t_end - t_prev) - if !isempty(saveat_times) - if save_idx <= length(saveat_times) && t_prev + tau > saveat_times[save_idx] - tau = saveat_times[save_idx] - t_prev + + rate(rate_cache, u_new_float, p, t_current + tau) + counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) + du .= zero(eltype(u_current)) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + du[spec_idx] += stoch * counts[j] end end - counts .= pois_rand.((rng,), max.(rate_cache * tau, 0.0)) - c(du, u_prev, p, t_prev, counts, nothing) - u_new = u_prev + du - if use_implicit - u_new = implicit_tau_step(u_prev, t_prev, tau, rate_cache, counts, nu, p, rate, numjumps) - if u_new === nothing || any(u_new .< 0) - tau /= 2 - continue - end - elseif any(u_new .< 0) + u_new = u_current + du + + if any(<(0), u_new) + # Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3 tau /= 2 continue end - u_new = max.(u_new, 0) - push!(u, u_new) - push!(t, t_prev + tau) - if !isempty(saveat_times) && save_idx <= length(saveat_times) && t[end] >= saveat_times[save_idx] - save_idx += 1 + # Ensure non-negativity, as per Cao et al. (2006), Section 3.3 + for i in eachindex(u_new) + u_new[i] = max(u_new[i], 0) end - end + t_new = t_current + tau - # Interpolate to saveat times if specified - if !isempty(saveat_times) - t_out = saveat_times - u_out = [u[end]] - for t_save in saveat_times - idx = findlast(ti -> ti <= t_save, t) - push!(u_out, u[idx]) + if isempty(saveat_times) || (save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx]) + push!(usave, copy(u_new)) + push!(tsave, t_new) + if !isempty(saveat_times) && t_new >= saveat_times[save_idx] + save_idx += 1 + end end - t = t_out - u = u_out[2:end] + + u_current = u_new + t_current = t_new end - sol = DiffEqBase.build_solution(prob, alg, t, u, + sol = DiffEqBase.build_solution(prob, alg, tsave, usave, calculate_error=false, - interp=DiffEqBase.ConstantInterpolation(t, u)) + interp=DiffEqBase.ConstantInterpolation(tsave, usave)) return sol end @@ -604,5 +672,3 @@ end function EnsembleGPUKernel() EnsembleGPUKernel(nothing, 0.0) end - -export SimpleTauLeaping, EnsembleGPUKernel, SimpleImplicitTauLeaping diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 738c9089..75169161 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -262,6 +262,7 @@ end influx_rate = 1.0 p = (β, ν, influx_rate) + # ConstantRateJump formulation for SSAStepper rate1(u, p, t) = p[1] * u[1] * u[2] rate2(u, p, t) = p[2] * u[2] rate3(u, p, t) = p[3] @@ -270,35 +271,38 @@ end affect3!(integrator) = (integrator.u[1] += 1; nothing) jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) - u0 = [999, 10, 0] # Integer initial conditions + u0 = [999.0, 10.0, 0.0] # S, I, R tspan = (0.0, 250.0) prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345)) - - sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) - regular_rate = (out, u, p, t) -> begin - out[1] = p[1] * u[1] * u[2] - out[2] = p[2] * u[2] - out[3] = p[3] - end - regular_c = (dc, u, p, t, counts, mark) -> begin - dc .= 0 - dc[1] = -counts[1] + counts[3] - dc[2] = counts[1] - counts[2] - dc[3] = counts[2] - end - rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) - - sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + # Solve with SSAStepper + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); + trajectories = Nsims, saveat = 1.0) + # MassActionJump formulation for SimpleImplicitTauLeaping + reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs = param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng = rng) + + # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) + sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()), + EnsembleSerial(); trajectories = Nsims, saveat = 1.0) + sol_implicit_trapezoidal = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()), + EnsembleSerial(); trajectories = Nsims, saveat = 1.0) + + # Compare peak I trajectories against SSA t_points = 0:1.0:250.0 - mean_direct_S = [mean(sol_direct[i](t)[1] for i in 1:Nsims) for t in t_points] - mean_implicit_S = [mean(sol_implicit[i](t)[1] for i in 1:Nsims) for t in t_points] + max_direct_I = maximum([mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_implicit_newton = maximum([mean(sol_implicit_newton[i](t)[2] for i in 1:Nsims) for t in t_points]) + max_implicit_trapezoidal = maximum([mean(sol_implicit_trapezoidal[i](t)[2] for i in 1:Nsims) for t in t_points]) - max_error_implicit = maximum(abs.(mean_direct_S .- mean_implicit_S)) - @test max_error_implicit < 0.01 * mean(mean_direct_S) + @test isapprox(max_direct_I, max_implicit_newton, rtol = 0.05) + @test isapprox(max_direct_I, max_implicit_trapezoidal, rtol = 0.05) end # SimpleImplicitTauLeaping correctness - SEIR model @@ -308,6 +312,7 @@ end ν = 0.01 p = (β, σ, ν) + # ConstantRateJump formulation for SSAStepper rate1(u, p, t) = p[1] * u[1] * u[3] rate2(u, p, t) = p[2] * u[2] rate3(u, p, t) = p[3] * u[3] @@ -316,36 +321,82 @@ end affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing) jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) - u0 = [999, 0, 10, 0] # Integer initial conditions + u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R tspan = (0.0, 250.0) prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng=StableRNG(12345)) - - sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); trajectories=Nsims) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) - regular_rate = (out, u, p, t) -> begin - out[1] = p[1] * u[1] * u[3] - out[2] = p[2] * u[2] - out[3] = p[3] * u[3] - end - regular_c = (dc, u, p, t, counts, mark) -> begin - dc .= 0 - dc[1] = -counts[1] - dc[2] = counts[1] - counts[2] - dc[3] = counts[2] - counts[3] - dc[4] = counts[3] - end - rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, Direct(), rj; rng=StableRNG(12345)) - - sol_implicit = solve(EnsembleProblem(jump_prob_tau), SimpleImplicitTauLeaping(), EnsembleSerial(); trajectories=Nsims, saveat=1.0) + # Solve with SSAStepper + sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); + trajectories = Nsims, saveat = 1.0) + # MassActionJump formulation for SimpleImplicitTauLeaping + reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]] + net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]] + param_idxs = [1, 2, 3] + maj = MassActionJump(reactant_stoich, net_stoich; param_idxs = param_idxs) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng = rng) + + # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) + sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()), + EnsembleSerial(); trajectories = Nsims, saveat = 1.0) + sol_implicit_trapezoidal = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()), + EnsembleSerial(); trajectories = Nsims, saveat = 1.0) + + # Compare peak I trajectories against SSA (I is index 3 in SEIR) t_points = 0:1.0:250.0 - mean_direct_R = [mean(sol_direct[i](t)[4] for i in 1:Nsims) for t in t_points] - mean_implicit_R = [mean(sol_implicit[i](t)[4] for i in 1:Nsims) for t in t_points] + max_direct_I = maximum([mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_implicit_newton = maximum([mean(sol_implicit_newton[i](t)[3] for i in 1:Nsims) for t in t_points]) + max_implicit_trapezoidal = maximum([mean(sol_implicit_trapezoidal[i](t)[3] for i in 1:Nsims) for t in t_points]) + + @test isapprox(max_direct_I, max_implicit_newton, rtol = 0.05) + @test isapprox(max_direct_I, max_implicit_trapezoidal, rtol = 0.05) +end - max_error_implicit = maximum(abs.(mean_direct_R .- mean_implicit_R)) - @test max_error_implicit < 0.01 * mean(mean_direct_R) +# SimpleImplicitTauLeaping integration tests (stiff and reversible systems) +@testset "SimpleImplicitTauLeaping Integration Tests" begin + # Stiff system from Cao et al. (2007): S1 -> S2, S2 -> S1, S2 -> S3 + c = (1000.0, 1000.0, 1.0) + reactant_stoich1 = [Pair(1, 1)] + net_stoich1 = [Pair(1, -1), Pair(2, 1)] + reactant_stoich2 = [Pair(2, 1)] + net_stoich2 = [Pair(1, 1), Pair(2, -1)] + reactant_stoich3 = [Pair(2, 1)] + net_stoich3 = [Pair(2, -1), Pair(3, 1)] + stiff_jumps = MassActionJump([c[1], c[2], c[3]], + [reactant_stoich1, reactant_stoich2, reactant_stoich3], + [net_stoich1, net_stoich2, net_stoich3]) + + u0 = [100, 0, 0] + tspan = (0.0, 10.0) + prob = DiscreteProblem(u0, tspan) + jump_prob = JumpProblem(prob, PureLeaping(), stiff_jumps) + sol = solve(jump_prob, SimpleImplicitTauLeaping(); saveat = 0.1) + @test sol.t[end] ≈ 10.0 atol = 1e-6 + + # Reversible isomerization from Cao et al. (2004) + c2 = (1000.0, 1000.0) + reactant_stoich_r1 = [Pair(1, 1)] + net_stoich_r1 = [Pair(1, -1), Pair(2, 1)] + reactant_stoich_r2 = [Pair(2, 1)] + net_stoich_r2 = [Pair(1, 1), Pair(2, -1)] + rev_jumps = MassActionJump([c2[1], c2[2]], + [reactant_stoich_r1, reactant_stoich_r2], + [net_stoich_r1, net_stoich_r2]) + + u0_rev = [1000, 0] + tspan_rev = (0.0, 0.1) + prob_rev = DiscreteProblem(u0_rev, tspan_rev) + jump_prob_rev = JumpProblem(prob_rev, PureLeaping(), rev_jumps) + + sol_newton = solve(jump_prob_rev, SimpleImplicitTauLeaping(epsilon = 0.05, + solver = NewtonImplicitSolver())) + sol_trapezoidal = solve(jump_prob_rev, SimpleImplicitTauLeaping(epsilon = 0.05, + solver = TrapezoidalImplicitSolver())) + @test sol_newton.t[end] ≈ 0.1 atol = 1e-6 + @test sol_trapezoidal.t[end] ≈ 0.1 atol = 1e-6 end # Test that saveat/save_start/save_end control which times are stored in solutions From fe71c61823884d66f8576e18e1f1aacb25dc6f8f Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:29:50 +0530 Subject: [PATCH 20/29] project.toml --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 3d442047..e96c9b4a 100644 --- a/Project.toml +++ b/Project.toml @@ -12,14 +12,12 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" PoissonRandom = "e409e4f3-bfea-5376-8464-e040bb5c01ab" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" -SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" From 1e2b6e6377bbbb1a6c828c5c726ebc1766af7a29 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sat, 6 Sep 2025 00:05:51 +0530 Subject: [PATCH 21/29] project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index e96c9b4a..1d77e4b3 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" +SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" From c8022cb4ee423c5be6b0a09926a6d7dd3a39a8f3 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sat, 6 Sep 2025 00:27:58 +0530 Subject: [PATCH 22/29] some --- test/regular_jumps.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 75169161..f201ab46 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -82,7 +82,7 @@ end @test all(isapprox.(mean_I_direct, mean_I_explicit, rtol = 0.1)) end -# SEIR model with exposed compartment +# SEIR model with exposed compartmen @testset "SEIR Model Correctness" begin β = 0.3 / 1000.0 σ = 0.2 From 2845eb4fad0f1a0cabdd89040702f3579115a567 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sat, 6 Sep 2025 00:28:03 +0530 Subject: [PATCH 23/29] some --- test/regular_jumps.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index f201ab46..75169161 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -82,7 +82,7 @@ end @test all(isapprox.(mean_I_direct, mean_I_explicit, rtol = 0.1)) end -# SEIR model with exposed compartmen +# SEIR model with exposed compartment @testset "SEIR Model Correctness" begin β = 0.3 / 1000.0 σ = 0.2 From e6a736e57d1718b1a3727f94601d362dfb23f7b4 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sat, 14 Feb 2026 00:19:52 +0530 Subject: [PATCH 24/29] refactor --- src/simple_regular_solve.jl | 164 ++++++++++++++++++++++-------------- 1 file changed, 99 insertions(+), 65 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 9b606c85..6a323838 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -433,9 +433,11 @@ end function compute_hor(reactant_stoch, numjumps) # Compute the highest order of reaction (HOR) for each reaction j, as per Cao et al. (2006), Section IV. # HOR is the sum of stoichiometric coefficients of reactants in reaction j. - hor = zeros(Int, numjumps) + # Extract the element type from reactant_stoch to avoid hardcoding type assumptions. + stoch_type = eltype(first(first(reactant_stoch))) + hor = zeros(stoch_type, numjumps) for j in 1:numjumps - order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=0) + order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=zero(stoch_type)) if order > 3 error("Reaction $j has order $order, which is not supported (maximum order is 3).") end @@ -449,8 +451,9 @@ function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjump # - max_hor: the highest order of reaction (HOR) where species i is a reactant. # - max_stoich: the maximum stoichiometry (nu_ij) in reactions with max_hor. # Used to optimize compute_gi, as per Cao et al. (2006), Section IV, equation (27). - max_hor = zeros(Int, numspecies) - max_stoich = zeros(Int, numspecies) + hor_type = eltype(hor) + max_hor = zeros(hor_type, numspecies) + max_stoich = zeros(hor_type, numspecies) for j in 1:numjumps for (spec_idx, stoch) in reactant_stoch[j] if stoch > 0 # Species is a reactant @@ -479,26 +482,31 @@ function compute_gi(u, max_hor, max_stoich, i, t) # - nu_ij = 2 (e.g., 2S_i + S_k -> products): g_i = (3/2) * (2 + 1/(x_i - 1)) # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) # Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep. + one_max_hor = one(1 / one(eltype(u))) + if max_hor[i] == 0 # No reactions involve species i as a reactant - return 1.0 + return one_max_hor elseif max_hor[i] == 1 - return 1.0 + return one_max_hor elseif max_hor[i] == 2 if max_stoich[i] == 1 - return 2.0 - elseif max_stoich[i] == 2 - return u[i] > 1 ? 2.0 + 1.0 / (u[i] - 1) : 2.0 # Fallback to 2.0 if x_i <= 1 + return 2 * one_max_hor + else # max_stoich[i] == 2 + return u[i] > one_max_hor ? + 2 * one_max_hor + one_max_hor / (u[i] - one_max_hor) : 2 * one_max_hor # Fallback to 2 if x_i <= 1 end elseif max_hor[i] == 3 if max_stoich[i] == 1 - return 3.0 + return 3 * one_max_hor elseif max_stoich[i] == 2 - return u[i] > 1 ? 1.5 * (2.0 + 1.0 / (u[i] - 1)) : 3.0 # Fallback to 3.0 if x_i <= 1 - elseif max_stoich[i] == 3 - return u[i] > 2 ? 3.0 + 1.0 / (u[i] - 1) + 2.0 / (u[i] - 2) : 3.0 # Fallback to 3.0 if x_i <= 2 + return u[i] > one_max_hor ? + (3 * one_max_hor / 2) * (2 * one_max_hor + one_max_hor / (u[i] - one_max_hor)) : 3 * one_max_hor # Fallback to 3 if x_i <= 1 + else # max_stoich[i] == 3 + return u[i] > 2 * one_max_hor ? + 3 * one_max_hor + one_max_hor / (u[i] - one_max_hor) + 2 * one_max_hor / (u[i] - 2 * one_max_hor) : 3 * one_max_hor # Fallback to 3 if x_i <= 2 end end - return 1.0 # Default case + return one_max_hor # Default case end function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps) @@ -508,10 +516,10 @@ function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor # mu_i(x) = sum_j nu_ij * a_j(x), sigma_i^2(x) = sum_j nu_ij^2 * a_j(x) # I_rs is the set of reactant species (assumed to be all species here, as critical reactions are not specified). rate(rate_cache, u, p, t) - if all(==(0.0), rate_cache) # Handle case where all rates are zero + if all(<=(0), rate_cache) # Handle case where all rates are zero or negative return dtmin end - tau = Inf + tau = typemax(typeof(t)) for i in 1:length(u) mu = zero(eltype(u)) sigma2 = zero(eltype(u)) @@ -520,9 +528,9 @@ function compute_tau(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor sigma2 += nu[i, j]^2 * rate_cache[j] # Equation (9b) end gi = compute_gi(u, max_hor, max_stoich, i, t) - bound = max(epsilon * u[i] / gi, 1.0) # max(epsilon * x_i / g_i, 1) - mu_term = abs(mu) > 0 ? bound / abs(mu) : Inf # First term in equation (8) - sigma_term = sigma2 > 0 ? bound^2 / sigma2 : Inf # Second term in equation (8) + bound = max(epsilon * u[i] / gi, one(eltype(u))) # max(epsilon * x_i / g_i, 1) + mu_term = abs(mu) > 0 ? bound / abs(mu) : typemax(typeof(t)) # First term in equation (8) + sigma_term = sigma2 > 0 ? bound^2 / sigma2 : typemax(typeof(t)) # Second term in equation (8) tau = min(tau, mu_term, sigma_term) # Equation (8) end return max(tau, dtmin) @@ -557,54 +565,20 @@ function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, so return sol.u, sol.retcode == ReturnCode.Success end -function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; - seed = nothing, - dtmin = 1e-10, - saveat = nothing) - validate_pure_leaping_inputs(jump_prob, alg) || - error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") - - @unpack prob, rng = jump_prob - (seed !== nothing) && seed!(rng, seed) - - maj = jump_prob.massaction_jump - numjumps = get_num_majumps(maj) - rj = jump_prob.regular_jump - # Extract rates - rate = rj !== nothing ? rj.rate : - (out, u, p, t) -> begin - for j in 1:numjumps - out[j] = evalrxrate(u, j, maj) - end - end - c = rj !== nothing ? rj.c : nothing - u0 = copy(prob.u0) - tspan = prob.tspan - p = prob.p - - u_current = copy(u0) - t_current = tspan[1] - usave = [copy(u0)] - tsave = [tspan[1]] - rate_cache = zeros(Float64, numjumps) - counts = zeros(Int64, numjumps) - du = similar(u0) - t_end = tspan[2] - epsilon = alg.epsilon - solver = alg.solver - - nu = zeros(Int64, length(u0), numjumps) - for j in 1:numjumps - for (spec_idx, stoch) in maj.net_stoch[j] - nu[spec_idx, j] = stoch +# Function to generate a mass action rate function +function massaction_rate(maj, numjumps) + return (out, u, p, t) -> begin + for j in 1:numjumps + out[j] = evalrxrate(u, j, maj) end end - reactant_stoch = maj.reactant_stoch - hor = compute_hor(reactant_stoch, numjumps) - max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) +end - saveat_times = isnothing(saveat) ? Vector{Float64}() : - (saveat isa Number ? collect(range(tspan[1], tspan[2], step=saveat)) : collect(saveat)) +function simple_implicit_tau_leaping_loop!( + prob, alg, u_current, t_current, t_end, p, rng, + rate, nu, hor, max_hor, max_stoich, numjumps, epsilon, + dtmin, saveat_times, usave, tsave, du, counts, rate_cache, + maj, solver) save_idx = 1 while t_current < t_end @@ -623,7 +597,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; rate(rate_cache, u_new_float, p, t_current + tau) counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) - du .= zero(eltype(u_current)) + du .= 0 for j in 1:numjumps for (spec_idx, stoch) in maj.net_stoch[j] du[spec_idx] += stoch * counts[j] @@ -653,6 +627,66 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; u_current = u_new t_current = t_new end +end + +function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; + seed = nothing, + dtmin = nothing, + saveat = nothing) + validate_pure_leaping_inputs(jump_prob, alg) || + error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") + + (; prob, rng) = jump_prob + (seed !== nothing) && Random.seed!(rng, seed) + + maj = jump_prob.massaction_jump + numjumps = get_num_majumps(maj) + rj = jump_prob.regular_jump + # Extract rates + rate = rj !== nothing ? rj.rate : massaction_rate(maj, numjumps) + c = rj !== nothing ? rj.c : nothing + u0 = copy(prob.u0) + tspan = prob.tspan + p = prob.p + + if dtmin === nothing + dtmin = 1e-10 * one(typeof(tspan[2])) + end + + u_current = copy(u0) + t_current = tspan[1] + usave = [copy(u0)] + tsave = [tspan[1]] + rate_cache = zeros(float(eltype(u0)), numjumps) + counts = zero(rate_cache) + du = similar(u0) + t_end = tspan[2] + epsilon = alg.epsilon + solver = alg.solver + + nu = zeros(float(eltype(u0)), length(u0), numjumps) + for j in 1:numjumps + for (spec_idx, stoch) in maj.net_stoch[j] + nu[spec_idx, j] = stoch + end + end + reactant_stoch = maj.reactant_stoch + hor = compute_hor(reactant_stoch, numjumps) + max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) + + # Set up saveat_times + if isnothing(saveat) + saveat_times = Vector{typeof(tspan[1])}() + elseif saveat isa Number + saveat_times = collect(range(tspan[1], tspan[2], step=saveat)) + else + saveat_times = collect(saveat) + end + simple_implicit_tau_leaping_loop!( + prob, alg, u_current, t_current, t_end, p, rng, + rate, nu, hor, max_hor, max_stoich, numjumps, epsilon, + dtmin, saveat_times, usave, tsave, du, counts, rate_cache, + maj, solver) sol = DiffEqBase.build_solution(prob, alg, tsave, usave, calculate_error=false, From 77991ee7856141fc68aaacb0100a1b3337b827de Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sat, 14 Feb 2026 00:39:47 +0530 Subject: [PATCH 25/29] some changes --- src/simple_regular_solve.jl | 57 +++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 6a323838..000304bf 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -436,9 +436,10 @@ function compute_hor(reactant_stoch, numjumps) # Extract the element type from reactant_stoch to avoid hardcoding type assumptions. stoch_type = eltype(first(first(reactant_stoch))) hor = zeros(stoch_type, numjumps) + max_order = 3 * one(stoch_type) # Maximum supported reaction order (type-aware) for j in 1:numjumps order = sum(stoch for (spec_idx, stoch) in reactant_stoch[j]; init=zero(stoch_type)) - if order > 3 + if order > max_order error("Reaction $j has order $order, which is not supported (maximum order is 3).") end hor[j] = order @@ -454,9 +455,11 @@ function precompute_reaction_conditions(reactant_stoch, hor, numspecies, numjump hor_type = eltype(hor) max_hor = zeros(hor_type, numspecies) max_stoich = zeros(hor_type, numspecies) + stoch_type = eltype(first(first(reactant_stoch))) + zero_stoch = zero(stoch_type) for j in 1:numjumps for (spec_idx, stoch) in reactant_stoch[j] - if stoch > 0 # Species is a reactant + if stoch > zero_stoch # Species is a reactant if hor[j] > max_hor[spec_idx] max_hor[spec_idx] = hor[j] max_stoich[spec_idx] = stoch @@ -483,22 +486,27 @@ function compute_gi(u, max_hor, max_stoich, i, t) # - nu_ij = 3 (e.g., 3S_i -> products): g_i = 3 + 1/(x_i - 1) + 2/(x_i - 2) # Uses precomputed max_hor and max_stoich to reduce work to O(num_species) per timestep. one_max_hor = one(1 / one(eltype(u))) + hor_type = eltype(max_hor) + zero_hor = zero(hor_type) + one_hor = one(hor_type) + two_hor = one_hor + one_hor + three_hor = two_hor + one_hor - if max_hor[i] == 0 # No reactions involve species i as a reactant + if max_hor[i] == zero_hor # No reactions involve species i as a reactant return one_max_hor - elseif max_hor[i] == 1 + elseif max_hor[i] == one_hor return one_max_hor - elseif max_hor[i] == 2 - if max_stoich[i] == 1 + elseif max_hor[i] == two_hor + if max_stoich[i] == one_hor return 2 * one_max_hor else # max_stoich[i] == 2 return u[i] > one_max_hor ? 2 * one_max_hor + one_max_hor / (u[i] - one_max_hor) : 2 * one_max_hor # Fallback to 2 if x_i <= 1 end - elseif max_hor[i] == 3 - if max_stoich[i] == 1 + elseif max_hor[i] == three_hor + if max_stoich[i] == one_hor return 3 * one_max_hor - elseif max_stoich[i] == 2 + elseif max_stoich[i] == two_hor return u[i] > one_max_hor ? (3 * one_max_hor / 2) * (2 * one_max_hor + one_max_hor / (u[i] - one_max_hor)) : 3 * one_max_hor # Fallback to 3 if x_i <= 1 else # max_stoich[i] == 3 @@ -543,14 +551,19 @@ function implicit_equation!(resid, u_new, params) u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params rate(rate_cache, u_new, p, t + tau) resid .= u_new .- u_current - for j in 1:numjumps - for spec_idx in 1:size(nu, 1) - if isa(solver, NewtonImplicitSolver) + if isa(solver, NewtonImplicitSolver) + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004) - else # TrapezoidalImplicitSolver - rate_current = similar(rate_cache) - rate(rate_current, u_current, p, t) - resid[spec_idx] -= nu[spec_idx, j] * 0.5 * (rate_cache[j] + rate_current[j]) * tau + end + end + else # TrapezoidalImplicitSolver + rate_current = similar(rate_cache) + rate(rate_current, u_current, p, t) + half = one(eltype(rate_cache)) / 2 + for j in 1:numjumps + for spec_idx in 1:size(nu, 1) + resid[spec_idx] -= nu[spec_idx, j] * half * (rate_cache[j] + rate_current[j]) * tau end end end @@ -559,7 +572,7 @@ end # Solve implicit equation using SimpleNonlinearSolve function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver) - u_new = convert(Vector{Float64}, u_current) + u_new = convert(Vector{float(eltype(u_current))}, u_current) prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)) sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6) return sol.u, sol.retcode == ReturnCode.Success @@ -596,8 +609,9 @@ function simple_implicit_tau_leaping_loop!( end rate(rate_cache, u_new_float, p, t_current + tau) - counts .= pois_rand.(rng, max.(rate_cache * tau, 0.0)) - du .= 0 + zero_rate = zero(eltype(rate_cache)) + counts .= pois_rand.(rng, max.(rate_cache * tau, zero_rate)) + du .= zero(eltype(du)) for j in 1:numjumps for (spec_idx, stoch) in maj.net_stoch[j] du[spec_idx] += stoch * counts[j] @@ -605,14 +619,15 @@ function simple_implicit_tau_leaping_loop!( end u_new = u_current + du - if any(<(0), u_new) + zero_pop = zero(eltype(u_new)) + if any(<(zero_pop), u_new) # Halve tau to avoid negative populations, as per Cao et al. (2006), Section 3.3 tau /= 2 continue end # Ensure non-negativity, as per Cao et al. (2006), Section 3.3 for i in eachindex(u_new) - u_new[i] = max(u_new[i], 0) + u_new[i] = max(u_new[i], zero_pop) end t_new = t_current + tau From 0277f27d500c52753229f1e9969b8c09829b1167 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan <95441117+sivasathyaseeelan@users.noreply.github.com> Date: Sun, 15 Feb 2026 00:14:10 +0530 Subject: [PATCH 26/29] refactor --- Project.toml | 4 +- src/JumpProcesses.jl | 5 +- test/regular_jumps.jl | 196 ++++++++++++++---------------------------- 3 files changed, 68 insertions(+), 137 deletions(-) diff --git a/Project.toml b/Project.toml index 1d77e4b3..77c51fbf 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas "] version = "9.25.1" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" @@ -65,7 +66,6 @@ Test = "1" julia = "1.10" [extras] -ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" @@ -80,4 +80,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ADTypes", "Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"] +test = ["Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"] diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 70e5306f..e2f49fcb 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -18,7 +18,8 @@ using StaticArrays: StaticArrays, SVector, setindex using Base.Threads: Threads using Base.FastMath: add_fast -using SimpleNonlinearSolve +using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleNewtonRaphson +using ADTypes: ADTypes, AutoFiniteDiff # Import functions we extend from Base import Base: size, getindex, setindex!, length, similar, show, merge!, merge @@ -42,7 +43,7 @@ using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction, ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!, deleteat!, isinplace, remake, savevalues!, step!, u_modified! -using SciMLBase: SciMLBase, DEIntegrator +using SciMLBase: SciMLBase, DEIntegrator, NonlinearProblem abstract type AbstractJump end abstract type AbstractMassActionJump <: AbstractJump end diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 75169161..2e2c1fed 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -61,7 +61,7 @@ end sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 0.1, saveat = t_compare) - # MassActionJump formulation for SimpleExplicitTauLeaping + # MassActionJump formulation for SimpleExplicitTauLeaping / SimpleImplicitTauLeaping reactant_stoich = [[1 => 1, 2 => 1], [2 => 1], Pair{Int, Int}[]] net_stoich = [[1 => -1, 2 => 1], [2 => -1, 3 => 1], [1 => 1]] param_idxs = [1, 2, 3] @@ -72,14 +72,26 @@ end sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims, saveat = t_compare) + # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) + sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()), EnsembleSerial(); + trajectories = Nsims, saveat = t_compare) + sol_implicit_trap = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()), EnsembleSerial(); + trajectories = Nsims, saveat = t_compare) + # Compute mean I trajectories via direct indexing (I is index 2 in SIR) mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 2) mean_I_simple = compute_mean_at_saves(sol_simple, Nsims, npts, 2) mean_I_explicit = compute_mean_at_saves(sol_adaptive, Nsims, npts, 2) + mean_I_implicit_newton = compute_mean_at_saves(sol_implicit_newton, Nsims, npts, 2) + mean_I_implicit_trap = compute_mean_at_saves(sol_implicit_trap, Nsims, npts, 2) # Compare full mean trajectories across all saved timepoints @test all(isapprox.(mean_I_direct, mean_I_simple, rtol = 0.1)) @test all(isapprox.(mean_I_direct, mean_I_explicit, rtol = 0.1)) + @test all(isapprox.(mean_I_direct, mean_I_implicit_newton, rtol = 0.1)) + @test all(isapprox.(mean_I_direct, mean_I_implicit_trap, rtol = 0.1)) end # SEIR model with exposed compartment @@ -127,7 +139,7 @@ end sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); trajectories = Nsims, dt = 0.1, saveat = t_compare) - # MassActionJump formulation for SimpleExplicitTauLeaping + # MassActionJump formulation for SimpleExplicitTauLeaping / SimpleImplicitTauLeaping reactant_stoich = [[1 => 1, 3 => 1], [2 => 1], [3 => 1]] net_stoich = [[1 => -1, 2 => 1], [2 => -1, 3 => 1], [3 => -1, 4 => 1]] param_idxs = [1, 2, 3] @@ -138,14 +150,26 @@ end sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); trajectories = Nsims, saveat = t_compare) + # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) + sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()), EnsembleSerial(); + trajectories = Nsims, saveat = t_compare) + sol_implicit_trap = solve(EnsembleProblem(jump_prob_maj), + SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()), EnsembleSerial(); + trajectories = Nsims, saveat = t_compare) + # Compute mean I trajectories via direct indexing (I is index 3 in SEIR) mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 3) mean_I_simple = compute_mean_at_saves(sol_simple, Nsims, npts, 3) mean_I_explicit = compute_mean_at_saves(sol_adaptive, Nsims, npts, 3) + mean_I_implicit_newton = compute_mean_at_saves(sol_implicit_newton, Nsims, npts, 3) + mean_I_implicit_trap = compute_mean_at_saves(sol_implicit_trap, Nsims, npts, 3) # Compare full mean trajectories across all saved timepoints @test all(isapprox.(mean_I_direct, mean_I_simple, rtol = 0.1)) @test all(isapprox.(mean_I_direct, mean_I_explicit, rtol = 0.1)) + @test all(isapprox.(mean_I_direct, mean_I_implicit_newton, rtol = 0.1)) + @test all(isapprox.(mean_I_direct, mean_I_implicit_trap, rtol = 0.1)) end # Test zero-rate case for SimpleExplicitTauLeaping @@ -255,109 +279,9 @@ end @test jp_params.massaction_jump.scaled_rates == scaled_rates end -# SimpleImplicitTauLeaping correctness - SIR model -@testset "SimpleImplicitTauLeaping SIR Correctness" begin - β = 0.1 / 1000.0 - ν = 0.01 - influx_rate = 1.0 - p = (β, ν, influx_rate) - - # ConstantRateJump formulation for SSAStepper - rate1(u, p, t) = p[1] * u[1] * u[2] - rate2(u, p, t) = p[2] * u[2] - rate3(u, p, t) = p[3] - affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) - affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) - affect3!(integrator) = (integrator.u[1] += 1; nothing) - jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) - - u0 = [999.0, 10.0, 0.0] # S, I, R - tspan = (0.0, 250.0) - prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) - - # Solve with SSAStepper - sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); - trajectories = Nsims, saveat = 1.0) - - # MassActionJump formulation for SimpleImplicitTauLeaping - reactant_stoich = [[1=>1, 2=>1], [2=>1], Pair{Int,Int}[]] - net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [1=>1]] - param_idxs = [1, 2, 3] - maj = MassActionJump(reactant_stoich, net_stoich; param_idxs = param_idxs) - jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng = rng) - - # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) - sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), - SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()), - EnsembleSerial(); trajectories = Nsims, saveat = 1.0) - sol_implicit_trapezoidal = solve(EnsembleProblem(jump_prob_maj), - SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()), - EnsembleSerial(); trajectories = Nsims, saveat = 1.0) - - # Compare peak I trajectories against SSA - t_points = 0:1.0:250.0 - max_direct_I = maximum([mean(sol_direct[i](t)[2] for i in 1:Nsims) for t in t_points]) - max_implicit_newton = maximum([mean(sol_implicit_newton[i](t)[2] for i in 1:Nsims) for t in t_points]) - max_implicit_trapezoidal = maximum([mean(sol_implicit_trapezoidal[i](t)[2] for i in 1:Nsims) for t in t_points]) - - @test isapprox(max_direct_I, max_implicit_newton, rtol = 0.05) - @test isapprox(max_direct_I, max_implicit_trapezoidal, rtol = 0.05) -end - -# SimpleImplicitTauLeaping correctness - SEIR model -@testset "SimpleImplicitTauLeaping SEIR Correctness" begin - β = 0.3 / 1000.0 - σ = 0.2 - ν = 0.01 - p = (β, σ, ν) - - # ConstantRateJump formulation for SSAStepper - rate1(u, p, t) = p[1] * u[1] * u[3] - rate2(u, p, t) = p[2] * u[2] - rate3(u, p, t) = p[3] * u[3] - affect1!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1; nothing) - affect2!(integrator) = (integrator.u[2] -= 1; integrator.u[3] += 1; nothing) - affect3!(integrator) = (integrator.u[3] -= 1; integrator.u[4] += 1; nothing) - jumps = (ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!), ConstantRateJump(rate3, affect3!)) - - u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R - tspan = (0.0, 250.0) - prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng = rng) - - # Solve with SSAStepper - sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); - trajectories = Nsims, saveat = 1.0) - - # MassActionJump formulation for SimpleImplicitTauLeaping - reactant_stoich = [[1=>1, 3=>1], [2=>1], [3=>1]] - net_stoich = [[1=>-1, 2=>1], [2=>-1, 3=>1], [3=>-1, 4=>1]] - param_idxs = [1, 2, 3] - maj = MassActionJump(reactant_stoich, net_stoich; param_idxs = param_idxs) - jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng = rng) - - # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) - sol_implicit_newton = solve(EnsembleProblem(jump_prob_maj), - SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()), - EnsembleSerial(); trajectories = Nsims, saveat = 1.0) - sol_implicit_trapezoidal = solve(EnsembleProblem(jump_prob_maj), - SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()), - EnsembleSerial(); trajectories = Nsims, saveat = 1.0) - - # Compare peak I trajectories against SSA (I is index 3 in SEIR) - t_points = 0:1.0:250.0 - max_direct_I = maximum([mean(sol_direct[i](t)[3] for i in 1:Nsims) for t in t_points]) - max_implicit_newton = maximum([mean(sol_implicit_newton[i](t)[3] for i in 1:Nsims) for t in t_points]) - max_implicit_trapezoidal = maximum([mean(sol_implicit_trapezoidal[i](t)[3] for i in 1:Nsims) for t in t_points]) - - @test isapprox(max_direct_I, max_implicit_newton, rtol = 0.05) - @test isapprox(max_direct_I, max_implicit_trapezoidal, rtol = 0.05) -end - -# SimpleImplicitTauLeaping integration tests (stiff and reversible systems) -@testset "SimpleImplicitTauLeaping Integration Tests" begin - # Stiff system from Cao et al. (2007): S1 -> S2, S2 -> S1, S2 -> S3 +# Test implicit solvers on stiff system (Cao et al. 2007) +@testset "Stiff System with Implicit Solvers" begin + # Reactions: S1 -> S2, S2 -> S1, S2 -> S3 c = (1000.0, 1000.0, 1.0) reactant_stoich1 = [Pair(1, 1)] net_stoich1 = [Pair(1, -1), Pair(2, 1)] @@ -365,38 +289,44 @@ end net_stoich2 = [Pair(1, 1), Pair(2, -1)] reactant_stoich3 = [Pair(2, 1)] net_stoich3 = [Pair(2, -1), Pair(3, 1)] - stiff_jumps = MassActionJump([c[1], c[2], c[3]], + + maj = MassActionJump([c[1], c[2], c[3]], [reactant_stoich1, reactant_stoich2, reactant_stoich3], [net_stoich1, net_stoich2, net_stoich3]) - u0 = [100, 0, 0] - tspan = (0.0, 10.0) + u0 = [100, 0, 0] # Initial: S1=100, S2=0, S3=0 + tspan = (0.0, 1.0) prob = DiscreteProblem(u0, tspan) - jump_prob = JumpProblem(prob, PureLeaping(), stiff_jumps) - sol = solve(jump_prob, SimpleImplicitTauLeaping(); saveat = 0.1) - @test sol.t[end] ≈ 10.0 atol = 1e-6 - - # Reversible isomerization from Cao et al. (2004) - c2 = (1000.0, 1000.0) - reactant_stoich_r1 = [Pair(1, 1)] - net_stoich_r1 = [Pair(1, -1), Pair(2, 1)] - reactant_stoich_r2 = [Pair(2, 1)] - net_stoich_r2 = [Pair(1, 1), Pair(2, -1)] - rev_jumps = MassActionJump([c2[1], c2[2]], - [reactant_stoich_r1, reactant_stoich_r2], - [net_stoich_r1, net_stoich_r2]) - - u0_rev = [1000, 0] - tspan_rev = (0.0, 0.1) - prob_rev = DiscreteProblem(u0_rev, tspan_rev) - jump_prob_rev = JumpProblem(prob_rev, PureLeaping(), rev_jumps) - - sol_newton = solve(jump_prob_rev, SimpleImplicitTauLeaping(epsilon = 0.05, - solver = NewtonImplicitSolver())) - sol_trapezoidal = solve(jump_prob_rev, SimpleImplicitTauLeaping(epsilon = 0.05, - solver = TrapezoidalImplicitSolver())) - @test sol_newton.t[end] ≈ 0.1 atol = 1e-6 - @test sol_trapezoidal.t[end] ≈ 0.1 atol = 1e-6 + jump_prob = JumpProblem(prob, PureLeaping(), maj; rng = rng) + + # Solve with SimpleExplicitTauLeaping + sol_explicit = solve(jump_prob, SimpleExplicitTauLeaping(); dtmin = 1e-6, saveat = 0.1) + + # Solve with SimpleImplicitTauLeaping (Newton and Trapezoidal) + sol_implicit_newton = solve(jump_prob, SimpleImplicitTauLeaping(solver = NewtonImplicitSolver()); + dtmin = 1e-6, saveat = 0.1) + sol_implicit_trap = solve(jump_prob, SimpleImplicitTauLeaping(solver = TrapezoidalImplicitSolver()); + dtmin = 1e-6, saveat = 0.1) + + # Check that all solvers completed successfully + @test sol_explicit.t[end] ≈ 1.0 atol = 1e-3 + @test sol_implicit_newton.t[end] ≈ 1.0 atol = 1e-3 + @test sol_implicit_trap.t[end] ≈ 1.0 atol = 1e-3 + + # Check conservation: S1 + S2 + S3 should equal initial total + @test all(sum(u) ≈ 100 for u in sol_explicit.u) + @test all(sum(u) ≈ 100 for u in sol_implicit_newton.u) + @test all(sum(u) ≈ 100 for u in sol_implicit_trap.u) + + # Check that solutions are non-negative + @test all(all(x >= 0 for x in u) for u in sol_explicit.u) + @test all(all(x >= 0 for x in u) for u in sol_implicit_newton.u) + @test all(all(x >= 0 for x in u) for u in sol_implicit_trap.u) + + # S3 should increase monotonically (it only has an influx reaction) + @test sol_explicit.u[end][3] >= sol_explicit.u[1][3] + @test sol_implicit_newton.u[end][3] >= sol_implicit_newton.u[1][3] + @test sol_implicit_trap.u[end][3] >= sol_implicit_trap.u[1][3] end # Test that saveat/save_start/save_end control which times are stored in solutions From 5770437ce655eeeec6b7822c54b7aa72871ee5e5 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Sun, 15 Feb 2026 01:06:02 +0530 Subject: [PATCH 27/29] comcat entries --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 77c51fbf..899eb285 100644 --- a/Project.toml +++ b/Project.toml @@ -57,6 +57,8 @@ RecursiveArrayTools = "3.35, 4" Reexport = "1.2" SafeTestsets = "0.1" SciMLBase = "2.115, 3.1" +Setfield = "1" +SimpleNonlinearSolve = "1, 2" StableRNGs = "1" StaticArrays = "1.9.8" Statistics = "1" From f33a0411abfc2a413e990524e8b0165f612f26fb Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Sun, 15 Feb 2026 13:56:28 +0530 Subject: [PATCH 28/29] test fix --- test/regular_jumps.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 2e2c1fed..54d4cd6a 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -427,4 +427,38 @@ end sol = solve(jp_explicit, SimpleExplicitTauLeaping(); saveat = [2.0, 8.0], save_start = true, save_end = true) @test sol.t == [0.0, 2.0, 8.0, 10.0] + + # --- SimpleImplicitTauLeaping save_start/save_end/saveat tests --- + u0_implicit = [100.0] + prob_implicit = DiscreteProblem(u0_implicit, tspan) + reactant_stoich_implicit = [[1 => 1]] + net_stoich_implicit = [[1 => -1]] + maj_implicit = MassActionJump([0.1], reactant_stoich_implicit, net_stoich_implicit) + jp_implicit = JumpProblem(prob_implicit, PureLeaping(), maj_implicit; rng) + + # saveat as Number: defaults save_start=true, save_end=true + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = 2.0) + @test sol.t == collect(0.0:2.0:10.0) + + # saveat as Number + save_start=false + save_end=false + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = 2.0, + save_start = false, save_end = false) + @test sol.t == collect(2.0:2.0:8.0) + + # saveat collection including endpoints + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = [0.0, 5.0, 10.0]) + @test sol.t == [0.0, 5.0, 10.0] + + # saveat collection without endpoints + explicit save_start=true, save_end=true + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = [2.0, 8.0], + save_start = true, save_end = true) + @test sol.t == [0.0, 2.0, 8.0, 10.0] + + # Test with save_start=false + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = 2.0, save_start = false) + @test sol.t == collect(2.0:2.0:10.0) + + # Test with save_end=false + sol = solve(jp_implicit, SimpleImplicitTauLeaping(); saveat = 2.0, save_end = false) + @test sol.t == collect(0.0:2.0:8.0) end From 08e46881d49a01c707c9575cfa2e250fbcc2c1a5 Mon Sep 17 00:00:00 2001 From: Siva Sathyaseelan D N Date: Sun, 15 Feb 2026 14:03:50 +0530 Subject: [PATCH 29/29] saveat implementation --- src/simple_regular_solve.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 000304bf..fb865da7 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -591,7 +591,7 @@ function simple_implicit_tau_leaping_loop!( prob, alg, u_current, t_current, t_end, p, rng, rate, nu, hor, max_hor, max_stoich, numjumps, epsilon, dtmin, saveat_times, usave, tsave, du, counts, rate_cache, - maj, solver) + maj, solver, save_end) save_idx = 1 while t_current < t_end @@ -642,17 +642,23 @@ function simple_implicit_tau_leaping_loop!( u_current = u_new t_current = t_new end + + # Save endpoint if requested and not already saved + if save_end && (isempty(tsave) || tsave[end] != t_end) + push!(usave, copy(u_current)) + push!(tsave, t_end) + end end function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; seed = nothing, dtmin = nothing, - saveat = nothing) + saveat = nothing, save_start = nothing, save_end = nothing) validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleImplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") (; prob, rng) = jump_prob - (seed !== nothing) && Random.seed!(rng, seed) + (seed !== nothing) && seed!(rng, seed) maj = jump_prob.massaction_jump numjumps = get_num_majumps(maj) @@ -668,10 +674,18 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; dtmin = 1e-10 * one(typeof(tspan[2])) end + saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end) + + # Initialize current state and saved history u_current = copy(u0) t_current = tspan[1] - usave = [copy(u0)] - tsave = [tspan[1]] + if save_start + usave = [copy(u0)] + tsave = [tspan[1]] + else + usave = typeof(u0)[] + tsave = typeof(tspan[1])[] + end rate_cache = zeros(float(eltype(u0)), numjumps) counts = zero(rate_cache) du = similar(u0) @@ -689,19 +703,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleImplicitTauLeaping; hor = compute_hor(reactant_stoch, numjumps) max_hor, max_stoich = precompute_reaction_conditions(reactant_stoch, hor, length(u0), numjumps) - # Set up saveat_times - if isnothing(saveat) - saveat_times = Vector{typeof(tspan[1])}() - elseif saveat isa Number - saveat_times = collect(range(tspan[1], tspan[2], step=saveat)) - else - saveat_times = collect(saveat) - end simple_implicit_tau_leaping_loop!( prob, alg, u_current, t_current, t_end, p, rng, rate, nu, hor, max_hor, max_stoich, numjumps, epsilon, dtmin, saveat_times, usave, tsave, du, counts, rate_cache, - maj, solver) + maj, solver, save_end) sol = DiffEqBase.build_solution(prob, alg, tsave, usave, calculate_error=false,