Skip to content

Commit

Permalink
Finish events and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Apr 4, 2024
1 parent b0a9c48 commit 6e9bd40
Show file tree
Hide file tree
Showing 4 changed files with 407 additions and 150 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ julia = "1.9"

[extras]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
Graphviz_jll = "3c863552-8265-54e4-a6dc-903eb78fde85"
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
Expand All @@ -80,4 +81,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["BifurcationKit", "DomainSets", "Graphviz_jll", "HomotopyContinuation", "NonlinearSolve", "OrdinaryDiffEq", "Plots", "Random", "SafeTestsets", "SciMLBase", "SciMLNLSolve", "StableRNGs", "Statistics", "SteadyStateDiffEq", "StochasticDiffEq", "StructuralIdentifiability", "Test", "Unitful"]
test = ["BifurcationKit", "DiffEqCallbacks", "DomainSets", "Graphviz_jll", "HomotopyContinuation", "NonlinearSolve", "OrdinaryDiffEq", "Plots", "Random", "SafeTestsets", "SciMLBase", "SciMLNLSolve", "StableRNGs", "Statistics", "SteadyStateDiffEq", "StochasticDiffEq", "StructuralIdentifiability", "Test", "Unitful"]
149 changes: 102 additions & 47 deletions src/reactionsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,70 +628,77 @@ function ReactionSystem(eqs, iv, unknowns, ps;
continuous_events = nothing,
discrete_events = nothing,
metadata = nothing)

name === nothing &&

# Error checks
if name === nothing &&
throw(ArgumentError("The `name` keyword must be provided. Please consider using the `@named` macro"))
end
sysnames = nameof.(systems)
(length(unique(sysnames)) == length(sysnames)) ||
throw(ArgumentError("System names must be unique."))
(length(unique(sysnames)) == length(sysnames)) || throw(ArgumentError("System names must be unique."))

# Handle defaults values provided via optional arguments.
if !(isempty(default_u0) && isempty(default_p))
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.",
:ReactionSystem, force = true)
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ReactionSystem, force = true)
end
defaults = MT.todict(defaults)
defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults))

# Extracts independent variables (iv and sivs), dependent variables (species and variables)
# and parameters. Sorts so that species comes before variables in unknowns vector.
iv′ = value(iv)
sivs′ = if spatial_ivs === nothing
Vector{typeof(iv′)}()
else
value.(MT.scalarize(spatial_ivs))
end
unknowns′ = sort!(value.(MT.scalarize(unknowns)), by = !isspecies) # species come first
unknowns′ = sort!(value.(MT.scalarize(unknowns)), by = !isspecies)
spcs = filter(isspecies, unknowns′)
ps′ = value.(MT.scalarize(ps))

# Checks that no (by Catalyst) forbidden symbols are used.
allsyms = Iterators.flatten((ps′, unknowns′))
all(sym -> getname(sym) forbidden_symbols_error, allsyms) ||
if !all(sym -> getname(sym) forbidden_symbols_error, allsyms)
error("Catalyst reserves the symbols $forbidden_symbols_error for internal use. Please do not use these symbols as parameters or unknowns/species.")
end

# sort Reactions before Equations
# Handles reactions and equations. Sorts so that reactions are before equaions in the equations vector.
eqs′ = CatalystEqType[eq for eq in eqs]
sort!(eqs′; by = eqsortby)
rxs = Reaction[rx for rx in eqs if rx isa Reaction]

# Additional error checks.
if any(MT.isparameter, unknowns′)
psts = filter(MT.isparameter, unknowns′)
throw(ArgumentError("Found one or more parameters among the unknowns; this is not allowed. Move: $psts to be parameters."))
end

if any(isconstant, unknowns′)
csts = filter(isconstant, unknowns′)
throw(ArgumentError("Found one or more constant species among the unknowns; this is not allowed. Move: $csts to be parameters."))
end

# if there are BC species, check they are balanced in their reactions
# If there are BC species, check they are balanced in their reactions.
if balanced_bc_check && any(isbc, unknowns′)
for rx in eqs
if rx isa Reaction
isbcbalanced(rx) ||
throw(ErrorException("BC species must be balanced, appearing as a substrate and product with the same stoichiometry. Please fix reaction: $rx"))
if (rx isa Reaction) && !isbcbalanced(rx)
throw(ErrorException("BC species must be balanced, appearing as a substrate and product with the same stoichiometry. Please fix reaction: $rx"))
end
end
end

# Adds all unknowns/parameters to the `var_to_name` vector.
# Adds their (potential) default values to the defaults vector.
var_to_name = Dict()
MT.process_variables!(var_to_name, defaults, unknowns′)
MT.process_variables!(var_to_name, defaults, ps′)
MT.collect_var_to_name!(var_to_name, eq.lhs for eq in observed)

nps = if networkproperties === nothing
NetworkProperties{Int, get_speciestype(iv′, unknowns′, systems)}()
# Computes network properties.
if networkproperties === nothing
nps = NetworkProperties{Int, get_speciestype(iv′, unknowns′, systems)}()
else
networkproperties
nps = networkproperties
end

# Creates the continious and discrete callbacks.
ccallbacks = MT.SymbolicContinuousCallbacks(continuous_events)
dcallbacks = MT.SymbolicDiscreteCallbacks(discrete_events)

Expand All @@ -705,77 +712,125 @@ function ReactionSystem(rxs::Vector, iv = Catalyst.DEFAULT_IV; kwargs...)
end

# search the symbolic expression for parameters or unknowns
# and save in ps and sts respectively. vars is used to cache results
function findvars!(ps, sts, exprtosearch, ivs, vars)
# and save in ps and us respectively. vars is used to cache results
function findvars!(ps, us, exprtosearch, ivs, vars)
MT.get_variables!(vars, exprtosearch)
for var in vars
(var ivs) && continue
if MT.isparameter(var)
push!(ps, var)
else
push!(sts, var)
push!(us, var)
end
end
empty!(vars)
end

# Only used internally by the @reaction_network macro. Permits giving an initial order to
# the parameters, and then adds additional ones found in the reaction. Name could be
# changed.
function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, sts_in, ps_in;
spatial_ivs = nothing, kwargs...)
# Called internally (whether DSL-based or programmtic model creation is used).
# Creates a sorted reactions + equations vector, also ensuring reaction is first in this vector.
# Extracts potential species, variables, and parameters from the input (if not provided as part of
# the model creation) and creates the corresponding vectors.
# While species are ordered before variables in the unknowns vector, this ordering is not imposed here,
# but carried out at a later stage.
function make_ReactionSystem_internal(rxs_and_eqs::Vector, iv, us_in, ps_in; spatial_ivs = nothing,
continuous_events = [], discrete_events = [], kwargs...)

# Creates a combined iv vector (iv and sivs). This is used later in the function (so that
# independent variables can be exluded when encountered quantities are added to `us` and `ps`).
t = value(iv)
ivs = Set([t])
if (spatial_ivs !== nothing)
for siv in (MT.scalarize(spatial_ivs))
push!(ivs, value(siv))
end
end
sts = OrderedSet{eltype(sts_in)}(sts_in)

# Initialises the new unknowns and parameter vectors.
# Preallocates the `vars` set, which is used by `findvars!`
us = OrderedSet{eltype(us_in)}(us_in)
ps = OrderedSet{eltype(ps_in)}(ps_in)
vars = OrderedSet()

# Extracts the reactions and equations from the combined reactions + equations input vector.
all(eq -> eq isa Union{Reaction, Equation}, rxs_and_eqs)
rxs = Reaction[eq for eq in rxs_and_eqs if eq isa Reaction]
eqs = Equation[eq for eq in rxs_and_eqs if eq isa Equation]

# add species / parameters that are substrates / products first
for rx in rxs, reactants in (rx.substrates, rx.products)
for spec in reactants
MT.isparameter(spec) ? push!(ps, spec) : push!(sts, spec)
end
end

# Loops through all reactions, adding encountered quantities to the unknown and parameter vectors.
for rx in rxs
findvars!(ps, sts, rx.rate, ivs, vars)
for s in rx.substoich
(s isa Symbolic) && findvars!(ps, sts, s, ivs, vars)
# Loops through all reaction substrates and products, extracting these.
for reactants in (rx.substrates, rx.products), spec in reactants
MT.isparameter(spec) ? push!(ps, spec) : push!(us, spec)
end
for p in rx.prodstoich
(p isa Symbolic) && findvars!(ps, sts, p, ivs, vars)

# Adds all quantitites encountered in the reaction's rate.
findvars!(ps, us, rx.rate, ivs, vars)

# Extracts all quantitites encountered within stoichiometries.
for stoichiometry in (rx.substoich, rx.prodstoich), sym in stoichiometry
(sym isa Symbolic) && findvars!(ps, us, sym, ivs, vars)
end
end

stsv = collect(sts)
psv = collect(ps)
# Will appear here: add stuff from nosie scaling.
end

# Extracts any species, variables, and parameters that occur in (non-reaction) equations.
# Creates the new reactions + equations vector, `fulleqs` (sorted reactions first, equations next).
if !isempty(eqs)
osys = ODESystem(eqs, iv; name = gensym())
fulleqs = CatalystEqType[rxs; equations(osys)]
union!(stsv, unknowns(osys))
union!(psv, parameters(osys))
union!(us, unknowns(osys))
union!(ps, parameters(osys))
else
fulleqs = rxs
end
end

ReactionSystem(fulleqs, t, stsv, psv; spatial_ivs, kwargs...)
# Loops through all events, adding encountered quantities to the unknwon and parameter vectors.
find_event_vars!(ps, us, continuous_events, ivs, vars)
find_event_vars!(ps, us, discrete_events, ivs, vars)

# Converts the found unknowns and parameters to vectors.
usv = collect(us)
psv = collect(ps)

# Passes the processed input into the next `ReactionSystem` call.
ReactionSystem(fulleqs, t, usv, psv; spatial_ivs, continuous_events, discrete_events, kwargs...)
end

function ReactionSystem(iv; kwargs...)
ReactionSystem(Reaction[], iv, [], []; kwargs...)
end


# Loops through all events in an supplied event vector, adding all unknowns and parameters found in
# its condition and affect functions to their respective vectors (`ps` and `us`).
function find_event_vars!(ps, us, events::Vector, ivs, vars)
foreach(event -> find_event_vars!(ps, us, event, ivs, vars), events)
end
# For a single event, adds quantitites from its condition and affect expression(s) to `ps` and `us`.
function find_event_vars!(ps, us, event, ivs, vars)
conds, affects = event
# For discrete events, the condition can be a single value (for periodic events).
# If not, it is a vector of conditions and we must check each.
if conds isa Vector
for cond in conds
# For continious events the conditions are equations (with lhs and rhs).
# For discrete events, they are single expressions.
if cond isa Equation
findvars!(ps, us, cond.lhs, ivs, vars)
findvars!(ps, us, cond.rhs, ivs, vars)
else
findvars!(ps, us, cond, ivs, vars)
end
end
else
findvars!(ps, us, conds, ivs, vars)
end
# The affects is always a vector of equations. Here, we handle the lhs and rhs separately.
for affect in affects
findvars!(ps, us, affect.lhs, ivs, vars)
findvars!(ps, us, affect.rhs, ivs, vars)
end
end
"""
remake_ReactionSystem_internal(rs::ReactionSystem;
default_reaction_metadata::Vector{Pair{Symbol, T}} = Vector{Pair{Symbol, Any}}()) where {T}
Expand Down
89 changes: 0 additions & 89 deletions test/dsl/dsl_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -860,93 +860,4 @@ let
@equations X ~ p - S
(P,D), 0 <--> S
end
end

### Events ###

# Compares models with complicated events that are created programmatically/with the DSL.
# Checks that simulations are correct.
# Checks that various simulation inputs works.
# Checks continuous, discrete, preset time, and periodic events.
# Tests event affecting non-species components.

let
# Creates model via DSL.
rn_dsl = @reaction_network rn begin
@parameters thres=1.0 dY_up
@variables Z(t)
@continuous_events begin
[t - 2.5] => [p ~ p + 0.2]
[X - thres, Y - X] => [X ~ X - 0.5, Z ~ Z + 0.1]
end
@discrete_events begin
2.0 => [dX ~ dX + 0.1, dY ~ dY + dY_up]
[1.0, 5.0] => [p ~ p - 0.1]
(Z > Y) => [Z ~ Z - 0.1]
end

(p, dX), 0 <--> X
(p, dY), 0 <--> Y
end

# Creates model programmatically.
@variables t Z(t)
@species X(t) Y(t)
@parameters p dX dY thres=1.0 dY_up
rxs = [
Reaction(p, nothing, [X], nothing, [1])
Reaction(dX, [X], nothing, [1], nothing)
Reaction(p, nothing, [Y], nothing, [1])
Reaction(dY, [Y], nothing, [1], nothing)
]
continuous_events = [
t - 2.5 => p ~ p + 0.2
[X - thres, Y - X] => [X ~ X - 0.5, Z ~ Z + 0.1]
]
discrete_events = [
2.0 => [dX ~ dX + 0.1, dY ~ dY + dY_up]
[1.0, 5.0] => [p ~ p - 0.1]
(Z > Y) => [Z ~ Z - 0.1]
]
rn_prog = ReactionSystem(rxs, t; continuous_events, discrete_events, name=:rn)

# Tests that approaches yield identical results.
@test isequal(rn_dsl, rn_prog)

u0 = [X => 1.0, Y => 0.5, Z => 0.25]
tspan = (0.0, 20.0)
ps = [p => 1.0, dX => 0.5, dY => 0.5, dY_up => 0.1]

sol_dsl = solve(ODEProblem(rn_dsl, u0, tspan, ps), Tsit5())
sol_prog = solve(ODEProblem(rn_prog, u0, tspan, ps), Tsit5())
@test sol_dsl == sol_prog
end

# Compares DLS events to those given as callbacks.
# Checks that events works when given to SDEs.
let
# Creates models.
rn = @reaction_network begin
(p, d), 0 <--> X
end
rn_events = @reaction_network begin
@discrete_events begin
[5.0, 10.0] => [X ~ X + 100.0]
end
@continuous_events begin
[X ~ 90.0] => [X ~ X + 10.0]
end
(p, d), 0 <--> X
end
cb_disc = ModelingToolkit.PresetTimeCallback([5.0, 10.0], int -> (int[:X] += 100.0))
cb_cont = ContinuousCallback((u, t, int) -> (u[1] - 90.0), int -> (int[:X] += 10.0))

# Simulates models.
u0 = [:X => 100.0]
tspan = (0.0, 50.0)
ps = [:p => 100.0, :d => 1.0]
sol = solve(SDEProblem(rn, u0, tspan, ps), ImplicitEM(); seed, callback = CallbackSet(cb_disc, cb_cont))
sol_events = solve(SDEProblem(rn_events, u0, tspan, ps), ImplicitEM(); seed)

@test sol == sol_events
end
Loading

0 comments on commit 6e9bd40

Please sign in to comment.