diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 0770894884..3bc65d5a7f 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -30,7 +30,6 @@ Union, ) from collections.abc import Sequence - import numpy as np import sympy as sp from sympy.matrices.dense import MutableDenseMatrix @@ -1117,11 +1116,10 @@ def transform_dxdt_to_concentration(species_id, dxdt): for llh in si.symbols[SymbolId.LLHY].values() ) - self._process_sbml_rate_of( - symbols - ) # substitute SBML-rateOf constructs + # substitute SBML-rateOf constructs + self._process_sbml_rate_of() - def _process_sbml_rate_of(self, symbols) -> None: + def _process_sbml_rate_of(self) -> None: """Substitute any SBML-rateOf constructs in the model equations""" rate_of_func = sp.core.function.UndefinedFunction("rateOf") species_sym_to_xdot = dict(zip(self.sym("x"), self.sym("xdot"))) @@ -1129,8 +1127,6 @@ def _process_sbml_rate_of(self, symbols) -> None: def get_rate(symbol: sp.Symbol): """Get rate of change of the given symbol""" - nonlocal symbols - if symbol.find(rate_of_func): raise SBMLException("Nesting rateOf() is not allowed.") @@ -1142,6 +1138,7 @@ def get_rate(symbol: sp.Symbol): return 0 # replace rateOf-instances in xdot by xdot symbols + made_substitutions = False for i_state in range(len(self.eq("xdot"))): if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func): self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs( @@ -1151,9 +1148,14 @@ def get_rate(symbol: sp.Symbol): for rate_of in rate_ofs } ) - # substitute in topological order - subs = toposort_symbols(dict(zip(self.sym("xdot"), self.eq("xdot")))) - self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) + made_substitutions = True + + if made_substitutions: + # substitute in topological order + subs = toposort_symbols( + dict(zip(self.sym("xdot"), self.eq("xdot"))) + ) + self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs) # replace rateOf-instances in x0 by xdot equation for i_state in range(len(self.eq("x0"))): @@ -1165,9 +1167,55 @@ def get_rate(symbol: sp.Symbol): } ) + # replace rateOf-instances in w by xdot equation + # here we may need toposort, as xdot may depend on w + made_substitutions = False + for i_expr in range(len(self.eq("w"))): + if rate_ofs := self._eqs["w"][i_expr].find(rate_of_func): + self._eqs["w"][i_expr] = self._eqs["w"][i_expr].subs( + { + rate_of: get_rate(rate_of.args[0]) + for rate_of in rate_ofs + } + ) + made_substitutions = True + + if made_substitutions: + # Sort expressions in self._expressions, w symbols, and w equations + # in topological order. Ideally, this would already happen before + # adding the expressions to the model, but at that point, we don't + # have access to xdot yet. + # NOTE: elsewhere, conservations law expressions are expected to + # occur before any other w expressions, so we must maintain their + # position + # toposort everything but conservation law expressions, + # then prepend conservation laws + w_sorted = toposort_symbols( + dict( + zip( + self.sym("w")[self.num_cons_law() :, :], + self.eq("w")[self.num_cons_law() :, :], + ) + ) + ) + w_sorted = ( + dict( + zip( + self.sym("w")[: self.num_cons_law(), :], + self.eq("w")[: self.num_cons_law(), :], + ) + ) + | w_sorted + ) + old_syms = tuple(self._syms["w"]) + topo_expr_syms = tuple(w_sorted.keys()) + new_order = [old_syms.index(s) for s in topo_expr_syms] + self._expressions = [self._expressions[i] for i in new_order] + self._syms["w"] = sp.Matrix(topo_expr_syms) + self._eqs["w"] = sp.Matrix(list(w_sorted.values())) + for component in chain( self.observables(), - self.expressions(), self.events(), self._algebraic_equations, ): @@ -2210,6 +2258,18 @@ def _compute_equation(self, name: str) -> None: self._eqs[name] = self.sym(name) elif name == "dwdx": + if ( + expected := list( + map( + ConservationLaw.get_x_rdata, + reversed(self.conservation_laws()), + ) + ) + ) != (actual := self.eq("w")[: self.num_cons_law()]): + raise AssertionError( + "Conservation laws are not at the beginning of 'w'. " + f"Got {actual}, expected {expected}." + ) x = self.sym("x") self._eqs[name] = sp.Matrix( [ diff --git a/python/tests/test_sbml_import_special_functions.py b/python/tests/test_sbml_import_special_functions.py index 9d8f447511..3f8383ce94 100644 --- a/python/tests/test_sbml_import_special_functions.py +++ b/python/tests/test_sbml_import_special_functions.py @@ -12,7 +12,11 @@ from amici.antimony_import import antimony2amici from amici.gradient_check import check_derivatives from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind -from numpy.testing import assert_approx_equal, assert_array_almost_equal_nulp +from numpy.testing import ( + assert_approx_equal, + assert_array_almost_equal_nulp, + assert_allclose, +) from scipy.special import loggamma @@ -222,3 +226,51 @@ def test_rateof(): assert_array_almost_equal_nulp( rdata.by_id("p2"), 1 + rdata.by_id("S1") ) + + +@skip_on_valgrind +def test_rateof_with_expression_dependent_rate(): + """Test rateOf, where the rateOf argument depends on `w` and requires + toposorting.""" + ant_model = """ + model test_rateof_with_expression_dependent_rate + S1 = 0; + S2 = 0; + S1' = rate; + S2' = 2 * rateOf(S1); + # the id of the following expression must be alphabetically before + # `rate`, so that toposort is required to evaluate the expressions + # in the correct order + e1 := 2 * rateOf(S1); + rate := time + end + """ + module_name = "test_rateof_with_expression_dependent_rate" + with TemporaryDirectoryWinSafe(prefix=module_name) as outdir: + antimony2amici( + ant_model, + model_name=module_name, + output_dir=outdir, + ) + model_module = amici.import_model_module( + module_name=module_name, module_path=outdir + ) + amici_model = model_module.getModel() + t = np.linspace(0, 10, 11) + amici_model.setTimepoints(t) + amici_solver = amici_model.getSolver() + rdata = amici.runAmiciSimulation(amici_model, amici_solver) + + state_ids_solver = amici_model.getStateIdsSolver() + + assert_array_almost_equal_nulp(rdata.by_id("e1"), 2 * t, 1) + + i_S1 = state_ids_solver.index("S1") + i_S2 = state_ids_solver.index("S2") + assert_approx_equal(rdata["xdot"][i_S1], t[-1]) + assert_approx_equal(rdata["xdot"][i_S2], 2 * t[-1]) + + assert_allclose(np.diff(rdata.by_id("S1")), t[:-1] + 0.5, atol=1e-9) + assert_array_almost_equal_nulp( + rdata.by_id("S2"), 2 * rdata.by_id("S1"), 10 + )