diff --git a/pytest.ini b/pytest.ini index adbf313922..8cc45e0fd9 100644 --- a/pytest.ini +++ b/pytest.ini @@ -12,6 +12,7 @@ filterwarnings = ignore:Conservation laws for non-constant species in models with Species-AssignmentRules are currently not supported and will be turned off.:UserWarning ignore:Conservation laws for non-constant species in combination with parameterized stoichiometric coefficients are not currently supported and will be turned off.:UserWarning ignore:Support for PEtab2.0 is experimental!:UserWarning + ignore:The JAX module is experimental and the API may change in the future.:ImportWarning # hundreds of SBML <=5.17 warnings ignore:.*inspect.getargspec\(\) is deprecated.*:DeprecationWarning # pysb warnings diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 7f541ea7e4..f6a4f10e98 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -151,7 +151,7 @@ " results (dict): Simulation results from run_simulations.\n", " \"\"\"\n", " # Extract the simulation results for the specific condition\n", - " sim_results = results[simulation_condition][1]\n", + " sim_results = results[simulation_condition]\n", "\n", " # Create a new figure for the state trajectories\n", " plt.figure(figsize=(8, 6))\n", diff --git a/python/sdist/amici/jax/__init__.py b/python/sdist/amici/jax/__init__.py index e14d231e1e..8b67abda27 100644 --- a/python/sdist/amici/jax/__init__.py +++ b/python/sdist/amici/jax/__init__.py @@ -1,6 +1,21 @@ -"""Interface to facilitate AMICI generated models using JAX""" +""" +JAX +--- + +This module provides an interface to generate and use AMICI models with JAX. Please note that this module is +experimental, the API may substantially change in the future. Use at your own risk and do not expect backward +compatibility. +""" + +from warnings import warn from amici.jax.petab import JAXProblem, run_simulations from amici.jax.model import JAXModel +warn( + "The JAX module is experimental and the API may change in the future.", + ImportWarning, + stacklevel=2, +) + __all__ = ["JAXModel", "JAXProblem", "run_simulations"] diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index d105ff9ab6..0f34af791b 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -432,6 +432,7 @@ def simulate_condition( ts_posteq: jt.Float[jt.Array, "nt_posteq"], my: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], + x_preeq: jt.Float[jt.Array, "nx"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 50c0154ee3..b39051390b 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -490,4 +490,7 @@ def run_simulations( ) for sc in simulation_conditions } - return sum(llh for llh, _ in results.values()), results | preeqs + return sum(llh for llh, _ in results.values()), { + sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1] + for sc, res in results.items() + }