From bbd1c9c9965deb8f7ce0e130cc30280618c33e46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Mon, 2 Dec 2024 22:38:29 +0000 Subject: [PATCH] fix tests --- python/sdist/amici/jax/jax.template.py | 34 ++++++++++++------------ python/sdist/amici/jax/ode_export.py | 2 +- python/sdist/amici/petab/petab_import.py | 2 -- python/sdist/amici/petab/sbml_import.py | 1 + python/tests/test_jax.py | 28 ++++++++++++------- 5 files changed, 37 insertions(+), 30 deletions(-) diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index ddddb8a64b..59b3ca2ecc 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -14,28 +14,28 @@ def __init__(self): super().__init__() def _xdot(self, t, x, args): - pk, tcl = args + p, tcl = args TPL_X_SYMS = x - TPL_PK_SYMS = pk + TPL_P_SYMS = p TPL_TCL_SYMS = tcl - TPL_W_SYMS = self._w(t, x, pk, tcl) + TPL_W_SYMS = self._w(t, x, p, tcl) TPL_XDOT_EQ return TPL_XDOT_RET - def _w(self, t, x, pk, tcl): + def _w(self, t, x, p, tcl): TPL_X_SYMS = x - TPL_PK_SYMS = pk + TPL_P_SYMS = p TPL_TCL_SYMS = tcl TPL_W_EQ return TPL_W_RET - def _x0(self, pk): - TPL_PK_SYMS = pk + def _x0(self, p): + TPL_P_SYMS = p TPL_X0_EQ @@ -56,25 +56,25 @@ def _x_rdata(self, x, tcl): return TPL_X_RDATA_RET - def _tcl(self, x, pk): + def _tcl(self, x, p): TPL_X_RDATA_SYMS = x - TPL_PK_SYMS = pk + TPL_P_SYMS = p TPL_TOTAL_CL_EQ return TPL_TOTAL_CL_RET - def _y(self, t, x, pk, tcl): + def _y(self, t, x, p, tcl): TPL_X_SYMS = x - TPL_PK_SYMS = pk - TPL_W_SYMS = self._w(t, x, pk, tcl) + TPL_P_SYMS = p + TPL_W_SYMS = self._w(t, x, p, tcl) TPL_Y_EQ return TPL_Y_RET - def _sigmay(self, y, pk): - TPL_PK_SYMS = pk + def _sigmay(self, y, p): + TPL_P_SYMS = p TPL_Y_SYMS = y @@ -82,10 +82,10 @@ def _sigmay(self, y, pk): return TPL_SIGMAY_RET - def _nllh(self, t, x, pk, tcl, my, iy): - y = self._y(t, x, pk, tcl) + def _nllh(self, t, x, p, tcl, my, iy): + y = self._y(t, x, p, tcl) TPL_Y_SYMS = y - TPL_SIGMAY_SYMS = self._sigmay(y, pk) + TPL_SIGMAY_SYMS = self._sigmay(y, p) TPL_JY_EQ diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 6ef7c2b9c1..5398aeb235 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -245,7 +245,7 @@ def _generate_jax_code(self) -> None: outdir.mkdir(parents=True, exist_ok=True) apply_template( - Path(amiciModulePath) / "jax.template.py", + Path(amiciModulePath) / "jax" / "jax.template.py", outdir / "__init__.py", tpl_data, ) diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 8a9c907439..ca69687270 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -146,7 +146,6 @@ def import_petab_problem( petab_problem, model_name=model_name, model_output_dir=model_output_dir, - compile=kwargs.pop("compile", not jax), jax=jax, **kwargs, ) @@ -156,7 +155,6 @@ def import_petab_problem( model_name=model_name, model_output_dir=model_output_dir, non_estimated_parameters_as_constants=non_estimated_parameters_as_constants, - compile=kwargs.pop("compile", not jax), jax=jax, **kwargs, ) diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index e157864176..f013156725 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -379,6 +379,7 @@ def import_model_sbml( verbose=verbose, **kwargs, ) + return sbml_importer else: sbml_importer.sbml2amici( model_name=model_name, diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 30e205ca26..c92b35e570 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -12,7 +12,7 @@ import numpy as np from beartype import beartype -from amici.pysb_import import pysb2amici +from amici.pysb_import import pysb2amici, pysb2jax from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind from amici.petab.petab_import import import_petab_problem from amici.jax import JAXProblem @@ -39,17 +39,21 @@ def test_conversion(): pysb.Rule("conv", a(s="a") >> a(s="b"), pysb.Parameter("kcat", 0.05)) pysb.Observable("ab", a(s="b")) - with TemporaryDirectoryWinSafe(prefix=model.name) as outdir: + with TemporaryDirectoryWinSafe() as outdir: pysb2amici(model, outdir, verbose=True, observables=["ab"]) + pysb2jax(model, outdir, verbose=True, observables=["ab"]) - model_module = amici.import_model_module( + amici_module = amici.import_model_module( module_name=model.name, module_path=outdir ) + jax_module = amici.import_model_module( + module_name=model.name + "_jax", module_path=outdir + ) ts = tuple(np.linspace(0, 1, 10)) p = jnp.stack((1.0, 0.1), axis=-1) k = tuple() - _test_model(model_module, ts, p, k) + _test_model(amici_module, jax_module, ts, p, k) @skip_on_valgrind @@ -86,7 +90,7 @@ def test_dimerization(): pysb.Observable("a_obs", a()) pysb.Observable("b_obs", b()) - with TemporaryDirectoryWinSafe(prefix=model.name) as outdir: + with TemporaryDirectoryWinSafe() as outdir: pysb2amici( model, outdir, @@ -94,26 +98,30 @@ def test_dimerization(): observables=["a_obs", "b_obs"], constant_parameters=["ksyn_a", "ksyn_b"], ) + pysb2jax(model, outdir, verbose=True, observables=["ab"]) - model_module = amici.import_model_module( + amici_module = amici.import_model_module( module_name=model.name, module_path=outdir ) + jax_module = amici.import_model_module( + module_name=model.name + "_jax", module_path=outdir + ) ts = tuple(np.linspace(0, 1, 10)) p = jnp.stack((5, 0.5, 0.5, 0.5), axis=-1) k = (0.5, 5) - _test_model(model_module, ts, p, k) + _test_model(amici_module, jax_module, ts, p, k) -def _test_model(model_module, ts, p, k): - amici_model = model_module.getModel() +def _test_model(amici_module, jax_module, ts, p, k): + amici_model = amici_module.getModel() amici_model.setTimepoints(np.asarray(ts, dtype=np.float64)) sol_amici_ref = amici.runAmiciSimulation( amici_model, amici_model.getSolver() ) - jax_model = model_module.get_jax_model() + jax_model = jax_module.Model() amici_model.setParameters(np.asarray(p, dtype=np.float64)) amici_model.setFixedParameters(np.asarray(k, dtype=np.float64))