From e0498e315f25727b4d674a18106ac7cf7b065b16 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Mon, 23 Feb 2026 22:05:31 -0500 Subject: [PATCH 01/20] Move RNG from aggregators/JumpProblem to integrator Make the integrator the single source of truth for RNG (Phase 1 Step 4). Remove rng field from all aggregator structs, JumpProblem, and variable-rate callback structs. All code now reads get_rng(integrator) at runtime. Key changes: - Remove RNG type param from AbstractSSAJumpAggregator and all concrete aggregators - Remove rng from build_jump_aggregation and all aggregate() signatures - Add rng field to SSAIntegrator with has_rng/get_rng/set_rng! interface - Add resolve_rng utility for rng/seed/fallback priority resolution - Remove rng from VR_FRMEventCallback and VR_DirectEventCache - JumpProblem rng kwarg now forwards to solver via solkwargs - Simplify solve.jl: remove _derive_jump_seed, simplify __jump_init - Update SimpleTauLeaping/SimpleExplicitTauLeaping to use rng kwarg - Add rng_kwarg_tests.jl covering all three solve pathways - Update docs with RNG control section and FAQ entry - Bump version to 9.23.0, SciMLBase compat to 2.144 Co-Authored-By: Claude Opus 4.5 --- Project.toml | 7 +- docs/src/api.md | 102 ++++++++- docs/src/faq.md | 32 ++- docs/src/tutorials/simple_poisson_process.md | 9 +- src/JumpProcesses.jl | 2 +- src/SSA_stepper.jl | 31 ++- src/aggregators/ccnrm.jl | 43 ++-- src/aggregators/coevolve.jl | 40 ++-- src/aggregators/direct.jl | 34 +-- src/aggregators/directcr.jl | 24 +- src/aggregators/frm.jl | 40 ++-- src/aggregators/nrm.jl | 37 ++-- src/aggregators/rdirect.jl | 22 +- src/aggregators/rssa.jl | 22 +- src/aggregators/rssacr.jl | 20 +- src/aggregators/sortingdirect.jl | 24 +- src/aggregators/ssajump.jl | 18 +- src/problem.jl | 55 ++--- src/simple_regular_solve.jl | 17 +- src/solve.jl | 71 +++--- src/spatial/directcrdirect.jl | 24 +- src/spatial/nsm.jl | 36 +-- src/spatial/utils.jl | 10 +- src/variable_rate.jl | 113 +++++----- test/degenerate_rx_cases.jl | 20 +- test/ensemble_problems.jl | 140 ++++++------ test/extended_jump_array_remake.jl | 43 ++-- test/remake_test.jl | 51 +++-- test/rng_kwarg_tests.jl | 219 +++++++++++++++++++ test/runtests.jl | 3 +- test/ssa_callback_test.jl | 76 +++---- test/variable_rate.jl | 3 +- 32 files changed, 873 insertions(+), 515 deletions(-) create mode 100644 test/rng_kwarg_tests.jl diff --git a/Project.toml b/Project.toml index 0a2cb01a..404b9b5e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "JumpProcesses" uuid = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" +version = "9.23.0" authors = ["Chris Rackauckas "] -version = "9.22.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -18,6 +18,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [weakdeps] @@ -45,14 +46,14 @@ KernelAbstractions = "0.9" LinearAlgebra = "1" LinearSolve = "3" OrdinaryDiffEq = "6" -OrdinaryDiffEqCore = "1.32.0" +OrdinaryDiffEqCore = "3.9" Pkg = "1" PoissonRandom = "0.4" Random = "1" RecursiveArrayTools = "3.35" Reexport = "1.2" SafeTestsets = "0.1" -SciMLBase = "2.115" +SciMLBase = "2.144" StableRNGs = "1" StaticArrays = "1.9.8" Statistics = "1" diff --git a/docs/src/api.md b/docs/src/api.md index d8fd944b..885be372 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -40,9 +40,109 @@ RSSACR SortingDirect ``` -# Private API Functions +## Random Number Generator Control + +JumpProcesses supports controlling the random number generator (RNG) used for +jump sampling via the `rng` and `seed` keyword arguments to `solve` or `init`. + +### `rng` keyword argument + +Pass any `AbstractRNG` to `solve` or `init`: + +```julia +using Random, StableRNGs + +# Using a StableRNG for cross-version reproducibility +sol = solve(jprob, SSAStepper(); rng = StableRNG(1234)) + +# Using Julia's built-in Xoshiro +sol = solve(jprob, Tsit5(); rng = Xoshiro(42)) +``` + +### `seed` keyword argument + +As a shorthand, pass an integer `seed` to create a `Xoshiro` generator: + +```julia +sol = solve(jprob, SSAStepper(); seed = 1234) +# equivalent to: solve(jprob, SSAStepper(); rng = Xoshiro(1234)) +``` + +### Passing `rng` via `JumpProblem` + +The `rng` keyword can also be passed to [`JumpProblem`](@ref), where it is +stored and automatically forwarded to the solver: + +```julia +jprob = JumpProblem(dprob, Direct(), jump; rng = StableRNG(1234)) +sol = solve(jprob, SSAStepper()) # uses StableRNG(1234) +``` + +An `rng` passed directly to `solve`/`init` takes priority over one stored in +the `JumpProblem`. + +### Resolution priority + +When multiple RNG sources are provided, the following priority determines which +is used: + +| User provides | Result | +|---|---| +| `rng` via `solve`/`init` | Uses that `rng` | +| `seed` via `solve`/`init` | Creates `Xoshiro(seed)` | +| `rng` via `JumpProblem` | Uses that `rng` | +| Nothing | Uses `Random.default_rng()` (SSAStepper, ODE, tau-leaping) or a randomly-seeded `Xoshiro` (SDE) | + +When both `rng` and `seed` are passed to the same call, `rng` takes priority. +When `solve`/`init` kwargs and `JumpProblem` kwargs overlap on the same key, +`solve`/`init` kwargs take priority. + +### Behavior by solver pathway + +| Solver | Default RNG (nothing passed) | `rng` / `seed` support | +|---|---|---| +| `SSAStepper` | `Random.default_rng()` | Full support via `solve`/`init` kwargs | +| ODE solvers (e.g., `Tsit5`) | `Random.default_rng()` | Full support via `solve`/`init` kwargs | +| SDE solvers (e.g., `SRIW1`) | Randomly-seeded `Xoshiro` | Full support; `TaskLocalRNG` is auto-converted to `Xoshiro` | +| `SimpleTauLeaping` | `Random.default_rng()` | Full support via `solve` kwargs | + +!!! note + For reproducible simulations, always pass an explicit `rng` or `seed`. + The default RNG is shared global state and may produce different results + depending on prior usage. + +# Private / Developer API ```@docs ExtendedJumpArray SSAIntegrator ``` + +## Internal Dispatch Pathways + +The following table documents which code handles `solve`/`init` for each solver +type. This is relevant for developers working on JumpProcesses or its solver +backends. + +| Solver type | `__solve` handled by | `__init` handled by | Uses `__jump_init`? | +|---|---|---|---| +| `SSAStepper` | JumpProcesses (`solve.jl`) | JumpProcesses (`SSA_stepper.jl`) | No | +| ODE (e.g., `Tsit5`) | JumpProcesses (`solve.jl`) | JumpProcesses (`solve.jl`) → OrdinaryDiffEq | Yes | +| SDE (e.g., `SRIW1`) | StochasticDiffEq | StochasticDiffEq | No | +| `SimpleTauLeaping` | JumpProcesses (`simple_regular_solve.jl`, custom `DiffEqBase.solve`) | N/A | No | + +For **SSAStepper**, `rng` is resolved via `resolve_rng` in `SSA_stepper.jl`'s +`__init` and stored on the [`SSAIntegrator`](@ref). + +For **ODE solvers**, `rng` is resolved via `resolve_rng` in `__jump_init` +(`solve.jl`) and forwarded to OrdinaryDiffEq's `init`, which stores it on the +`ODEIntegrator`. + +For **SDE solvers**, StochasticDiffEq handles the full solve/init pathway +directly (JumpProcesses' ambiguity-fix `__solve` method is never dispatched to). +StochasticDiffEq has its own `_resolve_rng` that additionally handles +`TaskLocalRNG` conversion and the problem's stored seed. + +For **tau-leaping**, JumpProcesses defines a custom `DiffEqBase.solve` that +bypasses the standard `__solve`/`__init` pathway. It calls `resolve_rng` with +the `JumpProblem`'s stored `rng` kwarg as a fallback. diff --git a/docs/src/faq.md b/docs/src/faq.md index d2d0be3b..0e3dad78 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -55,21 +55,35 @@ jset = JumpSet(; constant_jumps = cjvec, variable_jumps = vjtuple, ## How can I set the random number generator used in the jump process sampling algorithms (SSAs)? -Random number generators can be passed to `JumpProblem` via the `rng` keyword +Random number generators can be passed to `solve` or `init` via the `rng` keyword argument. Continuing the previous example: ```julia -#] add RandomNumbers -using RandomNumbers -jprob = JumpProblem(dprob, Direct(), maj, - rng = Xorshifts.Xoroshiro128Star(rand(UInt64))) +using Random +jprob = JumpProblem(dprob, Direct(), maj) +sol = solve(jprob, SSAStepper(); rng = Xoshiro(1234)) +``` + +Any `AbstractRNG` can be used. For example, to use a generator from +[StableRNGs.jl](https://github.com/JuliaRandom/StableRNGs.jl): + +```julia +using StableRNGs +sol = solve(jprob, SSAStepper(); rng = StableRNG(1234)) ``` -uses the `Xoroshiro128Star` generator from -[RandomNumbers.jl](https://github.com/JuliaRandom/RandomNumbers.jl). +A `seed` keyword argument is also supported as a shorthand for creating a `Xoshiro` +generator: `solve(jprob, SSAStepper(); seed = 1234)`. + +Alternatively, `rng` can be passed to `JumpProblem` where it will be forwarded to the +solver automatically: + +```julia +jprob = JumpProblem(dprob, Direct(), maj; rng = Xoshiro(1234)) +sol = solve(jprob, SSAStepper()) # uses Xoshiro(1234) +``` -On version 1.7 and up, JumpProcesses uses Julia's built-in random number generator by -default. On versions below 1.7 it uses `Xoroshiro128Star`. +By default, JumpProcesses uses Julia's built-in `Random.default_rng()`. ## What are these aggregators and aggregations in JumpProcesses? diff --git a/docs/src/tutorials/simple_poisson_process.md b/docs/src/tutorials/simple_poisson_process.md index ae8e0955..11f5c5f0 100644 --- a/docs/src/tutorials/simple_poisson_process.md +++ b/docs/src/tutorials/simple_poisson_process.md @@ -371,11 +371,10 @@ with ``N(t)`` a Poisson counting process with constant transition rate ``\lambda``, and the ``C_i`` independent and identical samples from a uniform distribution over ``\{-1,1\}``. We can simulate such a process as follows. -We first ensure that we use the same random number generator as JumpProcesses. We -can either pass one as an input to [`JumpProblem`](@ref) via the `rng` keyword -argument, and make sure it is the same one we use in our `affect!` function, or -we can just use the default generator chosen by JumpProcesses if one is not -specified, `JumpProcesses.DEFAULT_RNG`. Let's do the latter +We first ensure that we use the same random number generator as JumpProcesses. +Custom RNGs can be passed to `solve` or `init` via the `rng` keyword argument. +If no RNG is specified, JumpProcesses uses `Random.default_rng()`, which is also +available as `JumpProcesses.DEFAULT_RNG`. Let's use the default ```@example tut1 rng = JumpProcesses.DEFAULT_RNG diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index 87c8c2c2..e32c4fb8 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -40,7 +40,7 @@ using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction, ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!, deleteat!, isinplace, remake, savevalues!, step!, u_modified! -using SciMLBase: SciMLBase, DEIntegrator +using SciMLBase: SciMLBase, DEIntegrator, has_rng, get_rng, set_rng! abstract type AbstractJump end abstract type AbstractMassActionJump <: AbstractJump end diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index 3403b060..c4f72906 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -13,6 +13,9 @@ Highly efficient integrator for pure jump problems that involve only `ConstantRa `SSAStepper`. - Only supports a limited subset of the output controls from the common solver interface, specifically `save_start`, `save_end`, and `saveat`. + - Supports `rng` and `seed` keyword arguments in `solve`/`init` to control the random + number generator used for jump sampling. `rng` accepts any `AbstractRNG`, while `seed` + creates a `Xoshiro` generator. `rng` takes priority over `seed`. - As when using jumps with ODEs and SDEs, saving controls for whether to save each time a jump occurs are via the `save_positions` keyword argument to `JumpProblem`. Note that when choosing `SSAStepper` as the timestepper, `save_positions = (true,true)`, `(true,false)`, @@ -62,13 +65,13 @@ SciMLBase.allows_late_binding_tstops(::SSAStepper) = true """ $(TYPEDEF) -Solution objects for pure jump problems solved via `SSAStepper`. +Integrator for pure jump problems solved via `SSAStepper`. ## Fields $(FIELDS) """ -mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS} <: +mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS, R} <: AbstractSSAIntegrator{SSAStepper, Nothing, uType, tType} """The underlying `prob.f` function. Not currently used.""" f::F @@ -108,6 +111,15 @@ mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS} < alias_tstops::Bool """If true indicates we have already allocated the tstops array""" copied_tstops::Bool + """The random number generator.""" + rng::R +end + +SciMLBase.has_rng(::SSAIntegrator) = true +SciMLBase.get_rng(integrator::SSAIntegrator) = integrator.rng +function SciMLBase.set_rng!(integrator::SSAIntegrator, rng) + integrator.rng = rng + nothing end (integrator::SSAIntegrator)(t) = copy(integrator.u) @@ -198,6 +210,7 @@ function DiffEqBase.__init(jump_prob::JumpProblem, save_start = true, save_end = true, seed = nothing, + rng = nothing, alias_jump = Threads.threadid() == 1, saveat = nothing, callback = nothing, @@ -219,19 +232,13 @@ function DiffEqBase.__init(jump_prob::JumpProblem, # Check for continuous callbacks passed via kwargs (from JumpProblem constructor or solve) check_continuous_callback_error(callback) + + _rng = resolve_rng(rng, seed) + if alias_jump cb = jump_prob.jump_callback.discrete_callbacks[end] - if seed !== nothing - Random.seed!(cb.condition.rng, seed) - end else cb = deepcopy(jump_prob.jump_callback.discrete_callbacks[end]) - # Only reseed if an explicit seed is provided. This respects the user's RNG choice - # and enables reproducibility. For EnsembleProblems, use prob_func to set unique seeds - # for each trajectory if different results are needed. - if seed !== nothing - Random.seed!(cb.condition.rng, seed) - end end opts = (callback = CallbackSet(callback),) @@ -286,7 +293,7 @@ function DiffEqBase.__init(jump_prob::JumpProblem, integrator = SSAIntegrator(prob.f, copy(prob.u0), prob.tspan[1], prob.tspan[1], tdir, prob.p, sol, 1, prob.tspan[1], cb, _saveat, save_everystep, - save_end, cur_saveat, opts, _tstops, 1, false, true, alias_tstops, false) + save_end, cur_saveat, opts, _tstops, 1, false, true, alias_tstops, false, _rng) cb.initialize(cb, integrator.u, prob.tspan[1], integrator) DiffEqBase.initialize!(opts.callback, integrator.u, prob.tspan[1], integrator) if save_start diff --git a/src/aggregators/ccnrm.jl b/src/aggregators/ccnrm.jl index 379ad642..720868c5 100644 --- a/src/aggregators/ccnrm.jl +++ b/src/aggregators/ccnrm.jl @@ -3,8 +3,8 @@ # algorithm with optimal binning, Journal of Chemical Physics 143, 074108 # (2015). doi: 10.1063/1.4928635. -mutable struct CCNRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PT} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct CCNRMJumpAggregation{T, S, F1, F2, DEPGR, PT} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -15,15 +15,14 @@ mutable struct CCNRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PT} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR ptt::PT end function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + num_specs, dep_graph = nothing, + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -45,30 +44,31 @@ function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, binwidthconst = binwidthconst, numbinsconst = numbinsconst) # We will re-initialize this in initialize!() affecttype = F2 <: Tuple ? F2 : Any - CCNRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(ptt)}( + CCNRMJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(ptt)}( nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, - rng, dg, ptt) + dg, ptt) end -+############################# Required Functions ############################## +############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::CCNRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(CCNRMJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end # set up a new simulation and calculate the first jump / jump time function initialize!(p::CCNRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] - initialize_rates_and_times!(p, u, params, t) + rng = get_rng(integrator) + initialize_rates_and_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) nothing end @@ -79,7 +79,8 @@ function execute_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t, affec u = update_state!(p, integrator, u, affects!) # update current jump rates and times - update_dependent_rates!(p, u, params, t) + rng = get_rng(integrator) + update_dependent_rates!(p, u, params, t, rng) nothing end @@ -88,7 +89,7 @@ end function generate_jumps!(p::CCNRMJumpAggregation, integrator, u, params, t) p.next_jump, p.next_jump_time = getfirst(p.ptt) - # Rebuild the table if no next jump is found. + # Rebuild the table if no next jump is found. if p.next_jump == 0 timestep = 1 / sum(p.cur_rates) min_time = minimum(p.ptt.times) @@ -102,7 +103,7 @@ end ######################## SSA specific helper routines ######################## # Recalculate jump rates for jumps that depend on the just executed jump (p.next_jump) -function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) +function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t, rng) @inbounds dep_rxs = p.dep_gr[p.next_jump] (; ptt, cur_rates, rates, ma_jumps, end_time) = p num_majumps = get_num_majumps(ma_jumps) @@ -125,7 +126,7 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) end else if cur_rates[rx] > zero(eltype(cur_rates)) - update!(ptt, rx, oldtime, t + randexp(p.rng) / cur_rates[rx]) + update!(ptt, rx, oldtime, t + randexp(rng) / cur_rates[rx]) else update!(ptt, rx, oldtime, floatmax(typeof(t))) end @@ -134,15 +135,15 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t) nothing end -# Evaluate all the rates and initialize the times in the priority table. -function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t) +# Evaluate all the rates and initialize the times in the priority table. +function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t, rng) # Initialize next-reaction times for the mass action jumps majumps = p.ma_jumps cur_rates = p.cur_rates pttdata = Vector{typeof(t)}(undef, length(cur_rates)) @inbounds for i in 1:get_num_majumps(majumps) cur_rates[i] = evalrxrate(u, i, majumps) - pttdata[i] = t + randexp(p.rng) / cur_rates[i] + pttdata[i] = t + randexp(rng) / cur_rates[i] end # Initialize next-reaction times for the constant rates @@ -150,11 +151,11 @@ function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t) idx = get_num_majumps(majumps) + 1 @inbounds for rate in rates cur_rates[idx] = rate(u, params, t) - pttdata[idx] = t + randexp(p.rng) / cur_rates[idx] + pttdata[idx] = t + randexp(rng) / cur_rates[idx] idx += 1 end - # Build the priority time table with the times and bin width. + # Build the priority time table with the times and bin width. timestep = 1 / sum(cur_rates) p.ptt.times = pttdata rebuild!(p.ptt, t, timestep) diff --git a/src/aggregators/coevolve.jl b/src/aggregators/coevolve.jl index dc48a5e2..ab5dc5f4 100644 --- a/src/aggregators/coevolve.jl +++ b/src/aggregators/coevolve.jl @@ -1,8 +1,8 @@ """ Queue method. This method handles variable intensity rates. """ -mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct CoevolveJumpAggregation{T, S, F1, F2, GR, PQ} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int # the next jump to execute prev_jump::Int # the previous jump that was executed next_jump_time::T # the time of the next jump @@ -13,7 +13,6 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: rates::F1 # vector of rate functions affects!::F2 # vector of affect functions for VariableRateJumps save_positions::Tuple{Bool, Bool} # tuple for whether to save the jumps before and/or after event - rng::RNG # random number generator dep_gr::GR # map from jumps to jumps depending on it pq::PQ # priority queue of next time lrates::F1 # vector of rate lower bound functions @@ -24,10 +23,10 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: end function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Nothing, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; u::U, dep_graph = nothing, lrates, urates, + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + u::U, dep_graph = nothing, lrates, urates, rateintervals, haslratevec, - cur_lrates::Vector{T}) where {T, S, F1, F2, RNG, U} + cur_lrates::Vector{T}) where {T, S, F1, F2, U} if dep_graph === nothing if (get_num_majumps(maj) == 0) || !isempty(urates) error("To use Coevolve a dependency graph between jumps must be supplied.") @@ -49,9 +48,9 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not pq = MutableBinaryMinHeap{T}() affecttype = F2 <: Tuple ? F2 : Any - CoevolveJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), + CoevolveJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(pq)}(nj, nj, njt, et, crs, sr, maj, - rs, affs!, sps, rng, dg, pq, + rs, affs!, sps, dg, pq, lrates, urates, rateintervals, haslratevec, cur_lrates) end @@ -98,7 +97,7 @@ end # creating the JumpAggregation structure (tuple-based variable jumps) function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; dep_graph = nothing, + ma_jumps, save_positions; dep_graph = nothing, variable_jumps = nothing, kwargs...) RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), Tuple{typeof(u), typeof(p), typeof(t)}} @@ -141,7 +140,7 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, next_jump = 0 next_jump_time = typemax(t) CoevolveJumpAggregation(next_jump, next_jump_time, end_time, cur_rates, sum_rate, - ma_jumps, rates, affects!, save_positions, rng; + ma_jumps, rates, affects!, save_positions; u, dep_graph, lrates, urates, rateintervals, haslratevec, cur_lrates) end @@ -149,7 +148,8 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::CoevolveJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] - fill_rates_and_get_times!(p, u, params, t) + rng = get_rng(integrator) + fill_rates_and_get_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) nothing end @@ -160,7 +160,8 @@ function execute_jumps!(p::CoevolveJumpAggregation, integrator, u, params, t, # execute jump update_state!(p, integrator, u, affects!) # update current jump rates and times - update_dependent_rates!(p, integrator.u, integrator.p, t) + rng = get_rng(integrator) + update_dependent_rates!(p, integrator.u, integrator.p, t, rng) nothing end @@ -178,7 +179,8 @@ function accept_next_jump!(p::CoevolveJumpAggregation, integrator, u, params, t) (next_jump <= num_majumps) && return true - (; cur_rates, rates, rng, urates, cur_lrates) = p + (; cur_rates, rates, urates, cur_lrates) = p + rng = get_rng(integrator) num_cjumps = length(urates) - length(rates) uidx = next_jump - num_majumps lidx = uidx - num_cjumps @@ -225,11 +227,11 @@ function accept_next_jump!(p::CoevolveJumpAggregation, integrator, u, params, t) return false end -function update_dependent_rates!(p::CoevolveJumpAggregation, u, params, t) +function update_dependent_rates!(p::CoevolveJumpAggregation, u, params, t, rng) @inbounds deps = p.dep_gr[p.next_jump] (; cur_rates, pq) = p for (ix, i) in enumerate(deps) - ti, urate_i = next_time(p, u, params, t, i) + ti, urate_i = next_time(p, u, params, t, i, rng) update!(pq, i, ti) @inbounds cur_rates[i] = urate_i end @@ -256,8 +258,8 @@ end @inbounds return p.rates[lidx](u, params, t) end -function next_time(p::CoevolveJumpAggregation, u, params, t, i) - (; next_jump, cur_rates, ma_jumps, rates, rng, pq, urates) = p +function next_time(p::CoevolveJumpAggregation, u, params, t, i, rng) + (; next_jump, cur_rates, ma_jumps, rates, pq, urates) = p num_majumps = get_num_majumps(ma_jumps) num_cjumps = length(urates) - length(rates) uidx = i - num_majumps @@ -300,12 +302,12 @@ function next_candidate_time!(p::CoevolveJumpAggregation, u, params, t, s, lidx) end # re-evaluates all rates, recalculate all jump times, and reinit the priority queue -function fill_rates_and_get_times!(p::CoevolveJumpAggregation, u, params, t) +function fill_rates_and_get_times!(p::CoevolveJumpAggregation, u, params, t, rng) num_jumps = get_num_majumps(p.ma_jumps) + length(p.urates) p.cur_rates = zeros(typeof(t), num_jumps) jump_times = Vector{typeof(t)}(undef, num_jumps) @inbounds for i in 1:num_jumps - jump_times[i], p.cur_rates[i] = next_time(p, u, params, t, i) + jump_times[i], p.cur_rates[i] = next_time(p, u, params, t, i, rng) end p.pq = MutableBinaryMinHeap(jump_times) nothing diff --git a/src/aggregators/direct.jl b/src/aggregators/direct.jl index ab33dd84..b51e18f7 100644 --- a/src/aggregators/direct.jl +++ b/src/aggregators/direct.jl @@ -1,5 +1,5 @@ -mutable struct DirectJumpAggregation{T, S, F1, F2, RNG} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct DirectJumpAggregation{T, S, F1, F2} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -10,38 +10,37 @@ mutable struct DirectJumpAggregation{T, S, F1, F2, RNG} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG end function DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, - rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; - kwargs...) where {T, S, F1, F2, RNG} + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + kwargs...) where {T, S, F1, F2} affecttype = F2 <: Tuple ? F2 : Any - DirectJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng) + DirectJumpAggregation{T, S, F1, affecttype}(nj, nj, njt, et, crs, sr, maj, rs, + affs!, sps) end ############################# Required Functions ############################# # creating the JumpAggregation structure (tuple-based constant jumps) function aggregate(aggregator::Direct, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using tuples rates, affects! = get_jump_info_tuples(constant_jumps) build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; kwargs...) + rates, affects!, save_positions; kwargs...) end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::DirectFW, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; kwargs...) + rates, affects!, save_positions; kwargs...) end # set up a new simulation and calculate the first jump / jump time @@ -60,9 +59,10 @@ end # calculate the next jump / jump time function generate_jumps!(p::DirectJumpAggregation, integrator, u, params, t) - p.sum_rate, ttnj = time_to_next_jump(p, u, params, t) + rng = get_rng(integrator) + p.sum_rate, ttnj = time_to_next_jump(p, u, params, t, rng) p.next_jump_time = add_fast(t, ttnj) - @inbounds p.next_jump = searchsortedfirst(p.cur_rates, rand(p.rng) * p.sum_rate) + @inbounds p.next_jump = searchsortedfirst(p.cur_rates, rand(rng) * p.sum_rate) nothing end @@ -70,7 +70,7 @@ end # tuple-based constant jumps function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: Tuple} + t, rng) where {T, S, F1 <: Tuple} prev_rate = zero(t) new_rate = zero(t) cur_rates = p.cur_rates @@ -96,7 +96,7 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, end @inbounds sum_rate = cur_rates[end] - sum_rate, randexp(p.rng) / sum_rate + sum_rate, randexp(rng) / sum_rate end @inline function fill_cur_rates(u, p, t, cur_rates, idx, rate, rates...) @@ -112,7 +112,7 @@ end # function wrapper-based constant jumps function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: AbstractArray} + t, rng) where {T, S, F1 <: AbstractArray} prev_rate = zero(t) new_rate = zero(t) cur_rates = p.cur_rates @@ -137,5 +137,5 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, end @inbounds sum_rate = cur_rates[end] - sum_rate, randexp(p.rng) / sum_rate + sum_rate, randexp(rng) / sum_rate end diff --git a/src/aggregators/directcr.jl b/src/aggregators/directcr.jl index 41d079fe..636300bb 100644 --- a/src/aggregators/directcr.jl +++ b/src/aggregators/directcr.jl @@ -10,9 +10,9 @@ by S. Mauch and M. Stalzer, ACM Trans. Comp. Biol. and Bioinf., 8, No. 1, 27-35 const MINJUMPRATE = 2.0^exponent(1e-12) -mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTable, +mutable struct DirectCRJumpAggregation{T, S, F1, F2, DEPGR, U <: PriorityTable, W <: Function} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -23,7 +23,6 @@ mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTa rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR minrate::T maxrate::T # initial maxrate only, table can increase beyond it! @@ -32,11 +31,11 @@ mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTa end function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; num_specs, dep_graph = nothing, + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + num_specs, dep_graph = nothing, minrate = convert(T, MINJUMPRATE), maxrate = convert(T, Inf), - kwargs...) where {T, S, F1, F2, RNG} + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -63,9 +62,9 @@ function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) affecttype = F2 <: Tuple ? F2 : Any - DirectCRJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), + DirectCRJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, crs, sr, maj, - rs, affs!, sps, rng, dg, + rs, affs!, sps, dg, minrate, maxrate, rt, ratetogroup) end @@ -74,13 +73,13 @@ end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::DirectCR, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(DirectCRJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end @@ -113,10 +112,11 @@ end # calculate the next jump / jump time function generate_jumps!(p::DirectCRJumpAggregation, integrator, u, params, t) - p.next_jump_time = t + randexp(p.rng) / p.sum_rate + rng = get_rng(integrator) + p.next_jump_time = t + randexp(rng) / p.sum_rate if p.next_jump_time < p.end_time - p.next_jump = sample(p.rt, p.cur_rates, p.rng) + p.next_jump = sample(p.rt, p.cur_rates, rng) end nothing end diff --git a/src/aggregators/frm.jl b/src/aggregators/frm.jl index 94baed37..0db30113 100644 --- a/src/aggregators/frm.jl +++ b/src/aggregators/frm.jl @@ -1,5 +1,5 @@ -mutable struct FRMJumpAggregation{T, S, F1, F2, RNG} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct FRMJumpAggregation{T, S, F1, F2} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -10,40 +10,39 @@ mutable struct FRMJumpAggregation{T, S, F1, F2, RNG} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG end function FRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, - affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; - kwargs...) where {T, S, F1, F2, RNG} + affs!::F2, sps::Tuple{Bool, Bool}; + kwargs...) where {T, S, F1, F2} affecttype = F2 <: Tuple ? F2 : Any - FRMJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng) + FRMJumpAggregation{T, S, F1, affecttype}(nj, nj, njt, et, crs, sr, maj, rs, + affs!, sps) end ############################# Required Functions ############################# # creating the JumpAggregation structure (tuple-based constant jumps) function aggregate(aggregator::FRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using tuples rates, affects! = get_jump_info_tuples(constant_jumps) build_jump_aggregation( FRMJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, - save_positions, rng; kwargs...) + save_positions; kwargs...) end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::FRMFW, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation( FRMJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, - save_positions, rng; kwargs...) + save_positions; kwargs...) end # set up a new simulation and calculate the first jump / jump time @@ -62,8 +61,9 @@ end # calculate the next jump / jump time function generate_jumps!(p::FRMJumpAggregation, integrator, u, params, t) - nextmaj, ttnmaj = next_ma_jump(p, u, params, t) - nextcrj, ttncrj = next_constant_rate_jump(p, u, params, t) + rng = get_rng(integrator) + nextmaj, ttnmaj = next_ma_jump(p, u, params, t, rng) + nextcrj, ttncrj = next_constant_rate_jump(p, u, params, t, rng) # execute reaction with minimal time if ttnmaj < ttncrj @@ -79,13 +79,13 @@ end ######################## SSA specific helper routines ######################## # mass action jumps -function next_ma_jump(p::FRMJumpAggregation, u, params, t) +function next_ma_jump(p::FRMJumpAggregation, u, params, t, rng) ttnj = typemax(typeof(t)) nextrx = zero(Int) majumps = p.ma_jumps @inbounds for i in 1:get_num_majumps(majumps) p.cur_rates[i] = evalrxrate(u, i, majumps) - dt = randexp(p.rng) / p.cur_rates[i] + dt = randexp(rng) / p.cur_rates[i] if dt < ttnj ttnj = dt nextrx = i @@ -95,15 +95,15 @@ function next_ma_jump(p::FRMJumpAggregation, u, params, t) end # tuple-based constant jumps -function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, params, - t) where {T, S, F1 <: Tuple, F2 <: Tuple, RNG} +function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2}, u, params, + t, rng) where {T, S, F1 <: Tuple, F2 <: Tuple} ttnj = typemax(typeof(t)) nextrx = zero(Int) if !isempty(p.rates) idx = get_num_majumps(p.ma_jumps) + 1 fill_cur_rates(u, params, t, p.cur_rates, idx, p.rates...) @inbounds for i in idx:length(p.cur_rates) - dt = randexp(p.rng) / p.cur_rates[i] + dt = randexp(rng) / p.cur_rates[i] if dt < ttnj ttnj = dt nextrx = i @@ -115,14 +115,14 @@ end # function wrapper-based constant jumps function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1}, u, params, - t) where {T, S, F1 <: AbstractArray} + t, rng) where {T, S, F1 <: AbstractArray} ttnj = typemax(typeof(t)) nextrx = zero(Int) if !isempty(p.rates) idx = get_num_majumps(p.ma_jumps) + 1 @inbounds for i in 1:length(p.rates) p.cur_rates[idx] = p.rates[i](u, params, t) - dt = randexp(p.rng) / p.cur_rates[idx] + dt = randexp(rng) / p.cur_rates[idx] if dt < ttnj ttnj = dt nextrx = idx diff --git a/src/aggregators/nrm.jl b/src/aggregators/nrm.jl index 7fbcd596..93454044 100644 --- a/src/aggregators/nrm.jl +++ b/src/aggregators/nrm.jl @@ -1,8 +1,8 @@ # Implementation the original Next Reaction Method # Gibson and Bruck, J. Phys. Chem. A, 104 (9), (2000) -mutable struct NRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PQ} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct NRMJumpAggregation{T, S, F1, F2, DEPGR, PQ} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -13,15 +13,14 @@ mutable struct NRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PQ} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR pq::PQ end function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + num_specs, dep_graph = nothing, + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -40,29 +39,30 @@ function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, pq = MutableBinaryMinHeap{T}() affecttype = F2 <: Tuple ? F2 : Any - NRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(pq)}(nj, nj, njt, et, + NRMJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(pq)}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, - rng, dg, pq) + dg, pq) end -+############################# Required Functions ############################## +############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::NRM, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(NRMJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end # set up a new simulation and calculate the first jump / jump time function initialize!(p::NRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] - fill_rates_and_get_times!(p, u, params, t) + rng = get_rng(integrator) + fill_rates_and_get_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) nothing end @@ -73,7 +73,8 @@ function execute_jumps!(p::NRMJumpAggregation, integrator, u, params, t, affects u = update_state!(p, integrator, u, affects!) # update current jump rates and times - update_dependent_rates!(p, u, params, t) + rng = get_rng(integrator) + update_dependent_rates!(p, u, params, t, rng) nothing end @@ -87,7 +88,7 @@ end ######################## SSA specific helper routines ######################## # recalculate jump rates for jumps that depend on the just executed jump (p.next_jump) -function update_dependent_rates!(p::NRMJumpAggregation, u, params, t) +function update_dependent_rates!(p::NRMJumpAggregation, u, params, t, rng) @inbounds dep_rxs = p.dep_gr[p.next_jump] (; cur_rates, rates, ma_jumps) = p num_majumps = get_num_majumps(ma_jumps) @@ -108,7 +109,7 @@ function update_dependent_rates!(p::NRMJumpAggregation, u, params, t) end else if cur_rates[rx] > zero(eltype(cur_rates)) - update!(p.pq, rx, t + randexp(p.rng) / cur_rates[rx]) + update!(p.pq, rx, t + randexp(rng) / cur_rates[rx]) else update!(p.pq, rx, typemax(t)) end @@ -118,7 +119,7 @@ function update_dependent_rates!(p::NRMJumpAggregation, u, params, t) end # reevaluate all rates, recalculate all jump times, and reinit the priority queue -function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t) +function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t, rng) # mass action jumps majumps = p.ma_jumps @@ -126,7 +127,7 @@ function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t) pqdata = Vector{typeof(t)}(undef, length(cur_rates)) @inbounds for i in 1:get_num_majumps(majumps) cur_rates[i] = evalrxrate(u, i, majumps) - pqdata[i] = t + randexp(p.rng) / cur_rates[i] + pqdata[i] = t + randexp(rng) / cur_rates[i] end # constant rates @@ -134,7 +135,7 @@ function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t) idx = get_num_majumps(majumps) + 1 @inbounds for rate in rates cur_rates[idx] = rate(u, params, t) - pqdata[idx] = t + randexp(p.rng) / cur_rates[idx] + pqdata[idx] = t + randexp(rng) / cur_rates[idx] idx += 1 end diff --git a/src/aggregators/rdirect.jl b/src/aggregators/rdirect.jl index 8376b71b..eea28316 100644 --- a/src/aggregators/rdirect.jl +++ b/src/aggregators/rdirect.jl @@ -2,8 +2,8 @@ Direct with rejection sampling """ -mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct RDirectJumpAggregation{T, S, F1, F2, DEPGR} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -14,7 +14,6 @@ mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR max_rate::T counter::Int @@ -22,10 +21,10 @@ mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: end function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, - rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; num_specs, counter_threshold = length(crs), dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -42,9 +41,9 @@ function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, m max_rate = maximum(crs) affecttype = F2 <: Tuple ? F2 : Any - return RDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}(nj, nj, njt, et, + return RDirectJumpAggregation{T, S, F1, affecttype, typeof(dg)}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng, + affs!, sps, dg, max_rate, 0, counter_threshold) end @@ -53,13 +52,13 @@ end # creating the JumpAggregation structure (tuple-based constant jumps) function aggregate(aggregator::RDirect, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(RDirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end @@ -93,7 +92,8 @@ function generate_jumps!(p::RDirectJumpAggregation, integrator, u, params, t) if nomorejumps!(p, sum_rate) return nothing end - (; rng, cur_rates, max_rate) = p + rng = get_rng(integrator) + (; cur_rates, max_rate) = p num_rxs = length(cur_rates) counter = 0 @@ -105,7 +105,7 @@ function generate_jumps!(p::RDirectJumpAggregation, integrator, u, params, t) p.counter = counter p.next_jump = rx - p.next_jump_time = t + randexp(p.rng) / sum_rate + p.next_jump_time = t + randexp(rng) / sum_rate nothing end diff --git a/src/aggregators/rssa.jl b/src/aggregators/rssa.jl index 2943bf88..f86e0b44 100644 --- a/src/aggregators/rssa.jl +++ b/src/aggregators/rssa.jl @@ -4,8 +4,8 @@ # functions of the current population sizes (i.e. u) # requires vartojumps_map and fluct_rates as JumpProblem keywords -mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct RSSAJumpAggregation{T, S, F1, F2, VJMAP, JVMAP, BD, U} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -17,7 +17,6 @@ mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG vartojumps_map::VJMAP jumptovars_map::JVMAP bracket_data::BD @@ -26,10 +25,10 @@ mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <: end function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; u::U, vartojumps_map = nothing, + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + u::U, vartojumps_map = nothing, jumptovars_map = nothing, - bracket_data = nothing, kwargs...) where {T, S, F1, F2, RNG, U} + bracket_data = nothing, kwargs...) where {T, S, F1, F2, U} # a dependency graph is needed and must be provided if there are constant rate jumps if vartojumps_map === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -63,10 +62,10 @@ function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, uhigh = similar(u) affecttype = F2 <: Tuple ? F2 : Any - RSSAJumpAggregation{T, S, F1, affecttype, RNG, typeof(vtoj_map), + RSSAJumpAggregation{T, S, F1, affecttype, typeof(vtoj_map), typeof(jtov_map), typeof(bd), U}(nj, nj, njt, et, crl_bnds, crh_bnds, sr, maj, rs, affs!, sps, - rng, vtoj_map, jtov_map, bd, ulow, + vtoj_map, jtov_map, bd, ulow, uhigh) end @@ -74,13 +73,13 @@ end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::RSSA, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(RSSAJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; u = u, + rates, affects!, save_positions; u, kwargs...) end @@ -108,7 +107,8 @@ function generate_jumps!(p::RSSAJumpAggregation, integrator, u, params, t) return nothing end # next jump type - (; ma_jumps, rates, cur_rate_high, cur_rate_low, rng) = p + (; ma_jumps, rates, cur_rate_high, cur_rate_low) = p + rng = get_rng(integrator) num_majumps = get_num_majumps(ma_jumps) rerl = zero(sum_rate) diff --git a/src/aggregators/rssacr.jl b/src/aggregators/rssacr.jl index 1caf3be5..9cc47c3b 100644 --- a/src/aggregators/rssacr.jl +++ b/src/aggregators/rssacr.jl @@ -4,9 +4,9 @@ Composition-Rejection with Rejection sampling method (RSSA-CR) const MINJUMPRATE = 2.0^exponent(1e-12) -mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, +mutable struct RSSACRJumpAggregation{F, S, F1, F2, U, VJMAP, JVMAP, BD, P <: PriorityTable, W <: Function} <: - AbstractSSAJumpAggregator{F, S, F1, F2, RNG} + AbstractSSAJumpAggregator{F, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::F @@ -18,7 +18,6 @@ mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG vartojumps_map::VJMAP jumptovars_map::JVMAP bracket_data::BD @@ -31,11 +30,11 @@ mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, end function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate::F, maj::S, - rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; u::U, + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; u::U, vartojumps_map = nothing, jumptovars_map = nothing, bracket_data = nothing, minrate = convert(F, MINJUMPRATE), maxrate = convert(F, Inf), - kwargs...) where {F, S, F1, F2, RNG, U} + kwargs...) where {F, S, F1, F2, U} # a dependency graph is needed and must be provided if there are constant rate jumps if vartojumps_map === nothing if (get_num_majumps(maj) == 0) || !isempty(rs) @@ -80,10 +79,10 @@ function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate: rt = PriorityTable(ratetogroup, zeros(F, 1), minrate, 2 * minrate) affecttype = F2 <: Tuple ? F2 : Any - RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, RNG, U, typeof(vtoj_map), + RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, U, typeof(vtoj_map), typeof(jtov_map), typeof(bd), typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, crl_bnds, crh_bnds, - sum_rate, maj, rs, affs!, sps, rng, vtoj_map, + sum_rate, maj, rs, affs!, sps, vtoj_map, jtov_map, bd, ulow, uhigh, minrate, maxrate, rt, ratetogroup) end @@ -92,13 +91,13 @@ end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::RSSACR, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(RSSACRJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; u = u, kwargs...) + rates, affects!, save_positions; u, kwargs...) end # set up a new simulation and calculate the first jump / jump time @@ -134,7 +133,8 @@ function generate_jumps!(p::RSSACRJumpAggregation, integrator, u, params, t) return nothing end - (; rt, ma_jumps, rates, cur_rate_high, cur_rate_low, rng) = p + (; rt, ma_jumps, rates, cur_rate_high, cur_rate_low) = p + rng = get_rng(integrator) num_majumps = get_num_majumps(ma_jumps) rerl = zero(sum_rate) diff --git a/src/aggregators/sortingdirect.jl b/src/aggregators/sortingdirect.jl index f9048f03..c20ea1ae 100644 --- a/src/aggregators/sortingdirect.jl +++ b/src/aggregators/sortingdirect.jl @@ -2,8 +2,8 @@ # "The sorting direct method for stochastic simulation of biochemical systems with varying reaction execution behavior" # Comp. Bio. and Chem., 30, pg. 39-49 (2006). -mutable struct SortingDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} +mutable struct SortingDirectJumpAggregation{T, S, F1, F2, DEPGR} <: + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::Int prev_jump::Int next_jump_time::T @@ -14,16 +14,15 @@ mutable struct SortingDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR jump_search_order::Vector{Int} jump_search_idx::Int end function SortingDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, - maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, - rng::RNG; num_specs, dep_graph = nothing, - kwargs...) where {T, S, F1, F2, RNG} + maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + num_specs, dep_graph = nothing, + kwargs...) where {T, S, F1, F2} # a dependency graph is needed and must be provided if there are constant rate jumps if dep_graph === nothing @@ -42,9 +41,9 @@ function SortingDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr # map jump idx to idx in cur_rates jtoidx = collect(1:length(crs)) affecttype = F2 <: Tuple ? F2 : Any - SortingDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}(nj, nj, njt, et, + SortingDirectJumpAggregation{T, S, F1, affecttype, typeof(dg)}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps, rng, + affs!, sps, dg, jtoidx, zero(Int)) end @@ -53,13 +52,13 @@ end # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::SortingDirect, u, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; kwargs...) + ma_jumps, save_positions; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) build_jump_aggregation(SortingDirectJumpAggregation, u, p, t, end_time, ma_jumps, - rates, affects!, save_positions, rng; num_specs = length(u), + rates, affects!, save_positions; num_specs = length(u), kwargs...) end @@ -92,14 +91,15 @@ end # calculate the next jump / jump time function generate_jumps!(p::SortingDirectJumpAggregation, integrator, u, params, t) - p.next_jump_time = t + randexp(p.rng) / p.sum_rate + rng = get_rng(integrator) + p.next_jump_time = t + randexp(rng) / p.sum_rate # search for next jump if p.next_jump_time < p.end_time cur_rates = p.cur_rates numjumps = length(cur_rates) jso = p.jump_search_order - rn = p.sum_rate * rand(p.rng) + rn = p.sum_rate * rand(rng) @inbounds for idx in 1:numjumps rn -= cur_rates[jso[idx]] if rn < zero(rn) diff --git a/src/aggregators/ssajump.jl b/src/aggregators/ssajump.jl index 90c260c9..0e72b234 100644 --- a/src/aggregators/ssajump.jl +++ b/src/aggregators/ssajump.jl @@ -13,13 +13,11 @@ An aggregator interface for SSA-like algorithms. - `rates` # vector of rate functions for ConstantRateJumps - `affects!` # vector of affect functions for ConstantRateJumps - `save_positions` # tuple for whether to save the jumps before and/or after event - - `rng` # random number generator - ### Optional fields: - `dep_gr` # dependency graph, dep_gr[i] = indices of reactions that should be updated when rx i occurs. """ -abstract type AbstractSSAJumpAggregator{T, S, F1, F2, RNG} <: AbstractJumpAggregator end +abstract type AbstractSSAJumpAggregator{T, S, F1, F2} <: AbstractJumpAggregator end function DiscreteCallback(c::AbstractSSAJumpAggregator) DiscreteCallback(c, c, initialize = c, save_positions = c.save_positions) @@ -55,8 +53,8 @@ end nothing end -@inline function concretize_affects!(p::AbstractSSAJumpAggregator{T, S, F1, F2}, - ::I) where {T, S, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} +@inline function concretize_affects!(p::AbstractSSAJumpAggregator{<:Any, <:Any, <:Any, F2}, + ::I) where {F2 <: Tuple, I <: SciMLBase.DEIntegrator} nothing end @@ -88,8 +86,8 @@ function (p::AbstractSSAJumpAggregator)(integrator::I) where {I <: SciMLBase.DEI end function (p::AbstractSSAJumpAggregator{ - T, S, F1, F2})(integrator::SciMLBase.DEIntegrator) where - {T, S, F1, F2 <: Union{Tuple, Nothing}} + <:Any, <:Any, <:Any, F2})(integrator::SciMLBase.DEIntegrator) where + {F2 <: Union{Tuple, Nothing}} execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t, p.affects!) generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) register_next_jump_time!(integrator, p, integrator.t) @@ -112,12 +110,12 @@ end """ build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rates, - affects!, save_positions, rng; kwargs...) + affects!, save_positions; kwargs...) Helper routine for setting up standard fields of SSA jump aggregations. """ function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rates, - affects!, save_positions, rng; kwargs...) + affects!, save_positions; kwargs...) # mass action jumps majumps = ma_jumps @@ -134,7 +132,7 @@ function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rate next_jump = 0 next_jump_time = typemax(typeof(t)) jump_agg_type(next_jump, next_jump_time, end_time, cur_rates, sum_rate, - majumps, rates, affects!, save_positions, rng; kwargs...) + majumps, rates, affects!, save_positions; kwargs...) end """ diff --git a/src/problem.jl b/src/problem.jl index 157c97a3..e09d92dd 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -51,7 +51,6 @@ $(FIELDS) ## Keyword Arguments - - `rng`, the random number generator to use. Defaults to Julia's built-in generator. - `save_positions=(true,true)` when including variable rates and `(false,true)` for constant rates, specifies whether to save the system's state (before, after) the jump occurs. - `spatial_system`, for spatial problems the underlying spatial structure. @@ -61,14 +60,20 @@ $(FIELDS) integration interface, and treated like general `VariableRateJump`s. - `vr_aggregator`, indicates the aggregator to use for sampling variable rate jumps. Current default is `VR_FRM`. + - `rng`, a random number generator to use for jump sampling. If provided, it is stored in + the problem's solver keyword arguments and forwarded to the solver when `solve` or `init` + is called. The recommended approach is to pass `rng` directly to `solve` or `init`: + `solve(jprob, SSAStepper(); rng = my_rng)`. + - `tstops`, time stops to pass through to the solver. Can be an `AbstractVector` of times + or a callable `(p, tspan) -> times`. Please see the [tutorial page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage examples and commonly asked questions. """ -mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1, - J2, J3, J4, R, K} <: DiffEqBase.AbstractJumpProblem{P, J} +mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1, + J2, J3, J4, K} <: DiffEqBase.AbstractJumpProblem{P, J} """The type of problem to couple the jumps to. For a pure jump process use `DiscreteProblem`, to couple to ODEs, `ODEProblem`, etc.""" prob::P """The aggregator algorithm that determines the next jump times and types for `ConstantRateJump`s and `MassActionJump`s. Examples include `Direct`.""" @@ -85,26 +90,24 @@ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggrega regular_jump::J3 """The `MassActionJump`s.""" massaction_jump::J4 - """The random number generator to use.""" - rng::R """kwargs to pass on to solve call.""" kwargs::K end function JumpProblem(p::P, a::A, dj::J, jc::C, cj::J1, vj::J2, rj::J3, mj::J4, - rng::R, kwargs::K) where {P, A, J, C, J1, J2, J3, J4, R, K} + kwargs::K) where {P, A, J, C, J1, J2, J3, J4, K} iip = isinplace_jump(p, rj) - JumpProblem{iip, P, A, C, J, J1, J2, J3, J4, R, K}(p, a, dj, jc, cj, vj, rj, mj, - rng, kwargs) + JumpProblem{iip, P, A, C, J, J1, J2, J3, J4, K}(p, a, dj, jc, cj, vj, rj, mj, + kwargs) end ######## remaking ###### # for a problem where prob.u0 is an ExtendedJumpArray, create an ExtendedJumpArray that # aliases and resets prob.u0.jump_u while having newu0 as the new u component. -function remake_extended_u0(prob, newu0, rng) +function remake_extended_u0(prob, newu0) jump_u = prob.u0.jump_u ttype = eltype(prob.tspan) - @. jump_u = -randexp(rng, ttype) + @. jump_u = zero(ttype) ExtendedJumpArray(newu0, jump_u) end @@ -142,7 +145,7 @@ function DiffEqBase.remake(jprob::JumpProblem; u0 = missing, p = missing, error("Passed in u0 is incompatible with current u0 which has type: $(typeof(prob.u0.u)).") end - final_u0 = remake_extended_u0(prob, state_vals, jprob.rng) + final_u0 = remake_extended_u0(prob, state_vals) end newprob = DiffEqBase.remake(prob; u0 = final_u0, p, interpret_symbolicmap, use_defaults, kwargs...) else @@ -169,8 +172,8 @@ function DiffEqBase.remake(jprob::JumpProblem; u0 = missing, p = missing, end T(newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback, - jprob.constant_jumps, jprob.variable_jumps, jprob.regular_jump, - jprob.massaction_jump, jprob.rng, jprob.kwargs) + jprob.constant_jumps, jprob.variable_jumps, jprob.regular_jump, + jprob.massaction_jump, jprob.kwargs) end # for updating parameters in JumpProblems to update MassActionJumps @@ -239,7 +242,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS vr_aggregator::VariableRateAggregator = VR_FRM(), save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), - rng = DEFAULT_RNG, scale_rates = true, useiszero = true, + rng = nothing, scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, callback = nothing, tstops = nothing, use_vrj_bounds = true, kwargs...) @@ -289,15 +292,15 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS constant_jump_callback = CallbackSet() else disc_agg = aggregate(aggregator, u, prob.p, t, end_time, crjs, maj, - save_positions, rng; kwargs...) + save_positions; kwargs...) constant_jump_callback = DiscreteCallback(disc_agg) end # handle any remaining vrjs if length(cvrjs) > 0 # Handle variable rate jumps based on vr_aggregator - new_prob, variable_jump_callback = configure_jump_problem(prob, vr_aggregator, - jumps, cvrjs; rng) + new_prob, variable_jump_callback = configure_jump_problem(prob, vr_aggregator, + jumps, cvrjs) else new_prob = prob variable_jump_callback = CallbackSet() @@ -306,19 +309,19 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS jump_cbs = CallbackSet(constant_jump_callback, variable_jump_callback) iip = isinplace_jump(prob, jumps.regular_jump) - solkwargs = tstops === nothing ? make_kwarg(; callback) : make_kwarg(; callback, tstops) + solkwargs = make_kwarg(; filter(!isnothing, (; callback, tstops, rng))...) JumpProblem{iip, typeof(new_prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(cvrjs), typeof(jumps.regular_jump), - typeof(maj), typeof(rng), typeof(solkwargs)}(new_prob, aggregator, disc_agg, - jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, rng, solkwargs) + typeof(maj), typeof(solkwargs)}(new_prob, aggregator, disc_agg, + jump_cbs, crjs, cvrjs, jumps.regular_jump, maj, solkwargs) end # Special dispatch for PureLeaping aggregator - bypasses all aggregation function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), - rng = DEFAULT_RNG, scale_rates = true, useiszero = true, + rng = nothing, scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, callback = nothing, tstops = nothing, kwargs...) @@ -342,18 +345,18 @@ function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; # No discrete jump aggregation or variable rate callbacks are created disc_agg = nothing jump_cbs = CallbackSet() - + # Store all jump types for access by tau-leaping solver crjs = jumps.constant_jumps vrjs = jumps.variable_jumps - + iip = isinplace_jump(prob, jumps.regular_jump) - solkwargs = tstops === nothing ? make_kwarg(; callback) : make_kwarg(; callback, tstops) + solkwargs = make_kwarg(; filter(!isnothing, (; callback, tstops, rng))...) JumpProblem{iip, typeof(prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(vrjs), typeof(jumps.regular_jump), - typeof(maj), typeof(rng), typeof(solkwargs)}(prob, aggregator, disc_agg, - jump_cbs, crjs, vrjs, jumps.regular_jump, maj, rng, solkwargs) + typeof(maj), typeof(solkwargs)}(prob, aggregator, disc_agg, + jump_cbs, crjs, vrjs, jumps.regular_jump, maj, solkwargs) end aggregator(jp::JumpProblem{iip, P, A}) where {iip, P, A} = A diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index da98ba1c..4d817b03 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -70,13 +70,14 @@ function _process_saveat(saveat, tspan, save_start, save_end) end function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; - seed = nothing, dt = error("dt is required for SimpleTauLeaping."), + seed = nothing, rng = nothing, + dt = error("dt is required for SimpleTauLeaping."), saveat = nothing, save_start = nothing, save_end = nothing) validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only RegularJumps.") - (; prob, rng) = jump_prob - (seed !== nothing) && seed!(rng, seed) + prob = jump_prob.prob + _rng = resolve_rng(rng, seed, get(jump_prob.kwargs, :rng, nothing)) rj = jump_prob.regular_jump rate = rj.rate # rate function rate(out,u,p,t) @@ -117,7 +118,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; t_new = tprev + dt rate(rate_cache, uprev, p, tprev) rate_cache .*= dt - counts .= pois_rand.((rng,), rate_cache) + counts .= pois_rand.((_rng,), rate_cache) c(du, uprev, p, tprev, counts, mark) u_new .= du .+ uprev @@ -335,22 +336,20 @@ function simple_explicit_tau_leaping_loop!( end function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; - seed = nothing, + seed = nothing, rng = nothing, dtmin = nothing, saveat = nothing, save_start = nothing, save_end = nothing) validate_pure_leaping_inputs(jump_prob, alg) || error("SimpleExplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") prob = jump_prob.prob - rng = jump_prob.rng + _rng = resolve_rng(rng, seed, get(jump_prob.kwargs, :rng, nothing)) tspan = prob.tspan if dtmin === nothing dtmin = 1e-10 * one(typeof(tspan[2])) end - (seed !== nothing) && seed!(rng, seed) - maj = jump_prob.massaction_jump numjumps = get_num_majumps(maj) rj = jump_prob.regular_jump @@ -394,7 +393,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; reactant_stoch, hor, length(u0), numjumps) simple_explicit_tau_leaping_loop!( - prob, alg, u_current, u_new, t_current, t_end, p, rng, + prob, alg, u_current, u_new, t_current, t_end, p, _rng, rate, c, nu, hor, max_hor, max_stoich, numjumps, epsilon, dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj, save_end) diff --git a/src/solve.jl b/src/solve.jl index 11bc16bb..1ae5b384 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,3 +1,22 @@ +""" + resolve_rng(rng, seed[, fallback_rng]) + +Resolve which RNG to use for a jump simulation. + +Priority: `rng` > `seed` (creates `Xoshiro`) > `fallback_rng` > `Random.default_rng()`. +""" +function resolve_rng(rng, seed, fallback_rng = nothing) + if rng !== nothing + rng + elseif seed !== nothing + Random.Xoshiro(seed) + elseif fallback_rng !== nothing + fallback_rng + else + Random.default_rng() + end +end + function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg::DiffEqBase.DEAlgorithm; merge_callbacks = true, kwargs...) where {P} @@ -38,53 +57,33 @@ function DiffEqBase.__init(_jump_prob::DiffEqBase.AbstractJumpProblem{P}, kwargs = DiffEqBase.merge_problem_kwargs(_jump_prob; merge_callbacks, kwargs...) __jump_init(_jump_prob, alg; kwargs...) -end +end function __jump_init(_jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg; - callback = nothing, seed = nothing, + callback = nothing, seed = nothing, rng = nothing, alias_jump = Threads.threadid() == 1, kwargs...) where {P} + + _rng = resolve_rng(rng, seed) + if alias_jump jump_prob = _jump_prob - reset_jump_problem!(jump_prob, seed) else - jump_prob = resetted_jump_problem(_jump_prob, seed) + jump_prob = resetted_jump_problem(_jump_prob) end - # DDEProblems do not have a recompile_flag argument - if jump_prob.prob isa DiffEqBase.AbstractDDEProblem - # callback comes after jump consistent with SSAStepper - integrator = init(jump_prob.prob, alg; - callback = CallbackSet(jump_prob.jump_callback, callback), - kwargs...) - else - # callback comes after jump consistent with SSAStepper - integrator = init(jump_prob.prob, alg; - callback = CallbackSet(jump_prob.jump_callback, callback), - kwargs...) - end + init(jump_prob.prob, alg; + callback = CallbackSet(jump_prob.jump_callback, callback), + rng = _rng, kwargs...) end -# Derive an independent seed from the caller's seed. When a caller (e.g. StochasticDiffEq) -# passes the same seed used for its noise process, we must produce a distinct seed for the -# jump aggregator's RNG. We cannot assume the JumpProblem's stored RNG is any particular -# type, so we pass the seed through `hash` (to decorrelate from the input) and then through -# a Xoshiro draw (to ensure strong mixing regardless of the target RNG's seeding quality). -const _JUMP_SEED_SALT = 0x4a756d7050726f63 # "JumPProc" in ASCII -_derive_jump_seed(seed) = rand(Random.Xoshiro(hash(seed, _JUMP_SEED_SALT)), UInt64) - -function resetted_jump_problem(_jump_prob, seed) - jump_prob = deepcopy(_jump_prob) - if seed !== nothing && !isempty(jump_prob.jump_callback.discrete_callbacks) - rng = jump_prob.jump_callback.discrete_callbacks[1].condition.rng - Random.seed!(rng, _derive_jump_seed(seed)) - end - jump_prob +# Keep function signatures for StochasticDiffEq backward compatibility. +# The seed argument is accepted but no longer used to reseed aggregator RNGs +# (RNG state is now managed by the integrator). +function resetted_jump_problem(_jump_prob, seed = nothing) + deepcopy(_jump_prob) end -function reset_jump_problem!(jump_prob, seed) - if seed !== nothing && !isempty(jump_prob.jump_callback.discrete_callbacks) - Random.seed!(jump_prob.jump_callback.discrete_callbacks[1].condition.rng, - _derive_jump_seed(seed)) - end +function reset_jump_problem!(jump_prob, seed = nothing) + nothing end diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index f282bd1a..0961ec25 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -5,10 +5,10 @@ const MINJUMPRATE = 2.0^exponent(1e-12) #NOTE state vector u is a matrix. u[i,j] is species i, site j #NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j -mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, +mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, J, RX, HOP, DEPGR, VJMAP, JVMAP, SS, U <: PriorityTable, W <: Function} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::SpatialJump{J} #some structure to identify the next event: reaction or hop prev_jump::SpatialJump{J} #some structure to identify the previous event: reaction or hop next_jump_time::T @@ -19,7 +19,6 @@ mutable struct DirectCRDirectJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPG rates::F1 # legacy, not used affects!::F2 # legacy, not used save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR #dep graph is same for each site vartojumps_map::VJMAP #vartojumps_map is same for each site jumptovars_map::JVMAP #jumptovars_map is same for each site @@ -31,11 +30,11 @@ end function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop_rates::HOP, site_rates::Vector{T}, - sps::Tuple{Bool, Bool}, rng::RNG, spatial_system::SS; + sps::Tuple{Bool, Bool}, spatial_system::SS; num_specs, minrate = convert(T, MINJUMPRATE), vartojumps_map = nothing, jumptovars_map = nothing, dep_graph = nothing, - kwargs...) where {J, T, RX, HOP, RNG, SS} + kwargs...) where {J, T, RX, HOP, SS} # a dependency graph is needed if dep_graph === nothing @@ -69,12 +68,12 @@ function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rat # construct an empty initial priority table -- we'll reset this in init rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate) - DirectCRDirectJumpAggregation{T, Nothing, Nothing, Nothing, RNG, J, RX, HOP, + DirectCRDirectJumpAggregation{T, Nothing, Nothing, Nothing, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), SS, typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, rx_rates, hop_rates, site_rates, nothing, nothing, sps, - rng, dg, vtoj_map, + dg, vtoj_map, jtov_map, spatial_system, num_specs, rt, ratetogroup) end @@ -82,7 +81,7 @@ end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::DirectCRDirect, starting_state, p, t, end_time, - constant_jumps, ma_jumps, save_positions, rng; hopping_constants, + constant_jumps, ma_jumps, save_positions; hopping_constants, spatial_system, kwargs...) num_species = size(starting_state, 1) majumps = ma_jumps @@ -99,7 +98,7 @@ function aggregate(aggregator::DirectCRDirect, starting_state, p, t, end_time, site_rates = zeros(typeof(end_time), num_sites(spatial_system)) DirectCRDirectJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates, - site_rates, save_positions, rng, spatial_system; + site_rates, save_positions, spatial_system; num_specs = num_species, kwargs...) end @@ -113,10 +112,11 @@ end # calculate the next jump / jump time function generate_jumps!(p::DirectCRDirectJumpAggregation, integrator, params, u, t) - p.next_jump_time = t + randexp(p.rng) / p.rt.gsum + rng = get_rng(integrator) + p.next_jump_time = t + randexp(rng) / p.rt.gsum p.next_jump_time >= p.end_time && return nothing - site = sample(p.rt, p.site_rates, p.rng) - p.next_jump = sample_jump_direct(p, site) + site = sample(p.rt, p.site_rates, rng) + p.next_jump = sample_jump_direct(p, site, rng) nothing end diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index 07c48da6..acf6b56e 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -3,9 +3,9 @@ ############################ NSM ################################### #NOTE state vector u is a matrix. u[i,j] is species i, site j #NOTE hopping_constants is a matrix. hopping_constants[i,j] is species i, site j -mutable struct NSMJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, JVMAP, +mutable struct NSMJumpAggregation{T, S, F1, F2, J, RX, HOP, DEPGR, VJMAP, JVMAP, PQ, SS} <: - AbstractSSAJumpAggregator{T, S, F1, F2, RNG} + AbstractSSAJumpAggregator{T, S, F1, F2} next_jump::SpatialJump{J} #some structure to identify the next event: reaction or hop prev_jump::SpatialJump{J} #some structure to identify the previous event: reaction or hop next_jump_time::T @@ -15,7 +15,6 @@ mutable struct NSMJumpAggregation{T, S, F1, F2, RNG, J, RX, HOP, DEPGR, VJMAP, J rates::F1 # legacy, not used affects!::F2 # legacy, not used save_positions::Tuple{Bool, Bool} - rng::RNG dep_gr::DEPGR #dep graph is same for each site vartojumps_map::VJMAP #vartojumps_map is same for each site jumptovars_map::JVMAP #jumptovars_map is same for each site @@ -27,9 +26,9 @@ end function NSMJumpAggregation( nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop_rates::HOP, sps::Tuple{Bool, Bool}, - rng::RNG, spatial_system::SS; num_specs, + spatial_system::SS; num_specs, vartojumps_map = nothing, jumptovars_map = nothing, - dep_graph = nothing, kwargs...) where {J, T, RX, HOP, RNG, SS} + dep_graph = nothing, kwargs...) where {J, T, RX, HOP, SS} # a dependency graph is needed if dep_graph === nothing @@ -55,13 +54,13 @@ function NSMJumpAggregation( pq = MutableBinaryMinHeap{T}() - NSMJumpAggregation{T, Nothing, Nothing, Nothing, RNG, J, RX, HOP, typeof(dg), + NSMJumpAggregation{T, Nothing, Nothing, Nothing, J, RX, HOP, typeof(dg), typeof(vtoj_map), typeof(jtov_map), typeof(pq), SS}(nj, nj, njt, et, rx_rates, hop_rates, nothing, nothing, - sps, rng, dg, + sps, dg, vtoj_map, jtov_map, pq, @@ -72,7 +71,7 @@ end ############################# Required Functions ############################## # creating the JumpAggregation structure (function wrapper-based constant jumps) function aggregate(aggregator::NSM, starting_state, p, t, end_time, constant_jumps, - ma_jumps, save_positions, rng; hopping_constants, spatial_system, + ma_jumps, save_positions; hopping_constants, spatial_system, kwargs...) num_species = size(starting_state, 1) majumps = ma_jumps @@ -88,7 +87,7 @@ function aggregate(aggregator::NSM, starting_state, p, t, end_time, constant_jum hop_rates = HopRates(hopping_constants, spatial_system) NSMJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates, - save_positions, rng, spatial_system; num_specs = num_species, + save_positions, spatial_system; num_specs = num_species, kwargs...) end @@ -104,7 +103,8 @@ end function generate_jumps!(p::NSMJumpAggregation, integrator, params, u, t) p.next_jump_time, site = top_with_handle(p.pq) p.next_jump_time >= p.end_time && return nothing - p.next_jump = sample_jump_direct(p, site) + rng = get_rng(integrator) + p.next_jump = sample_jump_direct(p, site, rng) nothing end @@ -125,7 +125,8 @@ end reset all structs, reevaluate all rates, recalculate tentative site firing times, and reinit the priority queue """ function fill_rates_and_get_times!(aggregation::NSMJumpAggregation, integrator, t) - (; spatial_system, rx_rates, hop_rates, rng) = aggregation + (; spatial_system, rx_rates, hop_rates) = aggregation + rng = get_rng(integrator) u = integrator.u reset!(rx_rates) @@ -153,30 +154,31 @@ recalculate jump rates for jumps that depend on the just executed jump (p.prev_j """ function update_dependent_rates_and_firing_times!(p::NSMJumpAggregation, integrator, t) u = integrator.u + rng = get_rng(integrator) jump = p.prev_jump if is_hop(p, jump) source_site = jump.src target_site = jump.dst update_rates_after_hop!(p, integrator, source_site, target_site, jump.jidx) - update_site_time!(p, source_site, t) - update_site_time!(p, target_site, t) + update_site_time!(p, source_site, t, rng) + update_site_time!(p, target_site, t, rng) else site = jump.src update_rates_after_reaction!(p, integrator, site, reaction_id_from_jump(p, jump)) - update_site_time!(p, site, t) + update_site_time!(p, site, t, rng) end nothing end """ - update_site_time!(p::NSMJumpAggregation, site, t) + update_site_time!(p::NSMJumpAggregation, site, t, rng) update the time of site in the priority queue """ -function update_site_time!(p::NSMJumpAggregation, site, t) +function update_site_time!(p::NSMJumpAggregation, site, t, rng) site_rate = (total_site_rate(p.rx_rates, p.hop_rates, site)) if site_rate > zero(typeof(site_rate)) - update!(p.pq, site, t + randexp(p.rng) / site_rate) + update!(p.pq, site, t + randexp(rng) / site_rate) else update!(p.pq, site, typemax(t)) end diff --git a/src/spatial/utils.jl b/src/spatial/utils.jl index 370e8601..3be2b737 100644 --- a/src/spatial/utils.jl +++ b/src/spatial/utils.jl @@ -23,18 +23,18 @@ end ######################## helper routines for all spatial SSAs ######################## """ - sample_jump_direct(p, site) + sample_jump_direct(p, site, rng) sample jump at site with direct method """ -function sample_jump_direct(p, site) - if rand(p.rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < +function sample_jump_direct(p, site, rng) + if rand(rng) * (total_site_rate(p.rx_rates, p.hop_rates, site)) < total_site_rx_rate(p.rx_rates, site) - rx = sample_rx_at_site(p.rx_rates, site, p.rng) + rx = sample_rx_at_site(p.rx_rates, site, rng) return SpatialJump(site, rx + p.numspecies, site) else species_to_diffuse, - target_site = sample_hop_at_site(p.hop_rates, site, p.rng, + target_site = sample_hop_at_site(p.hop_rates, site, rng, p.spatial_system) return SpatialJump(site, species_to_diffuse, target_site) end diff --git a/src/variable_rate.jl b/src/variable_rate.jl index 48c40534..beed2a0a 100644 --- a/src/variable_rate.jl +++ b/src/variable_rate.jl @@ -28,21 +28,21 @@ Simulating a birth-death process with `VR_FRM`: ```julia using JumpProcesses, OrdinaryDiffEq -u0 = [1.0] # Initial population -p = [10.0, 0.5] # [birth rate, death rate] +u0 = [1.0] # Initial population +p = [10.0, 0.5] # [birth rate, death rate] tspan = (0.0, 10.0) -# Birth jump: ∅ → X +# Birth jump: ∅ → X birth_rate(u, p, t) = p[1] birth_affect!(integrator) = (integrator.u[1] += 1; nothing) birth_jump = VariableRateJump(birth_rate, birth_affect!) -# Death jump: X → ∅ +# Death jump: X → ∅ death_rate(u, p, t) = p[2] * u[1] death_affect!(integrator) = (integrator.u[1] -= 1; nothing) death_jump = VariableRateJump(death_rate, death_affect!) -# Problem setup +# Problem setup oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p) jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_FRM()) sol = solve(jprob, Tsit5()) @@ -58,26 +58,25 @@ sol = solve(jprob, Tsit5()) """ struct VR_FRM <: VariableRateAggregator end -function configure_jump_problem(prob, vr_aggregator::VR_FRM, jumps, cvrjs; - rng = DEFAULT_RNG) - new_prob = extend_problem(prob, cvrjs; rng) - variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...; rng) +function configure_jump_problem(prob, vr_aggregator::VR_FRM, jumps, cvrjs) + new_prob = extend_problem(prob, cvrjs) + variable_jump_callback = build_variable_callback(CallbackSet(), 0, cvrjs...) return new_prob, variable_jump_callback end # extends prob.u0 to an ExtendedJumpArray with Njumps integrated intensity values, # of type prob.tspan -function extend_u0(prob, Njumps, rng) +function extend_u0(prob, Njumps) ttype = eltype(prob.tspan) - u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:Njumps]) + u0 = ExtendedJumpArray(prob.u0, zeros(ttype, Njumps)) return u0 end -function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractDiscreteProblem, jumps) error("General `VariableRateJump`s require a continuous problem, like an ODE/SDE/DDE/DAE problem. To use a `DiscreteProblem` bounded `VariableRateJump`s must be used. See the JumpProcesses docs.") end -function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps) _f = SciMLBase.unwrapped_f(prob.f) if isinplace(prob) @@ -97,13 +96,13 @@ function extend_problem(prob::DiffEqBase.AbstractODEProblem, jumps; rng = DEFAUL end end - u0 = extend_u0(prob, length(jumps), rng) + u0 = extend_u0(prob, length(jumps)) f = ODEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, observed = prob.f.observed) remake(prob; f, u0) end -function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps) _f = SciMLBase.unwrapped_f(prob.f) if isinplace(prob) @@ -133,13 +132,13 @@ function extend_problem(prob::DiffEqBase.AbstractSDEProblem, jumps; rng = DEFAUL end end - u0 = extend_u0(prob, length(jumps), rng) + u0 = extend_u0(prob, length(jumps)) f = SDEFunction{isinplace(prob)}(jump_f, jump_g; sys = prob.f.sys, observed = prob.f.observed) remake(prob; f, g = jump_g, u0) end -function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps) _f = SciMLBase.unwrapped_f(prob.f) if isinplace(prob) @@ -159,14 +158,14 @@ function extend_problem(prob::DiffEqBase.AbstractDDEProblem, jumps; rng = DEFAUL end end - u0 = extend_u0(prob, length(jumps), rng) + u0 = extend_u0(prob, length(jumps)) f = DDEFunction{isinplace(prob)}(jump_f; sys = prob.f.sys, observed = prob.f.observed) remake(prob; f, u0) end # Not sure if the DAE one is correct: Should be a residual of sorts -function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAULT_RNG) +function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps) _f = SciMLBase.unwrapped_f(prob.f) if isinplace(prob) @@ -186,16 +185,15 @@ function extend_problem(prob::DiffEqBase.AbstractDAEProblem, jumps; rng = DEFAUL end end - u0 = extend_u0(prob, length(jumps), rng) + u0 = extend_u0(prob, length(jumps)) f = DAEFunction{isinplace(prob)}(jump_f, sys = prob.f.sys, observed = prob.f.observed) remake(prob; f, u0) end -struct VR_FRMEventCallback{F, RNG} +struct VR_FRMEventCallback{F} idx::Int affect!::F - rng::RNG end # condition: (u, t, integrator) @@ -204,20 +202,22 @@ end # affect: (integrator) function (c::VR_FRMEventCallback)(integrator) c.affect!(integrator) - integrator.u.jump_u[c.idx] = -randexp(c.rng, typeof(integrator.t)) + rng = get_rng(integrator) + integrator.u.jump_u[c.idx] = -randexp(rng, typeof(integrator.t)) nothing end # initialize: (cb, u, t, integrator) function (c::VR_FRMEventCallback)(cb, u, t, integrator) - integrator.u.jump_u[c.idx] = -randexp(c.rng, typeof(integrator.t)) + rng = get_rng(integrator) + integrator.u.jump_u[c.idx] = -randexp(rng, typeof(integrator.t)) integrator.uprev.jump_u[c.idx] = integrator.u.jump_u[c.idx] u_modified!(integrator, false) nothing end -function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG) - cb_functor = VR_FRMEventCallback(idx, jump.affect!, rng) +function wrap_jump_in_callback(idx, jump) + cb_functor = VR_FRMEventCallback(idx, jump.affect!) ContinuousCallback(cb_functor, cb_functor; initialize = cb_functor, idxs = jump.idxs, @@ -228,15 +228,15 @@ function wrap_jump_in_callback(idx, jump; rng = DEFAULT_RNG) reltol = jump.reltol) end -function build_variable_callback(cb, idx, jump, jumps...; rng = DEFAULT_RNG) +function build_variable_callback(cb, idx, jump, jumps...) idx += 1 - new_cb = wrap_jump_in_callback(idx, jump; rng) - build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...; rng = DEFAULT_RNG) + new_cb = wrap_jump_in_callback(idx, jump) + build_variable_callback(CallbackSet(cb, new_cb), idx, jumps...) end -function build_variable_callback(cb, idx, jump; rng = DEFAULT_RNG) +function build_variable_callback(cb, idx, jump) idx += 1 - CallbackSet(cb, wrap_jump_in_callback(idx, jump; rng)) + CallbackSet(cb, wrap_jump_in_callback(idx, jump)) end @inline function update_jumps!(du, u, p, t, idx, jump) @@ -268,21 +268,21 @@ Simulating a birth-death process with `VR_Direct` (default) and VR_DirectFW: ```julia using JumpProcesses, OrdinaryDiffEq -u0 = [1.0] # Initial population -p = [10.0, 0.5] # [birth rate, death rate coefficient] +u0 = [1.0] # Initial population +p = [10.0, 0.5] # [birth rate, death rate coefficient] tspan = (0.0, 10.0) -# Birth jump: ∅ → X +# Birth jump: ∅ → X birth_rate(u, p, t) = p[1] birth_affect!(integrator) = (integrator.u[1] += 1; nothing) birth_jump = VariableRateJump(birth_rate, birth_affect!) -# Death jump: X → ∅ +# Death jump: X → ∅ death_rate(u, p, t) = p[2] * u[1] death_affect!(integrator) = (integrator.u[1] -= 1; nothing) death_jump = VariableRateJump(death_rate, death_affect!) -# Problem setup +# Problem setup oprob = ODEProblem((du, u, p, t) -> du .= 0, u0, tspan, p) jprob = JumpProblem(oprob, birth_jump, death_jump; vr_aggregator = VR_Direct()) sol = solve(jprob, Tsit5()) @@ -298,21 +298,19 @@ sol = solve(jprob, Tsit5()) struct VR_Direct <: VariableRateAggregator end struct VR_DirectFW <: VariableRateAggregator end -mutable struct VR_DirectEventCache{T, RNG, F1, F2} +mutable struct VR_DirectEventCache{T, F1, F2} prev_time::T prev_threshold::T current_time::T current_threshold::T total_rate::T - rng::RNG rate_funcs::F1 affect_funcs::F2 cum_rate_sum::Vector{T} end function VR_DirectEventCache( - jumps::JumpSet, ::VR_Direct, prob, ::Type{T}; rng = DEFAULT_RNG) where {T} - initial_threshold = randexp(rng, T) + jumps::JumpSet, ::VR_Direct, prob, ::Type{T}) where {T} vjumps = jumps.variable_jumps # handle vjumps using tuples @@ -320,14 +318,13 @@ function VR_DirectEventCache( cum_rate_sum = Vector{T}(undef, length(vjumps)) - VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), typeof(affect_funcs)}(zero(T), - initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, + VR_DirectEventCache{T, typeof(rate_funcs), typeof(affect_funcs)}(zero(T), + zero(T), zero(T), zero(T), zero(T), rate_funcs, affect_funcs, cum_rate_sum) end function VR_DirectEventCache( - jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}; rng = DEFAULT_RNG) where {T} - initial_threshold = randexp(rng, T) + jumps::JumpSet, ::VR_DirectFW, prob, ::Type{T}) where {T} vjumps = jumps.variable_jumps t, u = prob.tspan[1], prob.u0 @@ -337,16 +334,17 @@ function VR_DirectEventCache( cum_rate_sum = Vector{T}(undef, length(vjumps)) - VR_DirectEventCache{T, typeof(rng), typeof(rate_funcs), Any}(zero(T), - initial_threshold, zero(T), initial_threshold, zero(T), rng, rate_funcs, + VR_DirectEventCache{T, typeof(rate_funcs), Any}(zero(T), + zero(T), zero(T), zero(T), zero(T), rate_funcs, affect_funcs, cum_rate_sum) end # Initialization function for VR_DirectEventCache function initialize_vr_direct_cache!(cache::VR_DirectEventCache, u, t, integrator) + rng = get_rng(integrator) cache.prev_time = zero(integrator.t) cache.current_time = zero(integrator.t) - cache.prev_threshold = randexp(cache.rng, eltype(integrator.t)) + cache.prev_threshold = randexp(rng, eltype(integrator.t)) cache.current_threshold = cache.prev_threshold cache.total_rate = zero(integrator.t) cache.cum_rate_sum .= 0 @@ -364,8 +362,8 @@ end nothing end -@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - ::I) where {T, RNG, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} +@inline function concretize_vr_direct_affects!(cache::VR_DirectEventCache{T, F1, F2}, + ::I) where {T, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} nothing end @@ -393,16 +391,16 @@ function build_variable_integcallback(cache::VR_DirectEventCache, jumps) save_positions, abstol, reltol) end -function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs; rng = DEFAULT_RNG) +function configure_jump_problem(prob, ::VR_Direct, jumps, cvrjs) new_prob = prob - cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan); rng) + cache = VR_DirectEventCache(jumps, VR_Direct(), prob, eltype(prob.tspan)) variable_jump_callback = build_variable_integcallback(cache, cvrjs) return new_prob, variable_jump_callback end -function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs; rng = DEFAULT_RNG) +function configure_jump_problem(prob, ::VR_DirectFW, jumps, cvrjs) new_prob = prob - cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan); rng) + cache = VR_DirectEventCache(jumps, VR_DirectFW(), prob, eltype(prob.tspan)) variable_jump_callback = build_variable_integcallback(cache, cvrjs) return new_prob, variable_jump_callback end @@ -435,8 +433,7 @@ end end function total_variable_rate( - cache::VR_DirectEventCache{ - T, RNG, F1, F2}, u, p, t) where {T, RNG, F1, F2} + cache::VR_DirectEventCache{T, F1, F2}, u, p, t) where {T, F1, F2} (; cum_rate_sum, rate_funcs) = cache sum_rate = cumsum_rates!(cum_rate_sum, u, p, t, rate_funcs) return sum_rate @@ -481,8 +478,8 @@ function (cache::VR_DirectEventCache)(u, t, integrator) return cache.current_threshold end -@generated function execute_affect!(cache::VR_DirectEventCache{T, RNG, F1, F2}, - integrator::I, idx) where {T, RNG, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} +@generated function execute_affect!(cache::VR_DirectEventCache{T, F1, F2}, + integrator::I, idx) where {T, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} quote (; affect_funcs) = cache Base.Cartesian.@nif $(fieldcount(F2)) i -> (i == idx) i -> (@inbounds affect_funcs[i](integrator)) i -> (@inbounds affect_funcs[fieldcount(F2)](integrator)) @@ -509,7 +506,7 @@ function (cache::VR_DirectEventCache)(integrator) end cache.total_rate = total_variable_rate_sum - rng = cache.rng + rng = get_rng(integrator) r = rand(rng) * total_variable_rate_sum @inbounds jump_idx = searchsortedfirst(cache.cum_rate_sum, r) execute_affect!(cache, integrator, jump_idx) diff --git a/test/degenerate_rx_cases.jl b/test/degenerate_rx_cases.jl index 3b211dec..55b2985e 100644 --- a/test/degenerate_rx_cases.jl +++ b/test/degenerate_rx_cases.jl @@ -20,8 +20,8 @@ rs = [[0 => 1]] ns = [[1 => 1]] jump = MassActionJump(rate, rs, ns) prob = DiscreteProblem([100], (0.0, 100.0)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) -sol = solve(jump_prob, SSAStepper()) +jump_prob = JumpProblem(prob, Direct(), jump) +sol = solve(jump_prob, SSAStepper(); rng) if doprint println("mass act jump using vectors of data: last val = ", sol[end, end]) end @@ -35,8 +35,8 @@ rate = 2.0 rs = [0 => 3] # stoich power should be ignored ns = [1 => 1] jump = MassActionJump(rate, rs, ns) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) -sol = solve(jump_prob, SSAStepper()) +jump_prob = JumpProblem(prob, Direct(), jump) +sol = solve(jump_prob, SSAStepper(); rng) if doprint println("mass act jump using scalar data: last val = ", sol[end, end]) end @@ -51,8 +51,8 @@ rs = [Vector{Pair{Int, Int}}()] ns = [[1 => 1]] jump = MassActionJump(rate, rs, ns) prob = DiscreteProblem([100], (0.0, 100.0)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) -sol = solve(jump_prob, SSAStepper()) +jump_prob = JumpProblem(prob, Direct(), jump) +sol = solve(jump_prob, SSAStepper(); rng) if doprint println("mass act jump using vector of Pair{Int,Int}: last val = ", sol[end, end]) end @@ -66,8 +66,8 @@ rate = 2.0 rs = Vector{Pair{Int, Int}}() ns = [1 => 1] jump = MassActionJump(rate, rs, ns) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) -sol = solve(jump_prob, SSAStepper()) +jump_prob = JumpProblem(prob, Direct(), jump) +sol = solve(jump_prob, SSAStepper(); rng) if doprint println("mass act jump using scalar Pair{Int,Int}: last val = ", sol[end, end]) end @@ -100,8 +100,8 @@ jump_to_dep_specs = [[1], [1]] namedpars = (dep_graph = dep_graph, vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs) for method in methods - local jump_prob = JumpProblem(prob, method, jump, jump2; rng = rng, namedpars...) - local sol = solve(jump_prob, SSAStepper()) + local jump_prob = JumpProblem(prob, method, jump, jump2; namedpars...) + local sol = solve(jump_prob, SSAStepper(); rng) if doplot plot!(plothand2, sol, label = ("A <-> 0, " * string(method))) diff --git a/test/ensemble_problems.jl b/test/ensemble_problems.jl index f042d515..541e7cf1 100644 --- a/test/ensemble_problems.jl +++ b/test/ensemble_problems.jl @@ -6,30 +6,30 @@ using StableRNGs, Random # ========================================================================== # Constant-rate birth-death for SSAStepper / ODE-coupled tests -function make_ssa_jump_prob(; rng = StableRNG(12345)) +function make_ssa_jump_prob() j1 = ConstantRateJump((u, p, t) -> 10.0, integrator -> (integrator.u[1] += 1)) j2 = ConstantRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1)) dprob = DiscreteProblem([10], (0.0, 20.0)) - JumpProblem(dprob, Direct(), j1, j2; rng) + JumpProblem(dprob, Direct(), j1, j2) end # ODE + variable-rate jump -function make_vr_jump_prob(agg; rng = StableRNG(12345)) +function make_vr_jump_prob(agg) f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) oprob = ODEProblem(f!, [100.0], (0.0, 10.0)) vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1.0)) - JumpProblem(oprob, Direct(), vrj; vr_aggregator = agg, rng) + JumpProblem(oprob, Direct(), vrj; vr_aggregator = agg) end # SDE + variable-rate jump -function make_sde_vr_jump_prob(agg; rng = StableRNG(12345)) +function make_sde_vr_jump_prob(agg) f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) g!(du, u, p, t) = (du[1] = 0.1 * u[1]; nothing) sprob = SDEProblem(f!, g!, [100.0], (0.0, 10.0)) vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1.0)) - JumpProblem(sprob, Direct(), vrj; vr_aggregator = agg, rng) + JumpProblem(sprob, Direct(), vrj; vr_aggregator = agg) end # Helpers @@ -43,7 +43,7 @@ first_jump_time(traj) = traj.t[2] @testset "SSAStepper" begin jprob = make_ssa_jump_prob() sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleSerial(); - trajectories = 3) + trajectories = 3, rng = StableRNG(12345)) times = [first_jump_time(sol.u[i]) for i in 1:3] @test allunique(times) end @@ -51,7 +51,7 @@ first_jump_time(traj) = traj.t[2] @testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jprob = make_vr_jump_prob(agg) sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial(); - trajectories = 3) + trajectories = 3, rng = StableRNG(12345)) times = [first_jump_time(sol.u[i]) for i in 1:3] @test allunique(times) finals = [sol.u[i].u[end][1] for i in 1:3] @@ -76,13 +76,15 @@ end @testset "Sequential solves: different RNG streams" begin @testset "SSAStepper" begin jprob = make_ssa_jump_prob() - times = [first_jump_time(solve(jprob, SSAStepper())) for _ in 1:3] + rng = StableRNG(12345) + times = [first_jump_time(solve(jprob, SSAStepper(); rng)) for _ in 1:3] @test allunique(times) end @testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jprob = make_vr_jump_prob(agg) - sols = [solve(jprob, Tsit5()) for _ in 1:3] + rng = StableRNG(12345) + sols = [solve(jprob, Tsit5(); rng) for _ in 1:3] times = [first_jump_time(s) for s in sols] @test allunique(times) finals = [s.u[end][1] for s in sols] @@ -93,13 +95,9 @@ end # ========================================================================== # 3. Threaded ensemble: no data race on the shared JumpProblem # -# The ODE/SSA path through __jump_init receives seed=nothing from -# SciMLBase, so deepcopy'd problems on non-main threads start with -# identical RNG states. We only assert completion here — uniqueness -# requires explicit seeding (tested in section 4 below). -# -# The SDE path goes through StochasticDiffEq's __init which generates -# per-trajectory seeds, so we can additionally verify uniqueness there. +# With integrator-owned RNGs, each thread's integrator gets its own +# default_rng(). We only assert completion here — uniqueness is tested +# via explicit rng kwarg in section 4. # ========================================================================== @testset "EnsembleThreads: no data race" begin @@ -112,8 +110,6 @@ end @testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jprob = make_vr_jump_prob(agg) - # This path previously had a data race: resetted_jump_problem called - # randexp!(_jump_prob.rng, ...) on the shared original problem. sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleThreads(); trajectories = 4, save_everystep = false) @test length(sol) == 4 @@ -121,8 +117,8 @@ end @testset "SDE + VR (VR_FRM): unique trajectories" begin jprob = make_sde_vr_jump_prob(VR_FRM()) - # StochasticDiffEq generates per-trajectory seeds and passes them to - # resetted_jump_problem, so trajectories should be distinct. + # StochasticDiffEq generates per-trajectory seeds, so trajectories + # should be distinct. sol = solve(EnsembleProblem(jprob), EM(), EnsembleThreads(); trajectories = 4, dt = 0.01, save_everystep = false) @test length(sol) == 4 @@ -132,57 +128,63 @@ end end # ========================================================================== -# 4. Seed-based stream independence: resetted_jump_problem and -# reset_jump_problem! produce distinct RNG streams for different seeds -# -# This tests the mechanism that EnsembleThreads relies on (when seeds are -# provided by the caller, e.g. StochasticDiffEq) to get independent streams -# on different threads. +# 4. rng kwarg reproducibility: same rng seed → identical trajectory, +# different rng seeds → different trajectories # ========================================================================== -@testset "resetted_jump_problem: different seeds → different streams" begin - jprob = make_ssa_jump_prob() - seeds = UInt64[100, 200, 300] +@testset "rng kwarg reproducibility" begin + @testset "SSAStepper: same seed → same trajectory" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(42)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(42)) + @test sol1.t == sol2.t + @test sol1.u == sol2.u + end - # Each seed should produce a distinct aggregator RNG state - rngs = map(seeds) do s - jp = JumpProcesses.resetted_jump_problem(jprob, s) - jp.jump_callback.discrete_callbacks[1].condition.rng + @testset "SSAStepper: different seeds → different trajectories" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(100)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(200)) + sol3 = solve(jprob, SSAStepper(); rng = StableRNG(300)) + times = [first_jump_time(sol1), first_jump_time(sol2), first_jump_time(sol3)] + @test allunique(times) end - draws = [rand(rng) for rng in rngs] - @test allunique(draws) - - # Same seed should be deterministic - jp1 = JumpProcesses.resetted_jump_problem(jprob, UInt64(42)) - jp2 = JumpProcesses.resetted_jump_problem(jprob, UInt64(42)) - rng1 = jp1.jump_callback.discrete_callbacks[1].condition.rng - rng2 = jp2.jump_callback.discrete_callbacks[1].condition.rng - @test rand(rng1) == rand(rng2) -end -@testset "reset_jump_problem!: different seeds → different streams" begin - seeds = UInt64[100, 200, 300] - draws = map(seeds) do s - jp = make_ssa_jump_prob() - JumpProcesses.reset_jump_problem!(jp, s) - rand(jp.jump_callback.discrete_callbacks[1].condition.rng) + @testset "ODE + VR ($agg): same seed → same trajectory" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) + jprob = make_vr_jump_prob(agg) + sol1 = solve(jprob, Tsit5(); rng = StableRNG(42)) + sol2 = solve(jprob, Tsit5(); rng = StableRNG(42)) + @test sol1.t ≈ sol2.t + @test sol1.u[end] ≈ sol2.u[end] + end + + @testset "ODE + VR ($agg): different seeds → different trajectories" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) + jprob = make_vr_jump_prob(agg) + sols = [solve(jprob, Tsit5(); rng = StableRNG(s)) for s in (100, 200, 300)] + finals = [s.u[end][1] for s in sols] + @test allunique(finals) end - @test allunique(draws) end -@testset "_derive_jump_seed: decorrelates from input seed" begin - seed = UInt64(12345) - derived = JumpProcesses._derive_jump_seed(seed) - # Derived seed should differ from input - @test derived != seed - # Should be deterministic - @test derived == JumpProcesses._derive_jump_seed(seed) - # Different inputs → different outputs - @test JumpProcesses._derive_jump_seed(UInt64(1)) != JumpProcesses._derive_jump_seed(UInt64(2)) +# ========================================================================== +# 5. has_rng / get_rng / set_rng! interface on SSAIntegrator +# ========================================================================== + +@testset "SSAIntegrator RNG interface" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); rng = StableRNG(42)) + + @test SciMLBase.has_rng(integrator) + rng = SciMLBase.get_rng(integrator) + @test rng isa StableRNG + + new_rng = StableRNG(99) + SciMLBase.set_rng!(integrator, new_rng) + @test SciMLBase.get_rng(integrator) === new_rng end # ========================================================================== -# 5. Variable-rate: jump_u thresholds are unique per trajectory +# 6. Variable-rate: jump_u thresholds are unique per trajectory # # For VR_FRM, each trajectory's first jump time is determined by the initial # jump_u threshold (set to -randexp() by the VR_FRMEventCallback initialize). @@ -193,9 +195,23 @@ end @testset "VR_FRM: jump_u thresholds unique per trajectory (EnsembleSerial)" begin jprob = make_vr_jump_prob(VR_FRM()) sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial(); - trajectories = 3) + trajectories = 3, rng = StableRNG(12345)) # The second time point is when the first variable-rate jump fires, # directly reflecting the initial -randexp() threshold. event_times = [sol.u[i].t[2] for i in 1:3] @test allunique(event_times) end + +# ========================================================================== +# 7. JumpProblem rng kwarg forwarded to solver +# ========================================================================== + +@testset "JumpProblem rng kwarg forwarded to solver" begin + j1 = ConstantRateJump((u, p, t) -> 1.0, integrator -> (integrator.u[1] += 1)) + dprob = DiscreteProblem([10], (0.0, 10.0)) + jprob = JumpProblem(dprob, Direct(), j1; rng = StableRNG(1)) + @test haskey(jprob.kwargs, :rng) + @test jprob.kwargs[:rng] isa StableRNG + sol = solve(jprob, SSAStepper()) + @test sol.retcode == ReturnCode.Success +end diff --git a/test/extended_jump_array_remake.jl b/test/extended_jump_array_remake.jl index 7f2168d5..224064ae 100644 --- a/test/extended_jump_array_remake.jl +++ b/test/extended_jump_array_remake.jl @@ -17,20 +17,21 @@ using StableRNGs vr_affect!(integrator) = (integrator.u[1] -= 1; integrator.u[2] += 1) vrj = VariableRateJump(vr_rate, vr_affect!) - jprob = JumpProblem(oprob, vrj; rng) + jprob = JumpProblem(oprob, vrj) # Verify we have ExtendedJumpArray @test jprob.prob.u0 isa ExtendedJumpArray @test jprob.prob.u0.u == [10.0, 5.0] @testset "remake with numeric Vector{Float64}" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - prob2 = remake(jprob; u0 = [20.0, 10.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u == [20.0, 10.0] - # jump_u should be resampled (different from original) - @test prob2.prob.u0.jump_u != original_jump_u + # jump_u is zeroed at construction; callback initializes it at solve time + @test all(iszero, prob2.prob.u0.jump_u) + # After init the callback should set jump_u to non-zero thresholds + integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + @test any(!iszero, integrator.u.jump_u) end @testset "remake with ExtendedJumpArray (no resample)" begin @@ -47,34 +48,34 @@ using StableRNGs end @testset "remake with Symbol pairs" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - # This was the FAILING case - should work after fix prob2 = remake(jprob; u0 = [:X => 25.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 25.0 - # jump_u should be resampled - @test prob2.prob.u0.jump_u != original_jump_u + # jump_u zeroed at construction, initialized by callback at solve time + @test all(iszero, prob2.prob.u0.jump_u) + integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + @test any(!iszero, integrator.u.jump_u) end @testset "remake with multiple Symbol pairs" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - prob2 = remake(jprob; u0 = [:X => 35.0, :Y => 15.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u == [35.0, 15.0] - # jump_u should be resampled - @test prob2.prob.u0.jump_u != original_jump_u + # jump_u zeroed at construction, initialized by callback at solve time + @test all(iszero, prob2.prob.u0.jump_u) + integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + @test any(!iszero, integrator.u.jump_u) end @testset "remake with Dict" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - prob2 = remake(jprob; u0 = Dict(:X => 40.0)) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 40.0 - # jump_u should be resampled - @test prob2.prob.u0.jump_u != original_jump_u + # jump_u zeroed at construction, initialized by callback at solve time + @test all(iszero, prob2.prob.u0.jump_u) + integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + @test any(!iszero, integrator.u.jump_u) end @testset "remake with parameters only (u0 unchanged)" begin @@ -88,14 +89,14 @@ using StableRNGs end @testset "remake with both u0 and p" begin - original_jump_u = copy(jprob.prob.u0.jump_u) - prob2 = remake(jprob; u0 = [:X => 50.0], p = [:k1 => 3.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 50.0 @test prob2.prob.p[1] == 3.0 - # jump_u should be resampled - @test prob2.prob.u0.jump_u != original_jump_u + # jump_u zeroed at construction, initialized by callback at solve time + @test all(iszero, prob2.prob.u0.jump_u) + integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + @test any(!iszero, integrator.u.jump_u) end @testset "remake preserves problem solvability" begin diff --git a/test/remake_test.jl b/test/remake_test.jl index 91ab4dab..90c863d3 100644 --- a/test/remake_test.jl +++ b/test/remake_test.jl @@ -1,4 +1,4 @@ -using JumpProcesses, DiffEqBase, OrdinaryDiffEq +using JumpProcesses, DiffEqBase, OrdinaryDiffEq, Test using StableRNGs rng = StableRNG(12345) @@ -21,21 +21,20 @@ p = (0.1 / 1000, 0.01) tspan = (0.0, 2500.0) dprob = DiscreteProblem(u0, tspan, p) -jprob = JumpProblem(dprob, Direct(), jump, jump2, save_positions = (false, false), - rng = rng) -sol = solve(jprob, SSAStepper()) +jprob = JumpProblem(dprob, Direct(), jump, jump2, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng) @test sol[3, end] == 1000 u02 = [1000, 1, 0] p2 = (0.1 / 1000, 0.0) dprob2 = remake(dprob, u0 = u02, p = p2) jprob2 = remake(jprob, prob = dprob2) -sol2 = solve(jprob2, SSAStepper()) +sol2 = solve(jprob2, SSAStepper(); rng) @test sol2[2, end] == 1001 tspan2 = (0.0, 25000.0) jprob3 = remake(jprob, p = p2, tspan = tspan2) -sol3 = solve(jprob3, SSAStepper()) +sol3 = solve(jprob3, SSAStepper(); rng) @test sol3[2, end] == 1000 @test sol3.t[end] == 25000.0 @@ -46,19 +45,19 @@ ns = [[2 => -1, 3 => 1], [1 => -1, 2 => 1]] pidxs = [2, 1] maj = MassActionJump(rs, ns; param_idxs = pidxs) dprob = DiscreteProblem(u0, tspan, p) -jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper()) +jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng) @test sol[3, end] == 1000 # update the MassActionJump dprob2 = remake(dprob, u0 = u02, p = p2) jprob2 = remake(jprob, prob = dprob2) -sol2 = solve(jprob2, SSAStepper()) +sol2 = solve(jprob2, SSAStepper(); rng) @test sol2[2, end] == 1001 tspan2 = (0.0, 25000.0) jprob3 = remake(jprob, p = p2, tspan = tspan2) -sol3 = solve(jprob3, SSAStepper()) +sol3 = solve(jprob3, SSAStepper(); rng) @test sol3[2, end] == 1000 @test sol3.t[end] == 25000.0 @@ -75,27 +74,27 @@ let rrate(u, p, t) = u[1] aaffect!(integrator) = (integrator.u[1] += 1; nothing) vrj = VariableRateJump(rrate, aaffect!) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) u0 = [4.0] jprob2 = remake(jprob; u0) @test jprob2.prob.u0 isa ExtendedJumpArray @test jprob2.prob.u0.u === u0 - sol = solve(jprob2, Tsit5()) + sol = solve(jprob2, Tsit5(); rng) u = sol[1, :] @test length(u) > 2 @test all(>(u0[1]), u[3:end]) u0 = deepcopy(jprob2.prob.u0) u0.u .= 0 jprob3 = remake(jprob2; u0) - sol = solve(jprob3, Tsit5()) + sol = solve(jprob3, Tsit5(); rng) @test all(==(0.0), sol[1, :]) @test_throws ErrorException jprob4=remake(jprob, u0 = 1) end @@ -107,24 +106,24 @@ let rrate(u, p, t) = u[1] aaffect!(integrator) = (integrator.u[1] += 1; nothing) vrj = VariableRateJump(rrate, aaffect!) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_Direct()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_DirectFW()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) - jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM(), rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, vrj; vr_aggregator = VR_FRM()) + sol = solve(jprob, Tsit5(); rng) @test all(==(0.0), sol[1, :]) u0 = [4.0] prob2 = remake(jprob.prob; u0) @test_throws ErrorException jprob2=remake(jprob; prob = prob2) - u0eja = JumpProcesses.remake_extended_u0(jprob.prob, u0, rng) + u0eja = JumpProcesses.remake_extended_u0(jprob.prob, u0) prob3 = remake(jprob.prob; u0 = u0eja) jprob3 = remake(jprob; prob = prob3) @test jprob3.prob.u0 isa ExtendedJumpArray @test jprob3.prob.u0 === u0eja - sol = solve(jprob3, Tsit5()) + sol = solve(jprob3, Tsit5(); rng) u = sol[1, :] @test length(u) > 2 @test all(>(u0[1]), u[3:end]) diff --git a/test/rng_kwarg_tests.jl b/test/rng_kwarg_tests.jl new file mode 100644 index 00000000..a619edb7 --- /dev/null +++ b/test/rng_kwarg_tests.jl @@ -0,0 +1,219 @@ +using JumpProcesses, OrdinaryDiffEq, StochasticDiffEq, Test +using StableRNGs, Random + +# ========================================================================== +# Test that rng can be passed via JumpProblem kwargs OR solve kwargs, +# and that solve-level rng takes precedence over JumpProblem-level rng. +# +# Strategy: use different RNG *types* to verify which one the integrator +# receives. StableRNG is passed at one level and Xoshiro at another, +# then we check the type on the integrator. +# ========================================================================== + +# -------------------------------------------------------------------------- +# Problem constructors +# -------------------------------------------------------------------------- +function make_ssa_jump_prob(; kwargs...) + j1 = ConstantRateJump((u, p, t) -> 10.0, integrator -> (integrator.u[1] += 1)) + j2 = ConstantRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1)) + dprob = DiscreteProblem([10], (0.0, 20.0)) + JumpProblem(dprob, Direct(), j1, j2; kwargs...) +end + +function make_ode_vr_jump_prob(; kwargs...) + f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) + oprob = ODEProblem(f!, [100.0], (0.0, 10.0)) + vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], + integrator -> (integrator.u[1] -= 1.0)) + JumpProblem(oprob, Direct(), vrj; kwargs...) +end + +function make_sde_vr_jump_prob(; kwargs...) + f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) + g!(du, u, p, t) = (du[1] = 0.1 * u[1]; nothing) + sprob = SDEProblem(f!, g!, [100.0], (0.0, 10.0)) + vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], + integrator -> (integrator.u[1] -= 1.0)) + JumpProblem(sprob, Direct(), vrj; kwargs...) +end + +# ========================================================================== +# 1. SSAStepper: rng via JumpProblem +# ========================================================================== +@testset "SSAStepper: rng via JumpProblem kwargs" begin + xrng = Xoshiro(42) + jprob = make_ssa_jump_prob(; rng = xrng) + integrator = init(jprob, SSAStepper()) + @test SciMLBase.get_rng(integrator) isa Xoshiro + sol = solve(jprob, SSAStepper()) + @test sol.retcode == ReturnCode.Success +end + +# ========================================================================== +# 2. SSAStepper: rng via solve overrides JumpProblem +# ========================================================================== +@testset "SSAStepper: solve rng overrides JumpProblem rng" begin + jprob = make_ssa_jump_prob(; rng = Xoshiro(42)) + integrator = init(jprob, SSAStepper(); rng = StableRNG(99)) + @test SciMLBase.get_rng(integrator) isa StableRNG +end + +# ========================================================================== +# 3. SSAStepper: reproducibility via JumpProblem rng +# ========================================================================== +@testset "SSAStepper: JumpProblem rng reproducibility" begin + jprob1 = make_ssa_jump_prob(; rng = StableRNG(123)) + jprob2 = make_ssa_jump_prob(; rng = StableRNG(123)) + sol1 = solve(jprob1, SSAStepper()) + sol2 = solve(jprob2, SSAStepper()) + @test sol1.t == sol2.t + @test sol1.u == sol2.u +end + +# ========================================================================== +# 4. SSAStepper: solve rng overrides for reproducibility +# ========================================================================== +@testset "SSAStepper: solve rng override reproducibility" begin + jprob = make_ssa_jump_prob(; rng = Xoshiro(1)) + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(42)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(42)) + @test sol1.t == sol2.t + @test sol1.u == sol2.u +end + +# ========================================================================== +# 5. ODE + VR: rng via JumpProblem +# ========================================================================== +@testset "ODE + VR: rng via JumpProblem kwargs" begin + jprob = make_ode_vr_jump_prob(; rng = Xoshiro(42)) + integrator = init(jprob, Tsit5()) + @test SciMLBase.get_rng(integrator) isa Xoshiro +end + +# ========================================================================== +# 6. ODE + VR: solve rng overrides JumpProblem rng +# ========================================================================== +@testset "ODE + VR: solve rng overrides JumpProblem rng" begin + jprob = make_ode_vr_jump_prob(; rng = Xoshiro(42)) + integrator = init(jprob, Tsit5(); rng = StableRNG(99)) + @test SciMLBase.get_rng(integrator) isa StableRNG +end + +# ========================================================================== +# 7. ODE + VR: reproducibility via JumpProblem rng +# ========================================================================== +@testset "ODE + VR: JumpProblem rng reproducibility" begin + jprob1 = make_ode_vr_jump_prob(; rng = StableRNG(123)) + jprob2 = make_ode_vr_jump_prob(; rng = StableRNG(123)) + sol1 = solve(jprob1, Tsit5()) + sol2 = solve(jprob2, Tsit5()) + @test sol1.t ≈ sol2.t + @test sol1.u[end] ≈ sol2.u[end] +end + +# ========================================================================== +# 8. ODE + VR: solve rng overrides for reproducibility +# ========================================================================== +@testset "ODE + VR: solve rng override reproducibility" begin + jprob = make_ode_vr_jump_prob(; rng = Xoshiro(1)) + sol1 = solve(jprob, Tsit5(); rng = StableRNG(42)) + sol2 = solve(jprob, Tsit5(); rng = StableRNG(42)) + @test sol1.t ≈ sol2.t + @test sol1.u[end] ≈ sol2.u[end] +end + +# ========================================================================== +# 9. SDE + VR: rng via JumpProblem +# ========================================================================== +@testset "SDE + VR: rng via JumpProblem kwargs" begin + jprob = make_sde_vr_jump_prob(; rng = Xoshiro(42)) + integrator = init(jprob, EM(); dt = 0.01) + @test SciMLBase.get_rng(integrator) isa Xoshiro +end + +# ========================================================================== +# 10. SDE + VR: solve rng overrides JumpProblem rng +# ========================================================================== +@testset "SDE + VR: solve rng overrides JumpProblem rng" begin + jprob = make_sde_vr_jump_prob(; rng = Xoshiro(42)) + integrator = init(jprob, EM(); dt = 0.01, rng = StableRNG(99)) + @test SciMLBase.get_rng(integrator) isa StableRNG +end + +# ========================================================================== +# 11. SDE + VR: reproducibility via JumpProblem rng +# ========================================================================== +@testset "SDE + VR: JumpProblem rng reproducibility" begin + jprob1 = make_sde_vr_jump_prob(; rng = StableRNG(123)) + jprob2 = make_sde_vr_jump_prob(; rng = StableRNG(123)) + sol1 = solve(jprob1, EM(); dt = 0.01, save_everystep = false) + sol2 = solve(jprob2, EM(); dt = 0.01, save_everystep = false) + @test sol1.u[end] ≈ sol2.u[end] +end + +# ========================================================================== +# 12. Tau-leaping: rng via JumpProblem +# ========================================================================== +@testset "SimpleTauLeaping: rng via JumpProblem kwargs" begin + rate(out, u, p, t) = (out .= max.(u, 0); nothing) + c(du, u, p, t, counts, mark) = (du .= counts; nothing) + rj = RegularJump(rate, c, 2) + dprob = DiscreteProblem([100, 100], (0.0, 1.0)) + jprob = JumpProblem(dprob, PureLeaping(), rj; rng = StableRNG(42)) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01) + jprob2 = JumpProblem(dprob, PureLeaping(), rj; rng = StableRNG(42)) + sol2 = solve(jprob2, SimpleTauLeaping(); dt = 0.01) + @test sol1.u == sol2.u +end + +# ========================================================================== +# 13. Tau-leaping: solve rng overrides JumpProblem rng +# ========================================================================== +@testset "SimpleTauLeaping: solve rng overrides JumpProblem rng" begin + rate(out, u, p, t) = (out .= max.(u, 0); nothing) + c(du, u, p, t, counts, mark) = (du .= counts; nothing) + rj = RegularJump(rate, c, 2) + dprob = DiscreteProblem([100, 100], (0.0, 1.0)) + # JumpProblem has Xoshiro, solve has StableRNG + jprob = JumpProblem(dprob, PureLeaping(), rj; rng = Xoshiro(1)) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) + @test sol1.u == sol2.u + # Different from using the JumpProblem rng + jprob2 = JumpProblem(dprob, PureLeaping(), rj; rng = Xoshiro(1)) + sol3 = solve(jprob2, SimpleTauLeaping(); dt = 0.01) + @test sol1.u != sol3.u +end + +# ========================================================================== +# 14. has_rng / get_rng / set_rng! interface on SSAIntegrator +# ========================================================================== +@testset "SSAIntegrator RNG interface" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); rng = StableRNG(42)) + + @test SciMLBase.has_rng(integrator) + rng = SciMLBase.get_rng(integrator) + @test rng isa StableRNG + + new_rng = StableRNG(99) + SciMLBase.set_rng!(integrator, new_rng) + @test SciMLBase.get_rng(integrator) === new_rng +end + +# ========================================================================== +# 15. No rng kwarg: uses default_rng (non-reproducible but functional) +# ========================================================================== +@testset "No rng kwarg: functional solve" begin + @testset "SSAStepper" begin + jprob = make_ssa_jump_prob() + sol = solve(jprob, SSAStepper()) + @test sol.retcode == ReturnCode.Success + end + + @testset "ODE + VR" begin + jprob = make_ode_vr_jump_prob() + sol = solve(jprob, Tsit5()) + @test sol.retcode == ReturnCode.Success + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 389ed1da..be0236c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,6 +39,7 @@ end if GROUP == "All" || GROUP == "InterfaceII" @time @safetestset "Saveat Regression test" begin include("saveat_regression.jl") end @time @safetestset "Save_positions test" begin include("save_positions.jl") end + @time @safetestset "RNG kwarg tests" begin include("rng_kwarg_tests.jl") end @time @safetestset "Ensemble Uniqueness test" begin include("ensemble_uniqueness.jl") end @time @safetestset "Thread Safety test" begin include("thread_safety.jl") end @time @safetestset "Ensemble Problem Tests" begin include("ensemble_problems.jl") end @@ -47,7 +48,7 @@ end @time @safetestset "ExtendedJumpArray remake tests" begin include("extended_jump_array_remake.jl") end @time @safetestset "Symbol based problem indexing" begin include("jprob_symbol_indexing.jl") end @time @safetestset "Long time accuracy test" begin include("longtimes_test.jl") end - @time @safetestset "Hawkes process" begin include("hawkes_test.jl") end + @time @safetestset "Hawkes process" begin include("hawkes_test.jl") end @time @safetestset "Reaction rates" begin include("spatial/reaction_rates.jl") end @time @safetestset "Hop rates" begin include("spatial/hop_rates.jl") end @time @safetestset "Topology" begin include("spatial/topology.jl") end diff --git a/test/ssa_callback_test.jl b/test/ssa_callback_test.jl index dbe9c969..1cbf678b 100644 --- a/test/ssa_callback_test.jl +++ b/test/ssa_callback_test.jl @@ -11,9 +11,9 @@ end jump = ConstantRateJump(rate, affect!) prob = DiscreteProblem([0.0, 0.0], (0.0, 10.0)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump) -sol = solve(jump_prob, SSAStepper()) +sol = solve(jump_prob, SSAStepper(); rng) @test sol.t == [0.0, 10.0] @test sol.u == [[0.0, 0.0], [0.0, 0.0]] @@ -25,13 +25,13 @@ function fuel_affect!(integrator) end cb = DiscreteCallback(condition, fuel_affect!, save_positions = (false, true)) -sol = solve(jump_prob, SSAStepper(); callback = cb, tstops = [5]) +sol = solve(jump_prob, SSAStepper(); rng, callback = cb, tstops = [5]) @test sol.t[1:2] == [0.0, 5.0] # no jumps between t=0 and t=5 @test sol(5 + 1e-10) == [100, 0] # state just after fueling before any decays can happen # test can pass callbacks via JumpProblem -jump_prob2 = JumpProblem(prob, Direct(), jump; rng = rng, callback = cb) -sol2 = solve(jump_prob2, SSAStepper(); tstops = [5]) +jump_prob2 = JumpProblem(prob, Direct(), jump; callback = cb) +sol2 = solve(jump_prob2, SSAStepper(); rng, tstops = [5]) @test sol2.t[1:2] == [0.0, 5.0] # no jumps between t=0 and t=5 @test sol2(5 + 1e-10) == [100, 0] # state just after fueling before any decays can happen @@ -48,7 +48,7 @@ finalizer_called = 0 fuel_finalize(cb, u, t, integrator) = global finalizer_called += 1 cb2 = DiscreteCallback(condition, fuel_affect!, initialize = fuel_init!, finalize = fuel_finalize) -sol = solve(jump_prob, SSAStepper(), callback = cb2) +sol = solve(jump_prob, SSAStepper(); rng, callback = cb2) for tstop in random_tstops @test tstop ∈ sol.t end @@ -62,37 +62,37 @@ maj = MassActionJump(rs, ns; param_idxs = [1, 2]) u₀ = [100, 0] tspan = (0.0, 2000.0) dprob = DiscreteProblem(u₀, tspan, p) -jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) +jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false)) pcondit(u, t, integrator) = t == 1000.0 function paffect!(integrator) integrator.p[1] = 0.0 integrator.p[2] = 1.0 reset_aggregated_jumps!(integrator) end -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p .== [0.0, 1.0]) @test sol[1, end] == 100 p .= [1.0, 0.0] maj1 = MassActionJump([1 => 1], [1 => -1, 2 => 1]; param_idxs = 1) maj2 = MassActionJump([2 => 1], [1 => 1, 2 => -1]; param_idxs = 2) -jprob = JumpProblem(dprob, Direct(), maj1, maj2, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) +jprob = JumpProblem(dprob, Direct(), maj1, maj2, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p .== [0.0, 1.0]) @test sol[1, end] == 100 p2 = [1.0, 0.0, 0.0] maj3 = MassActionJump([1 => 1], [1 => -1, 2 => 1]; param_idxs = 3) dprob = DiscreteProblem(u₀, tspan, p2) -jprob = JumpProblem(dprob, Direct(), maj1, maj2, maj3, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) +jprob = JumpProblem(dprob, Direct(), maj1, maj2, maj3, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p2 .== [0.0, 1.0, 0.0]) @test sol[1, end] == 100 p2 .= [1.0, 0.0, 0.0] jprob = JumpProblem(dprob, Direct(), JumpSet(; massaction_jumps = [maj1, maj2, maj3]), - save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) + save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p2 .== [0.0, 1.0, 0.0]) @test sol[1, end] == 100 @@ -100,8 +100,8 @@ p .= [1.0, 0.0] dprob = DiscreteProblem(u₀, tspan, p) maj4 = MassActionJump([[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]; param_idxs = [1, 2]) -jprob = JumpProblem(dprob, Direct(), maj4, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) +jprob = JumpProblem(dprob, Direct(), maj4, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCallback(pcondit, paffect!)) @test all(p .== [0.0, 1.0]) @test sol[1, end] == 100 @@ -109,9 +109,9 @@ sol = solve(jprob, SSAStepper(), tstops = [1000.0], callback = DiscreteCallback( p .= [1.0] dprob = DiscreteProblem(u₀, tspan, p) maj5 = MassActionJump([[1 => 2]], [[1 => -1, 2 => 1]]; param_idxs = [1]) -jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false), rng = rng) +jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false)) @test all(jprob.massaction_jump.scaled_rates .== [0.5]) -jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false), rng = rng, scale_rates = false) +jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false), scale_rates = false) @test all(jprob.massaction_jump.scaled_rates .== [1.0]) # test for https://github.com/SciML/JumpProcesses.jl/issues/239 @@ -119,11 +119,11 @@ maj6 = MassActionJump([[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1 param_idxs = [1, 2]) p = (0.1, 0.1) dprob = DiscreteProblem([10, 0], (0.0, 100.0), p) -jprob = JumpProblem(dprob, Direct(), maj6; save_positions = (false, false), rng = rng) +jprob = JumpProblem(dprob, Direct(), maj6; save_positions = (false, false)) cbtimes = [20.0, 30.0] affectpresets!(integrator) = integrator.u[1] += 10 cb = PresetTimeCallback(cbtimes, affectpresets!) -jsol = solve(jprob, SSAStepper(), saveat = 0.1, callback = cb) +jsol = solve(jprob, SSAStepper(); rng, saveat = 0.1, callback = cb) @test (jsol(20.00000000001) - jsol(19.9999999999))[1] == 10 # test periodic callbacks working, i.e. #417 @@ -134,18 +134,18 @@ let dprob = DiscreteProblem([0], (0.0, 10.0)) cbfun(integ) = (integ.u[1] += 1; nothing) cb = PeriodicCallback(cbfun, 1.0) - jprob = JumpProblem(dprob, crj; rng) - sol = solve(jprob; callback = cb) + jprob = JumpProblem(dprob, crj) + sol = solve(jprob; rng, callback = cb) @test sol[1, end] == 9 cb = PeriodicCallback(cbfun, 1.0; initial_affect = true) - jprob = JumpProblem(dprob, crj; rng) - sol = solve(jprob; callback = cb) + jprob = JumpProblem(dprob, crj) + sol = solve(jprob; rng, callback = cb) @test sol[1, end] == 10 cb = PeriodicCallback(cbfun, 1.0; initial_affect = true, final_affect = true) - jprob = JumpProblem(dprob, crj; rng) - sol = solve(jprob; callback = cb) + jprob = JumpProblem(dprob, crj) + sol = solve(jprob; rng, callback = cb) @test sol[1, end] == 11 end @@ -157,22 +157,22 @@ let dprob = DiscreteProblem([0], (0.0, 10.0)) cbfun(integ) = (integ.u[1] += 1; nothing) cb = PeriodicCallback(cbfun, 1.0) - jprob = JumpProblem(dprob, crj; rng) + jprob = JumpProblem(dprob, crj) tstops = Float64[] # tests for when aliasing system is in place - #sol = solve(jprob; callback = cb, tstops, alias_tstops = true) + #sol = solve(jprob; callback = cb, tstops, alias_tstops = true) # @test sol[1, end] == 9 - #@test tstops == 1.0:9.0 + #@test tstops == 1.0:9.0 # empty!(tstops) # sol = solve(jprob; callback = cb, tstops, alias_tstops = false) # @test sol[1, end] == 9 # @test isempty(tstops) - sol = solve(jprob; callback = cb, tstops) + sol = solve(jprob; rng, callback = cb, tstops) @test sol[1, end] == 9 @test isempty(tstops) empty!(tstops) - integ = init(jprob, SSAStepper(); callback = cb, tstops) + integ = init(jprob, SSAStepper(); rng, callback = cb, tstops) solve!(integ) @test integ.tstops !== tstops @test isempty(tstops) @@ -184,18 +184,18 @@ let affect!(integrator) = (integrator.u[1] += 1) crj = ConstantRateJump(rate, affect!) prob = DiscreteProblem([0], (0.0, 10.0), [10.0]) - jprob = JumpProblem(prob, Direct(), crj; rng) + jprob = JumpProblem(prob, Direct(), crj) # basic callable tstops my_tstops = (p, tspan) -> [3.0, 6.0] - sol = solve(jprob, SSAStepper(); tstops = my_tstops) + sol = solve(jprob, SSAStepper(); rng, tstops = my_tstops) @test sol.t[end] == 10.0 @test 3.0 ∈ sol.t @test 6.0 ∈ sol.t # parameter-dependent callable tstops param_tstops = (p, tspan) -> [p[1] / 5.0, p[1] / 2.0] - sol2 = solve(jprob, SSAStepper(); tstops = param_tstops) + sol2 = solve(jprob, SSAStepper(); rng, tstops = param_tstops) @test sol2.t[end] == 10.0 @test 2.0 ∈ sol2.t # 10.0 / 5.0 @test 5.0 ∈ sol2.t # 10.0 / 2.0 @@ -204,7 +204,7 @@ let condition(u, t, integrator) = t == 3.0 cb_affect!(integrator) = (integrator.u[1] += 1000) cb = DiscreteCallback(condition, cb_affect!) - sol3 = solve(jprob, SSAStepper(); tstops = my_tstops, callback = cb) + sol3 = solve(jprob, SSAStepper(); rng, tstops = my_tstops, callback = cb) @test sol3.t[end] == 10.0 @test 3.0 ∈ sol3.t # verify the callback fired: use findlast to get post-callback state at t=3.0 @@ -213,15 +213,15 @@ let # callable returning a tuple tuple_tstops = (p, tspan) -> (2.0, 7.0) - sol4 = solve(jprob, SSAStepper(); tstops = tuple_tstops) + sol4 = solve(jprob, SSAStepper(); rng, tstops = tuple_tstops) @test sol4.t[end] == 10.0 @test 2.0 ∈ sol4.t @test 7.0 ∈ sol4.t # callable tstops stored in JumpProblem via constructor kwarg - jprob2 = JumpProblem(prob, Direct(), crj; rng, tstops = my_tstops) + jprob2 = JumpProblem(prob, Direct(), crj; tstops = my_tstops) @test haskey(jprob2.kwargs, :tstops) - sol5 = solve(jprob2, SSAStepper()) + sol5 = solve(jprob2, SSAStepper(); rng) @test sol5.t[end] == 10.0 @test 3.0 ∈ sol5.t @test 6.0 ∈ sol5.t diff --git a/test/variable_rate.jl b/test/variable_rate.jl index ce08781d..2669b70e 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -293,7 +293,6 @@ end let seed = 12345 - rng = StableRNG(seed) b = 2.0 d = 1.0 n0 = 1.0 @@ -325,7 +324,7 @@ let ode_prob = ODEProblem(ode_fxn, u0, tspan, p) tsave = range(tspan[1], tspan[2]; step = 0.1) for vr_aggregator in (VR_Direct(), VR_DirectFW(), VR_FRM()) - sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator, rng) + sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator) for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization())) umean = getmean(Nsims, sjm_prob, alg, tsave, seed) From 8861d6e0df564ca72c6bcabfd495df75f801d995 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Mon, 23 Feb 2026 22:06:28 -0500 Subject: [PATCH 02/20] Revert version bump to 9.22.1 Co-Authored-By: Claude Opus 4.5 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 404b9b5e..ef960cdf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "JumpProcesses" uuid = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" -version = "9.23.0" +version = "9.22.1" authors = ["Chris Rackauckas "] [deps] From e37938b3ac741bbda0eace5f13cc0ead2d8d49d5 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 05:56:03 -0500 Subject: [PATCH 03/20] Remove rng kwarg from JumpProblem; pass rng to solve/init only Breaking: JumpProblem no longer accepts an `rng` keyword argument. Passing `rng` now throws an ArgumentError directing users to pass it to `solve` or `init` instead. This fixes inconsistent RNG priority ordering across solver pathways (SSAStepper, ODE/SDE, tau-leaping) and makes the integrator the single source of truth for RNG state. - Simplify resolve_rng to 2-arg form (remove fallback_rng) - Remove jump_prob.kwargs rng fallback from tau-leaping solvers - Update all tests to pass rng to solve/init instead of JumpProblem - Update docs and HISTORY.md with breaking change notes Co-Authored-By: Claude Opus 4.5 --- HISTORY.md | 22 +++- docs/src/api.md | 26 +---- docs/src/faq.md | 8 -- src/problem.jl | 20 ++-- src/simple_regular_solve.jl | 4 +- src/solve.jl | 8 +- test/allocations.jl | 32 +++--- test/bimolerx_test.jl | 22 ++-- test/bracketing.jl | 2 +- test/callbacks.jl | 94 ++++++++--------- test/constant_rate.jl | 20 ++-- test/ensemble_problems.jl | 11 +- test/ensemble_uniqueness.jl | 32 ++---- test/extinction_test.jl | 18 ++-- test/fp_unknowns.jl | 4 +- test/functionwrappers.jl | 10 +- test/geneexpr_test.jl | 53 ++++++---- test/gpu/regular_jumps.jl | 2 +- test/hawkes_test.jl | 14 +-- test/linearreaction_test.jl | 20 ++-- test/longtimes_test.jl | 4 +- test/monte_carlo_test.jl | 24 ++--- test/regular_jumps.jl | 54 +++++----- test/reversible_binding.jl | 9 +- test/rng_kwarg_tests.jl | 188 +++++++++++++++------------------ test/save_positions.jl | 18 ++-- test/saveat_regression.jl | 10 +- test/sir_model.jl | 8 +- test/spatial/ABC.jl | 14 +-- test/spatial/diffusion.jl | 32 +++--- test/spatial/spatial_majump.jl | 26 ++--- test/splitcoupled.jl | 52 ++++----- test/ssa_tests.jl | 17 ++- test/thread_safety.jl | 6 +- test/variable_rate.jl | 100 +++++++++--------- 35 files changed, 477 insertions(+), 507 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index 1851dbd6..d010778a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,26 @@ # Breaking updates and feature summaries across releases -## JumpProcesses unreleased (master branch) +## 10.0 (Breaking) + + - **Breaking**: The `rng` keyword argument has been removed from + `JumpProblem`. Pass `rng` to `solve` or `init` instead: + ```julia + # Before (no longer works): + jprob = JumpProblem(dprob, Direct(), jump; rng = Xoshiro(1234)) + sol = solve(jprob, SSAStepper()) + + # After: + jprob = JumpProblem(dprob, Direct(), jump) + sol = solve(jprob, SSAStepper(); rng = Xoshiro(1234)) + ``` + - RNG state is now owned by the integrator, not the aggregator. This + eliminates data races when sharing a `JumpProblem` across threads and + ensures a single, consistent RNG priority across all solver pathways: + `rng` > `seed` > `Random.default_rng()`. + - `rng` and `seed` kwargs are fully supported on `solve`/`init` for all + solver pathways (SSAStepper, ODE, SDE, tau-leaping). + - `SSAIntegrator` now supports the `SciMLBase` RNG interface (`has_rng`, + `get_rng`, `set_rng!`). ## 9.14 diff --git a/docs/src/api.md b/docs/src/api.md index 885be372..4a0db690 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -68,35 +68,17 @@ sol = solve(jprob, SSAStepper(); seed = 1234) # equivalent to: solve(jprob, SSAStepper(); rng = Xoshiro(1234)) ``` -### Passing `rng` via `JumpProblem` - -The `rng` keyword can also be passed to [`JumpProblem`](@ref), where it is -stored and automatically forwarded to the solver: - -```julia -jprob = JumpProblem(dprob, Direct(), jump; rng = StableRNG(1234)) -sol = solve(jprob, SSAStepper()) # uses StableRNG(1234) -``` - -An `rng` passed directly to `solve`/`init` takes priority over one stored in -the `JumpProblem`. - ### Resolution priority -When multiple RNG sources are provided, the following priority determines which -is used: +When both `rng` and `seed` are passed to the same `solve`/`init` call, `rng` +takes priority: | User provides | Result | |---|---| | `rng` via `solve`/`init` | Uses that `rng` | | `seed` via `solve`/`init` | Creates `Xoshiro(seed)` | -| `rng` via `JumpProblem` | Uses that `rng` | | Nothing | Uses `Random.default_rng()` (SSAStepper, ODE, tau-leaping) or a randomly-seeded `Xoshiro` (SDE) | -When both `rng` and `seed` are passed to the same call, `rng` takes priority. -When `solve`/`init` kwargs and `JumpProblem` kwargs overlap on the same key, -`solve`/`init` kwargs take priority. - ### Behavior by solver pathway | Solver | Default RNG (nothing passed) | `rng` / `seed` support | @@ -144,5 +126,5 @@ StochasticDiffEq has its own `_resolve_rng` that additionally handles `TaskLocalRNG` conversion and the problem's stored seed. For **tau-leaping**, JumpProcesses defines a custom `DiffEqBase.solve` that -bypasses the standard `__solve`/`__init` pathway. It calls `resolve_rng` with -the `JumpProblem`'s stored `rng` kwarg as a fallback. +bypasses the standard `__solve`/`__init` pathway. It calls `resolve_rng` +directly with the `rng` and `seed` kwargs from the `solve` call. diff --git a/docs/src/faq.md b/docs/src/faq.md index 0e3dad78..47d77c3c 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -75,14 +75,6 @@ sol = solve(jprob, SSAStepper(); rng = StableRNG(1234)) A `seed` keyword argument is also supported as a shorthand for creating a `Xoshiro` generator: `solve(jprob, SSAStepper(); seed = 1234)`. -Alternatively, `rng` can be passed to `JumpProblem` where it will be forwarded to the -solver automatically: - -```julia -jprob = JumpProblem(dprob, Direct(), maj; rng = Xoshiro(1234)) -sol = solve(jprob, SSAStepper()) # uses Xoshiro(1234) -``` - By default, JumpProcesses uses Julia's built-in `Random.default_rng()`. ## What are these aggregators and aggregations in JumpProcesses? diff --git a/src/problem.jl b/src/problem.jl index e09d92dd..718179f7 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -60,10 +60,6 @@ $(FIELDS) integration interface, and treated like general `VariableRateJump`s. - `vr_aggregator`, indicates the aggregator to use for sampling variable rate jumps. Current default is `VR_FRM`. - - `rng`, a random number generator to use for jump sampling. If provided, it is stored in - the problem's solver keyword arguments and forwarded to the solver when `solve` or `init` - is called. The recommended approach is to pass `rng` directly to `solve` or `init`: - `solve(jprob, SSAStepper(); rng = my_rng)`. - `tstops`, time stops to pass through to the solver. Can be an `AbstractVector` of times or a callable `(p, tspan) -> times`. @@ -242,10 +238,14 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS vr_aggregator::VariableRateAggregator = VR_FRM(), save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), - rng = nothing, scale_rates = true, useiszero = true, + scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, callback = nothing, tstops = nothing, use_vrj_bounds = true, kwargs...) + if haskey(kwargs, :rng) + throw(ArgumentError("`rng` is no longer a keyword argument for `JumpProblem`. Pass `rng` to `solve` or `init` instead, e.g. `solve(jprob, SSAStepper(); rng = my_rng)`.")) + end + # initialize the MassActionJump rate constants with the user parameters if using_params(jumps.massaction_jump) rates = jumps.massaction_jump.param_mapper(prob.p) @@ -309,7 +309,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS jump_cbs = CallbackSet(constant_jump_callback, variable_jump_callback) iip = isinplace_jump(prob, jumps.regular_jump) - solkwargs = make_kwarg(; filter(!isnothing, (; callback, tstops, rng))...) + solkwargs = make_kwarg(; filter(!isnothing, (; callback, tstops))...) JumpProblem{iip, typeof(new_prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(cvrjs), typeof(jumps.regular_jump), @@ -321,10 +321,14 @@ end function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), - rng = nothing, scale_rates = true, useiszero = true, + scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, callback = nothing, tstops = nothing, kwargs...) + if haskey(kwargs, :rng) + throw(ArgumentError("`rng` is no longer a keyword argument for `JumpProblem`. Pass `rng` to `solve` or `init` instead, e.g. `solve(jprob, SSAStepper(); rng = my_rng)`.")) + end + # Validate no spatial systems (not currently supported) (spatial_system !== nothing || hopping_constants !== nothing) && error("PureLeaping does not currently support spatial problems.") @@ -351,7 +355,7 @@ function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; vrjs = jumps.variable_jumps iip = isinplace_jump(prob, jumps.regular_jump) - solkwargs = make_kwarg(; filter(!isnothing, (; callback, tstops, rng))...) + solkwargs = make_kwarg(; filter(!isnothing, (; callback, tstops))...) JumpProblem{iip, typeof(prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(vrjs), typeof(jumps.regular_jump), diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 4d817b03..7c8df9fb 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -77,7 +77,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping; error("SimpleTauLeaping can only be used with PureLeaping JumpProblems with only RegularJumps.") prob = jump_prob.prob - _rng = resolve_rng(rng, seed, get(jump_prob.kwargs, :rng, nothing)) + _rng = resolve_rng(rng, seed) rj = jump_prob.regular_jump rate = rj.rate # rate function rate(out,u,p,t) @@ -343,7 +343,7 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; error("SimpleExplicitTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.") prob = jump_prob.prob - _rng = resolve_rng(rng, seed, get(jump_prob.kwargs, :rng, nothing)) + _rng = resolve_rng(rng, seed) tspan = prob.tspan if dtmin === nothing diff --git a/src/solve.jl b/src/solve.jl index 1ae5b384..db58e3e4 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,17 +1,15 @@ """ - resolve_rng(rng, seed[, fallback_rng]) + resolve_rng(rng, seed) Resolve which RNG to use for a jump simulation. -Priority: `rng` > `seed` (creates `Xoshiro`) > `fallback_rng` > `Random.default_rng()`. +Priority: `rng` > `seed` (creates `Xoshiro`) > `Random.default_rng()`. """ -function resolve_rng(rng, seed, fallback_rng = nothing) +function resolve_rng(rng, seed) if rng !== nothing rng elseif seed !== nothing Random.Xoshiro(seed) - elseif fallback_rng !== nothing - fallback_rng else Random.default_rng() end diff --git a/test/allocations.jl b/test/allocations.jl index 64efb758..25bd02f1 100644 --- a/test/allocations.jl +++ b/test/allocations.jl @@ -35,17 +35,17 @@ let u₀ = [999, 10, 0] tspan = (0.0, 250.0) dprob = DiscreteProblem(u₀, tspan, p) - jprob = JumpProblem(dprob, Direct(), maj, jump, jump2; save_positions, rng) - sol = solve(jprob, SSAStepper()) + jprob = JumpProblem(dprob, Direct(), maj, jump, jump2; save_positions) + sol = solve(jprob, SSAStepper(); rng) - al1 = @allocations solve(jprob, SSAStepper()) + al1 = @allocations solve(jprob, SSAStepper(); rng) tspan2 = (0.0, 2500.0) dprob2 = DiscreteProblem(u₀, tspan2, p) - jprob2 = JumpProblem(dprob2, Direct(), maj, jump, jump2; save_positions, rng) - sol2 = solve(jprob2, SSAStepper()) + jprob2 = JumpProblem(dprob2, Direct(), maj, jump, jump2; save_positions) + sol2 = solve(jprob2, SSAStepper(); rng) - al2 = @allocations solve(jprob2, SSAStepper()) + al2 = @allocations solve(jprob2, SSAStepper(); rng) @test al1 == al2 end @@ -56,7 +56,7 @@ let end function makeprob(; T = 100.0, alg = Direct(), save_positions = (false, false), - graphkwargs = (;), rng) + graphkwargs = (;)) r1(u, p, t) = rate(p[1], u[1], u[2], p[2]) * u[1] r2(u, p, t) = rate(p[1], u[2], u[1], p[2]) * u[2] r3(u, p, t) = p[3] * u[1] @@ -86,7 +86,7 @@ let ConstantRateJump(r3, aff3!), ConstantRateJump(r4, aff4!), ConstantRateJump(r5, aff5!), ConstantRateJump(r6, aff6!); - save_positions, rng, graphkwargs...) + save_positions, graphkwargs...) return jprob end @@ -99,15 +99,15 @@ let graphkwargs = (; dep_graph, vartojumps_map, jumptovars_map) @testset "Allocations for $agg" for agg in JumpProcesses.JUMP_AGGREGATORS - jprob1 = makeprob(; alg = agg, T = 10.0, graphkwargs, rng = StableRNG(1234)) + jprob1 = makeprob(; alg = agg, T = 10.0, graphkwargs) stepper = SSAStepper() - sol1 = solve(jprob1, stepper) - sol1 = solve(jprob1, stepper) - al1 = @allocated solve(jprob1, stepper) - jprob2 = makeprob(; alg = agg, T = 100.0, graphkwargs, rng = StableRNG(1234)) - sol2 = solve(jprob2, stepper) - sol2 = solve(jprob2, stepper) - al2 = @allocated solve(jprob2, stepper) + sol1 = solve(jprob1, stepper; rng = StableRNG(1234)) + sol1 = solve(jprob1, stepper; rng = StableRNG(1234)) + al1 = @allocated solve(jprob1, stepper; rng = StableRNG(1234)) + jprob2 = makeprob(; alg = agg, T = 100.0, graphkwargs) + sol2 = solve(jprob2, stepper; rng = StableRNG(1234)) + sol2 = solve(jprob2, stepper; rng = StableRNG(1234)) + al2 = @allocated solve(jprob2, stepper; rng = StableRNG(1234)) @test al1 == al2 end end diff --git a/test/bimolerx_test.jl b/test/bimolerx_test.jl index befeefab..832968b7 100644 --- a/test/bimolerx_test.jl +++ b/test/bimolerx_test.jl @@ -55,10 +55,14 @@ jump_to_dep_specs = [[1, 2], [1, 2], [1, 2, 3], [1, 2, 3], [1, 3]] majumps = MassActionJump(rates, reactstoch, netstoch) # average number of proteins in a simulation -function runSSAs(jump_prob; use_stepper = true) +function runSSAs(jump_prob; use_stepper = true, rng = nothing) Psamp = zeros(Int, Nsims) for i in 1:Nsims - sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob) + sol = if use_stepper + isnothing(rng) ? solve(jump_prob, SSAStepper()) : solve(jump_prob, SSAStepper(); rng) + else + isnothing(rng) ? solve(jump_prob) : solve(jump_prob; rng) + end Psamp[i] = sol[1, end] end mean(Psamp) @@ -72,8 +76,8 @@ if doplot for alg in SSAalgs local jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - local sol = solve(jump_prob, SSAStepper()) + jumptovars_map = jump_to_dep_specs) + local sol = solve(jump_prob, SSAStepper(); rng) local plothand = plot(sol, seriestype = :steppost, reuse = false) display(plothand) end @@ -84,15 +88,15 @@ if dotestmean for (i, alg) in enumerate(SSAalgs) local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - means = runSSAs(jump_prob) + jumptovars_map = jump_to_dep_specs) + means = runSSAs(jump_prob; rng) relerr = abs(means - expected_avg) / expected_avg doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means, ", rel err = ", relerr) @test abs(means - expected_avg) < reltol * expected_avg # test not specifying SSAStepper - means = runSSAs(jump_prob; use_stepper = false) + means = runSSAs(jump_prob; use_stepper = false, rng) relerr = abs(means - expected_avg) / expected_avg @test abs(means - expected_avg) < reltol * expected_avg end @@ -107,8 +111,8 @@ if dotestmean jset = JumpSet((), (), nothing, majump_vec) jump_prob = JumpProblem(prob, Direct(), jset, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - meanval = runSSAs(jump_prob) + jumptovars_map = jump_to_dep_specs) + meanval = runSSAs(jump_prob; rng) relerr = abs(meanval - expected_avg) / expected_avg if doprintmeans println("Using individual MassActionJumps; Mean from method: ", typeof(Direct()), diff --git a/test/bracketing.jl b/test/bracketing.jl index 7a4776da..18268943 100644 --- a/test/bracketing.jl +++ b/test/bracketing.jl @@ -49,7 +49,7 @@ t = 0.0 ### Aggregator ### mutable struct DummyAggregator{T, M, R, BD} <: - JP.AbstractSSAJumpAggregator{T, M, R, Nothing, Nothing} + JP.AbstractSSAJumpAggregator{T, M, R, Nothing} ulow::Vector{Int} uhigh::Vector{Int} cur_rate_low::Vector{T} diff --git a/test/callbacks.jl b/test/callbacks.jl index 6ac5f954..ce45d006 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -25,8 +25,8 @@ rng = StableRNG(12345) affect_cb!(integrator) = (cb_called[] = true) cb = ContinuousCallback(condition, affect_cb!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb) + sol = solve(jprob, Tsit5(); rng) @test cb_called[] @test sol.t[end] ≈ 10.0 @@ -37,8 +37,8 @@ rng = StableRNG(12345) affect_dcb!(integrator) = (dcb_called[] += 1) dcb = DiscreteCallback(condition_d, affect_dcb!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = dcb) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = dcb) + sol = solve(jprob, Tsit5(); rng) @test dcb_called[] > 0 # Should have fired multiple times @@ -47,8 +47,8 @@ rng = StableRNG(12345) affect_term!(integrator) = terminate!(integrator) cb_term = ContinuousCallback(condition_term, affect_term!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb_term) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb_term) + sol = solve(jprob, Tsit5(); rng) @test sol.t[end] ≈ 3.0 # Should terminate at t=3 @@ -57,8 +57,8 @@ rng = StableRNG(12345) affect_mod!(integrator) = (integrator.u[1] *= 2.0) cb_mod = ContinuousCallback(condition_mod, affect_mod!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb_mod) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb_mod) + sol = solve(jprob, Tsit5(); rng) # Check that state was modified at t=5 idx = findfirst(t -> t >= 5.0, sol.t) @@ -95,8 +95,8 @@ end # Test 1: Both callbacks should fire (default merge_callbacks = true) cb1_count[] = 0 cb2_count[] = 0 - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb1) - sol = solve(jprob, Tsit5(); callback = cb2) + jprob = JumpProblem(prob, Direct(), jump; callback = cb1) + sol = solve(jprob, Tsit5(); callback = cb2, rng) @test cb1_count[] > 0 @test cb2_count[] > 0 @@ -110,8 +110,8 @@ end # Test 2: Only solve callback should fire (merge_callbacks = false) cb1_count[] = 0 cb2_count[] = 0 - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb1) - sol = solve(jprob, Tsit5(); callback = cb2, merge_callbacks = false) + jprob = JumpProblem(prob, Direct(), jump; callback = cb1) + sol = solve(jprob, Tsit5(); callback = cb2, merge_callbacks = false, rng) @test cb1_count[] == 0 # Should not fire @test cb2_count[] == 1 # Should fire exactly once @@ -139,8 +139,8 @@ end cb = ContinuousCallback(condition, affect_cb!) # Callback in JumpProblem constructor - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb) - integrator = init(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb) + integrator = init(jprob, Tsit5(); rng) solve!(integrator) @test cb_called[] @@ -175,8 +175,8 @@ end # Create CallbackSet with both types cbset = CallbackSet(ccb, dcb) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cbset) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cbset) + sol = solve(jprob, Tsit5(); rng) @test ccb_called[] @test dcb_called[] > 0 @@ -202,8 +202,8 @@ end affect_term!(integrator) = (cb_called[] = true; terminate!(integrator)) dcb_term = DiscreteCallback(condition_term, affect_term!) - jprob = JumpProblem(dprob, Direct(), jump1; rng, callback = dcb_term) - sol = solve(jprob, SSAStepper()) + jprob = JumpProblem(dprob, Direct(), jump1; callback = dcb_term) + sol = solve(jprob, SSAStepper(); rng) @test cb_called[] # Should have fired @test sol.u[end][1] >= 5 # Should have reached threshold @@ -215,8 +215,8 @@ end affect_count!(integrator) = (dcb_counter[] += 1) dcb_count = DiscreteCallback(condition_count, affect_count!) - jprob2 = JumpProblem(dprob, Direct(), jump1; rng) - sol2 = solve(jprob2, SSAStepper(); callback = dcb_count) + jprob2 = JumpProblem(dprob, Direct(), jump1) + sol2 = solve(jprob2, SSAStepper(); callback = dcb_count, rng) @test dcb_counter[] > 0 # Should have fired at least once @@ -232,8 +232,8 @@ end affect_cb2!(integrator) = (cb2_count[] += 1) dcb2 = DiscreteCallback(condition2, affect_cb2!) - jprob3 = JumpProblem(dprob, Direct(), jump1; rng, callback = dcb1) - sol3 = solve(jprob3, SSAStepper(); callback = dcb2) + jprob3 = JumpProblem(dprob, Direct(), jump1; callback = dcb1) + sol3 = solve(jprob3, SSAStepper(); callback = dcb2, rng) @test cb1_count[] > 0 # First callback should fire @test cb2_count[] > 0 # Second callback should fire @@ -257,8 +257,8 @@ end affect_cb4!(integrator) = (cb4_called[] = true) dcb4 = DiscreteCallback(condition4, affect_cb4!) - jprob4 = JumpProblem(dprob, Direct(), jump1; rng, callback = dcb3) - sol4 = solve(jprob4, SSAStepper(); callback = dcb4, merge_callbacks = false) + jprob4 = JumpProblem(dprob, Direct(), jump1; callback = dcb3) + sol4 = solve(jprob4, SSAStepper(); callback = dcb4, merge_callbacks = false, rng) @test !cb3_called[] # First callback should NOT fire @test cb4_called[] # Second callback should fire @@ -288,8 +288,8 @@ end cb = ContinuousCallback(condition, affect_cb!) # This was broken in v9.17.0 - callback wouldn't fire - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cb) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), jump; callback = cb) + sol = solve(jprob, Tsit5(); rng) @test cb_called[] @test sol.t[end] ≈ 0.5 # Should terminate at 0.5, not run to 1.0 @@ -323,8 +323,8 @@ end affect_d!(integrator) = (dcb_called[] += 1) dcb = DiscreteCallback(condition_d, affect_d!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = ccb) - sol = solve(jprob, Tsit5(); callback = dcb) + jprob = JumpProblem(prob, Direct(), jump; callback = ccb) + sol = solve(jprob, Tsit5(); callback = dcb, rng) @test ccb_called[] # Continuous callback should fire @test dcb_called[] > 0 # Discrete callback should fire multiple times @@ -357,8 +357,8 @@ end affect_c!(integrator) = (ccb_called[] = true) ccb = ContinuousCallback(condition_c, affect_c!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = dcb) - sol = solve(jprob, Tsit5(); callback = ccb) + jprob = JumpProblem(prob, Direct(), jump; callback = dcb) + sol = solve(jprob, Tsit5(); callback = ccb, rng) @test dcb_called[] > 0 # Discrete callback should fire @test ccb_called[] # Continuous callback should fire @@ -398,8 +398,8 @@ end affect3!(integrator) = (cb3_called[] = true) ccb2 = ContinuousCallback(condition3, affect3!) - jprob = JumpProblem(prob, Direct(), jump; rng, callback = cbset) - sol = solve(jprob, Tsit5(); callback = ccb2) + jprob = JumpProblem(prob, Direct(), jump; callback = cbset) + sol = solve(jprob, Tsit5(); callback = ccb2, rng) @test cb1_called[] # First continuous callback should fire @test cb2_called[] > 0 # Discrete callback should fire @@ -421,12 +421,12 @@ end affect_cb!(integrator) = nothing ccb = ContinuousCallback(condition, affect_cb!) - jprob_ccb = JumpProblem(dprob, Direct(), jump; rng, callback = ccb) - @test_throws ErrorException solve(jprob_ccb, SSAStepper()) + jprob_ccb = JumpProblem(dprob, Direct(), jump; callback = ccb) + @test_throws ErrorException solve(jprob_ccb, SSAStepper(); rng) # Test 2: ContinuousCallback passed to solve should error - jprob = JumpProblem(dprob, Direct(), jump; rng) - @test_throws ErrorException solve(jprob, SSAStepper(); callback = ccb) + jprob = JumpProblem(dprob, Direct(), jump) + @test_throws ErrorException solve(jprob, SSAStepper(); callback = ccb, rng) # Test 3: CallbackSet with continuous callbacks passed to JumpProblem should error on solve condition_d(u, t, integrator) = true @@ -434,19 +434,19 @@ end dcb = DiscreteCallback(condition_d, affect_dcb!) cbset_with_continuous = CallbackSet(ccb, dcb) - jprob_cbset = JumpProblem(dprob, Direct(), jump; rng, callback = cbset_with_continuous) - @test_throws ErrorException solve(jprob_cbset, SSAStepper()) + jprob_cbset = JumpProblem(dprob, Direct(), jump; callback = cbset_with_continuous) + @test_throws ErrorException solve(jprob_cbset, SSAStepper(); rng) # Test 4: CallbackSet with continuous callbacks passed to solve should error - @test_throws ErrorException solve(jprob, SSAStepper(); callback = cbset_with_continuous) + @test_throws ErrorException solve(jprob, SSAStepper(); callback = cbset_with_continuous, rng) # Test 5: CallbackSet with multiple continuous callbacks should error with correct count ccb2 = ContinuousCallback(condition, affect_cb!) cbset_multi = CallbackSet(ccb, ccb2, dcb) - jprob_multi = JumpProblem(dprob, Direct(), jump; rng, callback = cbset_multi) + jprob_multi = JumpProblem(dprob, Direct(), jump; callback = cbset_multi) err = try - solve(jprob_multi, SSAStepper()) + solve(jprob_multi, SSAStepper(); rng) nothing catch e e @@ -457,18 +457,18 @@ end # Test 6: DiscreteCallbacks should work fine (no error) dcb_only = DiscreteCallback(condition_d, affect_dcb!) - jprob_dcb = JumpProblem(dprob, Direct(), jump; rng, callback = dcb_only) - sol = solve(jprob_dcb, SSAStepper()) + jprob_dcb = JumpProblem(dprob, Direct(), jump; callback = dcb_only) + sol = solve(jprob_dcb, SSAStepper(); rng) @test sol.retcode == ReturnCode.Success # Test 7: CallbackSet with only discrete callbacks should work dcb2 = DiscreteCallback(condition_d, affect_dcb!) cbset_discrete = CallbackSet(dcb_only, dcb2) - jprob_dcb2 = JumpProblem(dprob, Direct(), jump; rng, callback = cbset_discrete) - sol2 = solve(jprob_dcb2, SSAStepper()) + jprob_dcb2 = JumpProblem(dprob, Direct(), jump; callback = cbset_discrete) + sol2 = solve(jprob_dcb2, SSAStepper(); rng) @test sol2.retcode == ReturnCode.Success # Test 8: Error should also be thrown with init - @test_throws ErrorException init(jprob_ccb, SSAStepper()) - @test_throws ErrorException init(jprob, SSAStepper(); callback = ccb) + @test_throws ErrorException init(jprob_ccb, SSAStepper(); rng) + @test_throws ErrorException init(jprob, SSAStepper(); callback = ccb, rng) end diff --git a/test/constant_rate.jl b/test/constant_rate.jl index 86c237c0..635be26f 100644 --- a/test/constant_rate.jl +++ b/test/constant_rate.jl @@ -16,37 +16,37 @@ end jump2 = ConstantRateJump(rate, affect!) prob = DiscreteProblem(1.0, (0.0, 3.0)) -jump_prob = JumpProblem(prob, Direct(), jump; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump) -sol = solve(jump_prob, FunctionMap()) +sol = solve(jump_prob, FunctionMap(); rng) # using Plots; plot(sol) prob = DiscreteProblem(10.0, (0.0, 3.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump, jump2) -sol = solve(jump_prob, FunctionMap()) +sol = solve(jump_prob, FunctionMap(); rng) # plot(sol) nums = Int[] @time for i in 1:10000 - local jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) - local sol = solve(jump_prob, FunctionMap()) + local jump_prob = JumpProblem(prob, Direct(), jump, jump2) + local sol = solve(jump_prob, FunctionMap(); rng) push!(nums, sol.u[end]) end @test mean(nums) - 45 < 1 prob = DiscreteProblem(1.0, (0.0, 3.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump, jump2) -sol = solve(jump_prob, FunctionMap()) +sol = solve(jump_prob, FunctionMap(); rng) nums = Int[] @time for i in 1:10000 - local jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) - local sol = solve(jump_prob, FunctionMap()) + local jump_prob = JumpProblem(prob, Direct(), jump, jump2) + local sol = solve(jump_prob, FunctionMap(); rng) push!(nums, sol.u[2]) end diff --git a/test/ensemble_problems.jl b/test/ensemble_problems.jl index 541e7cf1..7cfc7066 100644 --- a/test/ensemble_problems.jl +++ b/test/ensemble_problems.jl @@ -206,12 +206,13 @@ end # 7. JumpProblem rng kwarg forwarded to solver # ========================================================================== -@testset "JumpProblem rng kwarg forwarded to solver" begin +@testset "JumpProblem rng kwarg throws ArgumentError" begin j1 = ConstantRateJump((u, p, t) -> 1.0, integrator -> (integrator.u[1] += 1)) dprob = DiscreteProblem([10], (0.0, 10.0)) - jprob = JumpProblem(dprob, Direct(), j1; rng = StableRNG(1)) - @test haskey(jprob.kwargs, :rng) - @test jprob.kwargs[:rng] isa StableRNG - sol = solve(jprob, SSAStepper()) + j1_local = ConstantRateJump((u, p, t) -> 1.0, integrator -> (integrator.u[1] += 1)) + dprob_local = DiscreteProblem([10], (0.0, 10.0)) + jprob = JumpProblem(dprob_local, Direct(), j1_local) + @test_throws ArgumentError JumpProblem(dprob_local, Direct(), j1_local; rng = StableRNG(1)) + sol = solve(jprob, SSAStepper(); rng = StableRNG(1)) @test sol.retcode == ReturnCode.Success end diff --git a/test/ensemble_uniqueness.jl b/test/ensemble_uniqueness.jl index adb97df3..9861231c 100644 --- a/test/ensemble_uniqueness.jl +++ b/test/ensemble_uniqueness.jl @@ -7,35 +7,19 @@ u0 = [0] dprob = DiscreteProblem(u0, (0.0, 100.0)) -# For EnsembleProblems, use prob_func to create a new JumpProblem with unique RNG per trajectory. -# This ensures different trajectories while maintaining reproducibility. -# Generate seeds from a seeded RNG for reproducibility of ensemble results. -function make_seeded_prob_func(dprob, aggregator, jumps, base_rng) - return function prob_func(prob, i, repeat) - seed = rand(base_rng, UInt64) - JumpProblem(dprob, aggregator, jumps...; rng = StableRNG(seed)) - end -end - -# Test with FunctionMap - use prob_func to create JumpProblems with unique RNGs -rng1 = StableRNG(12345) -jump_prob = JumpProblem(dprob, Direct(), j1, j2; rng = rng1) -ensemble_rng = StableRNG(99999) # separate RNG for generating trajectory seeds -ensemble_prob = EnsembleProblem(jump_prob; - prob_func = make_seeded_prob_func(dprob, Direct(), (j1, j2), ensemble_rng)) -sol = solve(ensemble_prob, FunctionMap(), trajectories = 3) +# Test with FunctionMap - pass rng to solve so trajectories get unique sequences +jump_prob = JumpProblem(dprob, Direct(), j1, j2) +ensemble_prob = EnsembleProblem(jump_prob) +sol = solve(ensemble_prob, FunctionMap(), trajectories = 3; rng = StableRNG(12345)) @test Array(sol.u[1]) !== Array(sol.u[2]) @test Array(sol.u[1]) !== Array(sol.u[3]) @test Array(sol.u[2]) !== Array(sol.u[3]) @test eltype(sol.u[1].u[1]) == Int -# Test with SSAStepper - use prob_func to create JumpProblems with unique RNGs -rng2 = StableRNG(12345) -jump_prob = JumpProblem(dprob, Direct(), j1, j2; rng = rng2) -ensemble_rng2 = StableRNG(99999) # separate RNG for generating trajectory seeds -ensemble_prob2 = EnsembleProblem(jump_prob; - prob_func = make_seeded_prob_func(dprob, Direct(), (j1, j2), ensemble_rng2)) -sol = solve(ensemble_prob2, SSAStepper(), trajectories = 3) +# Test with SSAStepper - pass rng to solve so trajectories get unique sequences +jump_prob = JumpProblem(dprob, Direct(), j1, j2) +ensemble_prob2 = EnsembleProblem(jump_prob) +sol = solve(ensemble_prob2, SSAStepper(), trajectories = 3; rng = StableRNG(12345)) @test Array(sol.u[1]) !== Array(sol.u[2]) @test Array(sol.u[1]) !== Array(sol.u[3]) @test Array(sol.u[2]) !== Array(sol.u[3]) diff --git a/test/extinction_test.jl b/test/extinction_test.jl index 880254ba..468c9674 100644 --- a/test/extinction_test.jl +++ b/test/extinction_test.jl @@ -21,9 +21,8 @@ algs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator()) for n in 1:Nsims for ssa in algs - local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false), - rng = rng) - local sol = solve(jprob, SSAStepper()) + local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false)) + local sol = solve(jprob, SSAStepper(); rng) @test sol[1, end] == 0 @test sol.t[end] < Inf end @@ -33,9 +32,8 @@ u0 = SA[10] dprob = DiscreteProblem(u0, (0.0, 100.0), rates) for ssa in algs - local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false), - rng = rng) - local sol = solve(jprob, SSAStepper(), saveat = 100.0) + local jprob = JumpProblem(dprob, ssa, majump, save_positions = (false, false)) + local sol = solve(jprob, SSAStepper(); saveat = 100.0, rng) @test sol[1, end] == 0 @test sol.t[end] < Inf end @@ -57,8 +55,8 @@ end et = ExtinctionTest() cb = DiscreteCallback(et, et, save_positions = (false, false)) dprob = DiscreteProblem(u0, (0.0, 1000.0), rates) -jprob = JumpProblem(dprob, Direct(), majump; save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), callback = cb, save_end = false) +jprob = JumpProblem(dprob, Direct(), majump; save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); callback = cb, save_end = false, rng) @test sol.t[end] < 1000.0 # test terminate @@ -73,8 +71,8 @@ end cb = DiscreteCallback(extinction_condition2, extinction_affect!2, save_positions = (false, false)) dprob = DiscreteProblem(u0, (0.0, 1000.0), rates) -jprob = JumpProblem(dprob, majump; save_positions = (false, false), rng) -sol = solve(jprob; callback = cb, save_end = false) +jprob = JumpProblem(dprob, majump; save_positions = (false, false)) +sol = solve(jprob; callback = cb, save_end = false, rng) @test sol[1, end] == 1 @test sol.retcode == ReturnCode.Terminated @test sol.t[end] < 1000.0 diff --git a/test/fp_unknowns.jl b/test/fp_unknowns.jl index 6a8b5f6d..8e1ecd9d 100644 --- a/test/fp_unknowns.jl +++ b/test/fp_unknowns.jl @@ -34,11 +34,11 @@ function test(rng) Xmeans = zeros(length(SSAalgs)) Ymeans = zeros(length(SSAalgs)) for (j, agg) in enumerate(SSAalgs) - jprob = JumpProblem(dprob, agg, maj; save_positions = (false, false), rng, + jprob = JumpProblem(dprob, agg, maj; save_positions = (false, false), vartojumps_map = vtoj, jumptovars_map = jtov, dep_graph = dg, scale_rates = false) for i in 1:Nsims - sol = solve(jprob, SSAStepper()) + sol = solve(jprob, SSAStepper(); rng) Xmeans[j] += sol[1, end] Ymeans[j] += sol[2, end] end diff --git a/test/functionwrappers.jl b/test/functionwrappers.jl index 2f009ead..0c1bc0a0 100644 --- a/test/functionwrappers.jl +++ b/test/functionwrappers.jl @@ -12,18 +12,18 @@ let rateinterval = (u, p, t) -> 0.1) prob = DiscreteProblem([0.0], (0.0, 2.0), [1.0]) - jprob = JumpProblem(prob, Coevolve(), jump; dep_graph = [[1]], rng) + jprob = JumpProblem(prob, Coevolve(), jump; dep_graph = [[1]]) agg = jprob.discrete_jump_aggregation @test agg.affects! isa Vector{Any} - integ = init(jprob, SSAStepper()) + integ = init(jprob, SSAStepper(); rng) T = Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{typeof(integ)}}} @test agg.affects! isa T affs = agg.affects! sol_c = solve!(integ) # check the affects vector is unchanged from a second call - integ = init(jprob, SSAStepper()) + integ = init(jprob, SSAStepper(); rng) sol_c = solve!(integ) @test affs === agg.affects! @@ -31,7 +31,7 @@ let terminate_condition(u, t, integrator) = (return u[1] >= 1) terminate_affect!(integrator) = terminate!(integrator) terminate_cb = DiscreteCallback(terminate_condition, terminate_affect!) - integ2 = init(jprob, SSAStepper(); callback = terminate_cb) + integ2 = init(jprob, SSAStepper(); rng, callback = terminate_cb) T2 = Vector{FunctionWrappers.FunctionWrapper{Nothing, Tuple{typeof(integ2)}}} @test T2 !== T @test agg.affects! isa T2 @@ -42,7 +42,7 @@ let solve!(integ2) # check affs2 is unchanged when solving again now - integ2 = init(jprob, SSAStepper(); callback = terminate_cb) + integ2 = init(jprob, SSAStepper(); rng, callback = terminate_cb) solve!(integ2) @test affs2 === agg.affects! end diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index ea55c82f..d26d9b8a 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -22,19 +22,28 @@ expected_avg = 5.926553750000000e+02 reltol = 0.01 # average number of proteins in a simulation -function runSSAs(jump_prob; use_stepper = true) +function runSSAs(jump_prob; use_stepper = true, rng = nothing) Psamp = zeros(Int, Nsims) for i in 1:Nsims - sol = use_stepper ? solve(jump_prob, SSAStepper()) : solve(jump_prob) + sol = if use_stepper + isnothing(rng) ? solve(jump_prob, SSAStepper()) : + solve(jump_prob, SSAStepper(); rng) + else + isnothing(rng) ? solve(jump_prob) : solve(jump_prob; rng) + end Psamp[i] = sol[3, end] end mean(Psamp) end -function runSSAs_ode(vrjprob) +function runSSAs_ode(vrjprob; rng = nothing) Psamp = zeros(Float64, Nsims) tsave = vrjprob.prob.tspan[2] - integrator = init(vrjprob, Tsit5(); saveat = tsave) + integrator = if isnothing(rng) + init(vrjprob, Tsit5(); saveat = tsave) + else + init(vrjprob, Tsit5(); saveat = tsave, rng) + end solve!(integrator) Psamp[1] = integrator.sol[3, end] for i in 2:Nsims @@ -94,8 +103,8 @@ if doplot for alg in SSAalgs local jump_prob = JumpProblem(prob, alg, majumps, vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - local sol = solve(jump_prob, SSAStepper()) + jumptovars_map = jump_to_dep_specs) + local sol = solve(jump_prob, SSAStepper(); rng) plot!(plothand, sol.t, sol[3, :], seriestype = :steppost) end display(plothand) @@ -106,8 +115,8 @@ if dotestmean for (i, alg) in enumerate(SSAalgs) local jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - means = runSSAs(jump_prob) + jumptovars_map = jump_to_dep_specs) + means = runSSAs(jump_prob; rng) relerr = abs(means - expected_avg) / expected_avg doprintmeans && println("Mean from method: ", typeof(alg), " is = ", means, ", rel err = ", relerr) @@ -118,8 +127,8 @@ if dotestmean let alg = Direct() jump_prob = JumpProblem(prob, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - means = runSSAs(jump_prob; use_stepper = false) + jumptovars_map = jump_to_dep_specs) + means = runSSAs(jump_prob; use_stepper = false, rng) @test abs(means - expected_avg) < reltol * expected_avg end @@ -128,8 +137,8 @@ if dotestmean for alg in (Direct(), RSSA()) jump_probf = JumpProblem(probf, alg, majumps, save_positions = (false, false), vartojumps_map = spec_to_dep_jumps, - jumptovars_map = jump_to_dep_specs, rng = rng) - means = runSSAs(jump_probf) + jumptovars_map = jump_to_dep_specs) + means = runSSAs(jump_probf; rng) relerr = abs(means - expected_avg) / expected_avg doprintmeans && println("Mean from method (Float64 u0): ", typeof(alg), " is = ", means, ", rel err = ", relerr) @@ -139,11 +148,11 @@ end # no-aggregator tests jump_prob = JumpProblem(prob, majumps; save_positions = (false, false), - vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs, rng) -@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg + vartojumps_map = spec_to_dep_jumps, jumptovars_map = jump_to_dep_specs) +@test abs(runSSAs(jump_prob; rng) - expected_avg) < reltol * expected_avg -jump_prob = JumpProblem(prob, majumps, save_positions = (false, false), rng = rng) -@test abs(runSSAs(jump_prob) - expected_avg) < reltol * expected_avg +jump_prob = JumpProblem(prob, majumps, save_positions = (false, false)) +@test abs(runSSAs(jump_prob; rng) - expected_avg) < reltol * expected_avg # crj/vrj accuracy test # k1, DNA --> mRNA + DNA @@ -187,20 +196,20 @@ let VariableRateJump(r6, a6!, save_positions = (false, false))) prob = DiscreteProblem(u0, (0.0, tf), rates) - crjprob = JumpProblem(prob, crjs; save_positions = (false, false), rng) - @test abs(runSSAs(crjprob) - expected_avg) < reltol * expected_avg + crjprob = JumpProblem(prob, crjs; save_positions = (false, false)) + @test abs(runSSAs(crjprob; rng) - expected_avg) < reltol * expected_avg # vrjs are very slow so test on a shorter time span and compare to the crjs prob = DiscreteProblem(u0, (0.0, tf / 5), rates) - crjprob = JumpProblem(prob, crjs; save_positions = (false, false), rng) - crjmean = runSSAs(crjprob) + crjprob = JumpProblem(prob, crjs; save_positions = (false, false)) + crjmean = runSSAs(crjprob; rng) f(du, u, p, t) = (du .= 0; nothing) oprob = ODEProblem(f, u0f, (0.0, tf / 5), rates) for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) vrjprob = JumpProblem( - oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false), rng) - vrjmean = runSSAs_ode(vrjprob) + oprob, vrjs; vr_aggregator = vr_agg, save_positions = (false, false)) + vrjmean = runSSAs_ode(vrjprob; rng) @test abs(vrjmean - crjmean) < reltol * crjmean end end diff --git a/test/gpu/regular_jumps.jl b/test/gpu/regular_jumps.jl index 5e60ee2d..b35276cf 100644 --- a/test/gpu/regular_jumps.jl +++ b/test/gpu/regular_jumps.jl @@ -72,7 +72,7 @@ let # Create JumpProblem prob_disc = DiscreteProblem(u0, tspan, p) rj = RegularJump(regular_rate, regular_c, 3) - jump_prob = JumpProblem(prob_disc, PureLeaping(), rj; rng = StableRNG(12345)) + jump_prob = JumpProblem(prob_disc, PureLeaping(), rj) sol = solve(EnsembleProblem(jump_prob), SimpleTauLeaping(), EnsembleGPUKernel(CUDABackend()); trajectories = Nsims, dt = 1.0) diff --git a/test/hawkes_test.jl b/test/hawkes_test.jl index adf5a83d..bae9ff2c 100644 --- a/test/hawkes_test.jl +++ b/test/hawkes_test.jl @@ -65,7 +65,7 @@ function hawkes_problem(p, agg::Coevolve; u = [0.0], tspan = (0.0, 50.0), save_positions = (false, true), g = [[1]], h = [[]], uselrate = true, kwargs...) dprob = DiscreteProblem(u, tspan, p) jumps = hawkes_jump(u, g, h; uselrate) - jprob = JumpProblem(dprob, agg, jumps...; dep_graph = g, save_positions, rng) + jprob = JumpProblem(dprob, agg, jumps...; dep_graph = g, save_positions) return jprob end @@ -78,7 +78,7 @@ function hawkes_problem(p, agg; u = [0.0], tspan = (0.0, 50.0), save_positions = (false, true), g = [[1]], h = [[]], vr_aggregator = VR_FRM(), kwargs...) oprob = ODEProblem(f!, u, tspan, p) jumps = hawkes_jump(u, g, h) - jprob = JumpProblem(oprob, agg, jumps...; vr_aggregator, save_positions, rng) + jprob = JumpProblem(oprob, agg, jumps...; vr_aggregator, save_positions) return jprob end @@ -119,7 +119,7 @@ for (i, alg) in enumerate(algs) sols = Vector{ODESolution}(undef, Nsims) for n in 1:Nsims reset_history!(h) - sols[n] = solve(jump_prob, stepper) + sols[n] = solve(jump_prob, stepper; rng) end if alg isa Coevolve @@ -137,12 +137,12 @@ let alg = Coevolve() for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) oprob = ODEProblem(f!, u0, tspan, p) jumps = hawkes_jump(u0, g, h) - jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng) + jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g) @test ((jprob.variable_jumps === nothing) || isempty(jprob.variable_jumps)) sols = Vector{ODESolution}(undef, Nsims) for n in 1:Nsims reset_history!(h) - sols[n] = solve(jprob, Tsit5()) + sols[n] = solve(jprob, Tsit5(); rng) end λs = permutedims(mapreduce((sol) -> empirical_rate(sol), hcat, sols)) @test isapprox(mean(λs), Eλ; atol = 0.01) @@ -156,12 +156,12 @@ let alg = Coevolve() for vr_aggregator in (VR_FRM(), VR_Direct(), VR_DirectFW()) oprob = ODEProblem(f!, u0, tspan, p) jumps = hawkes_jump(u0, g, h) - jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, rng, use_vrj_bounds = false) + jprob = JumpProblem(oprob, alg, jumps...; vr_aggregator, dep_graph = g, use_vrj_bounds = false) @test length(jprob.variable_jumps) == 1 sols = Vector{ODESolution}(undef, Nsims) for n in 1:Nsims reset_history!(h) - sols[n] = solve(jprob, Tsit5()) + sols[n] = solve(jprob, Tsit5(); rng) end cols = length(u0) diff --git a/test/linearreaction_test.jl b/test/linearreaction_test.jl index 657982a6..67e13b15 100644 --- a/test/linearreaction_test.jl +++ b/test/linearreaction_test.jl @@ -29,7 +29,7 @@ exactmeanval = exactmean(tf, rates) function runSSAs(jump_prob) Asamp = zeros(Int, Nsims) for i in 1:Nsims - sol = solve(jump_prob, SSAStepper()) + sol = solve(jump_prob, SSAStepper(); rng) Asamp[i] = sol[1, end] end mean(Asamp) @@ -52,7 +52,7 @@ function A_to_B_tuple(N, method) jumps = ((jump for jump in jumpvec)...,) jset = JumpSet((), jumps, nothing, nothing) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -74,7 +74,7 @@ function A_to_B_vec(N, method) # convert jumpvec to tuple to send to JumpProblem... jset = JumpSet((), jumps, nothing, nothing) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -92,7 +92,7 @@ function A_to_B_ma(N, method) majumps = MassActionJump(rates, reactstoch, netstoch) jset = JumpSet((), (), nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -126,7 +126,7 @@ function A_to_B_hybrid(N, method) majumps = MassActionJump(rates[1:switchidx], reactstoch, netstoch) jset = JumpSet((), jumps, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -161,7 +161,7 @@ function A_to_B_hybrid_nojset(N, method) jumps = (constjumps..., majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) jump_prob = JumpProblem(prob, method, jumps...; save_positions = (false, false), - rng, namedpars...) + namedpars...) jump_prob end @@ -190,7 +190,7 @@ function A_to_B_hybrid_vecs(N, method) end jset = JumpSet((), jumpvec, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -220,7 +220,7 @@ function A_to_B_hybrid_vecs_scalars(N, method) end jset = JumpSet((), jumpvec, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob @@ -252,7 +252,7 @@ function A_to_B_hybrid_tups_scalars(N, method) jumps = ((maj for maj in majumpsv)..., (jump for jump in jumpvec)...) prob = DiscreteProblem([A0, 0], (0.0, tf)) jump_prob = JumpProblem(prob, method, jumps...; save_positions = (false, false), - rng, namedpars...) + namedpars...) jump_prob end @@ -282,7 +282,7 @@ function A_to_B_hybrid_tups(N, method) jumps = ((jump for jump in jumpvec)...,) jset = JumpSet((), jumps, nothing, majumps) prob = DiscreteProblem([A0, 0], (0.0, tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), rng, + jump_prob = JumpProblem(prob, method, jset; save_positions = (false, false), namedpars...) jump_prob diff --git a/test/longtimes_test.jl b/test/longtimes_test.jl index 7c787b4f..4a63422e 100644 --- a/test/longtimes_test.jl +++ b/test/longtimes_test.jl @@ -11,6 +11,6 @@ u0 = [5] tspan = (0.0, 2e6) dt = tspan[2] / 1000 dprob = DiscreteProblem(u0, tspan, p) -jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) -sol = solve(jprob, SSAStepper(), saveat = tspan[1]:dt:tspan[2]) +jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false)) +sol = solve(jprob, SSAStepper(); saveat = tspan[1]:dt:tspan[2], rng) @test length(unique(sol.u[(end - 10):end][:])) > 1 diff --git a/test/monte_carlo_test.jl b/test/monte_carlo_test.jl index c197df9d..f675dc9f 100644 --- a/test/monte_carlo_test.jl +++ b/test/monte_carlo_test.jl @@ -8,27 +8,27 @@ prob = SDEProblem(f, g, [1.0], (0.0, 1.0)) rate = (u, p, t) -> 200.0 affect! = integrator -> (integrator.u[1] = integrator.u[1] / 2) jump = VariableRateJump(rate, affect!, save_positions = (false, true)) -jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM(), rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM()) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(); trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false, rng) @test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2] -jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_Direct(), rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_Direct()) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(); trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false, rng) @test allunique(sol.u[1].t) -jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_DirectFW(), rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_DirectFW()) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(); trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false, rng) @test allunique(sol.u[1].t) jump = ConstantRateJump(rate, affect!) -jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false), rng) +jump_prob = JumpProblem(prob, Direct(), jump; save_positions = (true, false)) monte_prob = EnsembleProblem(jump_prob) -sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3, - save_everystep = false, dt = 0.001, adaptive = false) +sol = solve(monte_prob, SRIW1(), EnsembleSerial(); trajectories = 3, + save_everystep = false, dt = 0.001, adaptive = false, rng) @test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2] diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index f4d009ea..99c3375c 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -36,11 +36,11 @@ end u0 = [999.0, 10.0, 0.0] # S, I, R tspan = (0.0, 250.0) prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng, save_positions = (false, false)) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; save_positions = (false, false)) # Solve with SSAStepper (save only at t_compare times) sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); - trajectories = Nsims, saveat = t_compare) + trajectories = Nsims, saveat = t_compare, rng) # RegularJump formulation for SimpleTauLeaping regular_rate = (out, u, p, t) -> begin @@ -55,22 +55,22 @@ end dc[3] = counts[2] end rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj) # Solve with SimpleTauLeaping (save only at t_compare times) sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); - trajectories = Nsims, dt = 0.1, saveat = t_compare) + trajectories = Nsims, dt = 0.1, saveat = t_compare, rng) # MassActionJump formulation for SimpleExplicitTauLeaping reactant_stoich = [[1 => 1, 2 => 1], [2 => 1], Pair{Int, Int}[]] net_stoich = [[1 => -1, 2 => 1], [2 => -1, 3 => 1], [1 => 1]] param_idxs = [1, 2, 3] maj = MassActionJump(reactant_stoich, net_stoich; param_idxs) - jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj) # Solve with SimpleExplicitTauLeaping (save only at t_compare times) sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); - trajectories = Nsims, saveat = t_compare) + trajectories = Nsims, saveat = t_compare, rng) # Compute mean I trajectories via direct indexing (I is index 2 in SIR) mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 2) @@ -101,11 +101,11 @@ end u0 = [999.0, 0.0, 10.0, 0.0] # S, E, I, R tspan = (0.0, 250.0) prob_disc = DiscreteProblem(u0, tspan, p) - jump_prob = JumpProblem(prob_disc, Direct(), jumps...; rng, save_positions = (false, false)) + jump_prob = JumpProblem(prob_disc, Direct(), jumps...; save_positions = (false, false)) # Solve with SSAStepper (save only at t_compare times) sol_direct = solve(EnsembleProblem(jump_prob), SSAStepper(), EnsembleSerial(); - trajectories = Nsims, saveat = t_compare) + trajectories = Nsims, saveat = t_compare, rng) # RegularJump formulation for SimpleTauLeaping regular_rate = (out, u, p, t) -> begin @@ -121,22 +121,22 @@ end dc[4] = counts[3] end rj = RegularJump(regular_rate, regular_c, 3) - jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj; rng) + jump_prob_tau = JumpProblem(prob_disc, PureLeaping(), rj) # Solve with SimpleTauLeaping (save only at t_compare times) sol_simple = solve(EnsembleProblem(jump_prob_tau), SimpleTauLeaping(), EnsembleSerial(); - trajectories = Nsims, dt = 0.1, saveat = t_compare) + trajectories = Nsims, dt = 0.1, saveat = t_compare, rng) # MassActionJump formulation for SimpleExplicitTauLeaping reactant_stoich = [[1 => 1, 3 => 1], [2 => 1], [3 => 1]] net_stoich = [[1 => -1, 2 => 1], [2 => -1, 3 => 1], [3 => -1, 4 => 1]] param_idxs = [1, 2, 3] maj = MassActionJump(reactant_stoich, net_stoich; param_idxs) - jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj; rng) + jump_prob_maj = JumpProblem(prob_disc, PureLeaping(), maj) # Solve with SimpleExplicitTauLeaping (save only at t_compare times) sol_adaptive = solve(EnsembleProblem(jump_prob_maj), SimpleExplicitTauLeaping(), EnsembleSerial(); - trajectories = Nsims, saveat = t_compare) + trajectories = Nsims, saveat = t_compare, rng) # Compute mean I trajectories via direct indexing (I is index 3 in SEIR) mean_I_direct = compute_mean_at_saves(sol_direct, Nsims, npts, 3) @@ -183,7 +183,7 @@ end maj = MassActionJump(rates, reactant_stoich, net_stoich) # Test PureLeaping JumpProblem creation - jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj); rng) + jp_pure = JumpProblem(prob, PureLeaping(), JumpSet(maj)) @test jp_pure.aggregator isa PureLeaping @test jp_pure.discrete_jump_aggregation === nothing @test jp_pure.massaction_jump !== nothing @@ -194,7 +194,7 @@ end affect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) crj = ConstantRateJump(rate, affect!) - jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj); rng) + jp_pure_crj = JumpProblem(prob, PureLeaping(), JumpSet(crj)) @test jp_pure_crj.aggregator isa PureLeaping @test jp_pure_crj.discrete_jump_aggregation === nothing @test length(jp_pure_crj.constant_jumps) == 1 @@ -204,7 +204,7 @@ end vaffect!(integrator) = (integrator.u[1] -= 1; integrator.u[3] += 1) vrj = VariableRateJump(vrate, vaffect!) - jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj); rng) + jp_pure_vrj = JumpProblem(prob, PureLeaping(), JumpSet(vrj)) @test jp_pure_vrj.aggregator isa PureLeaping @test jp_pure_vrj.discrete_jump_aggregation === nothing @test length(jp_pure_vrj.variable_jumps) == 1 @@ -224,7 +224,7 @@ end regj = RegularJump(rj_rate, rj_c, 1) - jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj); rng) + jp_pure_regj = JumpProblem(prob, PureLeaping(), JumpSet(regj)) @test jp_pure_regj.aggregator isa PureLeaping @test jp_pure_regj.discrete_jump_aggregation === nothing @test jp_pure_regj.regular_jump !== nothing @@ -232,7 +232,7 @@ end # Test mixed jump types mixed_jumps = JumpSet(; massaction_jumps = maj, constant_jumps = (crj,), variable_jumps = (vrj,), regular_jumps = regj) - jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps; rng) + jp_pure_mixed = JumpProblem(prob, PureLeaping(), mixed_jumps) @test jp_pure_mixed.aggregator isa PureLeaping @test jp_pure_mixed.discrete_jump_aggregation === nothing @test jp_pure_mixed.massaction_jump !== nothing @@ -243,14 +243,14 @@ end # Test spatial system error spatial_sys = CartesianGrid((2, 2)) hopping_consts = [1.0] - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, + @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); spatial_system = spatial_sys) - @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); rng, + @test_throws ErrorException JumpProblem(prob, PureLeaping(), JumpSet(maj); hopping_constants = hopping_consts) # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) - jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params); rng) + jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params)) scaled_rates = [p[1], p[2]/2] @test jp_params.massaction_jump.scaled_rates == scaled_rates end @@ -266,23 +266,23 @@ end prob = DiscreteProblem(u0, tspan) # SSAStepper with save_positions=(false,false) + saveat: only saveat times stored - jp = JumpProblem(prob, Direct(), crj; rng, save_positions = (false, false)) - sol = solve(jp, SSAStepper(); saveat = 1.0) + jp = JumpProblem(prob, Direct(), crj; save_positions = (false, false)) + sol = solve(jp, SSAStepper(); saveat = 1.0, rng) @test sol.t == collect(0.0:1.0:10.0) # SSAStepper with default save_positions + saveat: jump times stored too - jp2 = JumpProblem(prob, Direct(), crj; rng) - sol2 = solve(jp2, SSAStepper(); saveat = 1.0) + jp2 = JumpProblem(prob, Direct(), crj) + sol2 = solve(jp2, SSAStepper(); saveat = 1.0, rng) @test length(sol2.t) > length(sol.t) # --- SimpleTauLeaping save_start/save_end/saveat tests --- regular_rate = (out, u, p, t) -> (out[1] = 1.0) regular_c = (dc, u, p, t, counts, mark) -> (dc[1] = counts[1]) rj = RegularJump(regular_rate, regular_c, 1) - jp_tau = JumpProblem(prob, PureLeaping(), rj; rng) + jp_tau = JumpProblem(prob, PureLeaping(), rj) # No saveat: stores every dt step (save_start=true, save_end=true by default) - sol_tau = solve(jp_tau, SimpleTauLeaping(); dt = 1.0) + sol_tau = solve(jp_tau, SimpleTauLeaping(); dt = 1.0, rng) @test sol_tau.t == collect(0.0:1.0:10.0) # saveat as Number: defaults save_start=true, save_end=true @@ -334,7 +334,7 @@ end reactant_stoich = [[1 => 1]] net_stoich = [[1 => -1]] maj = MassActionJump([0.1], reactant_stoich, net_stoich) - jp_explicit = JumpProblem(prob_decay, PureLeaping(), maj; rng) + jp_explicit = JumpProblem(prob_decay, PureLeaping(), maj) # saveat as Number: defaults save_start=true, save_end=true sol = solve(jp_explicit, SimpleExplicitTauLeaping(); saveat = 2.0) diff --git a/test/reversible_binding.jl b/test/reversible_binding.jl index f872d92a..456a96d2 100644 --- a/test/reversible_binding.jl +++ b/test/reversible_binding.jl @@ -20,10 +20,10 @@ tspan = (0.0, 5.0) prob = DiscreteProblem(u0, tspan, rates) majumps = MassActionJump(rates, reactstoch, netstoch) -function getmean(jprob, Nsims) +function getmean(jprob, Nsims; rng = nothing) Amean = 0 for i in 1:Nsims - sol = solve(jprob, SSAStepper()) + sol = isnothing(rng) ? solve(jprob, SSAStepper()) : solve(jprob, SSAStepper(); rng) Amean += sol[1, end] end Amean /= Nsims @@ -48,8 +48,7 @@ mastereq_mean = mastereqmean(u0, rates) algs = (JumpProcesses.JUMP_AGGREGATORS..., JumpProcesses.NullAggregator()) relative_tolerance = 0.01 for alg in algs - local jprob = JumpProblem(prob, alg, majumps, save_positions = (false, false), - rng = rng) - local Amean = getmean(jprob, Nsims) + local jprob = JumpProblem(prob, alg, majumps, save_positions = (false, false)) + local Amean = getmean(jprob, Nsims; rng) @test abs(Amean - mastereq_mean) / mastereq_mean < relative_tolerance end diff --git a/test/rng_kwarg_tests.jl b/test/rng_kwarg_tests.jl index a619edb7..a677a911 100644 --- a/test/rng_kwarg_tests.jl +++ b/test/rng_kwarg_tests.jl @@ -2,191 +2,159 @@ using JumpProcesses, OrdinaryDiffEq, StochasticDiffEq, Test using StableRNGs, Random # ========================================================================== -# Test that rng can be passed via JumpProblem kwargs OR solve kwargs, -# and that solve-level rng takes precedence over JumpProblem-level rng. -# -# Strategy: use different RNG *types* to verify which one the integrator -# receives. StableRNG is passed at one level and Xoshiro at another, -# then we check the type on the integrator. +# Test that rng/seed can be passed via solve/init kwargs for all pathways, +# and that JumpProblem(; rng=...) throws an error. # ========================================================================== # -------------------------------------------------------------------------- # Problem constructors # -------------------------------------------------------------------------- -function make_ssa_jump_prob(; kwargs...) +function make_ssa_jump_prob() j1 = ConstantRateJump((u, p, t) -> 10.0, integrator -> (integrator.u[1] += 1)) j2 = ConstantRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1)) dprob = DiscreteProblem([10], (0.0, 20.0)) - JumpProblem(dprob, Direct(), j1, j2; kwargs...) + JumpProblem(dprob, Direct(), j1, j2) end -function make_ode_vr_jump_prob(; kwargs...) +function make_ode_vr_jump_prob() f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) oprob = ODEProblem(f!, [100.0], (0.0, 10.0)) vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1.0)) - JumpProblem(oprob, Direct(), vrj; kwargs...) + JumpProblem(oprob, Direct(), vrj) end -function make_sde_vr_jump_prob(; kwargs...) +function make_sde_vr_jump_prob() f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) g!(du, u, p, t) = (du[1] = 0.1 * u[1]; nothing) sprob = SDEProblem(f!, g!, [100.0], (0.0, 10.0)) vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], integrator -> (integrator.u[1] -= 1.0)) - JumpProblem(sprob, Direct(), vrj; kwargs...) + JumpProblem(sprob, Direct(), vrj) end # ========================================================================== -# 1. SSAStepper: rng via JumpProblem +# 1. JumpProblem(; rng=...) throws ArgumentError # ========================================================================== -@testset "SSAStepper: rng via JumpProblem kwargs" begin - xrng = Xoshiro(42) - jprob = make_ssa_jump_prob(; rng = xrng) - integrator = init(jprob, SSAStepper()) - @test SciMLBase.get_rng(integrator) isa Xoshiro - sol = solve(jprob, SSAStepper()) - @test sol.retcode == ReturnCode.Success +@testset "JumpProblem(; rng=...) throws ArgumentError" begin + dprob = DiscreteProblem([10], (0.0, 10.0)) + j1 = ConstantRateJump((u, p, t) -> 1.0, integrator -> (integrator.u[1] += 1)) + @test_throws ArgumentError JumpProblem(dprob, Direct(), j1; rng = StableRNG(42)) end # ========================================================================== -# 2. SSAStepper: rng via solve overrides JumpProblem +# 2. SSAStepper: rng via solve/init # ========================================================================== -@testset "SSAStepper: solve rng overrides JumpProblem rng" begin - jprob = make_ssa_jump_prob(; rng = Xoshiro(42)) - integrator = init(jprob, SSAStepper(); rng = StableRNG(99)) - @test SciMLBase.get_rng(integrator) isa StableRNG +@testset "SSAStepper: rng via solve kwargs" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); rng = Xoshiro(42)) + @test SciMLBase.get_rng(integrator) isa Xoshiro + sol = solve(jprob, SSAStepper(); rng = Xoshiro(42)) + @test sol.retcode == ReturnCode.Success end # ========================================================================== -# 3. SSAStepper: reproducibility via JumpProblem rng +# 3. SSAStepper: reproducibility via solve rng # ========================================================================== -@testset "SSAStepper: JumpProblem rng reproducibility" begin - jprob1 = make_ssa_jump_prob(; rng = StableRNG(123)) - jprob2 = make_ssa_jump_prob(; rng = StableRNG(123)) - sol1 = solve(jprob1, SSAStepper()) - sol2 = solve(jprob2, SSAStepper()) +@testset "SSAStepper: solve rng reproducibility" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(123)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(123)) @test sol1.t == sol2.t @test sol1.u == sol2.u end # ========================================================================== -# 4. SSAStepper: solve rng overrides for reproducibility +# 4. SSAStepper: different seeds → different trajectories # ========================================================================== -@testset "SSAStepper: solve rng override reproducibility" begin - jprob = make_ssa_jump_prob(; rng = Xoshiro(1)) - sol1 = solve(jprob, SSAStepper(); rng = StableRNG(42)) - sol2 = solve(jprob, SSAStepper(); rng = StableRNG(42)) - @test sol1.t == sol2.t - @test sol1.u == sol2.u +@testset "SSAStepper: different seeds → different trajectories" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); rng = StableRNG(100)) + sol2 = solve(jprob, SSAStepper(); rng = StableRNG(200)) + sol3 = solve(jprob, SSAStepper(); rng = StableRNG(300)) + times = [sol1.t[2], sol2.t[2], sol3.t[2]] + @test allunique(times) end # ========================================================================== -# 5. ODE + VR: rng via JumpProblem +# 5. ODE + VR: rng via solve/init # ========================================================================== -@testset "ODE + VR: rng via JumpProblem kwargs" begin - jprob = make_ode_vr_jump_prob(; rng = Xoshiro(42)) - integrator = init(jprob, Tsit5()) +@testset "ODE + VR: rng via solve kwargs" begin + jprob = make_ode_vr_jump_prob() + integrator = init(jprob, Tsit5(); rng = Xoshiro(42)) @test SciMLBase.get_rng(integrator) isa Xoshiro end # ========================================================================== -# 6. ODE + VR: solve rng overrides JumpProblem rng +# 6. ODE + VR: reproducibility via solve rng # ========================================================================== -@testset "ODE + VR: solve rng overrides JumpProblem rng" begin - jprob = make_ode_vr_jump_prob(; rng = Xoshiro(42)) - integrator = init(jprob, Tsit5(); rng = StableRNG(99)) - @test SciMLBase.get_rng(integrator) isa StableRNG -end - -# ========================================================================== -# 7. ODE + VR: reproducibility via JumpProblem rng -# ========================================================================== -@testset "ODE + VR: JumpProblem rng reproducibility" begin - jprob1 = make_ode_vr_jump_prob(; rng = StableRNG(123)) - jprob2 = make_ode_vr_jump_prob(; rng = StableRNG(123)) - sol1 = solve(jprob1, Tsit5()) - sol2 = solve(jprob2, Tsit5()) +@testset "ODE + VR: solve rng reproducibility" begin + jprob = make_ode_vr_jump_prob() + sol1 = solve(jprob, Tsit5(); rng = StableRNG(123)) + sol2 = solve(jprob, Tsit5(); rng = StableRNG(123)) @test sol1.t ≈ sol2.t @test sol1.u[end] ≈ sol2.u[end] end # ========================================================================== -# 8. ODE + VR: solve rng overrides for reproducibility +# 7. ODE + VR: different seeds → different trajectories # ========================================================================== -@testset "ODE + VR: solve rng override reproducibility" begin - jprob = make_ode_vr_jump_prob(; rng = Xoshiro(1)) - sol1 = solve(jprob, Tsit5(); rng = StableRNG(42)) - sol2 = solve(jprob, Tsit5(); rng = StableRNG(42)) - @test sol1.t ≈ sol2.t - @test sol1.u[end] ≈ sol2.u[end] +@testset "ODE + VR: different seeds → different trajectories" begin + jprob = make_ode_vr_jump_prob() + sols = [solve(jprob, Tsit5(); rng = StableRNG(s)) for s in (100, 200, 300)] + finals = [s.u[end][1] for s in sols] + @test allunique(finals) end # ========================================================================== -# 9. SDE + VR: rng via JumpProblem +# 8. SDE + VR: rng via solve/init # ========================================================================== -@testset "SDE + VR: rng via JumpProblem kwargs" begin - jprob = make_sde_vr_jump_prob(; rng = Xoshiro(42)) - integrator = init(jprob, EM(); dt = 0.01) +@testset "SDE + VR: rng via solve kwargs" begin + jprob = make_sde_vr_jump_prob() + integrator = init(jprob, EM(); dt = 0.01, rng = Xoshiro(42)) @test SciMLBase.get_rng(integrator) isa Xoshiro end # ========================================================================== -# 10. SDE + VR: solve rng overrides JumpProblem rng -# ========================================================================== -@testset "SDE + VR: solve rng overrides JumpProblem rng" begin - jprob = make_sde_vr_jump_prob(; rng = Xoshiro(42)) - integrator = init(jprob, EM(); dt = 0.01, rng = StableRNG(99)) - @test SciMLBase.get_rng(integrator) isa StableRNG -end - -# ========================================================================== -# 11. SDE + VR: reproducibility via JumpProblem rng +# 9. SDE + VR: reproducibility via solve rng # ========================================================================== -@testset "SDE + VR: JumpProblem rng reproducibility" begin - jprob1 = make_sde_vr_jump_prob(; rng = StableRNG(123)) - jprob2 = make_sde_vr_jump_prob(; rng = StableRNG(123)) - sol1 = solve(jprob1, EM(); dt = 0.01, save_everystep = false) - sol2 = solve(jprob2, EM(); dt = 0.01, save_everystep = false) +@testset "SDE + VR: solve rng reproducibility" begin + jprob = make_sde_vr_jump_prob() + sol1 = solve(jprob, EM(); dt = 0.01, save_everystep = false, rng = StableRNG(123)) + sol2 = solve(jprob, EM(); dt = 0.01, save_everystep = false, rng = StableRNG(123)) @test sol1.u[end] ≈ sol2.u[end] end # ========================================================================== -# 12. Tau-leaping: rng via JumpProblem +# 10. SimpleTauLeaping: rng via solve kwargs # ========================================================================== -@testset "SimpleTauLeaping: rng via JumpProblem kwargs" begin +@testset "SimpleTauLeaping: rng via solve kwargs" begin rate(out, u, p, t) = (out .= max.(u, 0); nothing) c(du, u, p, t, counts, mark) = (du .= counts; nothing) rj = RegularJump(rate, c, 2) dprob = DiscreteProblem([100, 100], (0.0, 1.0)) - jprob = JumpProblem(dprob, PureLeaping(), rj; rng = StableRNG(42)) - sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01) - jprob2 = JumpProblem(dprob, PureLeaping(), rj; rng = StableRNG(42)) - sol2 = solve(jprob2, SimpleTauLeaping(); dt = 0.01) + jprob = JumpProblem(dprob, PureLeaping(), rj) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) @test sol1.u == sol2.u end # ========================================================================== -# 13. Tau-leaping: solve rng overrides JumpProblem rng +# 11. SimpleTauLeaping: different seeds → different trajectories # ========================================================================== -@testset "SimpleTauLeaping: solve rng overrides JumpProblem rng" begin +@testset "SimpleTauLeaping: different seeds → different trajectories" begin rate(out, u, p, t) = (out .= max.(u, 0); nothing) c(du, u, p, t, counts, mark) = (du .= counts; nothing) rj = RegularJump(rate, c, 2) dprob = DiscreteProblem([100, 100], (0.0, 1.0)) - # JumpProblem has Xoshiro, solve has StableRNG - jprob = JumpProblem(dprob, PureLeaping(), rj; rng = Xoshiro(1)) + jprob = JumpProblem(dprob, PureLeaping(), rj) sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) - sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(42)) - @test sol1.u == sol2.u - # Different from using the JumpProblem rng - jprob2 = JumpProblem(dprob, PureLeaping(), rj; rng = Xoshiro(1)) - sol3 = solve(jprob2, SimpleTauLeaping(); dt = 0.01) - @test sol1.u != sol3.u + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, rng = StableRNG(99)) + @test sol1.u != sol2.u end # ========================================================================== -# 14. has_rng / get_rng / set_rng! interface on SSAIntegrator +# 12. has_rng / get_rng / set_rng! interface on SSAIntegrator # ========================================================================== @testset "SSAIntegrator RNG interface" begin jprob = make_ssa_jump_prob() @@ -202,7 +170,7 @@ end end # ========================================================================== -# 15. No rng kwarg: uses default_rng (non-reproducible but functional) +# 13. No rng kwarg: uses default_rng (non-reproducible but functional) # ========================================================================== @testset "No rng kwarg: functional solve" begin @testset "SSAStepper" begin @@ -217,3 +185,21 @@ end @test sol.retcode == ReturnCode.Success end end + +# ========================================================================== +# 14. seed kwarg: creates Xoshiro from integer seed +# ========================================================================== +@testset "seed kwarg creates Xoshiro" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); seed = 42) + @test SciMLBase.get_rng(integrator) isa Xoshiro +end + +# ========================================================================== +# 15. rng takes priority over seed +# ========================================================================== +@testset "rng takes priority over seed" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); rng = StableRNG(42), seed = 99) + @test SciMLBase.get_rng(integrator) isa StableRNG +end diff --git a/test/save_positions.jl b/test/save_positions.jl index 13413b5b..5d0a683f 100644 --- a/test/save_positions.jl +++ b/test/save_positions.jl @@ -15,8 +15,8 @@ let jump = VariableRateJump((u, p, t) -> 0, (integrator) -> integrator.u[1] += 1; urate = (u, p, t) -> 1.0, rateinterval = (u, p, t) -> 5.0) jumpproblem = JumpProblem(dprob, alg, jump; dep_graph = [[1]], - save_positions = (false, true), rng) - sol = solve(jumpproblem, SSAStepper()) + save_positions = (false, true)) + sol = solve(jumpproblem, SSAStepper(); rng) @test sol.t == [0.0, 30.0] oprob = ODEProblem((du, u, p, t) -> 0, u0, tspan) @@ -26,8 +26,8 @@ let for vr_agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jumpproblem = JumpProblem( oprob, alg, jump; vr_aggregator = vr_agg, dep_graph = [[1]], - save_positions = (false, true), rng) - sol = solve(jumpproblem, Tsit5(); save_everystep = false) + save_positions = (false, true)) + sol = solve(jumpproblem, Tsit5(); rng, save_everystep = false) @test sol.t == [0.0, 30.0] end end @@ -46,20 +46,20 @@ let # for pure jump problems dense = save_everystep vals = (true, true, true, false) for (sp, val) in zip(sps, vals) - jprob = JumpProblem(dprob, Direct(), crj; save_positions = sp, rng) - sol = solve(jprob, SSAStepper()) + jprob = JumpProblem(dprob, Direct(), crj; save_positions = sp) + sol = solve(jprob, SSAStepper(); rng) @test SciMLBase.isdenseplot(sol) == val end # for mixed problems sol.dense currently ignores save_positions oprob = ODEProblem((du, u, p, t) -> du[1] = 0.1, u0, tspan) for sp in sps - jprob = JumpProblem(oprob, Direct(), crj; save_positions = sp, rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(oprob, Direct(), crj; save_positions = sp) + sol = solve(jprob, Tsit5(); rng) @test sol.dense == true @test SciMLBase.isdenseplot(sol) == true - sol = solve(jprob, Tsit5(); dense = false) + sol = solve(jprob, Tsit5(); rng, dense = false) @test sol.dense == false @test SciMLBase.isdenseplot(sol) == false end diff --git a/test/saveat_regression.jl b/test/saveat_regression.jl index 03665d76..108fde97 100644 --- a/test/saveat_regression.jl +++ b/test/saveat_regression.jl @@ -10,12 +10,12 @@ maj = MassActionJump(rate_consts, reactant_stoich, net_stoich) n0 = [1, 1, 0] tspan = (0, 0.2) dprob = DiscreteProblem(n0, tspan) -jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng) +jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false)) ts = collect(0:0.002:tspan[2]) NA = zeros(length(ts)) Nsims = 10_000 -sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(), saveat = ts, - trajectories = Nsims) +sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(); saveat = ts, + trajectories = Nsims, rng) for i in 1:length(sol) NA .+= sol.u[i][1, :] @@ -26,10 +26,10 @@ for i in 1:length(ts) end NA = zeros(length(ts)) -jprob = JumpProblem(dprob, Direct(), maj; rng = rng) +jprob = JumpProblem(dprob, Direct(), maj) sol = nothing; GC.gc(); -sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(), trajectories = Nsims) +sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(); trajectories = Nsims, rng) for i in 1:Nsims for n in 1:length(ts) diff --git a/test/sir_model.jl b/test/sir_model.jl index e8cea455..8f4cc1cc 100644 --- a/test/sir_model.jl +++ b/test/sir_model.jl @@ -18,8 +18,8 @@ end jump2 = ConstantRateJump(rate, affect!) prob = DiscreteProblem([999.0, 1.0, 0.0], (0.0, 250.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) -integrator = init(jump_prob, FunctionMap()) +jump_prob = JumpProblem(prob, Direct(), jump, jump2) +integrator = init(jump_prob, FunctionMap(); rng) condition(u, t, integrator) = t == 100 function purge_affect!(integrator) @@ -27,8 +27,8 @@ function purge_affect!(integrator) reset_aggregated_jumps!(integrator) end cb = DiscreteCallback(condition, purge_affect!, save_positions = (false, false)) -sol = solve(jump_prob, FunctionMap(), callback = cb, tstops = [100]) -sol = solve(jump_prob, SSAStepper(), callback = cb, tstops = [100]) +sol = solve(jump_prob, FunctionMap(); callback = cb, tstops = [100], rng) +sol = solve(jump_prob, SSAStepper(); callback = cb, tstops = [100], rng) # test README example using the auto-solver selection runs let diff --git a/test/spatial/ABC.jl b/test/spatial/ABC.jl index 8d7230a7..3935ffca 100644 --- a/test/spatial/ABC.jl +++ b/test/spatial/ABC.jl @@ -38,10 +38,10 @@ prob = DiscreteProblem(starting_state, tspan, rates) hopping_constants = [hopping_rate for i in starting_state] # algs = [NSM(), DirectCRDirect()] -function get_mean_end_state(jump_prob, Nsims) +function get_mean_end_state(jump_prob, Nsims; rng = nothing) end_state = zeros(size(jump_prob.prob.u0)) for i in 1:Nsims - sol = solve(jump_prob, SSAStepper()) + sol = solve(jump_prob, SSAStepper(); rng) end_state .+= sol.u[end] end end_state / Nsims @@ -52,19 +52,19 @@ grids = [CartesianGridRej(dims), Graphs.grid(dims)] jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, spatial_system = grid, - save_positions = (false, false), rng = rng) + save_positions = (false, false)) for grid in grids] push!(jump_problems, JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) # setup flattenned jump prob push!(jump_problems, JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) # test for spatial_jump_prob in jump_problems - solution = solve(spatial_jump_prob, SSAStepper()) - mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) + solution = solve(spatial_jump_prob, SSAStepper(); rng) + mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims; rng) mean_end_state = reshape(mean_end_state, num_species, num_nodes) diff = sum(mean_end_state, dims = 2) - non_spatial_mean for (i, d) in enumerate(diff) diff --git a/test/spatial/diffusion.jl b/test/spatial/diffusion.jl index c014b5c5..e501ab07 100644 --- a/test/spatial/diffusion.jl +++ b/test/spatial/diffusion.jl @@ -5,10 +5,10 @@ using Graphs using StableRNGs rng = StableRNG(12345) -function get_mean_sol(jump_prob, Nsims, saveat) - sol = solve(jump_prob, SSAStepper(), saveat = saveat).u +function get_mean_sol(jump_prob, Nsims, saveat; rng = nothing) + sol = solve(jump_prob, SSAStepper(); saveat, rng).u for i in 1:(Nsims - 1) - sol += solve(jump_prob, SSAStepper(), saveat = saveat).u + sol += solve(jump_prob, SSAStepper(); saveat, rng).u end sol / Nsims end @@ -66,24 +66,24 @@ grids = [CartesianGridRej(dims), Graphs.grid(dims)] jump_problems = JumpProblem[JumpProblem(prob, algs[2], majumps, hopping_constants = hopping_constants, spatial_system = grid, - save_positions = (false, false), rng = rng) + save_positions = (false, false)) for grid in grids] sizehint!(jump_problems, 15 + length(jump_problems)) # flattenned jump prob push!(jump_problems, JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) # hop rates of form D_s hop_constants = [hopping_rate] for alg in algs push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false))) end # hop rates of form L_{s,i,j} @@ -96,10 +96,10 @@ end for alg in algs push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = hop_constants, - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false))) end # hop rates of form D_s * L_{i,j} @@ -112,11 +112,11 @@ for alg in algs push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false))) end # hop rates of form D_{s,i} * L_{i,j} @@ -129,16 +129,16 @@ for alg in algs push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[1], save_positions = (false, false), rng = rng)) + spatial_system = grids[1], save_positions = (false, false))) push!(jump_problems, JumpProblem(prob, alg, majumps, hopping_constants = Pair(species_hop_constants, site_hop_constants), - spatial_system = grids[end], save_positions = (false, false), rng = rng)) + spatial_system = grids[end], save_positions = (false, false))) end # testing for (j, spatial_jump_prob) in enumerate(jump_problems) - mean_sol = get_mean_sol(spatial_jump_prob, Nsims, tf / num_time_points) + mean_sol = get_mean_sol(spatial_jump_prob, Nsims, tf / num_time_points; rng) for (i, t) in enumerate(times) local diff = analytic_solution(t) - reshape(mean_sol[i], num_nodes, 1) @test abs(sum(diff[1:center_node]) / sum(analytic_solution(t)[1:center_node])) < @@ -165,7 +165,7 @@ tspan = (0.0, 10.0) prob = DiscreteProblem(starting_state, tspan) jp = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants, - spatial_system = grid, save_positions = (false, false), rng = rng) -sol = solve(jp, SSAStepper()) + spatial_system = grid, save_positions = (false, false)) +sol = solve(jp, SSAStepper(); rng) @test sol.u[end][1, 1] == sum(sol.u[end]) diff --git a/test/spatial/spatial_majump.jl b/test/spatial/spatial_majump.jl index 6afcd367..6018b68f 100644 --- a/test/spatial/spatial_majump.jl +++ b/test/spatial/spatial_majump.jl @@ -61,26 +61,26 @@ non_uniform_majumps = [non_uniform_majumps_1, non_uniform_majumps_2, non_uniform uniform_jump_problems = JumpProblem[JumpProblem(prob, NSM(), majump, hopping_constants = hopping_constants, spatial_system = grid, - save_positions = (false, false), rng = rng) + save_positions = (false, false)) for majump in uniform_majumps] # flattenned append!(uniform_jump_problems, JumpProblem[JumpProblem(prob, NRM(), majump, hopping_constants = hopping_constants, - spatial_system = grid, save_positions = (false, false), rng = rng) + spatial_system = grid, save_positions = (false, false)) for majump in uniform_majumps]) # non-uniform non_uniform_jump_problems = JumpProblem[JumpProblem(prob, NSM(), majump, hopping_constants = hopping_constants, spatial_system = grid, - save_positions = (false, false), rng = rng) + save_positions = (false, false)) for majump in non_uniform_majumps] # testing -function get_mean_end_state(jump_prob, Nsims) +function get_mean_end_state(jump_prob, Nsims; rng = nothing) end_state = zeros(size(jump_prob.prob.u0)) for i in 1:Nsims - sol = solve(jump_prob, SSAStepper()) + sol = solve(jump_prob, SSAStepper(); rng) end_state .+= sol.u[end] end end_state / Nsims @@ -106,8 +106,8 @@ ode_prob = ODEProblem(f, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) for spatial_jump_prob in uniform_jump_problems - solution = solve(spatial_jump_prob, SSAStepper()) - mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims) + solution = solve(spatial_jump_prob, SSAStepper(); rng) + mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims; rng) mean_end_state = reshape(mean_end_state, num_nodes) diff = mean_end_state - sol.u[end] for (i, d) in enumerate(diff) @@ -122,8 +122,8 @@ end ode_prob = ODEProblem(f2, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) -solution = solve(non_uniform_jump_problems[1], SSAStepper()) -mean_end_state = get_mean_end_state(non_uniform_jump_problems[1], Nsims) +solution = solve(non_uniform_jump_problems[1], SSAStepper(); rng) +mean_end_state = get_mean_end_state(non_uniform_jump_problems[1], Nsims; rng) mean_end_state = reshape(mean_end_state, num_nodes) diff = mean_end_state - sol.u[end] for (i, d) in enumerate(diff) @@ -135,8 +135,8 @@ f3(u, p, t) = L * u - diagm([0.0, 0.0, death_rate, 0.0, 0.0]) * u + ones(num_nod ode_prob = ODEProblem(f3, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) -solution = solve(non_uniform_jump_problems[2], SSAStepper()) -mean_end_state = get_mean_end_state(non_uniform_jump_problems[2], Nsims) +solution = solve(non_uniform_jump_problems[2], SSAStepper(); rng) +mean_end_state = get_mean_end_state(non_uniform_jump_problems[2], Nsims; rng) mean_end_state = reshape(mean_end_state, num_nodes) diff = mean_end_state - sol.u[end] for (i, d) in enumerate(diff) @@ -150,8 +150,8 @@ end ode_prob = ODEProblem(f4, zeros(num_nodes), tspan) sol = solve(ode_prob, Tsit5()) -solution = solve(non_uniform_jump_problems[3], SSAStepper()) -mean_end_state = get_mean_end_state(non_uniform_jump_problems[3], Nsims) +solution = solve(non_uniform_jump_problems[3], SSAStepper(); rng) +mean_end_state = get_mean_end_state(non_uniform_jump_problems[3], Nsims; rng) mean_end_state = reshape(mean_end_state, num_nodes) diff = mean_end_state - sol.u[end] for (i, d) in enumerate(diff) diff --git a/test/splitcoupled.jl b/test/splitcoupled.jl index e871e2c0..1030ec88 100644 --- a/test/splitcoupled.jl +++ b/test/splitcoupled.jl @@ -12,16 +12,15 @@ jump1 = ConstantRateJump(rate, affect!) prob = DiscreteProblem([10], (0.0, 50.0)) prob_control = DiscreteProblem([10], (0.0, 50.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump1) coupling_map = [(1, 1)] coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) + jump_prob, jump_prob_control, Direct(), coupling_map) -@time sol = solve(coupled_prob, FunctionMap()) -@time solve(jump_prob, FunctionMap()) +@time sol = solve(coupled_prob, FunctionMap(); rng) +@time solve(jump_prob, FunctionMap(); rng) @test [s[1] - s[2] for s in sol.u] == zeros(length(sol.t)) # coupling two copies of the same process should give zero rate = (u, p, t) -> 1.0 @@ -42,34 +41,31 @@ end # Jump ODE to jump ODE prob = ODEProblem(f, [1.0], (0.0, 1.0)) prob_control = ODEProblem(f, [1.0], (0.0, 1.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump2; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump2) coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) -sol = solve(coupled_prob, Tsit5()) + jump_prob, jump_prob_control, Direct(), coupling_map) +sol = solve(coupled_prob, Tsit5(); rng) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 # Jump SDE to Jump SDE prob = SDEProblem(f, g, [1.0], (0.0, 1.0)) prob_control = SDEProblem(f, g, [1.0], (0.0, 1.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump1) coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) -sol = solve(coupled_prob, SRIW1()) + jump_prob, jump_prob_control, Direct(), coupling_map) +sol = solve(coupled_prob, SRIW1(); rng) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 # Jump SDE to Jump ODE prob = ODEProblem(f, [1.0], (0.0, 1.0)) prob_control = SDEProblem(f, g, [1.0], (0.0, 1.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump1) coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) -sol = solve(coupled_prob, SRIW1()) + jump_prob, jump_prob_control, Direct(), coupling_map) +sol = solve(coupled_prob, SRIW1(); rng) @test mean([abs(s[1] - s[2]) for s in sol.u]) <= 5.0 # Jump SDE to Discrete @@ -79,12 +75,11 @@ affect! = function (integrator) end prob = DiscreteProblem([1.0], (0.0, 1.0)) prob_control = SDEProblem(f, g, [1.0], (0.0, 1.0)) -jump_prob = JumpProblem(prob, Direct(), jump1; rng = rng) -jump_prob_control = JumpProblem(prob_control, Direct(), jump1; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump1) +jump_prob_control = JumpProblem(prob_control, Direct(), jump1) coupled_prob = SplitCoupledJumpProblem( - jump_prob, jump_prob_control, Direct(), coupling_map; - rng = rng) -sol = solve(coupled_prob, SRIW1()) + jump_prob, jump_prob_control, Direct(), coupling_map) +sol = solve(coupled_prob, SRIW1(); rng) # test mass action jumps coupled to ODE # 0 -> A (stochasic) and A -> 0 (ODE) @@ -96,13 +91,12 @@ f = function (du, u, p, t) du[1] = -1.0 * u[1] end odeprob = ODEProblem(f, [10.0], (0.0, 10.0)) -jump_prob = JumpProblem(odeprob, Direct(), majumps, save_positions = (false, false); - rng = rng) +jump_prob = JumpProblem(odeprob, Direct(), majumps, save_positions = (false, false)) Nsims = 8000 Amean = 0.0 for i in 1:Nsims global Amean - local sol = solve(jump_prob, Tsit5(), saveat = 10.0) + local sol = solve(jump_prob, Tsit5(); saveat = 10.0, rng) Amean += sol[1, end] end Amean /= Nsims diff --git a/test/ssa_tests.jl b/test/ssa_tests.jl index e82c5095..79bdbe15 100644 --- a/test/ssa_tests.jl +++ b/test/ssa_tests.jl @@ -16,31 +16,30 @@ end jump2 = ConstantRateJump(rate, affect!) prob = DiscreteProblem([10.0], (0.0, 3.0)) -jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng) +jump_prob = JumpProblem(prob, Direct(), jump, jump2) -integrator = init(jump_prob, SSAStepper()) +integrator = init(jump_prob, SSAStepper(); rng) step!(integrator) integrator.u[1] # test different saving behaviors -sol = solve(jump_prob, SSAStepper()) +sol = solve(jump_prob, SSAStepper(); rng) @test SciMLBase.successful_retcode(sol) @test sol.t[begin] == 0.0 @test sol.t[end] == 3.0 -sol = solve(jump_prob, SSAStepper(), save_end = false) +sol = solve(jump_prob, SSAStepper(); save_end = false, rng) @test sol.t[begin] == 0.0 @test sol.t[end] < 3.0 -sol = solve(jump_prob, SSAStepper(), save_start = false) +sol = solve(jump_prob, SSAStepper(); save_start = false, rng) @test sol.t[begin] > 0.0 @test sol.t[end] == 3.0 -jump_prob = JumpProblem(prob, Direct(), jump, jump2, save_positions = (false, false); - rng = rng) -sol = solve(jump_prob, SSAStepper(), save_start = false, save_end = false) +jump_prob = JumpProblem(prob, Direct(), jump, jump2, save_positions = (false, false)) +sol = solve(jump_prob, SSAStepper(); save_start = false, save_end = false, rng) @test isempty(sol.t) && isempty(sol.u) -sol = solve(jump_prob, SSAStepper(), saveat = 0.0:0.1:2.9) +sol = solve(jump_prob, SSAStepper(); saveat = 0.0:0.1:2.9, rng) @test sol.t == collect(0.0:0.1:3.0) diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 197ed72c..3a7c2043 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -9,10 +9,10 @@ params = (1.0, 2.0, 50.0) tspan = (0.0, 4.0) u0 = [5] dprob = DiscreteProblem(u0, tspan, params) -jprob = JumpProblem(dprob, Direct(), maj; rng = rng) -solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); trajectories = 10) +jprob = JumpProblem(dprob, Direct(), maj) +solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); trajectories = 10, rng) solve(EnsembleProblem(jprob; safetycopy = true), SSAStepper(), EnsembleThreads(); - trajectories = 10) + trajectories = 10, rng) # test for https://github.com/SciML/JumpProcesses.jl/issues/472 let diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 2669b70e..63a6567f 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -30,15 +30,15 @@ f = function (du, u, p, t) end prob = ODEProblem(f, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) -integrator = init(jump_prob, Tsit5()) -sol = solve(jump_prob, Tsit5()) -sol = solve(jump_prob, Rosenbrock23(autodiff = AutoFiniteDiff())) -sol = solve(jump_prob, Rosenbrock23()) - -jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) -integrator = init(jump_prob_gill, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM()) +integrator = init(jump_prob, Tsit5(); rng) +sol = solve(jump_prob, Tsit5(); rng) +sol = solve(jump_prob, Rosenbrock23(autodiff = AutoFiniteDiff()); rng) +sol = solve(jump_prob, Rosenbrock23(); rng) + +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct()) +integrator = init(jump_prob_gill, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) sol_gill = solve(jump_prob, Rosenbrock23(autodiff = AutoFiniteDiff())) sol_gill = solve(jump_prob, Rosenbrock23()) @test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12 @@ -48,10 +48,10 @@ g = function (du, u, p, t) du[1] = u[1] end prob = SDEProblem(f, g, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng = rng) -sol = solve(jump_prob, SRIW1()) -jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng = rng) -sol_gill = solve(jump_prob_gill, SRIW1()) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM()) +sol = solve(jump_prob, SRIW1(); rng) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct()) +sol_gill = solve(jump_prob_gill, SRIW1(); rng) @test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12 @test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1e-12 @@ -74,10 +74,10 @@ function affect_switch!(integrator) end jump_switch = VariableRateJump(rate_switch, affect_switch!) prob = SDEProblem(ff, gg, ones(2), (0.0, 1.0), 0, noise_rate_prototype = zeros(2, 2)) -jump_prob = JumpProblem(prob, jump_switch; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump_switch; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, SRA1(), dt = 1.0) -sol_gill = solve(jump_prob_gill, SRA1(), dt = 1.0) +jump_prob = JumpProblem(prob, jump_switch; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump_switch; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, SRA1(), dt = 1.0; rng) +sol_gill = solve(jump_prob_gill, SRA1(), dt = 1.0; rng) ## Some integration tests @@ -88,10 +88,10 @@ prob = ODEProblem(f2, [0.2], (0.0, 10.0)) rate2(u, p, t) = 2 affect2!(integrator) = (integrator.u[1] = integrator.u[1] / 2) jump = ConstantRateJump(rate2, affect2!) -jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) sol(4.0) sol.u[4] @@ -99,10 +99,10 @@ rate2b(u, p, t) = u[1] affect2b!(integrator) = (integrator.u[1] = integrator.u[1] / 2) jump = VariableRateJump(rate2b, affect2b!) jump2 = deepcopy(jump) -jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) sol(4.0) sol.u[4] @@ -110,10 +110,10 @@ function g2(du, u, p, t) du[1] = u[1] end prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0)) -jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, SRIW1()) -sol_gill = solve(jump_prob_gill, SRIW1()) +jump_prob = JumpProblem(prob, jump, jump2; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump, jump2; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, SRIW1(); rng) +sol_gill = solve(jump_prob_gill, SRIW1(); rng) sol(4.0) sol.u[4] @@ -129,10 +129,10 @@ function affect3!(integrator) integrator.u[4] = 1) end jump = VariableRateJump(rate3, affect3!) -jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) # test for https://discourse.julialang.org/t/differentialequations-jl-package-variable-rate-jumps-with-complex-variables/80366/2 function f4(dx, x, p, t) @@ -146,10 +146,10 @@ jump = VariableRateJump(rate4, affect4!) x₀ = 1.0 + 0.0im Δt = (0.0, 6.0) prob = ODEProblem(f4, [x₀], Δt) -jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM(), rng) -jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct(), rng) -sol = solve(jump_prob, Tsit5()) -sol_gill = solve(jump_prob_gill, Tsit5()) +jump_prob = JumpProblem(prob, jump; vr_aggregator = VR_FRM()) +jump_prob_gill = JumpProblem(prob, jump; vr_aggregator = VR_Direct()) +sol = solve(jump_prob, Tsit5(); rng) +sol_gill = solve(jump_prob_gill, Tsit5(); rng) # Out of place test drift(x, p, t) = p * x @@ -158,7 +158,7 @@ affect!2(integrator) = (integrator.u ./= 2; nothing) x0 = rand(2) prob = ODEProblem(drift, x0, (0.0, 10.0), 2.0) jump = VariableRateJump(rate2c, affect!2) -jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM(), rng) +jump_prob = JumpProblem(prob, Direct(), jump; vr_aggregator = VR_FRM()) # test to check lack of dependency graphs is caught in Coevolve for systems with non-maj # jumps @@ -256,12 +256,12 @@ let d_jump = VariableRateJump(d_rate, death!) ode_prob = ODEProblem(ode_fxn, u0, tspan, p) - sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator = VR_FRM(), rng) + sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator = VR_FRM()) # After callback initialize, integrator.u.jump_u should have unique thresholds # that differ between sequential solves (RNG advances each time). jump_u_old = zeros(length(sjm_prob.prob.u0.jump_u)) for i in 1:Nsims - integrator = init(sjm_prob, Tsit5(); saveat = tspan[2]) + integrator = init(sjm_prob, Tsit5(); saveat = tspan[2], rng) @test allunique(integrator.u.jump_u) @test integrator.u.jump_u != jump_u_old jump_u_old .= integrator.u.jump_u @@ -339,10 +339,10 @@ end # Function to run ensemble and compute statistics function run_ensemble(prob, alg, jumps...; vr_aggregator = VR_FRM(), Nsims = 8000) rng = StableRNG(12345) - jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator, rng) + jump_prob = JumpProblem(prob, Direct(), jumps...; vr_aggregator) total = 0.0 for i in 1:Nsims - sol = solve(jump_prob, alg; save_everystep = false) + sol = solve(jump_prob, alg; save_everystep = false, rng) total += sol.u[end][1] end return total / Nsims @@ -438,10 +438,10 @@ let jump_counts = zeros(Int, Nsims) p = [0.0, 0.0, 0] prob = ODEProblem(f, u0, tspan, p) - jump_prob = JumpProblem(prob, Direct(), birth_jump, death_jump; vr_aggregator, rng) + jump_prob = JumpProblem(prob, Direct(), birth_jump, death_jump; vr_aggregator) for i in 1:Nsims - sol = solve(jump_prob, Tsit5(); save_everystep = false) + sol = solve(jump_prob, Tsit5(); save_everystep = false, rng) jump_counts[i] = jump_prob.prob.p[3] jump_prob.prob.p[3] = 0 end @@ -496,8 +496,8 @@ let ] for agg in aggregators - local jprob = JumpProblem(prob, agg, maj, vrj; rng) - local sol = solve(jprob, Tsit5()) + local jprob = JumpProblem(prob, agg, maj, vrj) + local sol = solve(jprob, Tsit5(); rng) @test SciMLBase.successful_retcode(sol) # Verify conservation: total population should be conserved @test sol.u[end][1] + sol.u[end][2] ≈ u0[1] + u0[2] @@ -527,8 +527,8 @@ let prob = ODEProblem(f!, u0, tspan) # Test with Direct aggregator (most common case) - jprob = JumpProblem(prob, Direct(), crj, vrj; rng) - sol = solve(jprob, Tsit5()) + jprob = JumpProblem(prob, Direct(), crj, vrj) + sol = solve(jprob, Tsit5(); rng) @test SciMLBase.successful_retcode(sol) @test sol.u[end][1] + sol.u[end][2] ≈ u0[1] + u0[2] end @@ -559,8 +559,8 @@ let # Test RSSA and RSSACR specifically (the aggregators that had the bug) for agg in [RSSA(), RSSACR()] - local jprob = JumpProblem(prob, agg, maj, vrj1, vrj2; rng) - local sol = solve(jprob, Tsit5()) + local jprob = JumpProblem(prob, agg, maj, vrj1, vrj2) + local sol = solve(jprob, Tsit5(); rng) @test SciMLBase.successful_retcode(sol) @test sol.u[end][1] + sol.u[end][2] ≈ u0[1] + u0[2] end From a7de080958638d377afc28c35b6e16d8eb175527 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 06:29:10 -0500 Subject: [PATCH 04/20] Clean up code quality: revert unnecessary diffs, add set_rng! type checking - Revert <:Any back to named type params in ssajump.jl (minimize diff from master) - Revert solkwargs to simpler ternary form in problem.jl (rng removal made filter unnecessary) - Add type checking to SSAIntegrator set_rng! to match StochasticDiffEq pattern Co-Authored-By: Claude Opus 4.5 --- src/SSA_stepper.jl | 8 ++++++++ src/aggregators/ssajump.jl | 8 ++++---- src/problem.jl | 4 ++-- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index c4f72906..913bcd74 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -118,6 +118,14 @@ end SciMLBase.has_rng(::SSAIntegrator) = true SciMLBase.get_rng(integrator::SSAIntegrator) = integrator.rng function SciMLBase.set_rng!(integrator::SSAIntegrator, rng) + R = typeof(integrator.rng) + if !isa(rng, R) + throw(ArgumentError( + "Cannot set RNG of type $(typeof(rng)) on an integrator " * + "whose RNG type parameter is $R. " * + "Construct a new integrator via `init(prob, alg; rng = your_rng)` instead." + )) + end integrator.rng = rng nothing end diff --git a/src/aggregators/ssajump.jl b/src/aggregators/ssajump.jl index 0e72b234..fec06c18 100644 --- a/src/aggregators/ssajump.jl +++ b/src/aggregators/ssajump.jl @@ -53,8 +53,8 @@ end nothing end -@inline function concretize_affects!(p::AbstractSSAJumpAggregator{<:Any, <:Any, <:Any, F2}, - ::I) where {F2 <: Tuple, I <: SciMLBase.DEIntegrator} +@inline function concretize_affects!(p::AbstractSSAJumpAggregator{T, S, F1, F2}, + ::I) where {T, S, F1, F2 <: Tuple, I <: SciMLBase.DEIntegrator} nothing end @@ -86,8 +86,8 @@ function (p::AbstractSSAJumpAggregator)(integrator::I) where {I <: SciMLBase.DEI end function (p::AbstractSSAJumpAggregator{ - <:Any, <:Any, <:Any, F2})(integrator::SciMLBase.DEIntegrator) where - {F2 <: Union{Tuple, Nothing}} + T, S, F1, F2})(integrator::SciMLBase.DEIntegrator) where + {T, S, F1, F2 <: Union{Tuple, Nothing}} execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t, p.affects!) generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) register_next_jump_time!(integrator, p, integrator.t) diff --git a/src/problem.jl b/src/problem.jl index 718179f7..3ce038de 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -309,7 +309,7 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS jump_cbs = CallbackSet(constant_jump_callback, variable_jump_callback) iip = isinplace_jump(prob, jumps.regular_jump) - solkwargs = make_kwarg(; filter(!isnothing, (; callback, tstops))...) + solkwargs = tstops === nothing ? make_kwarg(; callback) : make_kwarg(; callback, tstops) JumpProblem{iip, typeof(new_prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(cvrjs), typeof(jumps.regular_jump), @@ -355,7 +355,7 @@ function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; vrjs = jumps.variable_jumps iip = isinplace_jump(prob, jumps.regular_jump) - solkwargs = make_kwarg(; filter(!isnothing, (; callback, tstops))...) + solkwargs = tstops === nothing ? make_kwarg(; callback) : make_kwarg(; callback, tstops) JumpProblem{iip, typeof(prob), typeof(aggregator), typeof(jump_cbs), typeof(disc_agg), typeof(crjs), typeof(vrjs), typeof(jumps.regular_jump), From 0ded5246c803ba755afcbc229134be561e4fa173 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 06:38:07 -0500 Subject: [PATCH 05/20] Add tests for set_rng! type mismatch error on SSAIntegrator Co-Authored-By: Claude Opus 4.5 --- test/ensemble_problems.jl | 3 +++ test/rng_kwarg_tests.jl | 3 +++ 2 files changed, 6 insertions(+) diff --git a/test/ensemble_problems.jl b/test/ensemble_problems.jl index 7cfc7066..27e44650 100644 --- a/test/ensemble_problems.jl +++ b/test/ensemble_problems.jl @@ -181,6 +181,9 @@ end new_rng = StableRNG(99) SciMLBase.set_rng!(integrator, new_rng) @test SciMLBase.get_rng(integrator) === new_rng + + # mismatched RNG type should throw + @test_throws ArgumentError SciMLBase.set_rng!(integrator, Random.Xoshiro(123)) end # ========================================================================== diff --git a/test/rng_kwarg_tests.jl b/test/rng_kwarg_tests.jl index a677a911..010b96ce 100644 --- a/test/rng_kwarg_tests.jl +++ b/test/rng_kwarg_tests.jl @@ -167,6 +167,9 @@ end new_rng = StableRNG(99) SciMLBase.set_rng!(integrator, new_rng) @test SciMLBase.get_rng(integrator) === new_rng + + # mismatched RNG type should throw + @test_throws ArgumentError SciMLBase.set_rng!(integrator, Random.Xoshiro(123)) end # ========================================================================== From 20376a9ece795299d021d9f67b8f9d8b1e58c6f5 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 06:45:31 -0500 Subject: [PATCH 06/20] Remove unused StableRNG import and variable from GPU test The GPU tau-leaping kernel uses PassthroughRNG internally; the StableRNG was never consumed by the EnsembleGPUKernel path. Co-Authored-By: Claude Opus 4.5 --- test/gpu/regular_jumps.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/gpu/regular_jumps.jl b/test/gpu/regular_jumps.jl index b35276cf..bb136dd1 100644 --- a/test/gpu/regular_jumps.jl +++ b/test/gpu/regular_jumps.jl @@ -1,9 +1,6 @@ using JumpProcesses, DiffEqBase using Test, LinearAlgebra, Statistics using KernelAbstractions, Adapt, CUDA -using StableRNGs -rng = StableRNG(12345) - Nsims = 100_000 # SIR model with influx From 0c32cceea42d1b7636f5270be5d8023525b381d2 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 06:59:52 -0500 Subject: [PATCH 07/20] Simplify test solve calls: remove unnecessary isnothing(rng) checks solve(; rng = nothing) falls through to default_rng() via resolve_rng, so the ternary guards are redundant. Co-Authored-By: Claude Opus 4.5 --- test/bimolerx_test.jl | 4 ++-- test/geneexpr_test.jl | 3 +-- test/reversible_binding.jl | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/test/bimolerx_test.jl b/test/bimolerx_test.jl index 832968b7..05917bf5 100644 --- a/test/bimolerx_test.jl +++ b/test/bimolerx_test.jl @@ -59,9 +59,9 @@ function runSSAs(jump_prob; use_stepper = true, rng = nothing) Psamp = zeros(Int, Nsims) for i in 1:Nsims sol = if use_stepper - isnothing(rng) ? solve(jump_prob, SSAStepper()) : solve(jump_prob, SSAStepper(); rng) + solve(jump_prob, SSAStepper(); rng) else - isnothing(rng) ? solve(jump_prob) : solve(jump_prob; rng) + solve(jump_prob; rng) end Psamp[i] = sol[1, end] end diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index d26d9b8a..5212e071 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -26,10 +26,9 @@ function runSSAs(jump_prob; use_stepper = true, rng = nothing) Psamp = zeros(Int, Nsims) for i in 1:Nsims sol = if use_stepper - isnothing(rng) ? solve(jump_prob, SSAStepper()) : solve(jump_prob, SSAStepper(); rng) else - isnothing(rng) ? solve(jump_prob) : solve(jump_prob; rng) + solve(jump_prob; rng) end Psamp[i] = sol[3, end] end diff --git a/test/reversible_binding.jl b/test/reversible_binding.jl index 456a96d2..d0e9b9a6 100644 --- a/test/reversible_binding.jl +++ b/test/reversible_binding.jl @@ -23,7 +23,7 @@ majumps = MassActionJump(rates, reactstoch, netstoch) function getmean(jprob, Nsims; rng = nothing) Amean = 0 for i in 1:Nsims - sol = isnothing(rng) ? solve(jprob, SSAStepper()) : solve(jprob, SSAStepper(); rng) + sol = solve(jprob, SSAStepper(); rng) Amean += sol[1, end] end Amean /= Nsims From 33b86e9cb520947309e01890a4ce74b6d78eb95e Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 07:14:08 -0500 Subject: [PATCH 08/20] Use jump times instead of final state for trajectory uniqueness tests Comparing continuous-valued jump times is more robust than discrete final state values, which could collide by chance. Co-Authored-By: Claude Opus 4.5 --- test/ensemble_problems.jl | 8 ++------ test/rng_kwarg_tests.jl | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/test/ensemble_problems.jl b/test/ensemble_problems.jl index 27e44650..61c82e31 100644 --- a/test/ensemble_problems.jl +++ b/test/ensemble_problems.jl @@ -54,8 +54,6 @@ first_jump_time(traj) = traj.t[2] trajectories = 3, rng = StableRNG(12345)) times = [first_jump_time(sol.u[i]) for i in 1:3] @test allunique(times) - finals = [sol.u[i].u[end][1] for i in 1:3] - @test allunique(finals) end # EM() uses a fixed time grid so jump event times aren't directly visible @@ -87,8 +85,6 @@ end sols = [solve(jprob, Tsit5(); rng) for _ in 1:3] times = [first_jump_time(s) for s in sols] @test allunique(times) - finals = [s.u[end][1] for s in sols] - @test allunique(finals) end end @@ -161,8 +157,8 @@ end @testset "ODE + VR ($agg): different seeds → different trajectories" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jprob = make_vr_jump_prob(agg) sols = [solve(jprob, Tsit5(); rng = StableRNG(s)) for s in (100, 200, 300)] - finals = [s.u[end][1] for s in sols] - @test allunique(finals) + times = [first_jump_time(s) for s in sols] + @test allunique(times) end end diff --git a/test/rng_kwarg_tests.jl b/test/rng_kwarg_tests.jl index 010b96ce..09f6ab1e 100644 --- a/test/rng_kwarg_tests.jl +++ b/test/rng_kwarg_tests.jl @@ -102,8 +102,8 @@ end @testset "ODE + VR: different seeds → different trajectories" begin jprob = make_ode_vr_jump_prob() sols = [solve(jprob, Tsit5(); rng = StableRNG(s)) for s in (100, 200, 300)] - finals = [s.u[end][1] for s in sols] - @test allunique(finals) + times = [s.t[2] for s in sols] + @test allunique(times) end # ========================================================================== From 2b6d3c6287279efdb66388f60a9bd2de0d3d4421 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 07:54:40 -0500 Subject: [PATCH 09/20] Strengthen VR_FRM threshold uniqueness test Check jump_u thresholds directly via init, and verify both event times and post-event thresholds from ensemble solve. Co-Authored-By: Claude Opus 4.5 --- test/ensemble_problems.jl | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/test/ensemble_problems.jl b/test/ensemble_problems.jl index 61c82e31..f08ea524 100644 --- a/test/ensemble_problems.jl +++ b/test/ensemble_problems.jl @@ -187,18 +187,29 @@ end # # For VR_FRM, each trajectory's first jump time is determined by the initial # jump_u threshold (set to -randexp() by the VR_FRMEventCallback initialize). -# Distinct thresholds → distinct first event times, so we verify by checking -# that the second time point (first event) differs across serial trajectories. +# We verify both the thresholds (via init) and the resulting event times. # ========================================================================== @testset "VR_FRM: jump_u thresholds unique per trajectory (EnsembleSerial)" begin jprob = make_vr_jump_prob(VR_FRM()) + + # Check jump_u thresholds directly via init (callback sets them during initialization) + rng = StableRNG(12345) + thresholds = [begin + integrator = init(jprob, Tsit5(); rng) + integrator.u.jump_u[1] + end for _ in 1:3] + @test allunique(thresholds) + + # From a full ensemble solve, check both first event times and the + # post-event jump_u thresholds (u[3] is the post-event save where + # jump_u has been reset to a new -randexp() value). sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial(); trajectories = 3, rng = StableRNG(12345)) - # The second time point is when the first variable-rate jump fires, - # directly reflecting the initial -randexp() threshold. event_times = [sol.u[i].t[2] for i in 1:3] @test allunique(event_times) + post_event_thresholds = [sol.u[i].u[3].jump_u[1] for i in 1:3] + @test allunique(post_event_thresholds) end # ========================================================================== From 44ced847b40dc25d07821139a750ea1cfab6fcdd Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 09:33:08 -0500 Subject: [PATCH 10/20] Move EnsembleThreads tests to thread_safety.jl, fix Project.toml - Remove StochasticDiffEq from [deps] (test-only, belongs in [extras]) - Move EnsembleThreads data race tests from ensemble_problems.jl to thread_safety.jl with stronger assertions (400 trajectories, allunique on first reaction times, all 3 VR aggregators, adaptive SDE solver) - Renumber remaining ensemble_problems.jl sections Co-Authored-By: Claude Opus 4.5 --- Project.toml | 1 - test/ensemble_problems.jl | 43 ++++----------------------------------- test/thread_safety.jl | 21 ++++++++++++++++++- 3 files changed, 24 insertions(+), 41 deletions(-) diff --git a/Project.toml b/Project.toml index 3a528ceb..0d264333 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" [weakdeps] diff --git a/test/ensemble_problems.jl b/test/ensemble_problems.jl index aa619eab..efff24f6 100644 --- a/test/ensemble_problems.jl +++ b/test/ensemble_problems.jl @@ -90,42 +90,7 @@ end end # ========================================================================== -# 3. Threaded ensemble: no data race on the shared JumpProblem -# -# With integrator-owned RNGs, each thread's integrator gets its own -# default_rng(). We only assert completion here — uniqueness is tested -# via explicit rng kwarg in section 4. -# ========================================================================== - -@testset "EnsembleThreads: no data race" begin - @testset "SSAStepper" begin - jprob = make_ssa_jump_prob() - sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); - trajectories = 4) - @test length(sol) == 4 - end - - @testset "ODE + VR ($agg)" for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) - jprob = make_vr_jump_prob(agg) - sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleThreads(); - trajectories = 4, save_everystep = false) - @test length(sol) == 4 - end - - @testset "SDE + VR (VR_FRM): unique trajectories" begin - jprob = make_sde_vr_jump_prob(VR_FRM()) - # StochasticDiffEq generates per-trajectory seeds, so trajectories - # should be distinct. - sol = solve(EnsembleProblem(jprob), EM(), EnsembleThreads(); - trajectories = 4, dt = 0.01, save_everystep = false) - @test length(sol) == 4 - finals = [sol.u[i].u[end][1] for i in 1:4] - @test length(unique(finals)) > 1 - end -end - -# ========================================================================== -# 4. rng kwarg reproducibility: same rng seed → identical trajectory, +# 3. rng kwarg reproducibility: same rng seed → identical trajectory, # different rng seeds → different trajectories # ========================================================================== @@ -164,7 +129,7 @@ end end # ========================================================================== -# 5. has_rng / get_rng / set_rng! interface on SSAIntegrator +# 4. has_rng / get_rng / set_rng! interface on SSAIntegrator # ========================================================================== @testset "SSAIntegrator RNG interface" begin @@ -184,7 +149,7 @@ end end # ========================================================================== -# 6. Variable-rate: jump_u thresholds are unique per trajectory +# 5. Variable-rate: jump_u thresholds are unique per trajectory # # For VR_FRM, each trajectory's first jump time is determined by the initial # jump_u threshold (set to -randexp() by the VR_FRMEventCallback initialize). @@ -214,7 +179,7 @@ end end # ========================================================================== -# 7. JumpProblem rng kwarg forwarded to solver +# 6. JumpProblem rng kwarg throws ArgumentError # ========================================================================== @testset "JumpProblem rng kwarg throws ArgumentError" begin diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 349346db..2f80d5e8 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -1,5 +1,5 @@ using DiffEqBase, Test -using JumpProcesses, OrdinaryDiffEq +using JumpProcesses, OrdinaryDiffEq, StochasticDiffEq using StableRNGs rng = StableRNG(12345) @@ -34,3 +34,22 @@ let @test allunique(firstrx_time) end end + +# SDE + variable-rate jumps with EnsembleThreads +let + f!(du, u, p, t) = (du[1] = -0.1 * u[1]; nothing) + g!(du, u, p, t) = (du[1] = 0.1 * u[1]; nothing) + sde_prob = SDEProblem(f!, g!, [100.0], (0.0, 10.0)) + vrj = VariableRateJump((u, p, t) -> 0.5 * u[1], + integrator -> (integrator.u[1] -= 1.0)) + + for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) + jump_prob = JumpProblem(sde_prob, Direct(), vrj; vr_aggregator = agg) + prob_func(prob, i, repeat) = deepcopy(prob) + prob = EnsembleProblem(jump_prob; prob_func) + sol = solve(prob, SRIW1(), EnsembleThreads(); + trajectories = 400, save_everystep = false) + firstrx_time = [sol.u[i].t[findfirst(>(sol.u[i].t[1]), sol.u[i].t)] for i in 1:length(sol)] + @test allunique(firstrx_time) + end +end From 55e130b71d306a7a40742187cdb3e6f476021e47 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 11:54:25 -0500 Subject: [PATCH 11/20] Address PR review: strengthen tests, add seed coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ensemble_problems.jl: check initial jump_u threshold (u[2]) instead of post-event (u[3]) now that u_modified!(true) produces init save - extended_jump_array_remake.jl: rework to solve → remake → init → compare jump_u, verifying fresh thresholds after remake - geneexpr_test.jl: remove unnecessary isnothing(rng) branch - linearreaction_test.jl: make rng const for type stability - variable_rate.jl: switch getmean from seed integer to StableRNG - rng_kwarg_tests.jl: add SDE+VR different seeds test, SDE+VR no-rng functional test, seed kwarg reproducibility and different-seeds tests for all pathways (SSAStepper, ODE+VR, SDE+VR), use first_jump_time helper for robust event time extraction Co-Authored-By: Claude Opus 4.5 --- test/ensemble_problems.jl | 8 +-- test/extended_jump_array_remake.jl | 46 +++++++----- test/geneexpr_test.jl | 6 +- test/linearreaction_test.jl | 2 +- test/rng_kwarg_tests.jl | 111 ++++++++++++++++++++++++++--- test/variable_rate.jl | 10 +-- 6 files changed, 139 insertions(+), 44 deletions(-) diff --git a/test/ensemble_problems.jl b/test/ensemble_problems.jl index efff24f6..da11b321 100644 --- a/test/ensemble_problems.jl +++ b/test/ensemble_problems.jl @@ -168,14 +168,14 @@ end @test allunique(thresholds) # From a full ensemble solve, check both first event times and the - # post-event jump_u thresholds (u[3] is the post-event save where - # jump_u has been reset to a new -randexp() value). + # initial jump_u thresholds (u[2] is the initialization save where + # jump_u has been set to -randexp() by the callback). sol = solve(EnsembleProblem(jprob), Tsit5(), EnsembleSerial(); trajectories = 3, rng = StableRNG(12345)) event_times = [first_jump_time(sol.u[i]) for i in 1:3] @test allunique(event_times) - post_event_thresholds = [sol.u[i].u[3].jump_u[1] for i in 1:3] - @test allunique(post_event_thresholds) + init_thresholds = [sol.u[i].u[2].jump_u[1] for i in 1:3] + @test allunique(init_thresholds) end # ========================================================================== diff --git a/test/extended_jump_array_remake.jl b/test/extended_jump_array_remake.jl index 224064ae..7a32c64d 100644 --- a/test/extended_jump_array_remake.jl +++ b/test/extended_jump_array_remake.jl @@ -5,8 +5,6 @@ using JumpProcesses, OrdinaryDiffEq, Test, SymbolicIndexingInterface using StableRNGs @testset "remake JumpProblem with VariableRateJumps (ExtendedJumpArray)" begin - rng = StableRNG(12345) - # Setup: Create an ODEProblem with SymbolCache for symbolic indexing f(du, u, p, t) = (du .= 0; nothing) g = ODEFunction(f; sys = SymbolCache([:X, :Y], [:k1, :k2], :t)) @@ -23,15 +21,19 @@ using StableRNGs @test jprob.prob.u0 isa ExtendedJumpArray @test jprob.prob.u0.u == [10.0, 5.0] + # Solve original problem and capture jump_u after initialization + orig_integrator = init(jprob, Tsit5(); rng = StableRNG(42)) + orig_jump_u = copy(orig_integrator.u.jump_u) + @testset "remake with numeric Vector{Float64}" begin prob2 = remake(jprob; u0 = [20.0, 10.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u == [20.0, 10.0] - # jump_u is zeroed at construction; callback initializes it at solve time @test all(iszero, prob2.prob.u0.jump_u) - # After init the callback should set jump_u to non-zero thresholds - integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + # After init, callback sets fresh jump_u thresholds (different RNG seed) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake with ExtendedJumpArray (no resample)" begin @@ -48,34 +50,34 @@ using StableRNGs end @testset "remake with Symbol pairs" begin - # This was the FAILING case - should work after fix prob2 = remake(jprob; u0 = [:X => 25.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 25.0 - # jump_u zeroed at construction, initialized by callback at solve time @test all(iszero, prob2.prob.u0.jump_u) - integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + # After init, callback sets fresh jump_u thresholds (different RNG seed) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake with multiple Symbol pairs" begin prob2 = remake(jprob; u0 = [:X => 35.0, :Y => 15.0]) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u == [35.0, 15.0] - # jump_u zeroed at construction, initialized by callback at solve time @test all(iszero, prob2.prob.u0.jump_u) - integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake with Dict" begin prob2 = remake(jprob; u0 = Dict(:X => 40.0)) @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 40.0 - # jump_u zeroed at construction, initialized by callback at solve time @test all(iszero, prob2.prob.u0.jump_u) - integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake with parameters only (u0 unchanged)" begin @@ -93,21 +95,27 @@ using StableRNGs @test prob2.prob.u0 isa ExtendedJumpArray @test prob2.prob.u0.u[1] == 50.0 @test prob2.prob.p[1] == 3.0 - # jump_u zeroed at construction, initialized by callback at solve time @test all(iszero, prob2.prob.u0.jump_u) - integrator = init(prob2, Tsit5(); rng = StableRNG(42)) + integrator = init(prob2, Tsit5(); rng = StableRNG(99)) @test any(!iszero, integrator.u.jump_u) + @test integrator.u.jump_u != orig_jump_u end @testset "remake preserves problem solvability" begin - # Ensure remade problems can actually be solved + # Solve original, then remake and solve again — jump_u should differ + sol1 = solve(jprob, Tsit5(); rng = StableRNG(42)) + @test SciMLBase.successful_retcode(sol1) + prob2 = remake(jprob; u0 = [5.0, 2.0]) - sol = solve(prob2, Tsit5()) - @test SciMLBase.successful_retcode(sol) + sol2 = solve(prob2, Tsit5(); rng = StableRNG(99)) + @test SciMLBase.successful_retcode(sol2) + # Different RNG seeds → different jump_u thresholds after init + @test sol1.u[2].jump_u != sol2.u[2].jump_u - # With symbolic map (after fix) + # With symbolic map prob3 = remake(jprob; u0 = [:X => 8.0]) - sol3 = solve(prob3, Tsit5()) + sol3 = solve(prob3, Tsit5(); rng = StableRNG(77)) @test SciMLBase.successful_retcode(sol3) + @test sol1.u[2].jump_u != sol3.u[2].jump_u end end diff --git a/test/geneexpr_test.jl b/test/geneexpr_test.jl index 5212e071..45e9a02b 100644 --- a/test/geneexpr_test.jl +++ b/test/geneexpr_test.jl @@ -38,11 +38,7 @@ end function runSSAs_ode(vrjprob; rng = nothing) Psamp = zeros(Float64, Nsims) tsave = vrjprob.prob.tspan[2] - integrator = if isnothing(rng) - init(vrjprob, Tsit5(); saveat = tsave) - else - init(vrjprob, Tsit5(); saveat = tsave, rng) - end + integrator = init(vrjprob, Tsit5(); saveat = tsave, rng) solve!(integrator) Psamp[1] = integrator.sol[3, end] for i in 2:Nsims diff --git a/test/linearreaction_test.jl b/test/linearreaction_test.jl index 67e13b15..7883883c 100644 --- a/test/linearreaction_test.jl +++ b/test/linearreaction_test.jl @@ -3,7 +3,7 @@ using DiffEqBase, JumpProcesses, Statistics using Test using StableRNGs -rng = StableRNG(12345) +const rng = StableRNG(12345) # using BenchmarkTools # dobenchmark = true diff --git a/test/rng_kwarg_tests.jl b/test/rng_kwarg_tests.jl index 09f6ab1e..f3284f3b 100644 --- a/test/rng_kwarg_tests.jl +++ b/test/rng_kwarg_tests.jl @@ -33,6 +33,9 @@ function make_sde_vr_jump_prob() JumpProblem(sprob, Direct(), vrj) end +# Helpers +first_jump_time(traj) = traj.t[findfirst(>(traj.t[1]), traj.t)] + # ========================================================================== # 1. JumpProblem(; rng=...) throws ArgumentError # ========================================================================== @@ -102,7 +105,7 @@ end @testset "ODE + VR: different seeds → different trajectories" begin jprob = make_ode_vr_jump_prob() sols = [solve(jprob, Tsit5(); rng = StableRNG(s)) for s in (100, 200, 300)] - times = [s.t[2] for s in sols] + times = [first_jump_time(s) for s in sols] @test allunique(times) end @@ -126,7 +129,18 @@ end end # ========================================================================== -# 10. SimpleTauLeaping: rng via solve kwargs +# 10. SDE + VR: different seeds → different trajectories +# ========================================================================== +@testset "SDE + VR: different seeds → different trajectories" begin + jprob = make_sde_vr_jump_prob() + sols = [solve(jprob, SRIW1(); save_everystep = false, + rng = StableRNG(s)) for s in (100, 200, 300)] + times = [first_jump_time(s) for s in sols] + @test allunique(times) +end + +# ========================================================================== +# 11. SimpleTauLeaping: rng via solve kwargs # ========================================================================== @testset "SimpleTauLeaping: rng via solve kwargs" begin rate(out, u, p, t) = (out .= max.(u, 0); nothing) @@ -140,7 +154,7 @@ end end # ========================================================================== -# 11. SimpleTauLeaping: different seeds → different trajectories +# 12. SimpleTauLeaping: different seeds → different trajectories # ========================================================================== @testset "SimpleTauLeaping: different seeds → different trajectories" begin rate(out, u, p, t) = (out .= max.(u, 0); nothing) @@ -154,7 +168,7 @@ end end # ========================================================================== -# 12. has_rng / get_rng / set_rng! interface on SSAIntegrator +# 13. has_rng / get_rng / set_rng! interface on SSAIntegrator # ========================================================================== @testset "SSAIntegrator RNG interface" begin jprob = make_ssa_jump_prob() @@ -173,7 +187,7 @@ end end # ========================================================================== -# 13. No rng kwarg: uses default_rng (non-reproducible but functional) +# 14. No rng kwarg: uses default_rng (non-reproducible but functional) # ========================================================================== @testset "No rng kwarg: functional solve" begin @testset "SSAStepper" begin @@ -187,19 +201,96 @@ end sol = solve(jprob, Tsit5()) @test sol.retcode == ReturnCode.Success end + + @testset "SDE + VR" begin + jprob = make_sde_vr_jump_prob() + sol = solve(jprob, EM(); dt = 0.01) + @test sol.retcode == ReturnCode.Success + end end # ========================================================================== -# 14. seed kwarg: creates Xoshiro from integer seed +# 15. seed kwarg: creates Xoshiro from integer seed # ========================================================================== @testset "seed kwarg creates Xoshiro" begin - jprob = make_ssa_jump_prob() - integrator = init(jprob, SSAStepper(); seed = 42) - @test SciMLBase.get_rng(integrator) isa Xoshiro + @testset "SSAStepper" begin + jprob = make_ssa_jump_prob() + integrator = init(jprob, SSAStepper(); seed = 42) + @test SciMLBase.get_rng(integrator) isa Xoshiro + end + + @testset "ODE + VR" begin + jprob = make_ode_vr_jump_prob() + integrator = init(jprob, Tsit5(); seed = 42) + @test SciMLBase.get_rng(integrator) isa Xoshiro + end + + @testset "SDE + VR" begin + jprob = make_sde_vr_jump_prob() + integrator = init(jprob, EM(); dt = 0.01, seed = 42) + @test SciMLBase.get_rng(integrator) isa Xoshiro + end +end + +# ========================================================================== +# 16. seed kwarg reproducibility: same seed → same trajectory +# ========================================================================== +@testset "seed kwarg reproducibility" begin + @testset "SSAStepper" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); seed = 42) + sol2 = solve(jprob, SSAStepper(); seed = 42) + @test sol1.t == sol2.t + @test sol1.u == sol2.u + end + + @testset "ODE + VR" begin + jprob = make_ode_vr_jump_prob() + sol1 = solve(jprob, Tsit5(); seed = 42) + sol2 = solve(jprob, Tsit5(); seed = 42) + @test sol1.t ≈ sol2.t + @test sol1.u[end] ≈ sol2.u[end] + end + + @testset "SDE + VR" begin + jprob = make_sde_vr_jump_prob() + sol1 = solve(jprob, EM(); dt = 0.01, save_everystep = false, seed = 42) + sol2 = solve(jprob, EM(); dt = 0.01, save_everystep = false, seed = 42) + @test sol1.u[end] ≈ sol2.u[end] + end +end + +# ========================================================================== +# 17. seed kwarg: different seeds → different trajectories +# ========================================================================== +@testset "seed kwarg: different seeds → different trajectories" begin + @testset "SSAStepper" begin + jprob = make_ssa_jump_prob() + sol1 = solve(jprob, SSAStepper(); seed = 100) + sol2 = solve(jprob, SSAStepper(); seed = 200) + sol3 = solve(jprob, SSAStepper(); seed = 300) + times = [sol1.t[2], sol2.t[2], sol3.t[2]] + @test allunique(times) + end + + @testset "ODE + VR" begin + jprob = make_ode_vr_jump_prob() + sols = [solve(jprob, Tsit5(); seed = s) for s in (100, 200, 300)] + times = [first_jump_time(s) for s in sols] + @test allunique(times) + end + + @testset "SDE + VR" begin + jprob = make_sde_vr_jump_prob() + sols = [solve(jprob, SRIW1(); save_everystep = false, + seed = s) for s in (100, 200, 300)] + times = [first_jump_time(s) for s in sols] + @test allunique(times) + end end # ========================================================================== -# 15. rng takes priority over seed +# 18. rng takes priority over seed # ========================================================================== @testset "rng takes priority over seed" begin jprob = make_ssa_jump_prob() diff --git a/test/variable_rate.jl b/test/variable_rate.jl index 63a6567f..195479ec 100644 --- a/test/variable_rate.jl +++ b/test/variable_rate.jl @@ -273,9 +273,9 @@ end # https://github.com/SciML/JumpProcesses.jl/issues/320 # note that even with the seeded StableRNG this test is not # deterministic for some reason. -function getmean(Nsims, prob, alg, tsave, seed) +function getmean(Nsims, prob, alg, tsave, rng) umean = zeros(length(tsave)) - integrator = init(prob, alg; saveat = tsave, seed) + integrator = init(prob, alg; saveat = tsave, rng) solve!(integrator) for j in eachindex(umean) umean[j] += integrator.sol.u[j][1] @@ -292,7 +292,7 @@ function getmean(Nsims, prob, alg, tsave, seed) end let - seed = 12345 + rng_seed = 12345 b = 2.0 d = 1.0 n0 = 1.0 @@ -327,9 +327,9 @@ let sjm_prob = JumpProblem(ode_prob, b_jump, d_jump; vr_aggregator) for alg in (Tsit5(), Rodas5P(linsolve = QRFactorization())) - umean = getmean(Nsims, sjm_prob, alg, tsave, seed) + umean = getmean(Nsims, sjm_prob, alg, tsave, StableRNG(rng_seed)) @test all(abs.(umean .- n.(tsave)) .< 0.05 * n.(tsave)) - seed += Nsims + rng_seed += Nsims end end end From 96864366137e68c4406fade5f85af108c9f8ba32 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 14:21:14 -0500 Subject: [PATCH 12/20] Address review round 3: fix docs example, add tau-leaping seed tests, strengthen thread_safety assertions - docs/advanced_point_process.md: move rng from JumpProblem to solve - rng_kwarg_tests.jl: add seed kwarg tests for SimpleTauLeaping (tests 13-14), renumber 15-20 - thread_safety.jl: SSAStepper EnsembleThreads now runs 400 trajectories and asserts allunique first jump times Co-Authored-By: Claude Opus 4.5 --- .../applications/advanced_point_process.md | 4 +- test/rng_kwarg_tests.jl | 40 ++++++++++++++++--- test/thread_safety.jl | 16 ++++++-- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/docs/src/applications/advanced_point_process.md b/docs/src/applications/advanced_point_process.md index 97677001..87e397d8 100644 --- a/docs/src/applications/advanced_point_process.md +++ b/docs/src/applications/advanced_point_process.md @@ -387,10 +387,10 @@ function Base.rand(rng::AbstractRNG, out = Array{History, 1}(undef, n) p = params(pp) dprob = DiscreteProblem([0], tspan, p) - jprob = JumpProblem(dprob, Coevolve(), jumps...; dep_graph = pp.g, save_positions, rng) + jprob = JumpProblem(dprob, Coevolve(), jumps...; dep_graph = pp.g, save_positions) for i in 1:n params!(pp, p) - solve(jprob, SSAStepper()) + solve(jprob, SSAStepper(); rng) out[i] = deepcopy(p.h) end return out diff --git a/test/rng_kwarg_tests.jl b/test/rng_kwarg_tests.jl index f3284f3b..a237cecd 100644 --- a/test/rng_kwarg_tests.jl +++ b/test/rng_kwarg_tests.jl @@ -168,7 +168,35 @@ end end # ========================================================================== -# 13. has_rng / get_rng / set_rng! interface on SSAIntegrator +# 13. SimpleTauLeaping: seed kwarg reproducibility +# ========================================================================== +@testset "SimpleTauLeaping: seed kwarg reproducibility" begin + rate(out, u, p, t) = (out .= max.(u, 0); nothing) + c(du, u, p, t, counts, mark) = (du .= counts; nothing) + rj = RegularJump(rate, c, 2) + dprob = DiscreteProblem([100, 100], (0.0, 1.0)) + jprob = JumpProblem(dprob, PureLeaping(), rj) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, seed = 42) + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, seed = 42) + @test sol1.u == sol2.u +end + +# ========================================================================== +# 14. SimpleTauLeaping: seed different seeds → different trajectories +# ========================================================================== +@testset "SimpleTauLeaping: seed different seeds → different trajectories" begin + rate(out, u, p, t) = (out .= max.(u, 0); nothing) + c(du, u, p, t, counts, mark) = (du .= counts; nothing) + rj = RegularJump(rate, c, 2) + dprob = DiscreteProblem([100, 100], (0.0, 1.0)) + jprob = JumpProblem(dprob, PureLeaping(), rj) + sol1 = solve(jprob, SimpleTauLeaping(); dt = 0.01, seed = 42) + sol2 = solve(jprob, SimpleTauLeaping(); dt = 0.01, seed = 99) + @test sol1.u != sol2.u +end + +# ========================================================================== +# 15. has_rng / get_rng / set_rng! interface on SSAIntegrator # ========================================================================== @testset "SSAIntegrator RNG interface" begin jprob = make_ssa_jump_prob() @@ -187,7 +215,7 @@ end end # ========================================================================== -# 14. No rng kwarg: uses default_rng (non-reproducible but functional) +# 16. No rng kwarg: uses default_rng (non-reproducible but functional) # ========================================================================== @testset "No rng kwarg: functional solve" begin @testset "SSAStepper" begin @@ -210,7 +238,7 @@ end end # ========================================================================== -# 15. seed kwarg: creates Xoshiro from integer seed +# 17. seed kwarg: creates Xoshiro from integer seed # ========================================================================== @testset "seed kwarg creates Xoshiro" begin @testset "SSAStepper" begin @@ -233,7 +261,7 @@ end end # ========================================================================== -# 16. seed kwarg reproducibility: same seed → same trajectory +# 18. seed kwarg reproducibility: same seed → same trajectory # ========================================================================== @testset "seed kwarg reproducibility" begin @testset "SSAStepper" begin @@ -261,7 +289,7 @@ end end # ========================================================================== -# 17. seed kwarg: different seeds → different trajectories +# 19. seed kwarg: different seeds → different trajectories # ========================================================================== @testset "seed kwarg: different seeds → different trajectories" begin @testset "SSAStepper" begin @@ -290,7 +318,7 @@ end end # ========================================================================== -# 18. rng takes priority over seed +# 20. rng takes priority over seed # ========================================================================== @testset "rng takes priority over seed" begin jprob = make_ssa_jump_prob() diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 2f80d5e8..4c4cfcfd 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -10,9 +10,19 @@ tspan = (0.0, 4.0) u0 = [5] dprob = DiscreteProblem(u0, tspan, params) jprob = JumpProblem(dprob, Direct(), maj) -solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); trajectories = 10, rng) -solve(EnsembleProblem(jprob; safetycopy = true), SSAStepper(), EnsembleThreads(); - trajectories = 10, rng) + +# Verify threaded solves complete and produce distinct trajectories +sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); + trajectories = 400, rng) +@test length(sol) == 400 +firstrx_time = [sol.u[i].t[findfirst(>(sol.u[i].t[1]), sol.u[i].t)] for i in 1:length(sol)] +@test allunique(firstrx_time) + +sol2 = solve(EnsembleProblem(jprob; safetycopy = true), SSAStepper(), EnsembleThreads(); + trajectories = 400, rng) +@test length(sol2) == 400 +firstrx_time2 = [sol2.u[i].t[findfirst(>(sol2.u[i].t[1]), sol2.u[i].t)] for i in 1:length(sol2)] +@test allunique(firstrx_time2) # test for https://github.com/SciML/JumpProcesses.jl/issues/472 let From ec3b6610227a1cd4f78c1b6f30f62f64a0da6d0a Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Tue, 24 Feb 2026 14:41:49 -0500 Subject: [PATCH 13/20] Add dedicated threaded CI for thread_safety tests - Create ThreadSafety.yml workflow: runs thread_safety.jl with --threads=4 - Add ThreadSafety group to Tests.yml matrix (serial) and runtests.jl - Move thread_safety.jl from InterfaceII to ThreadSafety group - Fix shared rng data race: SSAStepper ensemble tests now use task-local default_rng() instead of a shared StableRNG across threads Co-Authored-By: Claude Opus 4.5 --- .github/workflows/Tests.yml | 1 + .github/workflows/ThreadSafety.yml | 65 ++++++++++++++++++++++++++++++ test/runtests.jl | 5 ++- test/thread_safety.jl | 13 +++--- 4 files changed, 78 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/ThreadSafety.yml diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 3f223b3d..dd8ce04f 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -29,6 +29,7 @@ jobs: group: - InterfaceI - InterfaceII + - ThreadSafety - QA exclude: - version: "pre" diff --git a/.github/workflows/ThreadSafety.yml b/.github/workflows/ThreadSafety.yml new file mode 100644 index 00000000..b82eda2d --- /dev/null +++ b/.github/workflows/ThreadSafety.yml @@ -0,0 +1,65 @@ +name: "Thread Safety Tests" + +on: + push: + branches: + - master + paths-ignore: + - 'docs/**' + pull_request: + branches: + - master + paths-ignore: + - 'docs/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref_name != github.event.repository.default_branch || github.ref != 'refs/tags/v*' }} + +jobs: + thread-safety: + name: "Thread Safety" + strategy: + fail-fast: false + matrix: + version: + - "1" + - "lts" + - "pre" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: "Setup Julia ${{ matrix.version }}" + uses: julia-actions/setup-julia@v2 + with: + version: "${{ matrix.version }}" + + - name: "Cache Julia packages" + uses: julia-actions/cache@v2 + with: + token: "${{ secrets.GITHUB_TOKEN }}" + + - name: "Build package" + uses: julia-actions/julia-buildpkg@v1 + + - name: "Run thread safety tests (4 threads)" + run: | + julia --threads=4 --code-coverage=user --check-bounds=yes --compiled-modules=yes \ + --project=@. --color=yes -e ' + using Pkg + Pkg.test() + ' + env: + GROUP: ThreadSafety + + - name: "Process Coverage" + uses: julia-actions/julia-processcoverage@v1 + + - name: "Report Coverage" + uses: codecov/codecov-action@v5 + if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} + with: + files: lcov.info + token: "${{ secrets.CODECOV_TOKEN }}" + fail_ci_if_error: false diff --git a/test/runtests.jl b/test/runtests.jl index be0236c3..e4b097c4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,7 +41,6 @@ end @time @safetestset "Save_positions test" begin include("save_positions.jl") end @time @safetestset "RNG kwarg tests" begin include("rng_kwarg_tests.jl") end @time @safetestset "Ensemble Uniqueness test" begin include("ensemble_uniqueness.jl") end - @time @safetestset "Thread Safety test" begin include("thread_safety.jl") end @time @safetestset "Ensemble Problem Tests" begin include("ensemble_problems.jl") end @time @safetestset "A + B <--> C" begin include("reversible_binding.jl") end @time @safetestset "Remake tests" begin include("remake_test.jl") end @@ -58,6 +57,10 @@ end @time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end end + if GROUP == "All" || GROUP == "ThreadSafety" + @time @safetestset "Thread Safety test (threaded)" begin include("thread_safety.jl") end + end + if GROUP == "CUDA" activate_gpu_env() @time @safetestset "GPU Tau Leaping test" begin include("gpu/regular_jumps.jl") end diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 4c4cfcfd..9f7c4491 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -1,7 +1,5 @@ using DiffEqBase, Test using JumpProcesses, OrdinaryDiffEq, StochasticDiffEq -using StableRNGs -rng = StableRNG(12345) sr = [1.0, 2.0, 50.0] maj = MassActionJump(sr, [[1 => 1], [1 => 1], [0 => 1]], [[1 => 1], [1 => -1], [1 => 1]]) @@ -11,15 +9,20 @@ u0 = [5] dprob = DiscreteProblem(u0, tspan, params) jprob = JumpProblem(dprob, Direct(), maj) -# Verify threaded solves complete and produce distinct trajectories +# Verify threaded solves complete and produce distinct trajectories. +# NOTE: We intentionally do NOT pass `rng` here. In threaded ensembles, passing a +# shared rng object via `solve(...; rng=...)` does not yet provide correct +# per-trajectory stream handling. Until SciMLBase's ensemble RNG updates land +# (master rng -> per-trajectory rng), correctness in threaded contexts relies on +# task-local `Random.default_rng()`. sol = solve(EnsembleProblem(jprob), SSAStepper(), EnsembleThreads(); - trajectories = 400, rng) + trajectories = 400) @test length(sol) == 400 firstrx_time = [sol.u[i].t[findfirst(>(sol.u[i].t[1]), sol.u[i].t)] for i in 1:length(sol)] @test allunique(firstrx_time) sol2 = solve(EnsembleProblem(jprob; safetycopy = true), SSAStepper(), EnsembleThreads(); - trajectories = 400, rng) + trajectories = 400) @test length(sol2) == 400 firstrx_time2 = [sol2.u[i].t[findfirst(>(sol2.u[i].t[1]), sol2.u[i].t)] for i in 1:length(sol2)] @test allunique(firstrx_time2) From 201e70e8410d4a2cf480ecbe2551845eae3e0213 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Wed, 25 Feb 2026 13:03:12 -0500 Subject: [PATCH 14/20] Add supports_solve_rng trait for JumpProblem solve paths Co-Authored-By: Claude Opus 4.5 --- src/SSA_stepper.jl | 1 + src/simple_regular_solve.jl | 3 +++ src/solve.jl | 6 ++++++ 3 files changed, 10 insertions(+) diff --git a/src/SSA_stepper.jl b/src/SSA_stepper.jl index 913bcd74..a3e4096e 100644 --- a/src/SSA_stepper.jl +++ b/src/SSA_stepper.jl @@ -61,6 +61,7 @@ for details. """ struct SSAStepper <: DiffEqBase.DEAlgorithm end SciMLBase.allows_late_binding_tstops(::SSAStepper) = true +SciMLBase.supports_solve_rng(::JumpProblem, ::SSAStepper) = true """ $(TYPEDEF) diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index 7c8df9fb..f71351ef 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -6,6 +6,9 @@ end SimpleExplicitTauLeaping(; epsilon = 0.05) = SimpleExplicitTauLeaping(epsilon) +SciMLBase.supports_solve_rng(::JumpProblem, ::SimpleTauLeaping) = true +SciMLBase.supports_solve_rng(::JumpProblem, ::SimpleExplicitTauLeaping) = true + function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg) if !(jump_prob.aggregator isa PureLeaping) @warn "When using $alg, please pass PureLeaping() as the aggregator to the \ diff --git a/src/solve.jl b/src/solve.jl index db58e3e4..ffe3177d 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -15,6 +15,9 @@ function resolve_rng(rng, seed) end end +SciMLBase.supports_solve_rng(jprob::JumpProblem, alg::DiffEqBase.DEAlgorithm) = + SciMLBase.supports_solve_rng(jprob.prob, alg) + function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}, alg::DiffEqBase.DEAlgorithm; merge_callbacks = true, kwargs...) where {P} @@ -38,6 +41,9 @@ function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}, integrator.sol end +SciMLBase.supports_solve_rng(jprob::JumpProblem, ::Nothing) = + jprob.prob isa DiffEqBase.DiscreteProblem + # if passed a JumpProblem over a DiscreteProblem, and no aggregator is selected use # SSAStepper function DiffEqBase.__solve(jump_prob::DiffEqBase.AbstractJumpProblem{P}; From 613611e505b39b621737337030d71aac7578c9ce Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Sat, 28 Feb 2026 07:24:12 -0500 Subject: [PATCH 15/20] Update package versions in Project.toml --- Project.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 0d264333..35ea0c11 100644 --- a/Project.toml +++ b/Project.toml @@ -45,18 +45,18 @@ KernelAbstractions = "0.9" LinearAlgebra = "1" LinearSolve = "3" OrdinaryDiffEq = "6" -OrdinaryDiffEqCore = "3.9" +OrdinaryDiffEqCore = "3.11" Pkg = "1" PoissonRandom = "0.4" Random = "1" RecursiveArrayTools = "3.35" Reexport = "1.2" SafeTestsets = "0.1" -SciMLBase = "2.144" +SciMLBase = "2.147" StableRNGs = "1" StaticArrays = "1.9.8" Statistics = "1" -StochasticDiffEq = "6.82" +StochasticDiffEq = "6.95" SymbolicIndexingInterface = "0.3.36" Test = "1" julia = "1.10" From a434959ece71f37fe1393737453af3dcf026b6b1 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Thu, 26 Feb 2026 17:00:48 -0500 Subject: [PATCH 16/20] Add rescale_rates_on_update field to MassActionJump to prevent double-scaling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When MTKBase/Catalyst constructs MassActionJumps, the symbolic rate expressions already include stoichiometric scaling (e.g. k/3! for 3X → Y), so they pass scale_rates=false. However, update_parameters! defaulted to scale_rates=true, causing every subsequent parameter update path (reset_aggregated_jumps!, remake, finalize_parameters_hook!) to double-scale the rates. The fix stores the intended scaling behavior on the struct as rescale_rates_on_update, which update_parameters! now reads as its default. This field is propagated through all merge/combine paths and validated for consistency when merging multiple MassActionJumps. Also fixes SpatialMassActionJump conversion constructor to default scale_rates=false since it receives already-scaled rates. Co-Authored-By: Claude Opus 4.5 --- src/jumps.jl | 94 +++++++++------ src/spatial/spatial_massaction_jump.jl | 4 +- test/runtests.jl | 1 + test/scale_rates_field_test.jl | 156 +++++++++++++++++++++++++ 4 files changed, 216 insertions(+), 39 deletions(-) create mode 100644 test/scale_rates_field_test.jl diff --git a/src/jumps.jl b/src/jumps.jl index 63704ef9..07185d95 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -333,10 +333,13 @@ struct MassActionJump{T, S, U, V} <: AbstractMassActionJump net_stoch::U """Parameter mapping functor to identify reaction rate constants with parameters in `p` vectors.""" param_mapper::V + """Whether `update_parameters!` should apply stoichiometric scaling to rates.""" + rescale_rates_on_update::Bool function MassActionJump{T, S, U, V}(rates::T, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {T <: AbstractVector, S, U, V} + nocopy::Bool, + rescale_rates_on_update::Bool = scale_rates) where {T <: AbstractVector, S, U, V} sr = nocopy ? rates : copy(rates) rs = nocopy ? rs_in : copy(rs_in) for i in eachindex(rs) @@ -348,14 +351,15 @@ struct MassActionJump{T, S, U, V} <: AbstractMassActionJump if scale_rates && !isempty(sr) scalerates!(sr, rs) end - new(sr, rs, ns, pmapper) + new(sr, rs, ns, pmapper, rescale_rates_on_update) end function MassActionJump{Nothing, Vector{S}, Vector{U}, V}(::Nothing, rs_in::Vector{S}, ns::Vector{U}, pmapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {S <: AbstractVector, + nocopy::Bool, + rescale_rates_on_update::Bool = scale_rates) where {S <: AbstractVector, U <: AbstractVector, V} rs = nocopy ? rs_in : copy(rs_in) for i in eachindex(rs) @@ -363,46 +367,51 @@ struct MassActionJump{T, S, U, V} <: AbstractMassActionJump rs[i] = typeof(rs[i])() end end - new(nothing, rs, ns, pmapper) + new(nothing, rs, ns, pmapper, rescale_rates_on_update) end function MassActionJump{T, S, U, V}(rate::T, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {T <: Number, S, U, V} + nocopy::Bool, + rescale_rates_on_update::Bool = scale_rates) where {T <: Number, S, U, V} rs = rs_in if useiszero && (length(rs) == 1) && iszero(rs[1][1]) rs = typeof(rs)() end sr = scale_rates ? scalerate(rate, rs) : rate - new(sr, rs, ns, pmapper) + new(sr, rs, ns, pmapper, rescale_rates_on_update) end function MassActionJump{Nothing, S, U, V}(::Nothing, rs_in::S, ns::U, pmapper::V, scale_rates::Bool, useiszero::Bool, - nocopy::Bool) where {S, U, V} + nocopy::Bool, + rescale_rates_on_update::Bool = scale_rates) where {S, U, V} rs = rs_in if useiszero && (length(rs) == 1) && iszero(rs[1][1]) rs = typeof(rs)() end - new(nothing, rs, ns, pmapper) + new(nothing, rs, ns, pmapper, rescale_rates_on_update) end end function MassActionJump(usr::T, rs::S, ns::U, pmapper::V; scale_rates = true, - useiszero = true, nocopy = false) where {T, S, U, V} - MassActionJump{T, S, U, V}(usr, rs, ns, pmapper, scale_rates, useiszero, nocopy) + useiszero = true, nocopy = false, + rescale_rates_on_update = scale_rates) where {T, S, U, V} + MassActionJump{T, S, U, V}(usr, rs, ns, pmapper, scale_rates, useiszero, nocopy, + rescale_rates_on_update) end function MassActionJump(usr::T, rs, ns; scale_rates = true, useiszero = true, - nocopy = false) where {T <: AbstractVector} - MassActionJump(usr, rs, ns, nothing; scale_rates = scale_rates, useiszero = useiszero, - nocopy = nocopy) + nocopy = false, rescale_rates_on_update = scale_rates) where {T <: AbstractVector} + MassActionJump(usr, rs, ns, nothing; scale_rates, useiszero, nocopy, + rescale_rates_on_update) end function MassActionJump(usr::T, rs, ns; scale_rates = true, useiszero = true, - nocopy = false) where {T <: Number} - MassActionJump(usr, rs, ns, nothing; scale_rates = scale_rates, useiszero = useiszero, - nocopy = nocopy) + nocopy = false, rescale_rates_on_update = scale_rates) where {T <: Number} + MassActionJump(usr, rs, ns, nothing; scale_rates, useiszero, nocopy, + rescale_rates_on_update) end # with parameter indices or mapping, multiple jump case function MassActionJump(rs, ns; param_idxs = nothing, param_mapper = nothing, - scale_rates = true, useiszero = true, nocopy = false) + scale_rates = true, useiszero = true, nocopy = false, + rescale_rates_on_update = scale_rates) if param_mapper === nothing (param_idxs === nothing) && error("If no parameter indices are given via param_idxs, an explicit parameter mapping must be passed in via param_mapper.") @@ -413,8 +422,8 @@ function MassActionJump(rs, ns; param_idxs = nothing, param_mapper = nothing, pmapper = param_mapper end - MassActionJump(nothing, nocopy ? rs : copy(rs), ns, pmapper; scale_rates = scale_rates, - useiszero = useiszero, nocopy = true) + MassActionJump(nothing, nocopy ? rs : copy(rs), ns, pmapper; scale_rates, useiszero, + nocopy = true, rescale_rates_on_update) end using_params(maj::MassActionJump{T, S, U, Nothing}) where {T, S, U} = false @@ -470,18 +479,19 @@ function Base.merge(pmap1::MassActionJumpParamMapper{Int}, end """ -update_parameters!(maj::MassActionJump, newparams; scale_rates=true) +update_parameters!(maj::MassActionJump, newparams; scale_rates=maj.rescale_rates_on_update) Updates the passed in MassActionJump with the parameter values in `newparams`. Notes: - Requires the jump to have been constructed with a user-passed `param_idxs` or `param_mapper`. - - `scale_rates=true` will scale the parameter representing the jump rate by an - appropriate combinatoric factor. i.e for 3A --> B at rate k it will scale - k --> k/3!. + - `scale_rates` defaults to `maj.rescale_rates_on_update`, which itself defaults to the + `scale_rates` value used when constructing the jump. When `true`, the parameter + representing the jump rate will be scaled by an appropriate combinatoric factor, i.e. + for 3A --> B at rate k it will scale k --> k/3!. """ -function update_parameters!(maj::MassActionJump, newparams; scale_rates = true, kwargs...) +function update_parameters!(maj::MassActionJump, newparams; scale_rates = maj.rescale_rates_on_update, kwargs...) (maj.param_mapper === nothing) && error("MassActionJumps must be constructed with param_idxs or a param_mapper to be updateable.") maj.param_mapper(maj, newparams; scale_rates, kwargs) @@ -556,8 +566,12 @@ function JumpSet(vjs, cjs, rj, majv::Vector{T}) where {T <: MassActionJump} error("JumpSets do not accept empty mass action jump collections; use \"nothing\" instead.") end + sr_val = majv[1].rescale_rates_on_update + if !all(m -> m.rescale_rates_on_update == sr_val, majv) + error("Cannot merge MassActionJumps with different rescale_rates_on_update settings.") + end maj = setup_majump_to_merge(majv[1].scaled_rates, majv[1].reactant_stoch, - majv[1].net_stoch, majv[1].param_mapper) + majv[1].net_stoch, majv[1].param_mapper, sr_val) for i in 2:length(majv) massaction_jump_combine(maj, majv[i]) end @@ -618,35 +632,35 @@ end # functionality to merge two mass action jumps together function check_majump_type(maj::MassActionJump{S, T, U, V}) where {S <: Number, T, U, V} setup_majump_to_merge(maj.scaled_rates, maj.reactant_stoch, maj.net_stoch, - maj.param_mapper) + maj.param_mapper, maj.rescale_rates_on_update) end function check_majump_type(maj::MassActionJump{Nothing, T, U, V}) where {T, U, V} setup_majump_to_merge(maj.scaled_rates, maj.reactant_stoch, maj.net_stoch, - maj.param_mapper) + maj.param_mapper, maj.rescale_rates_on_update) end # if given containers of rates and stoichiometry directly create a jump -function setup_majump_to_merge(sr::T, rs::AbstractVector{S}, ns::AbstractVector{U}, - pmapper) where {T <: AbstractVector, S <: AbstractArray, +function setup_majump_to_merge(sr::T, rs::AbstractVector{S}, ns::AbstractVector{U}, pmapper, + rescale_rates_on_update::Bool) where {T <: AbstractVector, S <: AbstractArray, U <: AbstractArray} - MassActionJump(sr, rs, ns, pmapper; scale_rates = false) + MassActionJump(sr, rs, ns, pmapper; scale_rates = false, rescale_rates_on_update) end # if just given the data for one jump (and not in a container) wrap in a vector -function setup_majump_to_merge(sr::S, rs::T, ns::U, - pmapper) where {S <: Number, T <: AbstractArray, +function setup_majump_to_merge(sr::S, rs::T, ns::U, pmapper, + rescale_rates_on_update::Bool) where {S <: Number, T <: AbstractArray, U <: AbstractArray} MassActionJump([sr], [rs], [ns], (pmapper === nothing) ? pmapper : to_collection(pmapper); - scale_rates = false) + scale_rates = false, rescale_rates_on_update) end # if no rate field setup yet -function setup_majump_to_merge(::Nothing, rs::T, ns::U, - pmapper) where {T <: AbstractArray, U <: AbstractArray} +function setup_majump_to_merge(::Nothing, rs::T, ns::U, pmapper, + rescale_rates_on_update::Bool) where {T <: AbstractArray, U <: AbstractArray} MassActionJump(nothing, [rs], [ns], (pmapper === nothing) ? pmapper : to_collection(pmapper); - scale_rates = false) + scale_rates = false, rescale_rates_on_update) end # when given a collection of reactions to add to maj @@ -697,10 +711,12 @@ function majump_merge!(maj::MassActionJump{T, S, U, V}, sr::T, rs::S, ns::U, (param_mapper === nothing) || error("Error, trying to merge a MassActionJump with a parameter mapping to one without a parameter mapping.") return MassActionJump(rates, [maj.reactant_stoch, rs], [maj.net_stoch, ns], - param_mapper; scale_rates = false) + param_mapper; scale_rates = false, + rescale_rates_on_update = maj.rescale_rates_on_update) else return MassActionJump(rates, [maj.reactant_stoch, rs], [maj.net_stoch, ns], - merge(maj.param_mapper, param_mapper); scale_rates = false) + merge(maj.param_mapper, param_mapper); scale_rates = false, + rescale_rates_on_update = maj.rescale_rates_on_update) end end @@ -708,6 +724,8 @@ massaction_jump_combine(maj1::MassActionJump, maj2::Nothing) = maj1 massaction_jump_combine(maj1::Nothing, maj2::MassActionJump) = maj2 massaction_jump_combine(maj1::Nothing, maj2::Nothing) = maj1 function massaction_jump_combine(maj1::MassActionJump, maj2::MassActionJump) + (maj1.rescale_rates_on_update == maj2.rescale_rates_on_update) || + error("Cannot merge MassActionJumps with different rescale_rates_on_update settings.") majump_merge!(maj1, maj2.scaled_rates, maj2.reactant_stoch, maj2.net_stoch, maj2.param_mapper) end diff --git a/src/spatial/spatial_massaction_jump.jl b/src/spatial/spatial_massaction_jump.jl index f1a4fea5..c9af9d1e 100644 --- a/src/spatial/spatial_massaction_jump.jl +++ b/src/spatial/spatial_massaction_jump.jl @@ -82,7 +82,9 @@ function SpatialMassActionJump(urates::A, rs, ns; scale_rates = true, useiszero useiszero = useiszero, nocopy = nocopy) end -function SpatialMassActionJump(ma_jumps::MassActionJump{T, S, U, V}; scale_rates = true, +# scale_rates defaults to false since ma_jumps.scaled_rates are already scaled; +# passing true would double-scale. +function SpatialMassActionJump(ma_jumps::MassActionJump{T, S, U, V}; scale_rates = false, useiszero = true, nocopy = false) where {T, S, U, V} SpatialMassActionJump(ma_jumps.scaled_rates, ma_jumps.reactant_stoch, ma_jumps.net_stoch, ma_jumps.param_mapper; diff --git a/test/runtests.jl b/test/runtests.jl index e4b097c4..6df569d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,7 @@ end @time @safetestset "Mass Action Jump Tests; Nonlinear Rx Model" begin include("bimolerx_test.jl") end @time @safetestset "Mass Action Jump Tests; Special Cases" begin include("degenerate_rx_cases.jl") end @time @safetestset "Mass Action Jump Tests; Floating Point Inputs" begin include("fp_unknowns.jl") end + @time @safetestset "scale_rates Field Tests" begin include("scale_rates_field_test.jl") end @time @safetestset "Direct allocations test" begin include("allocations.jl") end @time @safetestset "Bracketing Tests" begin include("bracketing.jl") end @time @safetestset "Composition-Rejection Table Tests" begin include("table_test.jl") end diff --git a/test/scale_rates_field_test.jl b/test/scale_rates_field_test.jl new file mode 100644 index 00000000..f07ce8d5 --- /dev/null +++ b/test/scale_rates_field_test.jl @@ -0,0 +1,156 @@ +using JumpProcesses, OrdinaryDiffEq, Test + +# Reaction: 3X → Y (third-order, factorial(3) = 6) +# reactant_stoch: species 1 consumed with stoichiometry 3 +# net_stoch: species 1 loses 3, species 2 gains 1 +reactant_stoch = [[1 => 3]] +net_stoch = [[1 => -3, 2 => 1]] + +# Custom mapper: returns pre-scaled rates (like MTKBase's JumpSysMajParamMapper) +struct PreScaledMapper + param_idxs::Vector{Int} + reactant_stoch::Vector{Vector{Pair{Int, Int}}} +end +function (m::PreScaledMapper)(params) + rates = [params[i] for i in m.param_idxs] + JumpProcesses.scalerates!(rates, m.reactant_stoch) + rates +end +function (m::PreScaledMapper)(maj::MassActionJump, newparams; scale_rates, kwargs...) + for i in 1:JumpProcesses.get_num_majumps(maj) + maj.scaled_rates[i] = newparams[m.param_idxs[i]] + end + JumpProcesses.scalerates!(maj.scaled_rates, maj.reactant_stoch) + nothing +end +JumpProcesses.to_collection(m::PreScaledMapper) = m +function Base.merge!(m1::PreScaledMapper, m2::PreScaledMapper) + append!(m1.param_idxs, m2.param_idxs) + append!(m1.reactant_stoch, m2.reactant_stoch) +end + +# Test 1: rescale_rates_on_update field is stored correctly +@testset "rescale_rates_on_update field storage" begin + # Default: scale_rates = true → rescale_rates_on_update = true + maj = MassActionJump([6.0], reactant_stoch, net_stoch) + @test maj.rescale_rates_on_update == true + @test maj.scaled_rates[1] ≈ 1.0 # 6.0 / 3! = 1.0 + + # Explicit: scale_rates = false → rescale_rates_on_update = false + maj = MassActionJump([6.0], reactant_stoch, net_stoch; scale_rates = false) + @test maj.rescale_rates_on_update == false + @test maj.scaled_rates[1] ≈ 6.0 # no scaling + + # Parameterized + maj = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) + @test maj.rescale_rates_on_update == true +end + +# Test 2: update_parameters! respects stored rescale_rates_on_update +@testset "update_parameters! respects rescale_rates_on_update" begin + # With param_idxs and scale_rates = true (default) — built-in mapper path + p = [6.0] + maj = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) + dprob = DiscreteProblem([100, 0], (0.0, 1.0), p) + jprob = JumpProblem(dprob, Direct(), maj) + @test jprob.massaction_jump.rescale_rates_on_update == true + @test jprob.massaction_jump.scaled_rates[1] ≈ 1.0 # 6.0 / 3! + + # update_parameters! should scale (rescale_rates_on_update = true from struct) + JumpProcesses.update_parameters!(jprob.massaction_jump, [12.0]) + @test jprob.massaction_jump.scaled_rates[1] ≈ 2.0 # 12.0 / 3! +end + +# Test 3: Custom pre-scaled mapper with scale_rates = false — the bug reproducer +@testset "Pre-scaled mapper with scale_rates=false (bug reproducer)" begin + mapper = PreScaledMapper([1], reactant_stoch) + k = 6.0 + p = [k] + + maj = MassActionJump(reactant_stoch, net_stoch; param_mapper = mapper, scale_rates = false) + dprob = DiscreteProblem([100, 0], (0.0, 100.0), p) + jprob = JumpProblem(dprob, Direct(), maj; scale_rates = false) + + expected_scaled = k / factorial(3) # 1.0 + @test jprob.massaction_jump.scaled_rates[1] ≈ expected_scaled + @test jprob.massaction_jump.rescale_rates_on_update == false + + # Test reset_aggregated_jumps! does NOT double-scale + integ = init(jprob, SSAStepper()) + integ.p[1] = 18.0 + reset_aggregated_jumps!(integ) + new_expected = 18.0 / factorial(3) # 3.0, NOT 3.0/6 = 0.5 + @test jprob.massaction_jump.scaled_rates[1] ≈ new_expected + + # Test remake does NOT double-scale + jprob2 = remake(jprob; p = [24.0]) + remake_expected = 24.0 / factorial(3) # 4.0, NOT 4.0/6 ≈ 0.667 + @test jprob2.massaction_jump.scaled_rates[1] ≈ remake_expected + + # Test remake round-trip + jprob3 = remake(jprob2; p = [k]) + @test jprob3.massaction_jump.scaled_rates[1] ≈ expected_scaled +end + +# Test 4: Callback parameter changes with built-in mapper (rescale_rates_on_update = true) +@testset "Callback with built-in mapper" begin + p = [6.0] + maj = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) + dprob = DiscreteProblem([100, 0], (0.0, 2000.0), p) + jprob = JumpProblem(dprob, Direct(), maj; save_positions = (false, false)) + + condit(u, t, integrator) = t == 1000.0 + function affect!(integrator) + integrator.p[1] = 24.0 + reset_aggregated_jumps!(integrator) + end + cb = DiscreteCallback(condit, affect!) + sol = solve(jprob, SSAStepper(); tstops = [1000.0], callback = cb) + @test jprob.massaction_jump.scaled_rates[1] ≈ 4.0 # 24.0 / 3! +end + +# Test 5: rescale_rates_on_update propagated through JumpSet merge and JumpProblem varargs +# Use explicit-rate MAJs since the JumpSet vector merge path doesn't support +# parameterized (Nothing-rated) MAJs. +@testset "rescale_rates_on_update propagated through merge paths" begin + reactant_stoch2 = [[2 => 3]] + net_stoch2 = [[2 => -3, 1 => 1]] + + # --- JumpSet merge path --- + + # Two MAJs with matching rescale_rates_on_update = false (rates already scaled) + maj1 = MassActionJump([1.0], reactant_stoch, net_stoch; scale_rates = false) + maj2 = MassActionJump([2.0], reactant_stoch2, net_stoch2; scale_rates = false) + jset = JumpSet(; massaction_jumps = [maj1, maj2]) + @test jset.massaction_jump.rescale_rates_on_update == false + + # Two MAJs with matching rescale_rates_on_update = true (rates get scaled) + maj3 = MassActionJump([6.0], reactant_stoch, net_stoch) # scale_rates = true (default) + maj4 = MassActionJump([12.0], reactant_stoch2, net_stoch2) + jset2 = JumpSet(; massaction_jumps = [maj3, maj4]) + @test jset2.massaction_jump.rescale_rates_on_update == true + + # Mismatched rescale_rates_on_update via JumpSet — should error + maj_true = MassActionJump([6.0], reactant_stoch, net_stoch) # rescale = true + maj_false = MassActionJump([1.0], reactant_stoch2, net_stoch2; scale_rates = false) # rescale = false + @test_throws ErrorException JumpSet(; massaction_jumps = [maj_true, maj_false]) + + # --- JumpProblem varargs path (split_jumps → massaction_jump_combine) --- + + dprob = DiscreteProblem([100, 100], (0.0, 1.0), [1.0, 1.0]) + + # Two MAJs with matching rescale_rates_on_update = false via JumpProblem varargs + maj_f1 = MassActionJump([1.0], reactant_stoch, net_stoch; scale_rates = false) + maj_f2 = MassActionJump([2.0], reactant_stoch2, net_stoch2; scale_rates = false) + jprob_f = JumpProblem(dprob, Direct(), maj_f1, maj_f2; scale_rates = false) + @test jprob_f.massaction_jump.rescale_rates_on_update == false + + # Two MAJs with matching rescale_rates_on_update = true via JumpProblem varargs + maj_t1 = MassActionJump([6.0], reactant_stoch, net_stoch) + maj_t2 = MassActionJump([12.0], reactant_stoch2, net_stoch2) + jprob_t = JumpProblem(dprob, Direct(), maj_t1, maj_t2) + @test jprob_t.massaction_jump.rescale_rates_on_update == true + + # Mismatched rescale_rates_on_update via JumpProblem varargs — should error + @test_throws ErrorException JumpProblem(dprob, Direct(), maj_true, maj_false) +end From 83352c59934099fb3a84193ef96e40e9a4eca1a6 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Thu, 26 Feb 2026 18:20:47 -0500 Subject: [PATCH 17/20] Fix PreScaledMapper test to actually catch double-scaling regression The mapper now conditionally applies scalerates! based on scale_rates, matching MTKBase JumpSysMajParamMapper behavior. Previously it ignored scale_rates entirely, so the test would pass even with the old buggy default of scale_rates=true. Co-Authored-By: Claude Opus 4.5 --- test/scale_rates_field_test.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/scale_rates_field_test.jl b/test/scale_rates_field_test.jl index f07ce8d5..3f06ed88 100644 --- a/test/scale_rates_field_test.jl +++ b/test/scale_rates_field_test.jl @@ -6,7 +6,13 @@ using JumpProcesses, OrdinaryDiffEq, Test reactant_stoch = [[1 => 3]] net_stoch = [[1 => -3, 2 => 1]] -# Custom mapper: returns pre-scaled rates (like MTKBase's JumpSysMajParamMapper) +# Custom mapper mimicking MTKBase's JumpSysMajParamMapper. +# The initial rates callable always pre-scales (simulating symbolic expressions that +# already contain the combinatoric factor, e.g. k/3! for 3X → Y). +# The update callable also pre-scales, then conditionally applies scalerates! based on +# scale_rates — matching MTKBase's behavior. This is what makes the test an actual +# regression test: with the old default of scale_rates=true, the second scalerates! +# would double-scale. struct PreScaledMapper param_idxs::Vector{Int} reactant_stoch::Vector{Vector{Pair{Int, Int}}} @@ -20,7 +26,8 @@ function (m::PreScaledMapper)(maj::MassActionJump, newparams; scale_rates, kwarg for i in 1:JumpProcesses.get_num_majumps(maj) maj.scaled_rates[i] = newparams[m.param_idxs[i]] end - JumpProcesses.scalerates!(maj.scaled_rates, maj.reactant_stoch) + JumpProcesses.scalerates!(maj.scaled_rates, m.reactant_stoch) + scale_rates && JumpProcesses.scalerates!(maj.scaled_rates, maj.reactant_stoch) nothing end JumpProcesses.to_collection(m::PreScaledMapper) = m From 4a52abe0448f26a0fdc21a1c116e7d2ce801f69a Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Thu, 26 Feb 2026 21:11:45 -0500 Subject: [PATCH 18/20] Update project version to 9.23 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 35ea0c11..1c59737f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "JumpProcesses" uuid = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" authors = ["Chris Rackauckas "] -version = "9.22.2" +version = "9.23" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From d1d6beb928d1f66cd0be8d5a56164a83eb9c4dc3 Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Sat, 28 Feb 2026 12:31:43 -0500 Subject: [PATCH 19/20] add thread-safety warning to JumpProblem docstring Document that JumpProblem contains mutable state and is not thread-safe for concurrent solve calls. Note that EnsembleProblem handles isolation automatically via SciMLBase per-task deepcopy and per-trajectory RNG. Co-Authored-By: Claude Opus 4.5 --- src/problem.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/problem.jl b/src/problem.jl index 3ce038de..56d58da6 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -67,6 +67,15 @@ Please see the [tutorial page](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/) in the DifferentialEquations.jl [docs](https://docs.sciml.ai/JumpProcesses/stable/) for usage examples and commonly asked questions. + +!!! warning "Thread Safety" + `JumpProblem` contains mutable state (aggregator data, callbacks) and is **not + thread-safe**. A single `JumpProblem` instance must not be solved concurrently from + multiple threads or tasks without first creating independent copies via `deepcopy`. + When running ensemble simulations via `EnsembleProblem`, this is handled automatically + — the `SciMLBase` ensemble layer provides per-task isolation and per-trajectory RNG + seeding. This warning only applies to manually parallelized `solve` calls outside the + ensemble interface. """ mutable struct JumpProblem{iip, P, A, C, J <: Union{Nothing, AbstractJumpAggregator}, J1, J2, J3, J4, K} <: DiffEqBase.AbstractJumpProblem{P, J} From a07a527ec48e5875fdbc8378c889b7e5f07593ca Mon Sep 17 00:00:00 2001 From: Sam Isaacson Date: Sun, 1 Mar 2026 17:21:51 -0500 Subject: [PATCH 20/20] Temporarily use SciMLBase fork branch for ensemble_rng_redesign Co-Authored-By: Claude Opus 4.5 --- Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index 63f3ea26..242da5ae 100644 --- a/Project.toml +++ b/Project.toml @@ -61,6 +61,9 @@ SymbolicIndexingInterface = "0.3.36" Test = "1" julia = "1.10" +[sources] +SciMLBase = {url = "https://github.com/isaacsas/SciMLBase.jl.git", rev = "ensemble_rng_redesign"} + [extras] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"