Skip to content

Commit

Permalink
Fix missing toposort after rateOf-substitutions in w
Browse files Browse the repository at this point in the history
Fixes potentially incorrect simulation results when using rateOf in `w`
where the rates depend on `w`.

Fixes AMICI-dev#2290
  • Loading branch information
dweindl committed Feb 19, 2024
1 parent 5bd921f commit 3c802c7
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 5 deletions.
29 changes: 25 additions & 4 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,7 @@ def get_rate(symbol: sp.Symbol):
return 0

# replace rateOf-instances in xdot by xdot symbols
made_substitutions = False
for i_state in range(len(self.eq("xdot"))):
if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func):
self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs(
Expand All @@ -1151,9 +1152,13 @@ def get_rate(symbol: sp.Symbol):
for rate_of in rate_ofs
}
)
# substitute in topological order
subs = toposort_symbols(dict(zip(self.sym("xdot"), self.eq("xdot"))))
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)
made_substitutions = True
if made_substitutions:
# substitute in topological order
subs = toposort_symbols(
dict(zip(self.sym("xdot"), self.eq("xdot")))
)
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)

# replace rateOf-instances in x0 by xdot equation
for i_state in range(len(self.eq("x0"))):
Expand All @@ -1165,9 +1170,25 @@ def get_rate(symbol: sp.Symbol):
}
)

# replace rateOf-instances in w by xdot equation
# here we may need toposort, as xdot may depend on w
made_substitutions = False
for i_state in range(len(self.eq("w"))):
if rate_ofs := self._eqs["w"][i_state].find(rate_of_func):
self._eqs["w"][i_state] = self._eqs["w"][i_state].subs(
{
rate_of: get_rate(rate_of.args[0])
for rate_of in rate_ofs
}
)
made_substitutions = True
if made_substitutions:
# sort in topological order
subs = toposort_symbols(dict(zip(self.sym("w"), self.eq("w"))))
self._eqs["w"] = smart_subs_dict(self.eq("w"), subs)

for component in chain(
self.observables(),
self.expressions(),
self.events(),
self._algebraic_equations,
):
Expand Down
54 changes: 53 additions & 1 deletion python/tests/test_sbml_import_special_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from amici.antimony_import import antimony2amici
from amici.gradient_check import check_derivatives
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from numpy.testing import assert_approx_equal, assert_array_almost_equal_nulp
from numpy.testing import (
assert_approx_equal,
assert_array_almost_equal_nulp,
assert_allclose,
)
from scipy.special import loggamma


Expand Down Expand Up @@ -222,3 +226,51 @@ def test_rateof():
assert_array_almost_equal_nulp(
rdata.by_id("p2"), 1 + rdata.by_id("S1")
)


@skip_on_valgrind
def test_rateof_with_expression_dependent_rate():
"""Test rateOf, where the rateOf argument depends on `w` and requires
toposorting."""
ant_model = """
model test_rateof_with_expression_dependent_rate
S1 = 0;
S2 = 0;
S1' = rate;
S2' = 2 * rateOf(S1);
# the id of the following expression must be alphabetically before
# `rate`, so that toposort is required to evaluate the expressions
# in the correct order
e1 := 2 * rateOf(S1);
rate := time
end
"""
module_name = "test_rateof_with_expression_dependent_rate"
with TemporaryDirectoryWinSafe(prefix=module_name) as outdir:
antimony2amici(
ant_model,
model_name=module_name,
output_dir=outdir,
)
model_module = amici.import_model_module(
module_name=module_name, module_path=outdir
)
amici_model = model_module.getModel()
t = np.linspace(0, 10, 11)
amici_model.setTimepoints(t)
amici_solver = amici_model.getSolver()
rdata = amici.runAmiciSimulation(amici_model, amici_solver)

state_ids_solver = amici_model.getStateIdsSolver()

assert_array_almost_equal_nulp(rdata.by_id("e1"), 2 * t, 1)

i_S1 = state_ids_solver.index("S1")
i_S2 = state_ids_solver.index("S2")
assert_approx_equal(rdata["xdot"][i_S1], t[-1])
assert_approx_equal(rdata["xdot"][i_S2], 2 * t[-1])

assert_allclose(np.diff(rdata.by_id("S1")), t[:-1] + 0.5, atol=1e-9)
assert_array_almost_equal_nulp(
rdata.by_id("S2"), 2 * rdata.by_id("S1"), 10
)

0 comments on commit 3c802c7

Please sign in to comment.