From 9465a6c73871015276dfc9ffe62d01c660bcafe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Tue, 3 Dec 2024 08:36:47 +0000 Subject: [PATCH] reviews --- .../example_jax_petab/ExampleJaxPEtab.ipynb | 1 + python/sdist/amici/__init__.template.py | 28 +------------------ python/sdist/amici/jax/ode_export.py | 11 ++------ python/sdist/amici/petab/sbml_import.py | 8 +++--- python/sdist/amici/pysb_import.py | 6 ++-- 5 files changed, 11 insertions(+), 43 deletions(-) diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index ef1c513dd0..855860e242 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -489,6 +489,7 @@ "amici_model = import_petab_problem(\n", " petab_problem,\n", " verbose=False,\n", + " compile_=True,\n", " jax=False, # load the amici model this time\n", ")\n", "\n", diff --git a/python/sdist/amici/__init__.template.py b/python/sdist/amici/__init__.template.py index efc8df0617..abd07a81ab 100644 --- a/python/sdist/amici/__init__.template.py +++ b/python/sdist/amici/__init__.template.py @@ -1,7 +1,5 @@ """AMICI-generated module for model TPL_MODELNAME""" -import datetime -import os import sys from pathlib import Path from typing import TYPE_CHECKING @@ -9,7 +7,7 @@ if TYPE_CHECKING: - from amici.jax import JAXModel + pass # Ensure we are binary-compatible, see #556 if "TPL_AMICI_VERSION" != amici.__version__: @@ -38,28 +36,4 @@ # when the model package is imported via `import` TPL_MODELNAME._model_module = sys.modules[__name__] - -def get_jax_model() -> "JAXModel": - # If the model directory was meanwhile overwritten, this would load the - # new version, which would not match the previously imported extension. - # This is not allowed, as it would lead to inconsistencies. - jax_py_file = Path(__file__).parent / "jax.py" - jax_py_file = jax_py_file.resolve() - t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access - t_modified = os.path.getmtime(jax_py_file) - if t_imported < t_modified: - t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat() - t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat() - raise RuntimeError( - f"Refusing to import {jax_py_file} which was changed since " - f"TPL_MODELNAME was imported. This is to avoid inconsistencies " - "between the different model implementations.\n" - f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n" - "Import the module with a different name or restart the " - "Python kernel." - ) - jax = amici._module_from_path("jax", jax_py_file) - return jax.JAXModel_TPL_MODELNAME() - - __version__ = "TPL_PACKAGE_VERSION" diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 5398aeb235..7ea4a29d8a 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -14,9 +14,6 @@ import logging import os from pathlib import Path -from typing import ( - TYPE_CHECKING, -) import sympy as sp @@ -38,9 +35,6 @@ _monkeypatched, ) -if TYPE_CHECKING: - pass - #: python log manager logger = get_logger(__name__, logging.ERROR) @@ -168,8 +162,7 @@ def __init__( @log_execution_time("generating jax code", logger) def generate_model_code(self) -> None: """ - Generates the native C++ code for the loaded model and a Matlab - script that can be run to compile a mex file from the C++ code + Generates the jax code for the loaded model """ with _monkeypatched( sp.Pow, "_eval_derivative", _custom_pow_eval_derivative @@ -221,7 +214,7 @@ def _generate_jax_code(self) -> None: strict=True, ) ) - subs = {**subs_heaviside, **subs_observables} + subs = subs_heaviside | subs_observables tpl_data = { # assign named variable using corresponding algebraic formula (function body) diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index f013156725..02a2c4e12c 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -def workaround_initial_states( +def _workaround_initial_states( petab_problem: petab.Problem, sbml_model: libsbml.Model, **kwargs ): # TODO: to parameterize initial states or compartment sizes, we currently @@ -146,7 +146,7 @@ def workaround_initial_states( return fixed_parameters -def workaround_observable_parameters( +def _workaround_observable_parameters( observables, sigmas, sbml_model, output_parameter_defaults ): # TODO: adding extra output parameters is currently not supported, @@ -345,10 +345,10 @@ def import_model_sbml( f"({len(sigmas)}) do not match." ) - workaround_observable_parameters( + _workaround_observable_parameters( observables, sigmas, sbml_model, output_parameter_defaults ) - fixed_parameters = workaround_initial_states( + fixed_parameters = _workaround_initial_states( petab_problem=petab_problem, sbml_model=sbml_model, **kwargs, diff --git a/python/sdist/amici/pysb_import.py b/python/sdist/amici/pysb_import.py index e01a09dc65..b84fadea44 100644 --- a/python/sdist/amici/pysb_import.py +++ b/python/sdist/amici/pysb_import.py @@ -111,9 +111,9 @@ def pysb2jax( see :attr:`amici.DEModel._simplify` :param cache_simplify: - see :func:`amici.DEModel.__init__` - Note that there are possible issues with PySB models: - https://github.com/AMICI-dev/AMICI/pull/1672 + see :func:`amici.DEModel.__init__` + Note that there are possible issues with PySB models: + https://github.com/AMICI-dev/AMICI/pull/1672 :param model_name: Name for the generated model module. If None, :attr:`pysb.Model.name`