Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
dwd[pwx]
Browse files Browse the repository at this point in the history
dweindl committed Feb 24, 2024
1 parent 3b123d4 commit 10940c6
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
@@ -1701,16 +1701,26 @@ def static_indices(self, name: str) -> list[int]:

return self._static_indices[name]

if name == "dwdw":
# TODO
...
if name == "dwdx":
# TODO
...

if name == "dwdp":
# TODO
...
if name in ("dwdw", "dwdx", "dwdp"):
static_indices_w = set(self.static_indices("w"))
w_syms = self.sym("w")
dynamic_syms = [
sym
for i, sym in enumerate(w_syms)
if i not in static_indices_w
]
dynamic_syms.append(amici_time_symbol)
dynamic_syms = sp.Matrix(dynamic_syms)
rowvals = self.rowvals(name)
sparseeq = self.sparseeq(name)
self._static_indices[name] = [
i
for i, r in enumerate(rowvals)
if r in static_indices_w
or sparseeq[i].is_Number
or sparseeq[i].diff(dynamic_syms).is_zero_matrix
]
return self._static_indices[name]

raise NotImplementedError(name)

@@ -1721,8 +1731,13 @@ def dynamic_indices(self, name: str) -> list[int]:
:param name: Name of the model entity.
:return: List of indices of dynamic expressions.
"""
static_idxs = self.static_indices(name)
return [i for i in range(len(self.sym(name))) if i not in static_idxs]
static_idxs = set(self.static_indices(name))
length = len(
self.sparsesym(name)
if name in sparse_functions
else self.sym(name)
)
return [i for i in range(length) if i not in static_idxs]

def _generate_symbol(self, name: str) -> None:
"""
@@ -3692,8 +3707,8 @@ def _get_function_body(
else:
symbols = self.model.sym(function)

if function == "w":
# Split w into a block of static and dynamic expressions.
if function in ("w", "dwdw", "dwdx", "dwdp"):
# Split into a block of static and dynamic expressions.
if len(static_idxs := self.model.static_indices(function)) > 0:
lines.append(" if (include_static) {")

0 comments on commit 10940c6

Please sign in to comment.