From 16ec8b279f4bf9bf3bb0192170f0e5bd2dbd096c Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 26 Feb 2024 11:56:15 +0100 Subject: [PATCH] Refactor DEExporter/DEModel/csc_matrix (#2311) * Refactor DEExporter/DEModel/csc_matrix Reduce unnecessary coupling: * `csc_matrix` as free function - removes the need for the codeprinter in DEModel * Move the codeprinter to `DEExporter` where it's actually needed * .. --- python/sdist/amici/cxxcodeprinter.py | 165 +++++++++++++-------------- python/sdist/amici/de_export.py | 54 +++++---- python/sdist/amici/pysb_import.py | 8 +- python/tests/test_ode_export.py | 13 +-- 4 files changed, 119 insertions(+), 121 deletions(-) diff --git a/python/sdist/amici/cxxcodeprinter.py b/python/sdist/amici/cxxcodeprinter.py index 3fe5b8cd17..032089393d 100644 --- a/python/sdist/amici/cxxcodeprinter.py +++ b/python/sdist/amici/cxxcodeprinter.py @@ -207,90 +207,6 @@ def format_line(symbol: sp.Symbol): if math not in [0, 0.0] ] - def csc_matrix( - self, - matrix: sp.Matrix, - rownames: list[sp.Symbol], - colnames: list[sp.Symbol], - identifier: Optional[int] = 0, - pattern_only: Optional[bool] = False, - ) -> tuple[list[int], list[int], sp.Matrix, list[str], sp.Matrix]: - """ - Generates the sparse symbolic identifiers, symbolic identifiers, - sparse matrix, column pointers and row values for a symbolic - variable - - :param matrix: - dense matrix to be sparsified - - :param rownames: - ids of the variable of which the derivative is computed (assuming - matrix is the jacobian) - - :param colnames: - ids of the variable with respect to which the derivative is computed - (assuming matrix is the jacobian) - - :param identifier: - additional identifier that gets appended to symbol names to - ensure their uniqueness in outer loops - - :param pattern_only: - flag for computing sparsity pattern without whole matrix - - :return: - symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, - sparse_matrix - """ - idx = 0 - - nrows, ncols = matrix.shape - - if not pattern_only: - sparse_matrix = sp.zeros(nrows, ncols) - symbol_list = [] - sparse_list = [] - symbol_col_ptrs = [] - symbol_row_vals = [] - - for col in range(ncols): - symbol_col_ptrs.append(idx) - for row in range(nrows): - if matrix[row, col] == 0: - continue - - symbol_row_vals.append(row) - idx += 1 - symbol_name = ( - f"d{rownames[row].name}" f"_d{colnames[col].name}" - ) - if identifier: - symbol_name += f"_{identifier}" - symbol_list.append(symbol_name) - if pattern_only: - continue - - sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True) - sparse_list.append(matrix[row, col]) - - if idx == 0: - symbol_col_ptrs = [] # avoid bad memory access for empty matrices - else: - symbol_col_ptrs.append(idx) - - if pattern_only: - sparse_matrix = None - else: - sparse_list = sp.Matrix(sparse_list) - - return ( - symbol_col_ptrs, - symbol_row_vals, - sparse_list, - symbol_list, - sparse_matrix, - ) - @staticmethod def print_bool(expr) -> str: """Print the boolean value of the given expression""" @@ -360,3 +276,84 @@ def get_switch_statement( ), indent0 + "}", ] + + +def csc_matrix( + matrix: sp.Matrix, + rownames: list[sp.Symbol], + colnames: list[sp.Symbol], + identifier: Optional[int] = 0, + pattern_only: Optional[bool] = False, +) -> tuple[list[int], list[int], sp.Matrix, list[str], sp.Matrix]: + """ + Generates the sparse symbolic identifiers, symbolic identifiers, + sparse matrix, column pointers and row values for a symbolic + variable + + :param matrix: + dense matrix to be sparsified + + :param rownames: + ids of the variable of which the derivative is computed (assuming + matrix is the jacobian) + + :param colnames: + ids of the variable with respect to which the derivative is computed + (assuming matrix is the jacobian) + + :param identifier: + additional identifier that gets appended to symbol names to + ensure their uniqueness in outer loops + + :param pattern_only: + flag for computing sparsity pattern without whole matrix + + :return: + symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list, + sparse_matrix + """ + idx = 0 + nrows, ncols = matrix.shape + + if not pattern_only: + sparse_matrix = sp.zeros(nrows, ncols) + symbol_list = [] + sparse_list = [] + symbol_col_ptrs = [] + symbol_row_vals = [] + + for col in range(ncols): + symbol_col_ptrs.append(idx) + for row in range(nrows): + if matrix[row, col] == 0: + continue + + symbol_row_vals.append(row) + idx += 1 + symbol_name = f"d{rownames[row].name}" f"_d{colnames[col].name}" + if identifier: + symbol_name += f"_{identifier}" + symbol_list.append(symbol_name) + if pattern_only: + continue + + sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True) + sparse_list.append(matrix[row, col]) + + if idx == 0: + symbol_col_ptrs = [] # avoid bad memory access for empty matrices + else: + symbol_col_ptrs.append(idx) + + if pattern_only: + sparse_matrix = None + else: + sparse_list = sp.Matrix(sparse_list) + + return ( + symbol_col_ptrs, + symbol_row_vals, + sparse_list, + symbol_list, + sparse_matrix, + ) diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index f2badbea76..0a6813a6ca 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -44,7 +44,11 @@ splines, ) from .constants import SymbolId -from .cxxcodeprinter import AmiciCxxCodePrinter, get_switch_statement +from .cxxcodeprinter import ( + AmiciCxxCodePrinter, + get_switch_statement, + csc_matrix, +) from .de_model import * from .import_utils import ( ObservableTransformation, @@ -725,9 +729,6 @@ class DEModel: whether all observables have a gaussian noise model, i.e. whether res and FIM make sense. - :ivar _code_printer: - Code printer to generate C++ code - :ivar _z2event: list of event indices for each event observable """ @@ -869,10 +870,6 @@ def cached_simplify( self._has_quadratic_nllh: bool = True set_log_level(logger, verbose) - self._code_printer = AmiciCxxCodePrinter() - for fun in CUSTOM_FUNCTIONS: - self._code_printer.known_functions[fun["sympy"]] = fun["c++"] - def differential_states(self) -> list[DifferentialState]: """Get all differential states.""" return self._differential_states @@ -1882,7 +1879,7 @@ def _generate_sparse_symbol(self, name: str) -> None: sparse_list, symbol_list, sparse_matrix, - ) = self._code_printer.csc_matrix( + ) = csc_matrix( matrix[iy, :], rownames=rownames, colnames=colnames, @@ -1900,7 +1897,7 @@ def _generate_sparse_symbol(self, name: str) -> None: sparse_list, symbol_list, sparse_matrix, - ) = self._code_printer.csc_matrix( + ) = csc_matrix( matrix, rownames=rownames, colnames=colnames, @@ -2884,6 +2881,9 @@ class DEExporter: If the given model uses special functions, this set contains hints for model building. + :ivar _code_printer: + Code printer to generate C++ code + :ivar generate_sensitivity_code: Specifies whether code for sensitivity computation is to be generated @@ -2950,10 +2950,14 @@ def __init__( self.set_name(model_name) self.set_paths(outdir) + self._code_printer = AmiciCxxCodePrinter() + for fun in CUSTOM_FUNCTIONS: + self._code_printer.known_functions[fun["sympy"]] = fun["c++"] + # Signatures and properties of generated model functions (see # include/amici/model.h for details) self.model: DEModel = de_model - self.model._code_printer.known_functions.update( + self._code_printer.known_functions.update( splines.spline_user_functions( self.model._splines, self._get_index("p") ) @@ -3519,7 +3523,7 @@ def _get_function_body( f"reinitialization_state_idxs.cend(), {index}) != " "reinitialization_state_idxs.cend())", f" {function}[{index}] = " - f"{self.model._code_printer.doprint(formula)};", + f"{self._code_printer.doprint(formula)};", ] ) cases[ipar] = expressions @@ -3534,12 +3538,12 @@ def _get_function_body( f"reinitialization_state_idxs.cend(), {index}) != " "reinitialization_state_idxs.cend())\n " f"{function}[{index}] = " - f"{self.model._code_printer.doprint(formula)};" + f"{self._code_printer.doprint(formula)};" ) elif function in event_functions: cases = { - ie: self.model._code_printer._get_sym_lines_array( + ie: self._code_printer._get_sym_lines_array( equations[ie], function, 0 ) for ie in range(self.model.num_events()) @@ -3552,7 +3556,7 @@ def _get_function_body( for ie, inner_equations in enumerate(equations): inner_lines = [] inner_cases = { - ipar: self.model._code_printer._get_sym_lines_array( + ipar: self._code_printer._get_sym_lines_array( inner_equations[:, ipar], function, 0 ) for ipar in range(self.model.num_par()) @@ -3567,7 +3571,7 @@ def _get_function_body( and equations.shape[1] == self.model.num_par() ): cases = { - ipar: self.model._code_printer._get_sym_lines_array( + ipar: self._code_printer._get_sym_lines_array( equations[:, ipar], function, 0 ) for ipar in range(self.model.num_par()) @@ -3577,7 +3581,7 @@ def _get_function_body( elif function in multiobs_functions: if function == "dJydy": cases = { - iobs: self.model._code_printer._get_sym_lines_array( + iobs: self._code_printer._get_sym_lines_array( equations[iobs], function, 0 ) for iobs in range(self.model.num_obs()) @@ -3585,7 +3589,7 @@ def _get_function_body( } else: cases = { - iobs: self.model._code_printer._get_sym_lines_array( + iobs: self._code_printer._get_sym_lines_array( equations[:, iobs], function, 0 ) for iobs in range(equations.shape[1]) @@ -3605,12 +3609,12 @@ def _get_function_body( symbols = list(map(sp.Symbol, self.model.sparsesym(function))) else: symbols = self.model.sym(function) - lines += self.model._code_printer._get_sym_lines_symbols( + lines += self._code_printer._get_sym_lines_symbols( symbols, equations, function, 4 ) else: - lines += self.model._code_printer._get_sym_lines_array( + lines += self._code_printer._get_sym_lines_array( equations, function, 4 ) @@ -3766,10 +3770,10 @@ def _write_model_header_cpp(self) -> None: "NK": self.model.num_const(), "O2MODE": "amici::SecondOrderMode::none", # using code printer ensures proper handling of nan/inf - "PARAMETERS": self.model._code_printer.doprint( - self.model.val("p") - )[1:-1], - "FIXED_PARAMETERS": self.model._code_printer.doprint( + "PARAMETERS": self._code_printer.doprint(self.model.val("p"))[ + 1:-1 + ], + "FIXED_PARAMETERS": self._code_printer.doprint( self.model.val("k") )[1:-1], "PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list( @@ -3961,7 +3965,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str: Template initializer list of ids """ return "\n".join( - f'"{self.model._code_printer.doprint(symbol)}", // {name}[{idx}]' + f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]' for idx, symbol in enumerate(self.model.sym(name)) ) diff --git a/python/sdist/amici/pysb_import.py b/python/sdist/amici/pysb_import.py index 4f843033f1..c79a8c50f9 100644 --- a/python/sdist/amici/pysb_import.py +++ b/python/sdist/amici/pysb_import.py @@ -178,6 +178,10 @@ def pysb2amici( compiler=compiler, generate_sensitivity_code=generate_sensitivity_code, ) + # Sympy code optimizations are incompatible with PySB objects, as + # `pysb.Observable` comes with its own `.match` which overrides + # `sympy.Basic.match()`, breaking `sympy.codegen.rewriting.optimize`. + exporter._code_printer._fpoptimizer = None exporter.generate_model_code() if compile: @@ -241,10 +245,6 @@ def ode_model_from_pysb_importer( simplify=simplify, cache_simplify=cache_simplify, ) - # Sympy code optimizations are incompatible with PySB objects, as - # `pysb.Observable` comes with its own `.match` which overrides - # `sympy.Basic.match()`, breaking `sympy.codegen.rewriting.optimize`. - ode._code_printer._fpoptimizer = None if constant_parameters is None: constant_parameters = [] diff --git a/python/tests/test_ode_export.py b/python/tests/test_ode_export.py index f34d78892d..65af2935bb 100644 --- a/python/tests/test_ode_export.py +++ b/python/tests/test_ode_export.py @@ -1,14 +1,13 @@ """Miscellaneous AMICI Python interface tests""" import sympy as sp -from amici.cxxcodeprinter import AmiciCxxCodePrinter +from amici.cxxcodeprinter import csc_matrix from amici.testing import skip_on_valgrind @skip_on_valgrind def test_csc_matrix(): """Test sparse CSC matrix creation""" - printer = AmiciCxxCodePrinter() matrix = sp.Matrix([[1, 0], [2, 3]]) ( symbol_col_ptrs, @@ -16,7 +15,7 @@ def test_csc_matrix(): sparse_list, symbol_list, sparse_matrix, - ) = printer.csc_matrix( + ) = csc_matrix( matrix, rownames=[sp.Symbol("a1"), sp.Symbol("a2")], colnames=[sp.Symbol("b1"), sp.Symbol("b2")], @@ -32,7 +31,6 @@ def test_csc_matrix(): @skip_on_valgrind def test_csc_matrix_empty(): """Test sparse CSC matrix creation for empty matrix""" - printer = AmiciCxxCodePrinter() matrix = sp.Matrix() ( symbol_col_ptrs, @@ -40,7 +38,7 @@ def test_csc_matrix_empty(): sparse_list, symbol_list, sparse_matrix, - ) = printer.csc_matrix(matrix, rownames=[], colnames=[]) + ) = csc_matrix(matrix, rownames=[], colnames=[]) assert symbol_col_ptrs == [] assert symbol_row_vals == [] @@ -52,7 +50,6 @@ def test_csc_matrix_empty(): @skip_on_valgrind def test_csc_matrix_vector(): """Test sparse CSC matrix creation from matrix slice""" - printer = AmiciCxxCodePrinter() matrix = sp.Matrix([[1, 0], [2, 3]]) ( symbol_col_ptrs, @@ -60,7 +57,7 @@ def test_csc_matrix_vector(): sparse_list, symbol_list, sparse_matrix, - ) = printer.csc_matrix( + ) = csc_matrix( matrix[:, 0], colnames=[sp.Symbol("b")], rownames=[sp.Symbol("a1"), sp.Symbol("a2")], @@ -79,7 +76,7 @@ def test_csc_matrix_vector(): sparse_list, symbol_list, sparse_matrix, - ) = printer.csc_matrix( + ) = csc_matrix( matrix[:, 1], colnames=[sp.Symbol("b")], rownames=[sp.Symbol("a1"), sp.Symbol("a2")],