From aeb5f3478ddb85372c3ae69268aafa6f2fdef91c Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 26 Feb 2024 13:58:02 +0100 Subject: [PATCH] Refactor de_export.py, extract sympy_utils.py (#2307) No changes in functionality. Related to #2306. --- python/sdist/amici/de_export.py | 205 ++--------------------------- python/sdist/amici/import_utils.py | 7 + python/sdist/amici/pysb_import.py | 2 +- python/sdist/amici/sbml_import.py | 4 +- python/sdist/amici/sympy_utils.py | 196 +++++++++++++++++++++++++++ python/tests/test_misc.py | 21 --- python/tests/test_sympy_utils.py | 24 ++++ 7 files changed, 241 insertions(+), 218 deletions(-) create mode 100644 python/sdist/amici/sympy_utils.py create mode 100644 python/tests/test_sympy_utils.py diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index 0a6813a6ca..3f8696a86b 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -1,7 +1,7 @@ """ C++ Export ---------- -This module provides all necessary functionality specify an DE model and +This module provides all necessary functionality specify a DE model and generate executable C++ simulation code. The user generally won't have to directly call any function from this module as this will be done by :py:func:`amici.pysb_import.pysb2amici`, @@ -18,12 +18,11 @@ import subprocess import sys from dataclasses import dataclass -from itertools import chain, starmap +from itertools import chain from pathlib import Path from string import Template from typing import ( TYPE_CHECKING, - Any, Callable, Literal, Optional, @@ -59,8 +58,17 @@ strip_pysb, toposort_symbols, unique_preserve_order, + _default_simplify, ) from .logging import get_logger, log_execution_time, set_log_level +from .sympy_utils import ( + _custom_pow_eval_derivative, + _monkeypatched, + smart_jacobian, + smart_multiply, + smart_is_zero_matrix, + _parallel_applyfunc, +) if TYPE_CHECKING: from . import sbml_import @@ -509,109 +517,6 @@ def var_in_function_signature(name: str, varname: str, ode: bool) -> bool: } -@log_execution_time("running smart_jacobian", logger) -def smart_jacobian( - eq: sp.MutableDenseMatrix, sym_var: sp.MutableDenseMatrix -) -> sp.MutableSparseMatrix: - """ - Wrapper around symbolic jacobian with some additional checks that reduce - computation time for large matrices - - :param eq: - equation - :param sym_var: - differentiation variable - :return: - jacobian of eq wrt sym_var - """ - nrow = eq.shape[0] - ncol = sym_var.shape[0] - if ( - not min(eq.shape) - or not min(sym_var.shape) - or smart_is_zero_matrix(eq) - or smart_is_zero_matrix(sym_var) - ): - return sp.MutableSparseMatrix(nrow, ncol, dict()) - - # preprocess sparsity pattern - elements = ( - (i, j, a, b) - for i, a in enumerate(eq) - for j, b in enumerate(sym_var) - if a.has(b) - ) - - if (n_procs := int(os.environ.get("AMICI_IMPORT_NPROCS", 1))) == 1: - # serial - return sp.MutableSparseMatrix( - nrow, ncol, dict(starmap(_jacobian_element, elements)) - ) - - # parallel - from multiprocessing import get_context - - # "spawn" should avoid potential deadlocks occurring with fork - # see e.g. https://stackoverflow.com/a/66113051 - ctx = get_context("spawn") - with ctx.Pool(n_procs) as p: - mapped = p.starmap(_jacobian_element, elements) - return sp.MutableSparseMatrix(nrow, ncol, dict(mapped)) - - -@log_execution_time("running smart_multiply", logger) -def smart_multiply( - x: Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix], - y: sp.MutableDenseMatrix, -) -> Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix]: - """ - Wrapper around symbolic multiplication with some additional checks that - reduce computation time for large matrices - - :param x: - educt 1 - :param y: - educt 2 - :return: - product - """ - if ( - not x.shape[0] - or not y.shape[1] - or smart_is_zero_matrix(x) - or smart_is_zero_matrix(y) - ): - return sp.zeros(x.shape[0], y.shape[1]) - return x.multiply(y) - - -def smart_is_zero_matrix( - x: Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix], -) -> bool: - """A faster implementation of sympy's is_zero_matrix - - Avoids repeated indexer type checks and double iteration to distinguish - False/None. Found to be about 100x faster for large matrices. - - :param x: Matrix to check - """ - - if isinstance(x, sp.MutableDenseMatrix): - return all(xx.is_zero is True for xx in x.flat()) - - if isinstance(x, list): - return all(smart_is_zero_matrix(xx) for xx in x) - - return x.nnz() == 0 - - -def _default_simplify(x): - """Default simplification applied in DEModel""" - # We need this as a free function instead of a lambda to have it picklable - # for parallel simplification - return sp.powsimp(x, deep=True) - - class DEModel: """ Defines a Differential Equation as set of ModelQuantities. @@ -4304,94 +4209,6 @@ def is_valid_identifier(x: str) -> bool: return IDENTIFIER_PATTERN.match(x) is not None -@contextlib.contextmanager -def _monkeypatched(obj: object, name: str, patch: Any): - """ - Temporarily monkeypatches an object. - - :param obj: - object to be patched - - :param name: - name of the attribute to be patched - - :param patch: - patched value - """ - pre_patched_value = getattr(obj, name) - setattr(obj, name, patch) - try: - yield object - finally: - setattr(obj, name, pre_patched_value) - - -def _custom_pow_eval_derivative(self, s): - """ - Custom Pow derivative that removes a removable singularity for - ``self.base == 0`` and ``self.base.diff(s) == 0``. This function is - intended to be monkeypatched into :py:method:`sympy.Pow._eval_derivative`. - - :param self: - sp.Pow class - - :param s: - variable with respect to which the derivative will be computed - """ - dbase = self.base.diff(s) - dexp = self.exp.diff(s) - part1 = sp.Pow(self.base, self.exp - 1) * self.exp * dbase - part2 = self * dexp * sp.log(self.base) - if self.base.is_nonzero or dbase.is_nonzero or part2.is_zero: - # first piece never applies or is zero anyways - return part1 + part2 - - return part1 + sp.Piecewise( - (self.base, sp.And(sp.Eq(self.base, 0), sp.Eq(dbase, 0))), - (part2, True), - ) - - -def _jacobian_element(i, j, eq_i, sym_var_j): - """Compute a single element of a jacobian""" - return (i, j), eq_i.diff(sym_var_j) - - -def _parallel_applyfunc(obj: sp.Matrix, func: Callable) -> sp.Matrix: - """Parallel implementation of sympy's Matrix.applyfunc""" - if (n_procs := int(os.environ.get("AMICI_IMPORT_NPROCS", 1))) == 1: - # serial - return obj.applyfunc(func) - - # parallel - from multiprocessing import get_context - from pickle import PicklingError - - from sympy.matrices.dense import DenseMatrix - - # "spawn" should avoid potential deadlocks occurring with fork - # see e.g. https://stackoverflow.com/a/66113051 - ctx = get_context("spawn") - with ctx.Pool(n_procs) as p: - try: - if isinstance(obj, DenseMatrix): - return obj._new(obj.rows, obj.cols, p.map(func, obj)) - elif isinstance(obj, sp.SparseMatrix): - dok = obj.todok() - mapped = p.map(func, dok.values()) - dok = {k: v for k, v in zip(dok.keys(), mapped) if v != 0} - return obj._new(obj.rows, obj.cols, dok) - else: - raise ValueError(f"Unsupported matrix type {type(obj)}") - except PicklingError as e: - raise ValueError( - f"Couldn't pickle {func}. This is likely because the argument " - "was not a module-level function. Either rewrite the argument " - "to a module-level function or disable parallelization by " - "setting `AMICI_IMPORT_NPROCS=1`." - ) from e - - def _write_gitignore(dest_dir: Path) -> None: """Write .gitignore file. diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 63a160c1de..029c2cc6de 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -748,3 +748,10 @@ def unique_preserve_order(seq: Sequence) -> list: sbml_time_symbol = symbol_with_assumptions("time") amici_time_symbol = symbol_with_assumptions("t") + + +def _default_simplify(x): + """Default simplification applied in DEModel""" + # We need this as a free function instead of a lambda to have it picklable + # for parallel simplification + return sp.powsimp(x, deep=True) diff --git a/python/sdist/amici/pysb_import.py b/python/sdist/amici/pysb_import.py index c79a8c50f9..05f9ed9b28 100644 --- a/python/sdist/amici/pysb_import.py +++ b/python/sdist/amici/pysb_import.py @@ -34,7 +34,6 @@ Observable, Parameter, SigmaY, - _default_simplify, ) from .import_utils import ( _get_str_symbol_identifiers, @@ -42,6 +41,7 @@ generate_measurement_symbol, noise_distribution_to_cost_function, noise_distribution_to_observable_transformation, + _default_simplify, ) from .logging import get_logger, log_execution_time, set_log_level diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index 9d6b7229c5..8fc2ab9fd9 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -31,9 +31,8 @@ from .de_export import ( DEExporter, DEModel, - _default_simplify, - smart_is_zero_matrix, ) +from .sympy_utils import smart_is_zero_matrix from .import_utils import ( RESERVED_SYMBOLS, _check_unsupported_functions, @@ -50,6 +49,7 @@ smart_subs_dict, symbol_with_assumptions, toposort_symbols, + _default_simplify, ) from .logging import get_logger, log_execution_time, set_log_level from .sbml_utils import SBMLException, _parse_logical_operators diff --git a/python/sdist/amici/sympy_utils.py b/python/sdist/amici/sympy_utils.py new file mode 100644 index 0000000000..863fc57f14 --- /dev/null +++ b/python/sdist/amici/sympy_utils.py @@ -0,0 +1,196 @@ +"""Functionality for working with sympy objects.""" +import os +from itertools import starmap +from typing import Union, Any, Callable +import contextlib +import sympy as sp +import logging +from amici.de_export import get_logger +from amici.logging import log_execution_time + + +logger = get_logger(__name__, logging.ERROR) + + +def _custom_pow_eval_derivative(self, s): + """ + Custom Pow derivative that removes a removable singularity for + ``self.base == 0`` and ``self.base.diff(s) == 0``. This function is + intended to be monkeypatched into :py:method:`sympy.Pow._eval_derivative`. + + :param self: + sp.Pow class + + :param s: + variable with respect to which the derivative will be computed + """ + dbase = self.base.diff(s) + dexp = self.exp.diff(s) + part1 = sp.Pow(self.base, self.exp - 1) * self.exp * dbase + part2 = self * dexp * sp.log(self.base) + if self.base.is_nonzero or dbase.is_nonzero or part2.is_zero: + # first piece never applies or is zero anyway + return part1 + part2 + + return part1 + sp.Piecewise( + (self.base, sp.And(sp.Eq(self.base, 0), sp.Eq(dbase, 0))), + (part2, True), + ) + + +@contextlib.contextmanager +def _monkeypatched(obj: object, name: str, patch: Any): + """ + Temporarily monkeypatches an object. + + :param obj: + object to be patched + + :param name: + name of the attribute to be patched + + :param patch: + patched value + """ + pre_patched_value = getattr(obj, name) + setattr(obj, name, patch) + try: + yield object + finally: + setattr(obj, name, pre_patched_value) + + +@log_execution_time("running smart_jacobian", logger) +def smart_jacobian( + eq: sp.MutableDenseMatrix, sym_var: sp.MutableDenseMatrix +) -> sp.MutableSparseMatrix: + """ + Wrapper around symbolic jacobian with some additional checks that reduce + computation time for large matrices + + :param eq: + equation + :param sym_var: + differentiation variable + :return: + jacobian of eq wrt sym_var + """ + nrow = eq.shape[0] + ncol = sym_var.shape[0] + if ( + not min(eq.shape) + or not min(sym_var.shape) + or smart_is_zero_matrix(eq) + or smart_is_zero_matrix(sym_var) + ): + return sp.MutableSparseMatrix(nrow, ncol, dict()) + + # preprocess sparsity pattern + elements = ( + (i, j, a, b) + for i, a in enumerate(eq) + for j, b in enumerate(sym_var) + if a.has(b) + ) + + if (n_procs := int(os.environ.get("AMICI_IMPORT_NPROCS", 1))) == 1: + # serial + return sp.MutableSparseMatrix( + nrow, ncol, dict(starmap(_jacobian_element, elements)) + ) + + # parallel + from multiprocessing import get_context + + # "spawn" should avoid potential deadlocks occurring with fork + # see e.g. https://stackoverflow.com/a/66113051 + ctx = get_context("spawn") + with ctx.Pool(n_procs) as p: + mapped = p.starmap(_jacobian_element, elements) + return sp.MutableSparseMatrix(nrow, ncol, dict(mapped)) + + +@log_execution_time("running smart_multiply", logger) +def smart_multiply( + x: Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix], + y: sp.MutableDenseMatrix, +) -> Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix]: + """ + Wrapper around symbolic multiplication with some additional checks that + reduce computation time for large matrices + + :param x: + educt 1 + :param y: + educt 2 + :return: + product + """ + if ( + not x.shape[0] + or not y.shape[1] + or smart_is_zero_matrix(x) + or smart_is_zero_matrix(y) + ): + return sp.zeros(x.shape[0], y.shape[1]) + return x.multiply(y) + + +def smart_is_zero_matrix( + x: Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix], +) -> bool: + """A faster implementation of sympy's is_zero_matrix + + Avoids repeated indexer type checks and double iteration to distinguish + False/None. Found to be about 100x faster for large matrices. + + :param x: Matrix to check + """ + + if isinstance(x, sp.MutableDenseMatrix): + return all(xx.is_zero is True for xx in x.flat()) + + if isinstance(x, list): + return all(smart_is_zero_matrix(xx) for xx in x) + + return x.nnz() == 0 + + +def _jacobian_element(i, j, eq_i, sym_var_j): + """Compute a single element of a jacobian""" + return (i, j), eq_i.diff(sym_var_j) + + +def _parallel_applyfunc(obj: sp.Matrix, func: Callable) -> sp.Matrix: + """Parallel implementation of sympy's Matrix.applyfunc""" + if (n_procs := int(os.environ.get("AMICI_IMPORT_NPROCS", 1))) == 1: + # serial + return obj.applyfunc(func) + + # parallel + from multiprocessing import get_context + from pickle import PicklingError + + from sympy.matrices.dense import DenseMatrix + + # "spawn" should avoid potential deadlocks occurring with fork + # see e.g. https://stackoverflow.com/a/66113051 + ctx = get_context("spawn") + with ctx.Pool(n_procs) as p: + try: + if isinstance(obj, DenseMatrix): + return obj._new(obj.rows, obj.cols, p.map(func, obj)) + elif isinstance(obj, sp.SparseMatrix): + dok = obj.todok() + mapped = p.map(func, dok.values()) + dok = {k: v for k, v in zip(dok.keys(), mapped) if v != 0} + return obj._new(obj.rows, obj.cols, dok) + else: + raise ValueError(f"Unsupported matrix type {type(obj)}") + except PicklingError as e: + raise ValueError( + f"Couldn't pickle {func}. This is likely because the argument " + "was not a module-level function. Either rewrite the argument " + "to a module-level function or disable parallelization by " + "setting `AMICI_IMPORT_NPROCS=1`." + ) from e diff --git a/python/tests/test_misc.py b/python/tests/test_misc.py index 24bba79888..1ddeb3f760 100644 --- a/python/tests/test_misc.py +++ b/python/tests/test_misc.py @@ -8,8 +8,6 @@ import pytest import sympy as sp from amici.de_export import ( - _custom_pow_eval_derivative, - _monkeypatched, smart_subs_dict, ) from amici.testing import skip_on_valgrind @@ -108,25 +106,6 @@ def test_smart_subs_dict(): assert sp.simplify(result_reverse - expected_reverse).is_zero -@skip_on_valgrind -def test_monkeypatch(): - t = sp.Symbol("t") - n = sp.Symbol("n") - vals = [(t, 0), (n, 1)] - - # check that the removable singularity still exists - assert (t**n).diff(t).subs(vals) is sp.nan - - # check that we can monkeypatch it out - with _monkeypatched( - sp.Pow, "_eval_derivative", _custom_pow_eval_derivative - ): - assert (t**n).diff(t).subs(vals) is not sp.nan - - # check that the monkeypatch is transient - assert (t**n).diff(t).subs(vals) is sp.nan - - @skip_on_valgrind def test_get_default_argument(): # no default diff --git a/python/tests/test_sympy_utils.py b/python/tests/test_sympy_utils.py new file mode 100644 index 0000000000..da89741352 --- /dev/null +++ b/python/tests/test_sympy_utils.py @@ -0,0 +1,24 @@ +"""Tests related to the sympy_utils module.""" + +from amici.sympy_utils import _custom_pow_eval_derivative, _monkeypatched +import sympy as sp +from amici.testing import skip_on_valgrind + + +@skip_on_valgrind +def test_monkeypatch(): + t = sp.Symbol("t") + n = sp.Symbol("n") + vals = [(t, 0), (n, 1)] + + # check that the removable singularity still exists + assert (t**n).diff(t).subs(vals) is sp.nan + + # check that we can monkeypatch it out + with _monkeypatched( + sp.Pow, "_eval_derivative", _custom_pow_eval_derivative + ): + assert (t**n).diff(t).subs(vals) is not sp.nan + + # check that the monkeypatch is transient + assert (t**n).diff(t).subs(vals) is sp.nan