Skip to content

Commit

Permalink
doc
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Feb 26, 2024
1 parent 67c4ef6 commit 6cab686
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -1662,13 +1662,13 @@ def static_indices(self, name: str) -> list[int]:
dwdw = self.eq("dwdw")
w = self.eq("w")

# check for indirect state and time dependency
# to avoid lengthy symbolic computations,
# we only check if we have any non-zeros in hierarchy,
# we currently neglect the case where different hierarchy
# levels may cancel out. treating a static expression as
# dynamic in such rare cases shouldn't be a problem.
nonzero_dwdx = np.asarray(
# Check for direct (via `t`) or indirect (via `x`, `h`, or splines)
# time dependency.
# To avoid lengthy symbolic computations, we only check if we have
# any non-zeros in hierarchy. We currently neglect the case where
# different hierarchy levels may cancel out. Treating a static
# expression as dynamic in such rare cases shouldn't be a problem.
dynamic_dependency = np.asarray(
dwdx.applyfunc(lambda x: int(not x.is_zero))
).astype(np.int64)
# to check for other time-dependence, we add a column to the dwdx
Expand All @@ -1679,9 +1679,9 @@ def static_indices(self, name: str) -> list[int]:
*self.sym("h"),
amici_time_symbol,
]
nonzero_dwdx = np.hstack(
dynamic_dependency = np.hstack(
(
nonzero_dwdx,
dynamic_dependency,
np.array(
[
expr.has(*dynamic_syms)
Expand All @@ -1706,7 +1706,12 @@ def static_indices(self, name: str) -> list[int]:
dwdw.applyfunc(lambda x: int(not x.is_zero))
).astype(np.int64)

tmp = nonzero_dwdx
# `w` is made up an expression hierarchy. Any given entry is only
# static if all its dependencies are static. Here, we unravel
# the hierarchical structure of `w`.
# If for an entry in `w`, the row sum of the intermediate products
# is 0 across all levels, the expression is static.
tmp = dynamic_dependency
res = np.sum(tmp, axis=1)
while np.any(tmp != 0):
tmp = nonzero_dwdw.dot(tmp)
Expand Down Expand Up @@ -1735,19 +1740,25 @@ def static_indices(self, name: str) -> list[int]:
dynamic_syms = sp.Matrix(dynamic_syms)
rowvals = self.rowvals(name)
sparseeq = self.sparseeq(name)

# collect the indices of static expressions of dwd* from the list
# of non-zeros entries of the sparse matrix
self._static_indices[name] = [
i
for i, r in enumerate(rowvals)
if r in static_indices_w
or sparseeq[i].is_Number
for i, (expr, row_idx) in enumerate(zip(sparseeq, rowvals))
# derivative of a static expression is static
if row_idx in static_indices_w
# constant expressions
or expr.is_Number
# check for dependencies on non-static entities
or (
# FIXME see spline comment above
# (check str before diff, as diff will fail on spline functions)
(
not self._splines
or "AmiciSpline" not in str(sparseeq[i])
# splines: non-static
not self._splines or "AmiciSpline" not in str(expr)
)
and sparseeq[i].diff(dynamic_syms).is_zero_matrix
and expr.diff(dynamic_syms).is_zero_matrix
)
]
return self._static_indices[name]
Expand Down

0 comments on commit 6cab686

Please sign in to comment.