diff --git a/.github/workflows/Downgrade.yml b/.github/workflows/Downgrade.yml index 54db9bad..3fe5b7a7 100644 --- a/.github/workflows/Downgrade.yml +++ b/.github/workflows/Downgrade.yml @@ -3,6 +3,7 @@ on: pull_request: branches: - master + - v10_working paths-ignore: - 'docs/**' push: diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index dd8ce04f..3d0b7ef3 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -8,6 +8,7 @@ on: pull_request: branches: - master + - v10_working paths-ignore: - 'docs/**' diff --git a/.github/workflows/ThreadSafety.yml b/.github/workflows/ThreadSafety.yml index b82eda2d..2dfbb464 100644 --- a/.github/workflows/ThreadSafety.yml +++ b/.github/workflows/ThreadSafety.yml @@ -9,6 +9,7 @@ on: pull_request: branches: - master + - v10_working paths-ignore: - 'docs/**' diff --git a/HISTORY.md b/HISTORY.md index d010778a..b12999b0 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -21,6 +21,31 @@ solver pathways (SSAStepper, ODE, SDE, tau-leaping). - `SSAIntegrator` now supports the `SciMLBase` RNG interface (`has_rng`, `get_rng`, `set_rng!`). + - **Breaking**: The `scale_rates` and `useiszero` keyword arguments have been + removed from `JumpProblem`. Set them on the `MassActionJump` directly: + ```julia + # Before (no longer works): + jprob = JumpProblem(dprob, Direct(), maj; scale_rates = false) + + # After: + maj = MassActionJump(rates, reactant_stoch, net_stoch; scale_rates = false) + jprob = JumpProblem(dprob, Direct(), maj) + ``` + - **Breaking**: Parameterized `MassActionJump`s (those constructed with + `param_idxs` or a custom `param_mapper`) are now immutable — rates are + computed from parameters at aggregator initialization rather than being + materialized into the jump at `JumpProblem` construction time. This means: + - `update_parameters!` has been removed. Mass action rates are now + automatically recomputed from the current parameter values whenever the + aggregator reinitializes. After modifying parameters (e.g. in a + callback), call `reset_aggregated_jumps!(integrator)` to trigger + reinitialization with the updated parameter values. + - Custom parameter mappers (e.g. ModelingToolkitBase's + `JumpSysMajParamMapper`) must implement the 3-arg callable API: + `(mapper)(dest::AbstractVector, maj::MassActionJump, params)`. + See [`MassActionJumpParamMapper`](@ref) for details. + - Scalar `param_idxs` (e.g. `param_idxs = 1`) is now internally converted to + a one-element vector. The scalar form continues to work as before. ## 9.14 diff --git a/Project.toml b/Project.toml index 242da5ae..7a3e6693 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "JumpProcesses" uuid = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" authors = ["Chris Rackauckas "] -version = "9.23.1" +version = "9.23.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -24,9 +24,11 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" [extensions] JumpProcessesKernelAbstractionsExt = ["Adapt", "KernelAbstractions"] +JumpProcessesOrdinaryDiffEqCoreExt = "OrdinaryDiffEqCore" [compat] ADTypes = "1" @@ -45,7 +47,7 @@ KernelAbstractions = "0.9" LinearAlgebra = "1" LinearSolve = "3" OrdinaryDiffEq = "6" -OrdinaryDiffEqCore = "3.11" +OrdinaryDiffEqCore = "3" Pkg = "1" PoissonRandom = "0.4" Random = "1" @@ -72,7 +74,6 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" -OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/docs/Project.toml b/docs/Project.toml index 5ef13280..c69a2805 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -13,7 +13,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" [compat] -Catalyst = "14.0, 15" +Catalyst = "16" DensityInterface = "0.4" DifferentialEquations = "7.11" Distributions = "0.25" diff --git a/docs/src/assets/Project.toml b/docs/src/assets/Project.toml index 5ef13280..c69a2805 100644 --- a/docs/src/assets/Project.toml +++ b/docs/src/assets/Project.toml @@ -13,7 +13,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" [compat] -Catalyst = "14.0, 15" +Catalyst = "16" DensityInterface = "0.4" DifferentialEquations = "7.11" Distributions = "0.25" diff --git a/docs/src/tutorials/discrete_stochastic_example.md b/docs/src/tutorials/discrete_stochastic_example.md index b6274b4b..b603f11b 100644 --- a/docs/src/tutorials/discrete_stochastic_example.md +++ b/docs/src/tutorials/discrete_stochastic_example.md @@ -190,36 +190,25 @@ sir_model = @reaction_network begin end ``` -To build a pure jump process model of the reaction system, where the state is -constant between jumps, we will use a -[`DiscreteProblem`](https://docs.sciml.ai/DiffEqDocs/stable/types/discrete_types/). -This encodes that the state only changes at the jump times. We do this by giving -the constructor `u₀`, the initial condition, and `tspan`, the timespan. Here, we -will start with ``990`` susceptible people, ``10`` infected person, and `0` recovered -people, and solve the problem from `t=0.0` to `t=250.0`. We use the parameters -`β = 0.1/1000` and `ν = 0.01`. Thus, we build the problem via: +To build a pure jump process model of the reaction system we construct a +[`JumpProblem`](@ref) directly from the Catalyst `ReactionSystem`. We specify +`u₀`, the initial condition, `tspan`, the timespan, and `p`, the parameters. +Here, we will start with ``990`` susceptible people, ``10`` infected person, and +`0` recovered people, and solve the problem from `t=0.0` to `t=250.0`. We use +the parameters `β = 0.1/1000` and `ν = 0.01`. ```@example tut2 p = (:β => 0.1 / 1000, :ν => 0.01) u₀ = [:S => 990, :I => 10, :R => 0] tspan = (0.0, 250.0) -prob = DiscreteProblem(sir_model, u₀, tspan, p) +jump_prob = JumpProblem(sir_model, u₀, tspan, p) ``` *Notice, the initial populations are integers, since we want the exact number of people in the different states.* -The Catalyst reaction network can be converted into various -DifferentialEquations.jl problem types, including `JumpProblem`s, `ODEProblem`s, -or `SDEProblem`s. To turn it into a [`JumpProblem`](@ref) representing the SIR jump -process model, we simply write - -```@example tut2 -jump_prob = JumpProblem(sir_model, prob, Direct()) -``` - -Here `Direct()` indicates that we will determine the random times and types of -reactions using [Gillespie's Direct stochastic simulation algorithm +Here `Direct()` is the default aggregator, which determines the random times and +types of reactions using [Gillespie's Direct stochastic simulation algorithm (SSA)](https://doi.org/10.1016/0021-9991(76)90041-3), also known as Doob's method or Kinetic Monte Carlo. See [Jump Aggregators for Exact Simulation](@ref) for other supported SSAs. diff --git a/ext/JumpProcessesOrdinaryDiffEqCoreExt.jl b/ext/JumpProcessesOrdinaryDiffEqCoreExt.jl new file mode 100644 index 00000000..28d769a2 --- /dev/null +++ b/ext/JumpProcessesOrdinaryDiffEqCoreExt.jl @@ -0,0 +1,26 @@ +module JumpProcessesOrdinaryDiffEqCoreExt + +using JumpProcesses +import DiffEqBase +import OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, DAEAlgorithm, + StochasticDiffEqAlgorithm, StochasticDiffEqRODEAlgorithm + +# Ambiguity fix: OrdinaryDiffEqCore defines +# __init(::Union{..., AbstractJumpProblem}, ::Union{OrdinaryDiffEqAlgorithm, +# DAEAlgorithm, StochasticDiffEqAlgorithm, StochasticDiffEqRODEAlgorithm}) +# which is ambiguous with JumpProcesses' +# __init(::AbstractJumpProblem{P}, ::DEAlgorithm) +# +# This method resolves the ambiguity by being more specific in the problem type +# (AbstractJumpProblem vs Union{..., AbstractJumpProblem, ...}) while matching +# the exact algorithm union from OrdinaryDiffEqCore. +function DiffEqBase.__init( + _jump_prob::DiffEqBase.AbstractJumpProblem{P}, + alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm, + StochasticDiffEqAlgorithm, StochasticDiffEqRODEAlgorithm}; + merge_callbacks = true, kwargs...) where {P} + kwargs = DiffEqBase.merge_problem_kwargs(_jump_prob; merge_callbacks, kwargs...) + JumpProcesses.__jump_init(_jump_prob, alg; kwargs...) +end + +end diff --git a/src/JumpProcesses.jl b/src/JumpProcesses.jl index e32c4fb8..2c1ea6e9 100644 --- a/src/JumpProcesses.jl +++ b/src/JumpProcesses.jl @@ -5,7 +5,7 @@ using Reexport: Reexport, @reexport # Explicit imports from standard libraries using LinearAlgebra: LinearAlgebra, mul! -using Random: Random, randexp, seed! +using Random: Random, randexp # Explicit imports from external packages using DocStringExtensions: DocStringExtensions, FIELDS, TYPEDEF diff --git a/src/aggregators/aggregated_api.jl b/src/aggregators/aggregated_api.jl index a65c4ad8..ec3376d0 100644 --- a/src/aggregators/aggregated_api.jl +++ b/src/aggregators/aggregated_api.jl @@ -1,55 +1,41 @@ """ - reset_aggregated_jumps!(integrator, uprev = nothing; update_jump_params=true) + reset_aggregated_jumps!(integrator, uprev = nothing) Reset the state of jump processes and associated solvers following a change -in parameters or such. - -Notes - - - `update_jump_params=true` will recalculate the rates stored within any - MassActionJump that was built from the parameter vector. If the parameter - vector is unchanged, this can safely be set to false to improve performance. +in parameters or such. Rate updates are handled automatically by `initialize!` +via `fill_scaled_rates!`. """ -function reset_aggregated_jumps!(integrator, uprev = nothing; update_jump_params = true, - kwargs...) - reset_aggregated_jumps!(integrator, uprev, integrator.opts.callback, - update_jump_params = update_jump_params, kwargs...) +function reset_aggregated_jumps!(integrator, uprev = nothing; kwargs...) + if haskey(kwargs, :update_jump_params) + throw(ArgumentError("`update_jump_params` keyword argument has been removed. " * + "Rate updates are now handled automatically by `initialize!` " * + "via `fill_scaled_rates!`.")) + end + reset_aggregated_jumps!(integrator, uprev, integrator.opts.callback) nothing end -function reset_aggregated_jumps!(integrator, uprev, callback::Nothing; - update_jump_params = true, kwargs...) +function reset_aggregated_jumps!(integrator, uprev, callback::Nothing) nothing end -function reset_aggregated_jumps!(integrator, uprev, callback::CallbackSet; - update_jump_params = true, kwargs...) +function reset_aggregated_jumps!(integrator, uprev, callback::CallbackSet) if !isempty(callback.discrete_callbacks) - reset_aggregated_jumps!(integrator, uprev, callback.discrete_callbacks..., - update_jump_params = update_jump_params, kwargs...) + reset_aggregated_jumps!(integrator, uprev, callback.discrete_callbacks...) end nothing end -function reset_aggregated_jumps!(integrator, uprev, cb::DiscreteCallback, cbs...; - update_jump_params = true, kwargs...) +function reset_aggregated_jumps!(integrator, uprev, cb::DiscreteCallback, cbs...) if cb.condition isa AbstractSSAJumpAggregator - maj = cb.condition.ma_jumps - update_jump_params && using_params(maj) && - update_parameters!(cb.condition.ma_jumps, integrator.p; kwargs...) cb.condition(cb, integrator.u, integrator.t, integrator) end - reset_aggregated_jumps!(integrator, uprev, cbs...; - update_jump_params = update_jump_params, kwargs...) + reset_aggregated_jumps!(integrator, uprev, cbs...) nothing end -function reset_aggregated_jumps!(integrator, uprev, cb::DiscreteCallback; - update_jump_params = true, kwargs...) +function reset_aggregated_jumps!(integrator, uprev, cb::DiscreteCallback) if cb.condition isa AbstractSSAJumpAggregator - maj = cb.condition.ma_jumps - update_jump_params && using_params(maj) && - update_parameters!(cb.condition.ma_jumps, integrator.p; kwargs...) cb.condition(cb, integrator.u, integrator.t, integrator) end nothing diff --git a/src/aggregators/bracketing.jl b/src/aggregators/bracketing.jl index 18bb9ece..b66092c4 100644 --- a/src/aggregators/bracketing.jl +++ b/src/aggregators/bracketing.jl @@ -47,8 +47,8 @@ end @inline get_spec_brackets(bd, i, u::AbstractVector) = get_spec_brackets(bd, i, u[i]) # Get propensity brackets of massaction jump k. -@inline function get_majump_brackets(ulow, uhigh, k, majumps) - evalrxrate(ulow, k, majumps), evalrxrate(uhigh, k, majumps) +@inline function get_majump_brackets(ulow, uhigh, k, majumps, maj_rates) + evalrxrate(ulow, k, majumps, maj_rates), evalrxrate(uhigh, k, majumps, maj_rates) end # for constant rate jumps we must check the ordering of the bracket values @@ -66,7 +66,7 @@ get brackets for the rate of reaction rx by first checking if the reaction is a ma_jumps = p.ma_jumps num_majumps = get_num_majumps(ma_jumps) if rx <= num_majumps - return get_majump_brackets(p.ulow, p.uhigh, rx, ma_jumps) + return get_majump_brackets(p.ulow, p.uhigh, rx, ma_jumps, p.maj_rates) else @inbounds return get_cjump_brackets(p.ulow, p.uhigh, p.rates[rx - num_majumps], params, t) @@ -108,8 +108,9 @@ function set_bracketing!(p::AbstractSSAJumpAggregator, u, params, t) majumps = p.ma_jumps crlow = p.cur_rate_low crhigh = p.cur_rate_high + maj_rates = p.maj_rates @inbounds for k in 1:get_num_majumps(majumps) - crlow[k], crhigh[k] = get_majump_brackets(p.ulow, p.uhigh, k, majumps) + crlow[k], crhigh[k] = get_majump_brackets(p.ulow, p.uhigh, k, majumps, maj_rates) sum_rate += crhigh[k] end diff --git a/src/aggregators/ccnrm.jl b/src/aggregators/ccnrm.jl index 720868c5..4788df88 100644 --- a/src/aggregators/ccnrm.jl +++ b/src/aggregators/ccnrm.jl @@ -12,6 +12,7 @@ mutable struct CCNRMJumpAggregation{T, S, F1, F2, DEPGR, PT} <: cur_rates::Vector{T} sum_rate::T ma_jumps::S + maj_rates::Vector{T} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} @@ -21,6 +22,7 @@ end function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), num_specs, dep_graph = nothing, kwargs...) where {T, S, F1, F2} @@ -46,7 +48,7 @@ function CCNRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, affecttype = F2 <: Tuple ? F2 : Any CCNRMJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(ptt)}( nj, nj, njt, et, - crs, sr, maj, + crs, sr, maj, maj_rates, rs, affs!, sps, dg, ptt) end @@ -67,6 +69,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::CCNRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) rng = get_rng(integrator) initialize_rates_and_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) @@ -115,7 +118,7 @@ function update_dependent_rates!(p::CCNRMJumpAggregation, u, params, t, rng) # update the jump rate @inbounds cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u, - params, t, rx) + params, t, rx, p.maj_rates) # Calculate new jump times for dependent jumps if rx != p.next_jump && oldrate > zero(oldrate) @@ -141,8 +144,9 @@ function initialize_rates_and_times!(p::CCNRMJumpAggregation, u, params, t, rng) majumps = p.ma_jumps cur_rates = p.cur_rates pttdata = Vector{typeof(t)}(undef, length(cur_rates)) + maj_rates = p.maj_rates @inbounds for i in 1:get_num_majumps(majumps) - cur_rates[i] = evalrxrate(u, i, majumps) + cur_rates[i] = evalrxrate(u, i, majumps, maj_rates) pttdata[i] = t + randexp(rng) / cur_rates[i] end diff --git a/src/aggregators/coevolve.jl b/src/aggregators/coevolve.jl index ab5dc5f4..651a7012 100644 --- a/src/aggregators/coevolve.jl +++ b/src/aggregators/coevolve.jl @@ -10,6 +10,7 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, GR, PQ} <: cur_rates::Vector{T} # the last computed upper bound for each rate sum_rate::Nothing # not used ma_jumps::S # MassActionJumps + maj_rates::Vector{T} # working copy of mass action jump rates rates::F1 # vector of rate functions affects!::F2 # vector of affect functions for VariableRateJumps save_positions::Tuple{Bool, Bool} # tuple for whether to save the jumps before and/or after event @@ -24,6 +25,7 @@ end function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Nothing, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), u::U, dep_graph = nothing, lrates, urates, rateintervals, haslratevec, cur_lrates::Vector{T}) where {T, S, F1, F2, U} @@ -49,7 +51,7 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not pq = MutableBinaryMinHeap{T}() affecttype = F2 <: Tuple ? F2 : Any CoevolveJumpAggregation{T, S, F1, affecttype, typeof(dg), - typeof(pq)}(nj, nj, njt, et, crs, sr, maj, + typeof(pq)}(nj, nj, njt, et, crs, sr, maj, maj_rates, rs, affs!, sps, dg, pq, lrates, urates, rateintervals, haslratevec, cur_lrates) @@ -135,19 +137,21 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, num_jumps = get_num_majumps(ma_jumps) + nrjs cur_rates = Vector{typeof(t)}(undef, num_jumps) + maj_rates = Vector{typeof(t)}(undef, get_num_majumps(ma_jumps)) cur_lrates = zeros(typeof(t), nvrjs) sum_rate = nothing next_jump = 0 next_jump_time = typemax(t) CoevolveJumpAggregation(next_jump, next_jump_time, end_time, cur_rates, sum_rate, ma_jumps, rates, affects!, save_positions; - u, dep_graph, lrates, urates, rateintervals, haslratevec, + maj_rates, u, dep_graph, lrates, urates, rateintervals, haslratevec, cur_lrates) end # set up a new simulation and calculate the first jump / jump time function initialize!(p::CoevolveJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) rng = get_rng(integrator) fill_rates_and_get_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) @@ -239,7 +243,7 @@ function update_dependent_rates!(p::CoevolveJumpAggregation, u, params, t, rng) end @inline function get_ma_urate(p::CoevolveJumpAggregation, i, u, params, t) - return evalrxrate(u, i, p.ma_jumps) + return evalrxrate(u, i, p.ma_jumps, p.maj_rates) end @inline function get_urate(p::CoevolveJumpAggregation, uidx, u, params, t) diff --git a/src/aggregators/direct.jl b/src/aggregators/direct.jl index b51e18f7..f976a4cb 100644 --- a/src/aggregators/direct.jl +++ b/src/aggregators/direct.jl @@ -7,16 +7,18 @@ mutable struct DirectJumpAggregation{T, S, F1, F2} <: cur_rates::Vector{T} sum_rate::T ma_jumps::S + maj_rates::Vector{T} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} end function DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), kwargs...) where {T, S, F1, F2} affecttype = F2 <: Tuple ? F2 : Any - DirectJumpAggregation{T, S, F1, affecttype}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps) + DirectJumpAggregation{T, S, F1, affecttype}(nj, nj, njt, et, crs, sr, maj, maj_rates, + rs, affs!, sps) end ############################# Required Functions ############################# @@ -46,6 +48,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::DirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) generate_jumps!(p, integrator, u, params, t) nothing end @@ -77,9 +80,10 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, # mass action rates majumps = p.ma_jumps + maj_rates = p.maj_rates idx = get_num_majumps(majumps) @inbounds for i in 1:idx - new_rate = evalrxrate(u, i, majumps) + new_rate = evalrxrate(u, i, majumps, maj_rates) cur_rates[i] = add_fast(new_rate, prev_rate) prev_rate = cur_rates[i] end @@ -119,9 +123,10 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params, # mass action rates majumps = p.ma_jumps + maj_rates = p.maj_rates idx = get_num_majumps(majumps) @inbounds for i in 1:idx - new_rate = evalrxrate(u, i, majumps) + new_rate = evalrxrate(u, i, majumps, maj_rates) cur_rates[i] = add_fast(new_rate, prev_rate) prev_rate = cur_rates[i] end diff --git a/src/aggregators/directcr.jl b/src/aggregators/directcr.jl index 636300bb..ae4ebf87 100644 --- a/src/aggregators/directcr.jl +++ b/src/aggregators/directcr.jl @@ -20,6 +20,7 @@ mutable struct DirectCRJumpAggregation{T, S, F1, F2, DEPGR, U <: PriorityTable, cur_rates::Vector{T} sum_rate::T ma_jumps::S + maj_rates::Vector{T} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} @@ -32,6 +33,7 @@ end function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), num_specs, dep_graph = nothing, minrate = convert(T, MINJUMPRATE), maxrate = convert(T, Inf), @@ -64,7 +66,7 @@ function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, affecttype = F2 <: Tuple ? F2 : Any DirectCRJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, crs, sr, maj, - rs, affs!, sps, dg, + maj_rates, rs, affs!, sps, dg, minrate, maxrate, rt, ratetogroup) end @@ -86,6 +88,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::DirectCRJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) # initialize rates fill_rates_and_sum!(p, u, params, t) @@ -134,7 +137,8 @@ function update_dependent_rates!(p::DirectCRJumpAggregation, u, params, t) oldrate = cur_rates[rx] # update rate - cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u, params, t, rx) + cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u, params, t, rx, + p.maj_rates) # update table update!(rt, rx, oldrate, cur_rates[rx]) diff --git a/src/aggregators/frm.jl b/src/aggregators/frm.jl index 0db30113..128b93db 100644 --- a/src/aggregators/frm.jl +++ b/src/aggregators/frm.jl @@ -7,16 +7,18 @@ mutable struct FRMJumpAggregation{T, S, F1, F2} <: cur_rates::Vector{T} sum_rate::T ma_jumps::S + maj_rates::Vector{T} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} end function FRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), kwargs...) where {T, S, F1, F2} affecttype = F2 <: Tuple ? F2 : Any - FRMJumpAggregation{T, S, F1, affecttype}(nj, nj, njt, et, crs, sr, maj, rs, - affs!, sps) + FRMJumpAggregation{T, S, F1, affecttype}(nj, nj, njt, et, crs, sr, maj, maj_rates, + rs, affs!, sps) end ############################# Required Functions ############################# @@ -48,6 +50,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::FRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) generate_jumps!(p, integrator, u, params, t) nothing end @@ -83,8 +86,9 @@ function next_ma_jump(p::FRMJumpAggregation, u, params, t, rng) ttnj = typemax(typeof(t)) nextrx = zero(Int) majumps = p.ma_jumps + maj_rates = p.maj_rates @inbounds for i in 1:get_num_majumps(majumps) - p.cur_rates[i] = evalrxrate(u, i, majumps) + p.cur_rates[i] = evalrxrate(u, i, majumps, maj_rates) dt = randexp(rng) / p.cur_rates[i] if dt < ttnj ttnj = dt diff --git a/src/aggregators/nrm.jl b/src/aggregators/nrm.jl index 93454044..ef1a6b74 100644 --- a/src/aggregators/nrm.jl +++ b/src/aggregators/nrm.jl @@ -10,6 +10,7 @@ mutable struct NRMJumpAggregation{T, S, F1, F2, DEPGR, PQ} <: cur_rates::Vector{T} sum_rate::T ma_jumps::S + maj_rates::Vector{T} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} @@ -19,6 +20,7 @@ end function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), num_specs, dep_graph = nothing, kwargs...) where {T, S, F1, F2} @@ -40,7 +42,7 @@ function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, affecttype = F2 <: Tuple ? F2 : Any NRMJumpAggregation{T, S, F1, affecttype, typeof(dg), typeof(pq)}(nj, nj, njt, et, - crs, sr, maj, + crs, sr, maj, maj_rates, rs, affs!, sps, dg, pq) end @@ -61,6 +63,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::NRMJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) rng = get_rng(integrator) fill_rates_and_get_times!(p, u, params, t, rng) generate_jumps!(p, integrator, u, params, t) @@ -98,7 +101,7 @@ function update_dependent_rates!(p::NRMJumpAggregation, u, params, t, rng) # update the jump rate @inbounds cur_rates[rx] = calculate_jump_rate(ma_jumps, num_majumps, rates, u, - params, t, rx) + params, t, rx, p.maj_rates) # calculate new jump times for dependent jumps if rx != p.next_jump && oldrate > zero(oldrate) @@ -125,8 +128,9 @@ function fill_rates_and_get_times!(p::NRMJumpAggregation, u, params, t, rng) majumps = p.ma_jumps cur_rates = p.cur_rates pqdata = Vector{typeof(t)}(undef, length(cur_rates)) + maj_rates = p.maj_rates @inbounds for i in 1:get_num_majumps(majumps) - cur_rates[i] = evalrxrate(u, i, majumps) + cur_rates[i] = evalrxrate(u, i, majumps, maj_rates) pqdata[i] = t + randexp(rng) / cur_rates[i] end diff --git a/src/aggregators/rdirect.jl b/src/aggregators/rdirect.jl index eea28316..8d8b34c3 100644 --- a/src/aggregators/rdirect.jl +++ b/src/aggregators/rdirect.jl @@ -11,6 +11,7 @@ mutable struct RDirectJumpAggregation{T, S, F1, F2, DEPGR} <: cur_rates::Vector{T} sum_rate::T ma_jumps::S + maj_rates::Vector{T} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} @@ -22,6 +23,7 @@ end function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), num_specs, counter_threshold = length(crs), dep_graph = nothing, kwargs...) where {T, S, F1, F2} @@ -42,7 +44,7 @@ function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, m max_rate = maximum(crs) affecttype = F2 <: Tuple ? F2 : Any return RDirectJumpAggregation{T, S, F1, affecttype, typeof(dg)}(nj, nj, njt, et, - crs, sr, maj, rs, + crs, sr, maj, maj_rates, rs, affs!, sps, dg, max_rate, 0, counter_threshold) @@ -65,6 +67,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::RDirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) fill_rates_and_sum!(p, u, params, t) p.max_rate = maximum(p.cur_rates) generate_jumps!(p, integrator, u, params, t) @@ -118,7 +121,7 @@ function update_dependent_rates!(p::RDirectJumpAggregation, u, params, t) @inbounds for rx in dep_rxs @inbounds new_rate = calculate_jump_rate( ma_jumps, num_majumps, rates, u, params, t, - rx) + rx, p.maj_rates) sum_rate += new_rate - cur_rates[rx] if new_rate > p.max_rate p.max_rate = new_rate diff --git a/src/aggregators/rssa.jl b/src/aggregators/rssa.jl index f86e0b44..1e19259e 100644 --- a/src/aggregators/rssa.jl +++ b/src/aggregators/rssa.jl @@ -14,6 +14,7 @@ mutable struct RSSAJumpAggregation{T, S, F1, F2, VJMAP, JVMAP, BD, U} <: cur_rate_high::Vector{T} sum_rate::T ma_jumps::S + maj_rates::Vector{T} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} @@ -26,6 +27,7 @@ end function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), u::U, vartojumps_map = nothing, jumptovars_map = nothing, bracket_data = nothing, kwargs...) where {T, S, F1, F2, U} @@ -64,7 +66,7 @@ function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, affecttype = F2 <: Tuple ? F2 : Any RSSAJumpAggregation{T, S, F1, affecttype, typeof(vtoj_map), typeof(jtov_map), typeof(bd), U}(nj, nj, njt, et, crl_bnds, - crh_bnds, sr, maj, rs, affs!, sps, + crh_bnds, sr, maj, maj_rates, rs, affs!, sps, vtoj_map, jtov_map, bd, ulow, uhigh) end @@ -86,6 +88,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::RSSAJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) set_bracketing!(p, u, params, t) generate_jumps!(p, integrator, u, params, t) nothing @@ -120,7 +123,7 @@ function generate_jumps!(p::RSSAJumpAggregation, integrator, u, params, t) end rerl += randexp(rng) @inbounds while rejectrx(ma_jumps, num_majumps, rates, cur_rate_high, - cur_rate_low, rng, u, jidx, params, t) + cur_rate_low, rng, u, jidx, params, t, p.maj_rates) # sample candidate reaction r = rand(rng) * sum_rate jidx = linear_search(cur_rate_high, r) diff --git a/src/aggregators/rssacr.jl b/src/aggregators/rssacr.jl index 9cc47c3b..887c718e 100644 --- a/src/aggregators/rssacr.jl +++ b/src/aggregators/rssacr.jl @@ -15,6 +15,7 @@ mutable struct RSSACRJumpAggregation{F, S, F1, F2, U, VJMAP, JVMAP, BD, cur_rate_high::Vector{F} sum_rate::F ma_jumps::S + maj_rates::Vector{F} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} @@ -30,8 +31,9 @@ mutable struct RSSACRJumpAggregation{F, S, F1, F2, U, VJMAP, JVMAP, BD, end function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate::F, maj::S, - rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; u::U, - vartojumps_map = nothing, jumptovars_map = nothing, + rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{F}(undef, get_num_majumps(maj)), + u::U, vartojumps_map = nothing, jumptovars_map = nothing, bracket_data = nothing, minrate = convert(F, MINJUMPRATE), maxrate = convert(F, Inf), kwargs...) where {F, S, F1, F2, U} @@ -82,7 +84,7 @@ function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate: RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, U, typeof(vtoj_map), typeof(jtov_map), typeof(bd), typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, crl_bnds, crh_bnds, - sum_rate, maj, rs, affs!, sps, vtoj_map, + sum_rate, maj, maj_rates, rs, affs!, sps, vtoj_map, jtov_map, bd, ulow, uhigh, minrate, maxrate, rt, ratetogroup) end @@ -103,6 +105,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::RSSACRJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) set_bracketing!(p, u, params, t) # setup PriorityTable @@ -145,7 +148,7 @@ function generate_jumps!(p::RSSACRJumpAggregation, integrator, u, params, t) end rerl += randexp(rng) while rejectrx(ma_jumps, num_majumps, rates, cur_rate_high, cur_rate_low, rng, u, jidx, - params, t) + params, t, p.maj_rates) # sample candidate reaction jidx = sample(rt, cur_rate_high, rng) rerl += randexp(rng) diff --git a/src/aggregators/sortingdirect.jl b/src/aggregators/sortingdirect.jl index c20ea1ae..6555a757 100644 --- a/src/aggregators/sortingdirect.jl +++ b/src/aggregators/sortingdirect.jl @@ -11,6 +11,7 @@ mutable struct SortingDirectJumpAggregation{T, S, F1, F2, DEPGR} <: cur_rates::Vector{T} sum_rate::T ma_jumps::S + maj_rates::Vector{T} rates::F1 affects!::F2 save_positions::Tuple{Bool, Bool} @@ -21,6 +22,7 @@ end function SortingDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool, Bool}; + maj_rates = Vector{T}(undef, get_num_majumps(maj)), num_specs, dep_graph = nothing, kwargs...) where {T, S, F1, F2} @@ -42,7 +44,7 @@ function SortingDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr jtoidx = collect(1:length(crs)) affecttype = F2 <: Tuple ? F2 : Any SortingDirectJumpAggregation{T, S, F1, affecttype, typeof(dg)}(nj, nj, njt, et, - crs, sr, maj, rs, + crs, sr, maj, maj_rates, rs, affs!, sps, dg, jtoidx, zero(Int)) @@ -65,6 +67,7 @@ end # set up a new simulation and calculate the first jump / jump time function initialize!(p::SortingDirectJumpAggregation, integrator, u, params, t) p.end_time = integrator.sol.prob.tspan[2] + fill_scaled_rates!(p.maj_rates, p.ma_jumps, params) fill_rates_and_sum!(p, u, params, t) generate_jumps!(p, integrator, u, params, t) nothing diff --git a/src/aggregators/ssajump.jl b/src/aggregators/ssajump.jl index fec06c18..c68979d7 100644 --- a/src/aggregators/ssajump.jl +++ b/src/aggregators/ssajump.jl @@ -127,12 +127,13 @@ function build_jump_aggregation(jump_agg_type, u, p, t, end_time, ma_jumps, rate # current jump rates, allows mass action rates and constant jumps cur_rates = Vector{typeof(t)}(undef, get_num_majumps(majumps) + length(rates)) + maj_rates = Vector{typeof(t)}(undef, get_num_majumps(majumps)) sum_rate = zero(typeof(t)) next_jump = 0 next_jump_time = typemax(typeof(t)) jump_agg_type(next_jump, next_jump_time, end_time, cur_rates, sum_rate, - majumps, rates, affects!, save_positions; kwargs...) + majumps, rates, affects!, save_positions; maj_rates, kwargs...) end """ @@ -146,8 +147,9 @@ function fill_rates_and_sum!(p::AbstractSSAJumpAggregator, u, params, t) # mass action jumps majumps = p.ma_jumps cur_rates = p.cur_rates + maj_rates = p.maj_rates @inbounds for i in 1:get_num_majumps(majumps) - cur_rates[i] = evalrxrate(u, i, majumps) + cur_rates[i] = evalrxrate(u, i, majumps, maj_rates) sum_rate += cur_rates[i] end @@ -169,9 +171,10 @@ end Recalculate the rate for the jump with index `rx`. """ -@inline function calculate_jump_rate(ma_jumps, num_majumps, rates, u, params, t, rx) +@inline function calculate_jump_rate(ma_jumps, num_majumps, rates, u, params, t, rx, + maj_rates) if rx <= num_majumps - return evalrxrate(u, rx, ma_jumps) + return evalrxrate(u, rx, ma_jumps, maj_rates) else @inbounds return rates[rx - num_majumps](u, params, t) end @@ -194,7 +197,7 @@ function update_dependent_rates!(p::AbstractSSAJumpAggregator, u, params, t) @inbounds for rx in dep_rxs sum_rate -= cur_rates[rx] @inbounds cur_rates[rx] = calculate_jump_rate(p.ma_jumps, num_majumps, p.rates, u, - params, t, rx) + params, t, rx, p.maj_rates) sum_rate += cur_rates[rx] end @@ -296,7 +299,7 @@ Perform rejection sampling test (used in RSSA methods). """ @inline function rejectrx( ma_jumps, num_majumps, rates, cur_rate_high, cur_rate_low, rng, u, - jidx, params, t) + jidx, params, t, maj_rates) # rejection test @inbounds r2 = rand(rng) * cur_rate_high[jidx] @inbounds crlow = cur_rate_low[jidx] @@ -305,7 +308,8 @@ Perform rejection sampling test (used in RSSA methods). return false else # calculate actual propensity, split up for type stability - crate = calculate_jump_rate(ma_jumps, num_majumps, rates, u, params, t, jidx) + crate = calculate_jump_rate(ma_jumps, num_majumps, rates, u, params, t, jidx, + maj_rates) if crate > zero(crate) && r2 <= crate return false end diff --git a/src/extended_jump_array.jl b/src/extended_jump_array.jl index 6ba64588..0eacede3 100644 --- a/src/extended_jump_array.jl +++ b/src/extended_jump_array.jl @@ -121,8 +121,23 @@ function ArrayInterface.zeromatrix(A::ExtendedJumpArray) u = [vec(A.u); vec(A.jump_u)] u .* u' .* false end + +# Helper: concatenate fields into a flat vector, apply op, scatter back +function _eja_flat_apply_and_scatter!(op!, A, b::ExtendedJumpArray) + N = length(b.u) + tmp = [vec(b.u); vec(b.jump_u)] + op!(A, tmp) + copyto!(vec(b.u), 1, tmp, 1, N) + copyto!(vec(b.jump_u), 1, tmp, N + 1, length(b.jump_u)) + b +end + function LinearAlgebra.ldiv!(A::LinearAlgebra.LU, b::ExtendedJumpArray) - LinearAlgebra.ldiv!(A, [vec(b.u); vec(b.jump_u)]) + _eja_flat_apply_and_scatter!(LinearAlgebra.ldiv!, A, b) +end + +function LinearAlgebra.lmul!(A::LinearAlgebra.AbstractQ, b::ExtendedJumpArray) + _eja_flat_apply_and_scatter!(LinearAlgebra.lmul!, A, b) end function recursivecopy!(dest::T, src::T) where {T <: ExtendedJumpArray} diff --git a/src/jumps.jl b/src/jumps.jl index 07185d95..f1e1bf5a 100644 --- a/src/jumps.jl +++ b/src/jumps.jl @@ -325,7 +325,13 @@ jprob = JumpProblem(prob, Direct(), maj) """ struct MassActionJump{T, S, U, V} <: AbstractMassActionJump - """The (scaled) reaction rate constants.""" + """ + The (scaled) reaction rate constants. When stored within a `JumpProblem`, this vector is + shared across remade problems. Users are responsible for maintaining consistent + stoichiometric scaling and thread safety if mutating directly. In general, prefer using + the parameter index or mapping form for rates that need to change, rather than mutating + `scaled_rates` directly. + """ scaled_rates::T """The reactant stoichiometry vectors.""" reactant_stoch::S @@ -333,7 +339,7 @@ struct MassActionJump{T, S, U, V} <: AbstractMassActionJump net_stoch::U """Parameter mapping functor to identify reaction rate constants with parameters in `p` vectors.""" param_mapper::V - """Whether `update_parameters!` should apply stoichiometric scaling to rates.""" + """Whether the built-in mapper callable should apply stoichiometric scaling to rates.""" rescale_rates_on_update::Bool function MassActionJump{T, S, U, V}(rates::T, rs_in::S, ns::U, pmapper::V, @@ -415,7 +421,7 @@ function MassActionJump(rs, ns; param_idxs = nothing, param_mapper = nothing, if param_mapper === nothing (param_idxs === nothing) && error("If no parameter indices are given via param_idxs, an explicit parameter mapping must be passed in via param_mapper.") - pmapper = MassActionJumpParamMapper(param_idxs) + pmapper = MassActionJumpParamMapper(param_idxs isa Integer ? [param_idxs] : param_idxs) else (param_idxs !== nothing) && error("Only one of param_idxs and param_mapper should be passed.") @@ -432,71 +438,69 @@ using_params(maj::Nothing) = false @inline get_num_majumps(maj::MassActionJump) = length(maj.net_stoch) @inline get_num_majumps(maj::Nothing) = 0 -struct MassActionJumpParamMapper{U} - param_idxs::U -end +""" + MassActionJumpParamMapper{U} -# create the initial parameter vector for use in a MassActionJump -# Note these are unscaled -function (ratemap::MassActionJumpParamMapper{U})(params) where {U <: AbstractArray} - [params[pidx] for pidx in ratemap.param_idxs] -end +Default parameter mapper for parameterized `MassActionJump`s (those constructed +with `param_idxs` instead of explicit `scaled_rates`). Stores the indices into +the parameter vector `p` that correspond to each reaction's rate constant. + +Implements the in-place mapper callable API: + + (mapper)(dest::AbstractVector, maj::MassActionJump, params) -# Note this is unscaled -function (ratemap::MassActionJumpParamMapper{U})(params) where {U <: Int} - params[ratemap.param_idxs] +which is called by [`fill_scaled_rates!`](@ref) during aggregator initialization +and reinitialization to populate the working rate vector from parameters. + +`dest` should be filled with the current rates for each reaction. If +`maj.rescale_rates_on_update` is `true`, the mapper should also apply +stoichiometric scaling via [`scalerates!`](@ref). + +Custom mappers (e.g. ModelingToolkitBase's `JumpSysMajParamMapper`) should +implement this 3-arg callable to support parameterized `MassActionJump`s. +""" +struct MassActionJumpParamMapper{U} + param_idxs::U end -# update a maj with parameter vectors -function (ratemap::MassActionJumpParamMapper{U})(maj::MassActionJump, newparams; - scale_rates, - kwargs...) where {U <: AbstractArray} - for i in 1:get_num_majumps(maj) - maj.scaled_rates[i] = newparams[ratemap.param_idxs[i]] +function (mapper::MassActionJumpParamMapper{U})(dest::AbstractVector, + maj::MassActionJump, params) where {U <: AbstractArray} + @inbounds for i in eachindex(dest) + dest[i] = params[mapper.param_idxs[i]] end - scale_rates && scalerates!(maj.scaled_rates, maj.reactant_stoch) + maj.rescale_rates_on_update && scalerates!(dest, maj.reactant_stoch) nothing end -function to_collection(ratemap::MassActionJumpParamMapper{Int}) - MassActionJumpParamMapper([ratemap.param_idxs]) -end +to_collection(ratemap::MassActionJumpParamMapper) = MassActionJumpParamMapper(copy(ratemap.param_idxs)) function Base.merge!(pmap1::MassActionJumpParamMapper{U}, pmap2::MassActionJumpParamMapper{U}) where {U <: AbstractVector} append!(pmap1.param_idxs, pmap2.param_idxs) end -function Base.merge!(pmap1::MassActionJumpParamMapper{U}, - pmap2::MassActionJumpParamMapper{V}) where {U <: AbstractVector, - V <: Int} - push!(pmap1.param_idxs, pmap2.param_idxs) -end - -function Base.merge(pmap1::MassActionJumpParamMapper{Int}, - pmap2::MassActionJumpParamMapper{Int}) - MassActionJumpParamMapper([pmap1.param_idxs, pmap2.param_idxs]) -end - """ -update_parameters!(maj::MassActionJump, newparams; scale_rates=maj.rescale_rates_on_update) + fill_scaled_rates!(dest::AbstractVector, maj::MassActionJump, params) -Updates the passed in MassActionJump with the parameter values in `newparams`. +Fill `dest` with the current scaled rates for the mass action jump. Dispatches on the +`scaled_rates` type parameter of `maj`: -Notes: - - - Requires the jump to have been constructed with a user-passed `param_idxs` or `param_mapper`. - - `scale_rates` defaults to `maj.rescale_rates_on_update`, which itself defaults to the - `scale_rates` value used when constructing the jump. When `true`, the parameter - representing the jump rate will be scaled by an appropriate combinatoric factor, i.e. - for 3A --> B at rate k it will scale k --> k/3!. + - `T = Nothing` (parameterized MAJ): delegates to `maj.param_mapper(dest, maj, params)`. + - `T <: AbstractVector` (non-parameterized MAJ): copies `maj.scaled_rates` into `dest`. """ -function update_parameters!(maj::MassActionJump, newparams; scale_rates = maj.rescale_rates_on_update, kwargs...) - (maj.param_mapper === nothing) && - error("MassActionJumps must be constructed with param_idxs or a param_mapper to be updateable.") - maj.param_mapper(maj, newparams; scale_rates, kwargs) +function fill_scaled_rates!(dest::AbstractVector, maj::MassActionJump{Nothing}, params) + maj.param_mapper(dest, maj, params) + nothing end +function fill_scaled_rates!(dest::AbstractVector, maj::MassActionJump{T}, + params) where {T <: AbstractVector} + dest .= maj.scaled_rates + nothing +end + +fill_scaled_rates!(dest, maj::Nothing, params) = nothing + """ $(TYPEDEF) @@ -544,6 +548,11 @@ end function JumpSet(vj, cj, rj, maj::MassActionJump{S, T, U, V}) where {S <: Number, T, U, V} JumpSet(vj, cj, rj, check_majump_type(maj)) end +function JumpSet(vj, cj, rj, + maj::MassActionJump{Nothing, T, U, V}) where { + T <: AbstractArray{<:Pair}, U <: AbstractArray{<:Pair}, V} + JumpSet(vj, cj, rj, check_majump_type(maj)) +end JumpSet(jump::ConstantRateJump) = JumpSet((), (jump,), nothing, nothing) JumpSet(jump::VariableRateJump) = JumpSet((jump,), (), nothing, nothing) @@ -566,6 +575,11 @@ function JumpSet(vjs, cjs, rj, majv::Vector{T}) where {T <: MassActionJump} error("JumpSets do not accept empty mass action jump collections; use \"nothing\" instead.") end + if any(m -> using_params(m) && !(m.param_mapper isa MassActionJumpParamMapper), majv) + error("Cannot merge MassActionJumps backed by custom parameter mappers. " * + "Construct a single MassActionJump with all reactions instead.") + end + sr_val = majv[1].rescale_rates_on_update if !all(m -> m.rescale_rates_on_update == sr_val, majv) error("Cannot merge MassActionJumps with different rescale_rates_on_update settings.") @@ -640,10 +654,12 @@ function check_majump_type(maj::MassActionJump{Nothing, T, U, V}) where {T, U, V end # if given containers of rates and stoichiometry directly create a jump +# copy arrays since majump_merge! will mutate them in-place function setup_majump_to_merge(sr::T, rs::AbstractVector{S}, ns::AbstractVector{U}, pmapper, rescale_rates_on_update::Bool) where {T <: AbstractVector, S <: AbstractArray, U <: AbstractArray} - MassActionJump(sr, rs, ns, pmapper; scale_rates = false, rescale_rates_on_update) + MassActionJump(copy(sr), copy(rs), copy(ns), pmapper; + scale_rates = false, nocopy = true, rescale_rates_on_update) end # if just given the data for one jump (and not in a container) wrap in a vector @@ -655,7 +671,17 @@ function setup_majump_to_merge(sr::S, rs::T, ns::U, pmapper, scale_rates = false, rescale_rates_on_update) end -# if no rate field setup yet +# if no rate field setup yet — collection case (rs is already Vector{<:AbstractArray}) +# copy arrays and mapper since majump_merge! will mutate them in-place +function setup_majump_to_merge(::Nothing, rs::AbstractVector{S}, ns::AbstractVector{U}, + pmapper, + rescale_rates_on_update::Bool) where {S <: AbstractArray, U <: AbstractArray} + pm = (pmapper === nothing) ? nothing : deepcopy(pmapper) + MassActionJump(nothing, copy(rs), copy(ns), pm; + scale_rates = false, nocopy = true, rescale_rates_on_update) +end + +# if no rate field setup yet — single reaction case (wrap in a vector) function setup_majump_to_merge(::Nothing, rs::T, ns::U, pmapper, rescale_rates_on_update::Bool) where {T <: AbstractArray, U <: AbstractArray} MassActionJump(nothing, [rs], [ns], @@ -721,9 +747,18 @@ function majump_merge!(maj::MassActionJump{T, S, U, V}, sr::T, rs::S, ns::U, end massaction_jump_combine(maj1::MassActionJump, maj2::Nothing) = maj1 -massaction_jump_combine(maj1::Nothing, maj2::MassActionJump) = maj2 massaction_jump_combine(maj1::Nothing, maj2::Nothing) = maj1 +# copy the MAJ so it is safe to mutate during subsequent merges +function massaction_jump_combine(maj1::Nothing, maj2::MassActionJump) + setup_majump_to_merge(maj2.scaled_rates, maj2.reactant_stoch, maj2.net_stoch, + maj2.param_mapper, maj2.rescale_rates_on_update) +end function massaction_jump_combine(maj1::MassActionJump, maj2::MassActionJump) + for m in (maj1, maj2) + (using_params(m) && !(m.param_mapper isa MassActionJumpParamMapper)) && + error("Cannot merge MassActionJumps backed by custom parameter mappers. " * + "Construct a single MassActionJump with all reactions instead.") + end (maj1.rescale_rates_on_update == maj2.rescale_rates_on_update) || error("Cannot merge MassActionJumps with different rescale_rates_on_update settings.") majump_merge!(maj1, maj2.scaled_rates, maj2.reactant_stoch, maj2.net_stoch, diff --git a/src/massaction_rates.jl b/src/massaction_rates.jl index 06d4dfa8..31e639f6 100644 --- a/src/massaction_rates.jl +++ b/src/massaction_rates.jl @@ -4,7 +4,7 @@ ############################################################################### @inline function evalrxrate(speciesvec::AbstractVector{T}, rxidx, - majump::MassActionJump{U})::R where {T <: Integer, R, U <: AbstractVector{R}} + majump::MassActionJump, maj_rates::AbstractVector{R})::R where {T <: Integer, R} val = one(T) @inbounds for specstoch in majump.reactant_stoch[rxidx] specpop = speciesvec[specstoch[1]] @@ -15,11 +15,11 @@ end end - @inbounds return val * majump.scaled_rates[rxidx] + @inbounds return val * maj_rates[rxidx] end @inline function evalrxrate(speciesvec::AbstractVector{T}, rxidx, - majump::MassActionJump{U})::R where {T <: Real, R, U <: AbstractVector{R}} + majump::MassActionJump, maj_rates::AbstractVector{R})::R where {T <: Real, R} val = one(T) @inbounds for specstoch in majump.reactant_stoch[rxidx] specpop = speciesvec[specstoch[1]] @@ -29,11 +29,11 @@ end val *= specpop end # we need to check the smallest rate law term is positive - # i.e. for an order k reaction: x - k + 1 > 0 + # i.e. for an order k reaction: x - k + 1 > 0 (specpop <= 0) && return zero(R) end - @inbounds return val * majump.scaled_rates[rxidx] + @inbounds return val * maj_rates[rxidx] end @inline function executerx!(speciesvec::AbstractVector{T}, rxidx::S, diff --git a/src/problem.jl b/src/problem.jl index 56d58da6..05838c46 100644 --- a/src/problem.jl +++ b/src/problem.jl @@ -157,10 +157,6 @@ function DiffEqBase.remake(jprob::JumpProblem; u0 = missing, p = missing, newprob = DiffEqBase.remake(prob; u0, p, interpret_symbolicmap, use_defaults, kwargs...) end - # if the parameters were changed we must remake the MassActionJump too - if (p !== missing) && using_params(jprob.massaction_jump) - update_parameters!(jprob.massaction_jump, newprob.p; kwargs...) - end else ((u0 !== missing) || (p !== missing) || (:tspan ∈ keys(kwargs))) && error("If remaking a JumpProblem you can not pass both prob and any of u0, p, or tspan.") @@ -169,11 +165,6 @@ function DiffEqBase.remake(jprob::JumpProblem; u0 = missing, p = missing, # when passing a new wrapped problem directly we require u0 has the correct type (typeof(newprob.u0) == typeof(jprob.prob.u0)) || error("The new u0 within the passed prob does not have the same type as the existing u0. Please pass a u0 of type $(typeof(jprob.prob.u0)).") - - # we can't know if p was changed, so we must remake the MassActionJump - if using_params(jprob.massaction_jump) - update_parameters!(jprob.massaction_jump, newprob.p; kwargs...) - end end T(newprob, jprob.aggregator, jprob.discrete_jump_aggregation, jprob.jump_callback, @@ -181,14 +172,6 @@ function DiffEqBase.remake(jprob::JumpProblem; u0 = missing, p = missing, jprob.massaction_jump, jprob.kwargs) end -# for updating parameters in JumpProblems to update MassActionJumps -function SII.finalize_parameters_hook!(prob::JumpProblem, p) - if using_params(prob.massaction_jump) - update_parameters!(prob.massaction_jump, SII.parameter_values(prob)) - end - nothing -end - DiffEqBase.isinplace(::JumpProblem{iip}) where {iip} = iip JumpProblem(prob::JumpProblem) = prob @@ -247,26 +230,23 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS vr_aggregator::VariableRateAggregator = VR_FRM(), save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), - scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, callback = nothing, tstops = nothing, use_vrj_bounds = true, kwargs...) if haskey(kwargs, :rng) throw(ArgumentError("`rng` is no longer a keyword argument for `JumpProblem`. Pass `rng` to `solve` or `init` instead, e.g. `solve(jprob, SSAStepper(); rng = my_rng)`.")) end - - # initialize the MassActionJump rate constants with the user parameters - if using_params(jumps.massaction_jump) - rates = jumps.massaction_jump.param_mapper(prob.p) - maj = MassActionJump(rates, jumps.massaction_jump.reactant_stoch, - jumps.massaction_jump.net_stoch, - jumps.massaction_jump.param_mapper; scale_rates = scale_rates, - useiszero = useiszero, - nocopy = true) - else - maj = jumps.massaction_jump + if haskey(kwargs, :scale_rates) + throw(ArgumentError("`scale_rates` is no longer a keyword argument for `JumpProblem`. Set `scale_rates` on the `MassActionJump` directly instead.")) + end + if haskey(kwargs, :useiszero) + throw(ArgumentError("`useiszero` is no longer a keyword argument for `JumpProblem`. Set `useiszero` on the `MassActionJump` directly instead.")) end + # keep the original MAJ (including {Nothing} parameterized MAJs); + # fill_scaled_rates! in each aggregator's initialize! handles rate setup + maj = jumps.massaction_jump + ## Spatial jumps handling if spatial_system !== nothing && hopping_constants !== nothing (num_crjs(jumps) == num_vrjs(jumps) == 0) || @@ -330,7 +310,6 @@ end function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; save_positions = prob isa DiffEqBase.AbstractDiscreteProblem ? (false, true) : (true, true), - scale_rates = true, useiszero = true, spatial_system = nothing, hopping_constants = nothing, callback = nothing, tstops = nothing, kwargs...) @@ -342,17 +321,9 @@ function JumpProblem(prob, aggregator::PureLeaping, jumps::JumpSet; (spatial_system !== nothing || hopping_constants !== nothing) && error("PureLeaping does not currently support spatial problems.") - # Initialize the MassActionJump rate constants with the user parameters - if using_params(jumps.massaction_jump) - rates = jumps.massaction_jump.param_mapper(prob.p) - maj = MassActionJump(rates, jumps.massaction_jump.reactant_stoch, - jumps.massaction_jump.net_stoch, - jumps.massaction_jump.param_mapper; scale_rates = scale_rates, - useiszero = useiszero, - nocopy = true) - else - maj = jumps.massaction_jump - end + # keep the original MAJ (including {Nothing} parameterized MAJs); + # fill_scaled_rates! in the tau-leaping solver handles rate setup + maj = jumps.massaction_jump # For PureLeaping, all jumps are handled by the tau-leaping solver # No discrete jump aggregation or variable rate callbacks are created diff --git a/src/simple_regular_solve.jl b/src/simple_regular_solve.jl index f71351ef..78b70428 100644 --- a/src/simple_regular_solve.jl +++ b/src/simple_regular_solve.jl @@ -262,10 +262,10 @@ function compute_tau( end # Function to generate a mass action rate function -function massaction_rate(maj, numjumps) +function massaction_rate(maj, maj_rates, numjumps) return (out, u, p, t) -> begin for j in 1:numjumps - out[j] = evalrxrate(u, j, maj) + out[j] = evalrxrate(u, j, maj, maj_rates) end end end @@ -355,9 +355,11 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping; maj = jump_prob.massaction_jump numjumps = get_num_majumps(maj) + maj_rates = Vector{typeof(tspan[2])}(undef, numjumps) + fill_scaled_rates!(maj_rates, maj, prob.p) rj = jump_prob.regular_jump # Extract rates - rate = rj !== nothing ? rj.rate : massaction_rate(maj, numjumps) + rate = rj !== nothing ? rj.rate : massaction_rate(maj, maj_rates, numjumps) c = rj !== nothing ? rj.c : nothing u0 = copy(prob.u0) p = prob.p diff --git a/src/spatial/directcrdirect.jl b/src/spatial/directcrdirect.jl index 0961ec25..04e77334 100644 --- a/src/spatial/directcrdirect.jl +++ b/src/spatial/directcrdirect.jl @@ -142,6 +142,7 @@ function fill_rates_and_get_times!( (; spatial_system, rx_rates, hop_rates, site_rates, rt) = aggregation u = integrator.u + fill_scaled_rates!(rx_rates.maj_rates, rx_rates.ma_jumps, integrator.p) reset!(rx_rates) reset!(hop_rates) site_rates .= zero(typeof(t)) diff --git a/src/spatial/flatten.jl b/src/spatial/flatten.jl index aed6279a..2b0dbc67 100644 --- a/src/spatial/flatten.jl +++ b/src/spatial/flatten.jl @@ -15,7 +15,9 @@ function flatten(ma_jump, prob::DiscreteProblem, spatial_system, hopping_constan netstoch = ma_jump.net_stoch reactstoch = ma_jump.reactant_stoch rx_rates = if isa(ma_jump, MassActionJump) - ma_jump.scaled_rates + rates_init = Vector{typeof(tspan[1])}(undef, get_num_majumps(ma_jump)) + fill_scaled_rates!(rates_init, ma_jump, prob.p) + rates_init elseif isa(ma_jump, SpatialMassActionJump) num_nodes = num_sites(spatial_system) if isnothing(ma_jump.uniform_rates) && isnothing(ma_jump.spatial_rates) diff --git a/src/spatial/nsm.jl b/src/spatial/nsm.jl index acf6b56e..e661cefa 100644 --- a/src/spatial/nsm.jl +++ b/src/spatial/nsm.jl @@ -129,6 +129,7 @@ function fill_rates_and_get_times!(aggregation::NSMJumpAggregation, integrator, rng = get_rng(integrator) u = integrator.u + fill_scaled_rates!(rx_rates.maj_rates, rx_rates.ma_jumps, integrator.p) reset!(rx_rates) reset!(hop_rates) diff --git a/src/spatial/reaction_rates.jl b/src/spatial/reaction_rates.jl index 737cc5c9..c0685851 100644 --- a/src/spatial/reaction_rates.jl +++ b/src/spatial/reaction_rates.jl @@ -12,6 +12,9 @@ struct RxRates{F, M} "AbstractMassActionJump" ma_jumps::M + + "working copy of scaled mass action jump rates" + maj_rates::Vector{F} end """ @@ -22,7 +25,8 @@ initializes RxRates with zero rates function RxRates(num_sites::Int, ma_jumps::M) where {M} numrxjumps = get_num_majumps(ma_jumps) rates = zeros(Float64, numrxjumps, num_sites) - RxRates{Float64, M}(rates, vec(sum(rates, dims = 1)), ma_jumps) + maj_rates = Vector{Float64}(undef, numrxjumps) + RxRates{Float64, M}(rates, vec(sum(rates, dims = 1)), ma_jumps, maj_rates) end num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) @@ -55,8 +59,9 @@ update rates of all reactions in rxs at site function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator, site) ma_jumps = rx_rates.ma_jumps + maj_rates = rx_rates.maj_rates @inbounds for rx in rxs - rate = eval_massaction_rate(u, rx, ma_jumps, site) + rate = eval_massaction_rate(u, rx, ma_jumps, site, maj_rates) set_rx_rate_at_site!(rx_rates, site, rx, rate) end end @@ -90,9 +95,9 @@ function Base.show(io::IO, ::MIME"text/plain", rx_rates::RxRates) println(io, "RxRates with $num_rxs reactions and $num_sites sites") end -function eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: SpatialMassActionJump} +function eval_massaction_rate(u, rx, ma_jumps::M, site, maj_rates) where {M <: SpatialMassActionJump} evalrxrate(u, rx, ma_jumps, site) end -function eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: MassActionJump} - evalrxrate((@view u[:, site]), rx, ma_jumps) +function eval_massaction_rate(u, rx, ma_jumps::M, site, maj_rates) where {M <: MassActionJump} + evalrxrate((@view u[:, site]), rx, ma_jumps, maj_rates) end diff --git a/src/spatial/spatial_massaction_jump.jl b/src/spatial/spatial_massaction_jump.jl index c9af9d1e..9706a182 100644 --- a/src/spatial/spatial_massaction_jump.jl +++ b/src/spatial/spatial_massaction_jump.jl @@ -82,6 +82,11 @@ function SpatialMassActionJump(urates::A, rs, ns; scale_rates = true, useiszero useiszero = useiszero, nocopy = nocopy) end +function SpatialMassActionJump(ma_jumps::MassActionJump{Nothing}; kwargs...) + error("Cannot construct SpatialMassActionJump from a parameter-mapped MassActionJump. " * + "Provide explicit rate vectors or use a SpatialMassActionJump directly.") +end + # scale_rates defaults to false since ma_jumps.scaled_rates are already scaled; # passing true would double-scale. function SpatialMassActionJump(ma_jumps::MassActionJump{T, S, U, V}; scale_rates = false, @@ -91,6 +96,10 @@ function SpatialMassActionJump(ma_jumps::MassActionJump{T, S, U, V}; scale_rates scale_rates = scale_rates, useiszero = useiszero, nocopy = nocopy) end +# SpatialMassActionJump stores rates in the struct itself (uniform_rates/spatial_rates), +# so fill_scaled_rates! is a no-op. +fill_scaled_rates!(dest, maj::SpatialMassActionJump, params) = nothing + ############################################## function get_num_majumps(smaj::SpatialMassActionJump{ diff --git a/test/bracketing.jl b/test/bracketing.jl index 18268943..6138e971 100644 --- a/test/bracketing.jl +++ b/test/bracketing.jl @@ -36,9 +36,11 @@ reactstoch = [[1 => 1]] netstoch = [[1 => -1]] majump = MassActionJump(majump_rates, reactstoch, netstoch) +maj_rates_work = Vector{Float64}(undef, JP.get_num_majumps(majump)) +JP.fill_scaled_rates!(maj_rates_work, majump, nothing) reaction_index = 1 -@test JP.get_majump_brackets(ulow, uhigh, reaction_index, majump)[1] == majump_rates[1] * ulow[1] # low -@test JP.get_majump_brackets(ulow, uhigh, reaction_index, majump)[2] == majump_rates[1] * uhigh[1] # high +@test JP.get_majump_brackets(ulow, uhigh, reaction_index, majump, maj_rates_work)[1] == majump_rates[1] * ulow[1] # low +@test JP.get_majump_brackets(ulow, uhigh, reaction_index, majump, maj_rates_work)[2] == majump_rates[1] * uhigh[1] # high # constant rate rate(u, params, t) = 1 / u[1] @@ -56,6 +58,7 @@ mutable struct DummyAggregator{T, M, R, BD} <: cur_rate_high::Vector{T} sum_rate::T ma_jumps::M + maj_rates::Vector{T} rates::R bracket_data::BD end @@ -63,7 +66,7 @@ end cur_rate_low = [0.0, 0.0] cur_rate_high = [0.0, 0.0] sum_rate = 0.0 -p = DummyAggregator([0], [0], cur_rate_low, cur_rate_high, sum_rate, majump, [rate], bd) +p = DummyAggregator([0], [0], cur_rate_low, cur_rate_high, sum_rate, majump, maj_rates_work, [rate], bd) u = [100] JP.update_u_brackets!(p, u) @@ -77,7 +80,7 @@ reaction_index = 2 @test JP.get_jump_brackets(reaction_index, p, params, t)[1] == rate(p.uhigh, params, t) @test JP.get_jump_brackets(reaction_index, p, params, t)[2] == rate(p.ulow, params, t) -p = DummyAggregator([0], [0], cur_rate_low, cur_rate_high, sum_rate, majump, [rate], bd) +p = DummyAggregator([0], [0], cur_rate_low, cur_rate_high, sum_rate, majump, maj_rates_work, [rate], bd) JP.set_bracketing!(p, u, params, t) @test p.ulow[1]≈u[1] * (1 - fluctuation_rate) atol=1 @test p.uhigh[1]≈u[1] * (1 + fluctuation_rate) atol=1 diff --git a/test/callbacks.jl b/test/callbacks.jl index ce45d006..4fe72732 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -472,3 +472,46 @@ end @test_throws ErrorException init(jprob_ccb, SSAStepper(); rng) @test_throws ErrorException init(jprob, SSAStepper(); callback = ccb, rng) end + +@testset "SDE + jump callback not duplicated" begin + # Regression test for PR #567: verify that when using an SDE algorithm with a + # JumpProblem, the jump callback is added exactly once (not duplicated by both + # JumpProcesses and StochasticDiffEq). See: + # https://github.com/SciML/JumpProcesses.jl/pull/567#issuecomment-4092662794 + using StochasticDiffEq + + # SDE: dX = -X dt + 0.1 dW + f(u, p, t) = -u + g(u, p, t) = 0.1 + + # Constant rate jump + rate(u, p, t) = 1.0 + affect!(integrator) = (integrator.u += 0.5) + crj = ConstantRateJump(rate, affect!) + + sde_prob = SDEProblem(f, g, 1.0, (0.0, 1.0)) + jprob = JumpProblem(sde_prob, Direct(), crj) + + integrator = init(jprob, EM(); dt = 0.01, rng) + + # Count discrete callbacks — the jump should appear exactly once. + # If both JumpProcesses and StochasticDiffEq add it, we'd see duplicates. + n_discrete = length(integrator.opts.callback.discrete_callbacks) + @test n_discrete == 1 + + # Also test with a VariableRateJump (needs array u0 for ExtendedJumpArray) + f_vr(du, u, p, t) = (du[1] = -u[1]) + g_vr(du, u, p, t) = (du[1] = 0.1) + sde_prob_vr = SDEProblem(f_vr, g_vr, [1.0], (0.0, 1.0)) + + vrate(u, p, t) = 1.0 + vaffect!(integrator) = (integrator.u[1] += 0.5) + vrj = VariableRateJump(vrate, vaffect!) + + jprob_vr = JumpProblem(sde_prob_vr, Direct(), vrj; vr_aggregator = VR_FRM()) + integrator_vr = init(jprob_vr, EM(); dt = 0.01, rng) + + # VariableRateJumps produce continuous callbacks + n_continuous = length(integrator_vr.opts.callback.continuous_callbacks) + @test n_continuous == 1 +end diff --git a/test/extended_jump_array.jl b/test/extended_jump_array.jl index 349b5bb3..b0789da1 100644 --- a/test/extended_jump_array.jl +++ b/test/extended_jump_array.jl @@ -1,4 +1,4 @@ -using Test, JumpProcesses, DiffEqBase, OrdinaryDiffEq, SciMLBase +using Test, JumpProcesses, DiffEqBase, OrdinaryDiffEq, SciMLBase, LinearAlgebra, LinearSolve using FastBroadcast using StableRNGs @@ -118,3 +118,44 @@ let @test eltype(sol.u) <: ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}} @test SciMLBase.plottable_indices(sol.u[1]) == 1:length(u₀) end + +# Test ldiv! and lmul! for stiff solver support +let rng = StableRNG(456) + u = rand(rng, 3) + jump_u = rand(rng, 2) + flat = [u; jump_u] + + # ldiv! with LU should modify eja in place and match plain vector result + eja = ExtendedJumpArray(copy(u), copy(jump_u)) + A = rand(rng, 5, 5) + 5I + F = lu(A) + expected = F \ flat + ldiv!(F, eja) + @test vcat(eja.u, eja.jump_u) ≈ expected + + # lmul! with Q from QR + eja2 = ExtendedJumpArray(copy(u), copy(jump_u)) + Q = qr(rand(rng, 5, 5)).Q + expected_q = Q * flat + lmul!(Q, eja2) + @test vcat(eja2.u, eja2.jump_u) ≈ expected_q + + # lmul! with AdjointQ from QR (the actual CI failure case) + eja3 = ExtendedJumpArray(copy(u), copy(jump_u)) + expected_qt = Q' * flat + lmul!(Q', eja3) + @test vcat(eja3.u, eja3.jump_u) ≈ expected_qt +end + +# Integration test: stiff solver with QRFactorization and ExtendedJumpArray +let + f!(du, u, p, t) = (du .= 0; nothing) + rate(u, p, t) = 0.5 + affect!(integrator) = (integrator.u[1] += 1; nothing) + vrj = VariableRateJump(rate, affect!) + oprob = ODEProblem(f!, [0.0], (0.0, 1.0)) + jprob = JumpProblem(oprob, Direct(), vrj; vr_aggregator = VR_FRM()) + sol = solve(jprob, Rodas5P(linsolve = QRFactorization()); + rng = StableRNG(789)) + @test sol.retcode == ReturnCode.Success +end diff --git a/test/jprob_symbol_indexing.jl b/test/jprob_symbol_indexing.jl index 4cad845c..9419c3c5 100644 --- a/test/jprob_symbol_indexing.jl +++ b/test/jprob_symbol_indexing.jl @@ -24,16 +24,25 @@ jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj) jprob[:a] = 20 @test jprob[:a] == 20 -# test mass action jumps update with parameter mutation in problems -@test jprob.massaction_jump.scaled_rates[1] == 1.0 +# test parameterized MAJ stores nothing for scaled_rates +@test jprob.massaction_jump.scaled_rates === nothing + +# test mass action jump rates update correctly after init +integ = init(jprob, SSAStepper()) +@test jprob.discrete_jump_aggregation.maj_rates == [1.0, 2.0] + +# test parameter mutation updates problem params jprob.ps[:p1] = 3.0 @test jprob.ps[:p1] == 3.0 -@test jprob.massaction_jump.scaled_rates[1] == 3.0 +# rates update after re-initialization +integ = init(jprob, SSAStepper()) +@test jprob.discrete_jump_aggregation.maj_rates[1] == 3.0 p1setter = setp(jprob, [:p1, :p2]) p1setter(jprob, [4.0, 10.0]) @test jprob.ps[:p1] == 4.0 @test jprob.ps[:p2] == 10.0 -@test jprob.massaction_jump.scaled_rates == [4.0, 10.0] +integ = init(jprob, SSAStepper()) +@test jprob.discrete_jump_aggregation.maj_rates == [4.0, 10.0] # integrator tests # note that `setu` is not currently supported as `set_u!` is not implemented for SSAStepper @@ -44,9 +53,8 @@ integ[[:b, :a]] = [40, 5] @test getp(integ, :p2)(integ) == 10.0 setp(integ, :p2)(integ, 15.0) @test getp(integ, :p2)(integ) == 15.0 -@test jprob.massaction_jump.scaled_rates[2] == 10.0 # jump rate not updated reset_aggregated_jumps!(integ) -@test jprob.massaction_jump.scaled_rates[2] == 15.0 # jump rate now updated +@test jprob.discrete_jump_aggregation.maj_rates[2] == 15.0 # jump rate now updated # remake tests dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0]) @@ -54,19 +62,22 @@ jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj) jprob = remake(jprob; u0 = [:a => -10, :b => 100], p = [:p2 => 3.5, :p1 => 0.5]) @test jprob.prob.u0 == [-10, 100] @test jprob.prob.p == [0.5, 3.5] -@test jprob.massaction_jump.scaled_rates == [0.5, 3.5] +integ = init(jprob, SSAStepper()) +@test jprob.discrete_jump_aggregation.maj_rates == [0.5, 3.5] jprob = remake(jprob; u0 = [:b => 10], p = [:p2 => 4.5]) @test jprob.prob.u0 == [-10, 10] @test jprob.prob.p == [0.5, 4.5] -@test jprob.massaction_jump.scaled_rates == [0.5, 4.5] +integ = init(jprob, SSAStepper()) +@test jprob.discrete_jump_aggregation.maj_rates == [0.5, 4.5] -# test updating problems via regular indexing still updates the mass action jump +# test updating problems via regular indexing dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0]) jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj) -@test jprob.massaction_jump.scaled_rates[1] == 1.0 +@test jprob.massaction_jump.scaled_rates === nothing jprob.ps[1] = 3.0 @test jprob.ps[1] == 3.0 -@test jprob.massaction_jump.scaled_rates[1] == 3.0 +integ = init(jprob, SSAStepper()) +@test jprob.discrete_jump_aggregation.maj_rates[1] == 3.0 # test updating integrators via regular indexing dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0]) @@ -76,12 +87,12 @@ integ.u .= [40, 5] @test getu(integ, [1, 2])(integ) == [40, 5] @test getp(integ, 2)(integ) == 2.0 @test integ.p[2] == 2.0 -@test jprob.massaction_jump.scaled_rates[2] == 2.0 +@test jprob.discrete_jump_aggregation.maj_rates[2] == 2.0 setp(integ, 2)(integ, 15.0) @test integ.p[2] == 15.0 @test getp(integ, 2)(integ) == 15.0 reset_aggregated_jumps!(integ) -@test jprob.massaction_jump.scaled_rates[2] == 15.0 # jump rate now updated +@test jprob.discrete_jump_aggregation.maj_rates[2] == 15.0 # jump rate now updated # remake tests for regular indexing dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0]) @@ -89,8 +100,10 @@ jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj) jprob = remake(jprob; u0 = [-10, 100], p = [0.5, 3.5]) @test jprob.prob.u0 == [-10, 100] @test jprob.prob.p == [0.5, 3.5] -@test jprob.massaction_jump.scaled_rates == [0.5, 3.5] +integ = init(jprob, SSAStepper()) +@test jprob.discrete_jump_aggregation.maj_rates == [0.5, 3.5] jprob = remake(jprob; u0 = [2 => 10], p = [2 => 4.5]) @test jprob.prob.u0 == [-10, 10] @test jprob.prob.p == [0.5, 4.5] -@test jprob.massaction_jump.scaled_rates == [0.5, 4.5] +integ = init(jprob, SSAStepper()) +@test jprob.discrete_jump_aggregation.maj_rates == [0.5, 4.5] diff --git a/test/regular_jumps.jl b/test/regular_jumps.jl index 99c3375c..6cbc2804 100644 --- a/test/regular_jumps.jl +++ b/test/regular_jumps.jl @@ -251,8 +251,10 @@ end # Test MassActionJump with parameter mapping maj_params = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1, 2]) jp_params = JumpProblem(prob, PureLeaping(), JumpSet(maj_params)) - scaled_rates = [p[1], p[2]/2] - @test jp_params.massaction_jump.scaled_rates == scaled_rates + @test jp_params.massaction_jump.scaled_rates === nothing + dest = zeros(2) + JumpProcesses.fill_scaled_rates!(dest, jp_params.massaction_jump, p) + @test dest == [p[1], p[2] / 2] end # Test that saveat/save_start/save_end control which times are stored in solutions diff --git a/test/scale_rates_field_test.jl b/test/scale_rates_field_test.jl index 3f06ed88..7044659e 100644 --- a/test/scale_rates_field_test.jl +++ b/test/scale_rates_field_test.jl @@ -7,12 +7,10 @@ reactant_stoch = [[1 => 3]] net_stoch = [[1 => -3, 2 => 1]] # Custom mapper mimicking MTKBase's JumpSysMajParamMapper. -# The initial rates callable always pre-scales (simulating symbolic expressions that -# already contain the combinatoric factor, e.g. k/3! for 3X → Y). -# The update callable also pre-scales, then conditionally applies scalerates! based on -# scale_rates — matching MTKBase's behavior. This is what makes the test an actual -# regression test: with the old default of scale_rates=true, the second scalerates! -# would double-scale. +# The 1-arg callable pre-scales rates (simulating symbolic expressions that already +# contain the combinatoric factor, e.g. k/3! for 3X → Y). +# The 3-arg in-place callable also pre-scales, then conditionally applies standard +# stoichiometric scaling based on maj.rescale_rates_on_update. struct PreScaledMapper param_idxs::Vector{Int} reactant_stoch::Vector{Vector{Pair{Int, Int}}} @@ -22,12 +20,12 @@ function (m::PreScaledMapper)(params) JumpProcesses.scalerates!(rates, m.reactant_stoch) rates end -function (m::PreScaledMapper)(maj::MassActionJump, newparams; scale_rates, kwargs...) - for i in 1:JumpProcesses.get_num_majumps(maj) - maj.scaled_rates[i] = newparams[m.param_idxs[i]] +function (m::PreScaledMapper)(dest::AbstractVector, maj::MassActionJump, params) + @inbounds for i in eachindex(dest) + dest[i] = params[m.param_idxs[i]] end - JumpProcesses.scalerates!(maj.scaled_rates, m.reactant_stoch) - scale_rates && JumpProcesses.scalerates!(maj.scaled_rates, maj.reactant_stoch) + JumpProcesses.scalerates!(dest, m.reactant_stoch) + maj.rescale_rates_on_update && JumpProcesses.scalerates!(dest, maj.reactant_stoch) nothing end JumpProcesses.to_collection(m::PreScaledMapper) = m @@ -53,19 +51,25 @@ end @test maj.rescale_rates_on_update == true end -# Test 2: update_parameters! respects stored rescale_rates_on_update -@testset "update_parameters! respects rescale_rates_on_update" begin +# Test 2: fill_scaled_rates! respects rescale_rates_on_update +@testset "fill_scaled_rates! respects rescale_rates_on_update" begin # With param_idxs and scale_rates = true (default) — built-in mapper path p = [6.0] maj = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) - dprob = DiscreteProblem([100, 0], (0.0, 1.0), p) - jprob = JumpProblem(dprob, Direct(), maj) - @test jprob.massaction_jump.rescale_rates_on_update == true - @test jprob.massaction_jump.scaled_rates[1] ≈ 1.0 # 6.0 / 3! + @test maj.rescale_rates_on_update == true + + dest = zeros(1) + JumpProcesses.fill_scaled_rates!(dest, maj, p) + @test dest[1] ≈ 1.0 # 6.0 / 3! - # update_parameters! should scale (rescale_rates_on_update = true from struct) - JumpProcesses.update_parameters!(jprob.massaction_jump, [12.0]) - @test jprob.massaction_jump.scaled_rates[1] ≈ 2.0 # 12.0 / 3! + # fill_scaled_rates! with new params + JumpProcesses.fill_scaled_rates!(dest, maj, [12.0]) + @test dest[1] ≈ 2.0 # 12.0 / 3! + + # Non-parameterized MAJ: fill_scaled_rates! copies stored rates + maj_explicit = MassActionJump([6.0], reactant_stoch, net_stoch) + JumpProcesses.fill_scaled_rates!(dest, maj_explicit, p) + @test dest[1] ≈ 1.0 # copies maj.scaled_rates which is 6.0/3! = 1.0 end # Test 3: Custom pre-scaled mapper with scale_rates = false — the bug reproducer @@ -76,27 +80,34 @@ end maj = MassActionJump(reactant_stoch, net_stoch; param_mapper = mapper, scale_rates = false) dprob = DiscreteProblem([100, 0], (0.0, 100.0), p) - jprob = JumpProblem(dprob, Direct(), maj; scale_rates = false) + jprob = JumpProblem(dprob, Direct(), maj) expected_scaled = k / factorial(3) # 1.0 - @test jprob.massaction_jump.scaled_rates[1] ≈ expected_scaled + + # parameterized MAJ stores nothing for scaled_rates + @test jprob.massaction_jump.scaled_rates === nothing @test jprob.massaction_jump.rescale_rates_on_update == false - # Test reset_aggregated_jumps! does NOT double-scale + # rates are materialized after init triggers initialize! integ = init(jprob, SSAStepper()) + @test jprob.discrete_jump_aggregation.maj_rates[1] ≈ expected_scaled + + # Test reset_aggregated_jumps! does NOT double-scale integ.p[1] = 18.0 reset_aggregated_jumps!(integ) new_expected = 18.0 / factorial(3) # 3.0, NOT 3.0/6 = 0.5 - @test jprob.massaction_jump.scaled_rates[1] ≈ new_expected + @test jprob.discrete_jump_aggregation.maj_rates[1] ≈ new_expected - # Test remake does NOT double-scale + # Test remake + init materializes rates correctly jprob2 = remake(jprob; p = [24.0]) - remake_expected = 24.0 / factorial(3) # 4.0, NOT 4.0/6 ≈ 0.667 - @test jprob2.massaction_jump.scaled_rates[1] ≈ remake_expected + init(jprob2, SSAStepper()) + remake_expected = 24.0 / factorial(3) # 4.0 + @test jprob2.discrete_jump_aggregation.maj_rates[1] ≈ remake_expected # Test remake round-trip jprob3 = remake(jprob2; p = [k]) - @test jprob3.massaction_jump.scaled_rates[1] ≈ expected_scaled + init(jprob3, SSAStepper()) + @test jprob3.discrete_jump_aggregation.maj_rates[1] ≈ expected_scaled end # Test 4: Callback parameter changes with built-in mapper (rescale_rates_on_update = true) @@ -113,12 +124,11 @@ end end cb = DiscreteCallback(condit, affect!) sol = solve(jprob, SSAStepper(); tstops = [1000.0], callback = cb) - @test jprob.massaction_jump.scaled_rates[1] ≈ 4.0 # 24.0 / 3! + @test jprob.discrete_jump_aggregation.maj_rates[1] ≈ 4.0 # 24.0 / 3! end # Test 5: rescale_rates_on_update propagated through JumpSet merge and JumpProblem varargs -# Use explicit-rate MAJs since the JumpSet vector merge path doesn't support -# parameterized (Nothing-rated) MAJs. +# Uses explicit-rate MAJs to test rescale_rates_on_update propagation and mismatch errors. @testset "rescale_rates_on_update propagated through merge paths" begin reactant_stoch2 = [[2 => 3]] net_stoch2 = [[2 => -3, 1 => 1]] @@ -149,7 +159,7 @@ end # Two MAJs with matching rescale_rates_on_update = false via JumpProblem varargs maj_f1 = MassActionJump([1.0], reactant_stoch, net_stoch; scale_rates = false) maj_f2 = MassActionJump([2.0], reactant_stoch2, net_stoch2; scale_rates = false) - jprob_f = JumpProblem(dprob, Direct(), maj_f1, maj_f2; scale_rates = false) + jprob_f = JumpProblem(dprob, Direct(), maj_f1, maj_f2) @test jprob_f.massaction_jump.rescale_rates_on_update == false # Two MAJs with matching rescale_rates_on_update = true via JumpProblem varargs @@ -161,3 +171,44 @@ end # Mismatched rescale_rates_on_update via JumpProblem varargs — should error @test_throws ErrorException JumpProblem(dprob, Direct(), maj_true, maj_false) end + +# Test 6: Custom mapper-backed MAJ merge raises error; built-in mapper merge works +@testset "MAJ merge behavior" begin + # Built-in MassActionJumpParamMapper merges should succeed + maj_p1 = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) + maj_p2 = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) + merged = JumpSet(; massaction_jumps = [maj_p1, maj_p2]) + @test JumpProcesses.get_num_majumps(merged.massaction_jump) == 2 + + dprob = DiscreteProblem([100, 0], (0.0, 1.0), [1.0]) + maj_p3 = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) + maj_p4 = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) + jprob = JumpProblem(dprob, Direct(), maj_p3, maj_p4) + @test JumpProcesses.get_num_majumps(jprob.massaction_jump) == 2 + + # Custom mapper merge should error + mapper1 = PreScaledMapper([1], reactant_stoch) + mapper2 = PreScaledMapper([1], reactant_stoch) + maj_c1 = MassActionJump(reactant_stoch, net_stoch; param_mapper = mapper1, scale_rates = false) + maj_c2 = MassActionJump(reactant_stoch, net_stoch; param_mapper = mapper2, scale_rates = false) + @test_throws ErrorException JumpSet(; massaction_jumps = [maj_c1, maj_c2]) + @test_throws ErrorException JumpProblem(dprob, Direct(), maj_c1, maj_c2) +end + +# Test 7: Immutability and aliasing +@testset "Immutability and aliasing after remake" begin + p = [6.0] + maj = MassActionJump(reactant_stoch, net_stoch; param_idxs = [1]) + dprob = DiscreteProblem([100, 0], (0.0, 100.0), p) + jprob = JumpProblem(dprob, Direct(), maj) + init(jprob, SSAStepper()) + rates_before = copy(jprob.discrete_jump_aggregation.maj_rates) + + # remake with new p, then init — original MAJ is not mutated + jprob2 = remake(jprob; p = [12.0]) + @test jprob.massaction_jump === jprob2.massaction_jump # shared MAJ + @test jprob.massaction_jump.scaled_rates === nothing # still nothing + init(jprob2, SSAStepper()) + # aggregation is shared, so maj_rates reflects latest init + @test jprob2.discrete_jump_aggregation.maj_rates[1] ≈ 2.0 # 12.0 / 3! +end diff --git a/test/spatial/bracketing.jl b/test/spatial/bracketing.jl index 31c1e23b..7afcaba8 100644 --- a/test/spatial/bracketing.jl +++ b/test/spatial/bracketing.jl @@ -19,6 +19,8 @@ netstoch = [[1 => -1]] majump = MassActionJump(majump_rates, reactstoch, netstoch) rx_rates = JP.LowHigh(JP.RxRates(n, majump)) +JP.fill_scaled_rates!(rx_rates.low.maj_rates, majump, nothing) +JP.fill_scaled_rates!(rx_rates.high.maj_rates, majump, nothing) # set up hop rates hop_constants = [1.0] diff --git a/test/spatial/reaction_rates.jl b/test/spatial/reaction_rates.jl index 8199dd39..2e6b5ac3 100644 --- a/test/spatial/reaction_rates.jl +++ b/test/spatial/reaction_rates.jl @@ -36,10 +36,13 @@ rx_rates_list = [JP.RxRates(num_nodes, ma_jumps), JP.RxRates(num_nodes, spatial_ for rx_rates in rx_rates_list @test JP.num_rxs(rx_rates) == length(rates) show(io, "text/plain", rx_rates) + if rx_rates.ma_jumps isa MassActionJump + JP.fill_scaled_rates!(rx_rates.maj_rates, rx_rates.ma_jumps, nothing) + end for site in 1:num_nodes JP.update_rx_rates!(rx_rates, 1:num_rxs, integrator, site) @test JP.total_site_rx_rate(rx_rates, site) == 1.1 - rx_props = [JP.evalrxrate(u[:, site], rx, ma_jumps) for rx in 1:num_rxs] + rx_props = [JP.eval_massaction_rate(u, rx, rx_rates.ma_jumps, site, rx_rates.maj_rates) for rx in 1:num_rxs] rx_probs = rx_props / sum(rx_props) d = Dict{Int, Int}() for i in 1:num_samples diff --git a/test/ssa_callback_test.jl b/test/ssa_callback_test.jl index 1cbf678b..212ce9f4 100644 --- a/test/ssa_callback_test.jl +++ b/test/ssa_callback_test.jl @@ -105,14 +105,19 @@ sol = solve(jprob, SSAStepper(); rng, tstops = [1000.0], callback = DiscreteCall @test all(p .== [0.0, 1.0]) @test sol[1, end] == 100 -# test scale_rates kwarg +# test scale_rates / rescale_rates_on_update behavior p .= [1.0] dprob = DiscreteProblem(u₀, tspan, p) maj5 = MassActionJump([[1 => 2]], [[1 => -1, 2 => 1]]; param_idxs = [1]) jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false)) -@test all(jprob.massaction_jump.scaled_rates .== [0.5]) -jprob = JumpProblem(dprob, Direct(), maj5, save_positions = (false, false), scale_rates = false) -@test all(jprob.massaction_jump.scaled_rates .== [1.0]) +@test jprob.massaction_jump.scaled_rates === nothing # parameterized MAJ +integ = init(jprob, SSAStepper()) +@test all(jprob.discrete_jump_aggregation.maj_rates .== [0.5]) +# with scale_rates = false on the MAJ itself +maj5_noscale = MassActionJump([[1 => 2]], [[1 => -1, 2 => 1]]; param_idxs = [1], scale_rates = false) +jprob = JumpProblem(dprob, Direct(), maj5_noscale, save_positions = (false, false)) +integ = init(jprob, SSAStepper()) +@test all(jprob.discrete_jump_aggregation.maj_rates .== [1.0]) # test for https://github.com/SciML/JumpProcesses.jl/issues/239 maj6 = MassActionJump([[1 => 1], [2 => 1]], [[1 => -1, 2 => 1], [1 => 1, 2 => -1]]; @@ -226,3 +231,28 @@ let @test 3.0 ∈ sol5.t @test 6.0 ∈ sol5.t end + +# test that a single parameterized MAJ with scalar param_idxs works without merging +@testset "Single parameterized MAJ with scalar param_idxs" begin + maj = MassActionJump([1 => 1, 2 => 1], [1 => -1, 2 => -1, 3 => 1]; + param_idxs = 1) + p = [0.1] + u0 = [100, 100, 0] + dprob = DiscreteProblem(u0, (0.0, 10.0), p) + jprob = JumpProblem(dprob, Direct(), maj) + sol = solve(jprob, SSAStepper(); rng = StableRNG(12345)) + @test sol.retcode == ReturnCode.Success + @test sol[3, end] > 0 +end + +# test that merging MAJs does not mutate the input MAJs' param_mapper +@testset "Merging MAJs does not mutate inputs" begin + maj1 = MassActionJump([1 => 1], [1 => -1]; param_idxs = [1]) + maj2 = MassActionJump([2 => 1], [2 => -1]; param_idxs = [2]) + orig1 = copy(maj1.param_mapper.param_idxs) + orig2 = copy(maj2.param_mapper.param_idxs) + dprob = DiscreteProblem([100, 100], (0.0, 1.0), [0.1, 0.2]) + jprob = JumpProblem(dprob, Direct(), maj1, maj2) + @test maj1.param_mapper.param_idxs == orig1 + @test maj2.param_mapper.param_idxs == orig2 +end diff --git a/test/thread_safety.jl b/test/thread_safety.jl index 9f7c4491..460dec73 100644 --- a/test/thread_safety.jl +++ b/test/thread_safety.jl @@ -39,8 +39,7 @@ let for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jump_prob = JumpProblem(ode_prob, Direct(), vrj; vr_aggregator = agg) - prob_func(prob, i, repeat) = deepcopy(prob) - prob = EnsembleProblem(jump_prob, prob_func = prob_func) + prob = EnsembleProblem(jump_prob) sol = solve(prob, Tsit5(), EnsembleThreads(), trajectories = 400, save_everystep = false) firstrx_time = [sol.u[i].t[findfirst(>(sol.u[i].t[1]), sol.u[i].t)] for i in 1:length(sol)] @@ -58,8 +57,7 @@ let for agg in (VR_FRM(), VR_Direct(), VR_DirectFW()) jump_prob = JumpProblem(sde_prob, Direct(), vrj; vr_aggregator = agg) - prob_func(prob, i, repeat) = deepcopy(prob) - prob = EnsembleProblem(jump_prob; prob_func) + prob = EnsembleProblem(jump_prob) sol = solve(prob, SRIW1(), EnsembleThreads(); trajectories = 400, save_everystep = false) firstrx_time = [sol.u[i].t[findfirst(>(sol.u[i].t[1]), sol.u[i].t)] for i in 1:length(sol)]