Skip to content

Commit

Permalink
Faster generation of drootdt_total for models without state-dependent…
Browse files Browse the repository at this point in the history
… root functions (#2417)

Don't compute things we don't need.

For my test model, this reduces code generation time from 76s to 12s (-83%).
  • Loading branch information
dweindl authored May 2, 2024
1 parent 304c23a commit 6b05ddd
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1561,16 +1561,18 @@ def _compute_equation(self, name: str) -> None:
self._eqs[name] = smart_jacobian(self.eq("root"), time_symbol)

elif name == "drootdt_total":
# backsubstitution of optimized right-hand side terms into RHS
# calling subs() is costly. Due to looping over events though, the
# following lines are only evaluated if a model has events
w_sorted = toposort_symbols(
dict(zip(self.sym("w"), self.eq("w"), strict=True))
)
tmp_xdot = smart_subs_dict(self.eq("xdot"), w_sorted)
self._eqs[name] = self.eq("drootdt")
if self.num_states_solver():
self._eqs[name] += smart_multiply(self.eq("drootdx"), tmp_xdot)
# backsubstitution of optimized right-hand side terms into RHS
# calling subs() is costly. We can skip it if we don't have any
# state-dependent roots.
if self.num_states_solver() and not smart_is_zero_matrix(
drootdx := self.eq("drootdx")
):
w_sorted = toposort_symbols(
dict(zip(self.sym("w"), self.eq("w"), strict=True))
)
tmp_xdot = smart_subs_dict(self.eq("xdot"), w_sorted)
self._eqs[name] += smart_multiply(drootdx, tmp_xdot)

elif name == "deltax":
# fill boluses for Heaviside functions, as empty state updates
Expand Down

0 comments on commit 6b05ddd

Please sign in to comment.