Skip to content

Commit

Permalink
reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 3, 2024
1 parent 1f0a13f commit 9465a6c
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 43 deletions.
1 change: 1 addition & 0 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@
"amici_model = import_petab_problem(\n",
" petab_problem,\n",
" verbose=False,\n",
" compile_=True,\n",
" jax=False, # load the amici model this time\n",
")\n",
"\n",
Expand Down
28 changes: 1 addition & 27 deletions python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
"""AMICI-generated module for model TPL_MODELNAME"""

import datetime
import os
import sys
from pathlib import Path
from typing import TYPE_CHECKING
import amici


if TYPE_CHECKING:
from amici.jax import JAXModel
pass

# Ensure we are binary-compatible, see #556
if "TPL_AMICI_VERSION" != amici.__version__:
Expand Down Expand Up @@ -38,28 +36,4 @@
# when the model package is imported via `import`
TPL_MODELNAME._model_module = sys.modules[__name__]


def get_jax_model() -> "JAXModel":
# If the model directory was meanwhile overwritten, this would load the
# new version, which would not match the previously imported extension.
# This is not allowed, as it would lead to inconsistencies.
jax_py_file = Path(__file__).parent / "jax.py"
jax_py_file = jax_py_file.resolve()
t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access
t_modified = os.path.getmtime(jax_py_file)
if t_imported < t_modified:
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
raise RuntimeError(
f"Refusing to import {jax_py_file} which was changed since "
f"TPL_MODELNAME was imported. This is to avoid inconsistencies "
"between the different model implementations.\n"
f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n"
"Import the module with a different name or restart the "
"Python kernel."
)
jax = amici._module_from_path("jax", jax_py_file)
return jax.JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"
11 changes: 2 additions & 9 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
import logging
import os
from pathlib import Path
from typing import (
TYPE_CHECKING,
)

import sympy as sp

Expand All @@ -38,9 +35,6 @@
_monkeypatched,
)

if TYPE_CHECKING:
pass

#: python log manager
logger = get_logger(__name__, logging.ERROR)

Check warning on line 39 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L39

Added line #L39 was not covered by tests

Expand Down Expand Up @@ -168,8 +162,7 @@ def __init__(
@log_execution_time("generating jax code", logger)
def generate_model_code(self) -> None:

Check warning on line 163 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L162-L163

Added lines #L162 - L163 were not covered by tests
"""
Generates the native C++ code for the loaded model and a Matlab
script that can be run to compile a mex file from the C++ code
Generates the jax code for the loaded model
"""
with _monkeypatched(

Check warning on line 167 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L167

Added line #L167 was not covered by tests
sp.Pow, "_eval_derivative", _custom_pow_eval_derivative
Expand Down Expand Up @@ -221,7 +214,7 @@ def _generate_jax_code(self) -> None:
strict=True,
)
)
subs = {**subs_heaviside, **subs_observables}
subs = subs_heaviside | subs_observables

Check warning on line 217 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L217

Added line #L217 was not covered by tests

tpl_data = {

Check warning on line 219 in python/sdist/amici/jax/ode_export.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/ode_export.py#L219

Added line #L219 was not covered by tests
# assign named variable using corresponding algebraic formula (function body)
Expand Down
8 changes: 4 additions & 4 deletions python/sdist/amici/petab/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
logger = logging.getLogger(__name__)


def workaround_initial_states(
def _workaround_initial_states(
petab_problem: petab.Problem, sbml_model: libsbml.Model, **kwargs
):
# TODO: to parameterize initial states or compartment sizes, we currently
Expand Down Expand Up @@ -146,7 +146,7 @@ def workaround_initial_states(
return fixed_parameters


def workaround_observable_parameters(
def _workaround_observable_parameters(
observables, sigmas, sbml_model, output_parameter_defaults
):
# TODO: adding extra output parameters is currently not supported,
Expand Down Expand Up @@ -345,10 +345,10 @@ def import_model_sbml(
f"({len(sigmas)}) do not match."
)

workaround_observable_parameters(
_workaround_observable_parameters(
observables, sigmas, sbml_model, output_parameter_defaults
)
fixed_parameters = workaround_initial_states(
fixed_parameters = _workaround_initial_states(
petab_problem=petab_problem,
sbml_model=sbml_model,
**kwargs,
Expand Down
6 changes: 3 additions & 3 deletions python/sdist/amici/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ def pysb2jax(
see :attr:`amici.DEModel._simplify`
:param cache_simplify:
see :func:`amici.DEModel.__init__`
Note that there are possible issues with PySB models:
https://github.com/AMICI-dev/AMICI/pull/1672
see :func:`amici.DEModel.__init__`
Note that there are possible issues with PySB models:
https://github.com/AMICI-dev/AMICI/pull/1672
:param model_name:
Name for the generated model module. If None, :attr:`pysb.Model.name`
Expand Down

0 comments on commit 9465a6c

Please sign in to comment.