Skip to content

Commit

Permalink
Refactor DEExporter/DEModel/csc_matrix (#2311)
Browse files Browse the repository at this point in the history
* Refactor DEExporter/DEModel/csc_matrix

Reduce unnecessary coupling:

* `csc_matrix` as free function - removes the need for the codeprinter in DEModel
* Move the codeprinter to `DEExporter` where it's actually needed

* ..
  • Loading branch information
dweindl authored Feb 26, 2024
1 parent b355ab7 commit 16ec8b2
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 121 deletions.
165 changes: 81 additions & 84 deletions python/sdist/amici/cxxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,90 +207,6 @@ def format_line(symbol: sp.Symbol):
if math not in [0, 0.0]
]

def csc_matrix(
self,
matrix: sp.Matrix,
rownames: list[sp.Symbol],
colnames: list[sp.Symbol],
identifier: Optional[int] = 0,
pattern_only: Optional[bool] = False,
) -> tuple[list[int], list[int], sp.Matrix, list[str], sp.Matrix]:
"""
Generates the sparse symbolic identifiers, symbolic identifiers,
sparse matrix, column pointers and row values for a symbolic
variable
:param matrix:
dense matrix to be sparsified
:param rownames:
ids of the variable of which the derivative is computed (assuming
matrix is the jacobian)
:param colnames:
ids of the variable with respect to which the derivative is computed
(assuming matrix is the jacobian)
:param identifier:
additional identifier that gets appended to symbol names to
ensure their uniqueness in outer loops
:param pattern_only:
flag for computing sparsity pattern without whole matrix
:return:
symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list,
sparse_matrix
"""
idx = 0

nrows, ncols = matrix.shape

if not pattern_only:
sparse_matrix = sp.zeros(nrows, ncols)
symbol_list = []
sparse_list = []
symbol_col_ptrs = []
symbol_row_vals = []

for col in range(ncols):
symbol_col_ptrs.append(idx)
for row in range(nrows):
if matrix[row, col] == 0:
continue

symbol_row_vals.append(row)
idx += 1
symbol_name = (
f"d{rownames[row].name}" f"_d{colnames[col].name}"
)
if identifier:
symbol_name += f"_{identifier}"
symbol_list.append(symbol_name)
if pattern_only:
continue

sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True)
sparse_list.append(matrix[row, col])

if idx == 0:
symbol_col_ptrs = [] # avoid bad memory access for empty matrices
else:
symbol_col_ptrs.append(idx)

if pattern_only:
sparse_matrix = None
else:
sparse_list = sp.Matrix(sparse_list)

return (
symbol_col_ptrs,
symbol_row_vals,
sparse_list,
symbol_list,
sparse_matrix,
)

@staticmethod
def print_bool(expr) -> str:
"""Print the boolean value of the given expression"""
Expand Down Expand Up @@ -360,3 +276,84 @@ def get_switch_statement(
),
indent0 + "}",
]


def csc_matrix(
matrix: sp.Matrix,
rownames: list[sp.Symbol],
colnames: list[sp.Symbol],
identifier: Optional[int] = 0,
pattern_only: Optional[bool] = False,
) -> tuple[list[int], list[int], sp.Matrix, list[str], sp.Matrix]:
"""
Generates the sparse symbolic identifiers, symbolic identifiers,
sparse matrix, column pointers and row values for a symbolic
variable
:param matrix:
dense matrix to be sparsified
:param rownames:
ids of the variable of which the derivative is computed (assuming
matrix is the jacobian)
:param colnames:
ids of the variable with respect to which the derivative is computed
(assuming matrix is the jacobian)
:param identifier:
additional identifier that gets appended to symbol names to
ensure their uniqueness in outer loops
:param pattern_only:
flag for computing sparsity pattern without whole matrix
:return:
symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list,
sparse_matrix
"""
idx = 0
nrows, ncols = matrix.shape

if not pattern_only:
sparse_matrix = sp.zeros(nrows, ncols)
symbol_list = []
sparse_list = []
symbol_col_ptrs = []
symbol_row_vals = []

for col in range(ncols):
symbol_col_ptrs.append(idx)
for row in range(nrows):
if matrix[row, col] == 0:
continue

symbol_row_vals.append(row)
idx += 1
symbol_name = f"d{rownames[row].name}" f"_d{colnames[col].name}"
if identifier:
symbol_name += f"_{identifier}"
symbol_list.append(symbol_name)
if pattern_only:
continue

sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True)
sparse_list.append(matrix[row, col])

if idx == 0:
symbol_col_ptrs = [] # avoid bad memory access for empty matrices
else:
symbol_col_ptrs.append(idx)

if pattern_only:
sparse_matrix = None
else:
sparse_list = sp.Matrix(sparse_list)

return (
symbol_col_ptrs,
symbol_row_vals,
sparse_list,
symbol_list,
sparse_matrix,
)
54 changes: 29 additions & 25 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@
splines,
)
from .constants import SymbolId
from .cxxcodeprinter import AmiciCxxCodePrinter, get_switch_statement
from .cxxcodeprinter import (
AmiciCxxCodePrinter,
get_switch_statement,
csc_matrix,
)
from .de_model import *
from .import_utils import (
ObservableTransformation,
Expand Down Expand Up @@ -725,9 +729,6 @@ class DEModel:
whether all observables have a gaussian noise model, i.e. whether
res and FIM make sense.
:ivar _code_printer:
Code printer to generate C++ code
:ivar _z2event:
list of event indices for each event observable
"""
Expand Down Expand Up @@ -869,10 +870,6 @@ def cached_simplify(
self._has_quadratic_nllh: bool = True
set_log_level(logger, verbose)

self._code_printer = AmiciCxxCodePrinter()
for fun in CUSTOM_FUNCTIONS:
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]

def differential_states(self) -> list[DifferentialState]:
"""Get all differential states."""
return self._differential_states
Expand Down Expand Up @@ -1882,7 +1879,7 @@ def _generate_sparse_symbol(self, name: str) -> None:
sparse_list,
symbol_list,
sparse_matrix,
) = self._code_printer.csc_matrix(
) = csc_matrix(
matrix[iy, :],
rownames=rownames,
colnames=colnames,
Expand All @@ -1900,7 +1897,7 @@ def _generate_sparse_symbol(self, name: str) -> None:
sparse_list,
symbol_list,
sparse_matrix,
) = self._code_printer.csc_matrix(
) = csc_matrix(
matrix,
rownames=rownames,
colnames=colnames,
Expand Down Expand Up @@ -2884,6 +2881,9 @@ class DEExporter:
If the given model uses special functions, this set contains hints for
model building.
:ivar _code_printer:
Code printer to generate C++ code
:ivar generate_sensitivity_code:
Specifies whether code for sensitivity computation is to be generated
Expand Down Expand Up @@ -2950,10 +2950,14 @@ def __init__(
self.set_name(model_name)
self.set_paths(outdir)

self._code_printer = AmiciCxxCodePrinter()
for fun in CUSTOM_FUNCTIONS:
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]

# Signatures and properties of generated model functions (see
# include/amici/model.h for details)
self.model: DEModel = de_model
self.model._code_printer.known_functions.update(
self._code_printer.known_functions.update(
splines.spline_user_functions(
self.model._splines, self._get_index("p")
)
Expand Down Expand Up @@ -3519,7 +3523,7 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())",
f" {function}[{index}] = "
f"{self.model._code_printer.doprint(formula)};",
f"{self._code_printer.doprint(formula)};",
]
)
cases[ipar] = expressions
Expand All @@ -3534,12 +3538,12 @@ def _get_function_body(
f"reinitialization_state_idxs.cend(), {index}) != "
"reinitialization_state_idxs.cend())\n "
f"{function}[{index}] = "
f"{self.model._code_printer.doprint(formula)};"
f"{self._code_printer.doprint(formula)};"
)

elif function in event_functions:
cases = {
ie: self.model._code_printer._get_sym_lines_array(
ie: self._code_printer._get_sym_lines_array(
equations[ie], function, 0
)
for ie in range(self.model.num_events())
Expand All @@ -3552,7 +3556,7 @@ def _get_function_body(
for ie, inner_equations in enumerate(equations):
inner_lines = []
inner_cases = {
ipar: self.model._code_printer._get_sym_lines_array(
ipar: self._code_printer._get_sym_lines_array(
inner_equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -3567,7 +3571,7 @@ def _get_function_body(
and equations.shape[1] == self.model.num_par()
):
cases = {
ipar: self.model._code_printer._get_sym_lines_array(
ipar: self._code_printer._get_sym_lines_array(
equations[:, ipar], function, 0
)
for ipar in range(self.model.num_par())
Expand All @@ -3577,15 +3581,15 @@ def _get_function_body(
elif function in multiobs_functions:
if function == "dJydy":
cases = {
iobs: self.model._code_printer._get_sym_lines_array(
iobs: self._code_printer._get_sym_lines_array(
equations[iobs], function, 0
)
for iobs in range(self.model.num_obs())
if not smart_is_zero_matrix(equations[iobs])
}
else:
cases = {
iobs: self.model._code_printer._get_sym_lines_array(
iobs: self._code_printer._get_sym_lines_array(
equations[:, iobs], function, 0
)
for iobs in range(equations.shape[1])
Expand All @@ -3605,12 +3609,12 @@ def _get_function_body(
symbols = list(map(sp.Symbol, self.model.sparsesym(function)))
else:
symbols = self.model.sym(function)
lines += self.model._code_printer._get_sym_lines_symbols(
lines += self._code_printer._get_sym_lines_symbols(
symbols, equations, function, 4
)

else:
lines += self.model._code_printer._get_sym_lines_array(
lines += self._code_printer._get_sym_lines_array(
equations, function, 4
)

Expand Down Expand Up @@ -3766,10 +3770,10 @@ def _write_model_header_cpp(self) -> None:
"NK": self.model.num_const(),
"O2MODE": "amici::SecondOrderMode::none",
# using code printer ensures proper handling of nan/inf
"PARAMETERS": self.model._code_printer.doprint(
self.model.val("p")
)[1:-1],
"FIXED_PARAMETERS": self.model._code_printer.doprint(
"PARAMETERS": self._code_printer.doprint(self.model.val("p"))[
1:-1
],
"FIXED_PARAMETERS": self._code_printer.doprint(
self.model.val("k")
)[1:-1],
"PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list(
Expand Down Expand Up @@ -3961,7 +3965,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
Template initializer list of ids
"""
return "\n".join(
f'"{self.model._code_printer.doprint(symbol)}", // {name}[{idx}]'
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
for idx, symbol in enumerate(self.model.sym(name))
)

Expand Down
8 changes: 4 additions & 4 deletions python/sdist/amici/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ def pysb2amici(
compiler=compiler,
generate_sensitivity_code=generate_sensitivity_code,
)
# Sympy code optimizations are incompatible with PySB objects, as
# `pysb.Observable` comes with its own `.match` which overrides
# `sympy.Basic.match()`, breaking `sympy.codegen.rewriting.optimize`.
exporter._code_printer._fpoptimizer = None
exporter.generate_model_code()

if compile:
Expand Down Expand Up @@ -241,10 +245,6 @@ def ode_model_from_pysb_importer(
simplify=simplify,
cache_simplify=cache_simplify,
)
# Sympy code optimizations are incompatible with PySB objects, as
# `pysb.Observable` comes with its own `.match` which overrides
# `sympy.Basic.match()`, breaking `sympy.codegen.rewriting.optimize`.
ode._code_printer._fpoptimizer = None

if constant_parameters is None:
constant_parameters = []
Expand Down
Loading

0 comments on commit 16ec8b2

Please sign in to comment.