diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 4e7e0999f2..3f8cb3839c 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -1142,6 +1142,7 @@ def get_rate(symbol: sp.Symbol): return 0 # replace rateOf-instances in xdot by xdot symbols + made_substitutions = False for i_state in range(len(self.eq("xdot"))): if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func): self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs( @@ -1151,9 +1152,13 @@ def get_rate(symbol: sp.Symbol): for rate_of in rate_ofs } ) - # substitute in topological order - subs = toposort_symbols(dict(zip(self.sym("xdot"), self.eq("xdot")))) - self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) + made_substitutions = True + if made_substitutions: + # substitute in topological order + subs = toposort_symbols( + dict(zip(self.sym("xdot"), self.eq("xdot"))) + ) + self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) # replace rateOf-instances in x0 by xdot equation for i_state in range(len(self.eq("x0"))): @@ -1165,9 +1170,25 @@ def get_rate(symbol: sp.Symbol): } ) + # replace rateOf-instances in w by xdot equation + # here we may need toposort, as xdot may depend on w + made_substitutions = False + for i_state in range(len(self.eq("w"))): + if rate_ofs := self._eqs["w"][i_state].find(rate_of_func): + self._eqs["w"][i_state] = self._eqs["w"][i_state].subs( + { + rate_of: get_rate(rate_of.args[0]) + for rate_of in rate_ofs + } + ) + made_substitutions = True + if made_substitutions: + # sort in topological order + subs = toposort_symbols(dict(zip(self.sym("w"), self.eq("w")))) + self._eqs["w"] = smart_subs_dict(self.eq("w"), subs) + for component in chain( self.observables(), - self.expressions(), self.events(), self._algebraic_equations, ): diff --git a/python/tests/test_sbml_import_special_functions.py b/python/tests/test_sbml_import_special_functions.py index 9d8f447511..3f8383ce94 100644 --- a/python/tests/test_sbml_import_special_functions.py +++ b/python/tests/test_sbml_import_special_functions.py @@ -12,7 +12,11 @@ from amici.antimony_import import antimony2amici from amici.gradient_check import check_derivatives from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind -from numpy.testing import assert_approx_equal, assert_array_almost_equal_nulp +from numpy.testing import ( + assert_approx_equal, + assert_array_almost_equal_nulp, + assert_allclose, +) from scipy.special import loggamma @@ -222,3 +226,51 @@ def test_rateof(): assert_array_almost_equal_nulp( rdata.by_id("p2"), 1 + rdata.by_id("S1") ) + + +@skip_on_valgrind +def test_rateof_with_expression_dependent_rate(): + """Test rateOf, where the rateOf argument depends on `w` and requires + toposorting.""" + ant_model = """ + model test_rateof_with_expression_dependent_rate + S1 = 0; + S2 = 0; + S1' = rate; + S2' = 2 * rateOf(S1); + # the id of the following expression must be alphabetically before + # `rate`, so that toposort is required to evaluate the expressions + # in the correct order + e1 := 2 * rateOf(S1); + rate := time + end + """ + module_name = "test_rateof_with_expression_dependent_rate" + with TemporaryDirectoryWinSafe(prefix=module_name) as outdir: + antimony2amici( + ant_model, + model_name=module_name, + output_dir=outdir, + ) + model_module = amici.import_model_module( + module_name=module_name, module_path=outdir + ) + amici_model = model_module.getModel() + t = np.linspace(0, 10, 11) + amici_model.setTimepoints(t) + amici_solver = amici_model.getSolver() + rdata = amici.runAmiciSimulation(amici_model, amici_solver) + + state_ids_solver = amici_model.getStateIdsSolver() + + assert_array_almost_equal_nulp(rdata.by_id("e1"), 2 * t, 1) + + i_S1 = state_ids_solver.index("S1") + i_S2 = state_ids_solver.index("S2") + assert_approx_equal(rdata["xdot"][i_S1], t[-1]) + assert_approx_equal(rdata["xdot"][i_S2], 2 * t[-1]) + + assert_allclose(np.diff(rdata.by_id("S1")), t[:-1] + 0.5, atol=1e-9) + assert_array_almost_equal_nulp( + rdata.by_id("S2"), 2 * rdata.by_id("S1"), 10 + )