diff --git a/include/amici/defines.h b/include/amici/defines.h index b49ebc6a83..44fa4cadc5 100644 --- a/include/amici/defines.h +++ b/include/amici/defines.h @@ -244,6 +244,7 @@ enum class RDataReporting { full, residuals, likelihood, + observables_likelihood, }; /** boundary conditions for splines */ diff --git a/include/amici/rdata.h b/include/amici/rdata.h index 793be9435a..9f7378751a 100644 --- a/include/amici/rdata.h +++ b/include/amici/rdata.h @@ -476,6 +476,13 @@ class ReturnData : public ModelDimensions { */ void initializeLikelihoodReporting(bool quadratic_llh); + /** + * @brief initializes storage for observables + likelihood reporting mode + * @param quadratic_llh whether model defines a quadratic nllh and computing + * res, sres and FIM makes sense. + */ + void initializeObservablesLikelihoodReporting(bool quadratic_llh); + /** * @brief initializes storage for residual reporting mode * @param enable_res whether residuals are to be computed diff --git a/pytest.ini b/pytest.ini index 29463d5b09..b24e565354 100644 --- a/pytest.ini +++ b/pytest.ini @@ -25,5 +25,7 @@ filterwarnings = ignore:.*PyDevIPCompleter6.*:DeprecationWarning # ignore numpy log(0) warnings (np.log(0) = -inf) ignore:divide by zero encountered in log:RuntimeWarning + # ignore jax deprecation warnings + ignore:jax.* is deprecated:DeprecationWarning norecursedirs = .git amici_models build doc documentation matlab models ThirdParty amici sdist examples diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 6d645a1451..2acd82bdbe 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -365,7 +365,7 @@ "simulation_condition = (\"model1_data1\",)\n", "\n", "# Load condition-specific data\n", - "ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n", + "ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n", " simulation_condition\n", "]\n", "\n", @@ -378,7 +378,6 @@ "def grad_ts_dyn(tt):\n", " return jax_problem.model.simulate_condition(\n", " p=p,\n", - " ts_init=ts_init,\n", " ts_dyn=tt,\n", " ts_posteq=ts_posteq,\n", " my=jnp.array(my),\n", @@ -386,6 +385,7 @@ " iy_trafos=jnp.array(iy_trafos),\n", " solver=diffrax.Kvaerno5(),\n", " controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n", + " steady_state_event=diffrax.steady_state_event(),\n", " max_steps=2**10,\n", " adjoint=diffrax.DirectAdjoint(),\n", " ret=ReturnValue.y, # Return observables\n", diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 51923fd517..8b820498d6 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -12,6 +12,8 @@ import jax import jaxtyping as jt +from collections.abc import Callable + class ReturnValue(enum.Enum): llh = "log-likelihood" @@ -32,6 +34,13 @@ class JAXModel(eqx.Module): JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements routines for simulation and evaluation of derived quantities, model specific implementations need to be provided by classes inheriting from JAXModel. + + :ivar api_version: + API version of the derived class. Needs to match the API version of the base class (MODEL_API_VERSION). + :ivar MODEL_API_VERSION: + API version of the base class. + :ivar jax_py_file: + Path to the JAX model file. """ MODEL_API_VERSION = "0.0.2" @@ -249,6 +258,9 @@ def _eq( x0: jt.Float[jt.Array, "nxs"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "1 nxs"], dict]: """ @@ -279,10 +291,20 @@ def _eq( stepsize_controller=controller, max_steps=max_steps, adjoint=diffrax.DirectAdjoint(), - event=diffrax.Event(cond_fn=diffrax.steady_state_event()), + event=diffrax.Event( + cond_fn=steady_state_event, + ), throw=False, ) - return sol.ys[-1, :], sol.stats + # If the event was triggered, the event mask is True and the solution is the steady state. Otherwise, the + # solution is the last state and the event mask is False. In the latter case, we return inf for the steady + # state. + ys = jnp.where( + sol.event_mask, + sol.ys[-1, :], + jnp.inf * jnp.ones_like(sol.ys[-1, :]), + ) + return ys, sol.stats def _solve( self, @@ -443,7 +465,6 @@ def _sigmays( def simulate_condition( self, p: jt.Float[jt.Array, "np"], - ts_init: jt.Float[jt.Array, "nt_preeq"], ts_dyn: jt.Float[jt.Array, "nt_dyn"], ts_posteq: jt.Float[jt.Array, "nt_posteq"], my: jt.Float[jt.Array, "nt"], @@ -452,6 +473,9 @@ def simulate_condition( solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: int | jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), mask_reinit: jt.Bool[jt.Array, "*nx"] = jnp.array([]), @@ -463,13 +487,9 @@ def simulate_condition( :param p: parameters for simulation ordered according to ids in :ivar parameter_ids: - :param ts_init: - time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to - the number of observables that are evaluated before dynamic simulation. :param ts_dyn: - time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order. - Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time - points. + time points for dynamic simulation. Sorted in monotonically increasing order but duplicate time points are + allowed to facilitate the evaluation of multiple observables at specific time points. :param ts_posteq: time points for post-equilibration. Usually valued \Infty, but needs to be shaped according to the number of observables that are evaluated after post-equilibration. @@ -509,8 +529,6 @@ def simulate_condition( x_solver = self._x_solver(x) tcl = self._tcl(x, p) - x_preq = jnp.repeat(x_solver.reshape(1, -1), ts_init.shape[0], axis=0) - # Dynamic simulation if ts_dyn.shape[0]: x_dyn, stats_dyn = self._solve( @@ -533,7 +551,13 @@ def simulate_condition( # Post-equilibration if ts_posteq.shape[0]: x_solver, stats_posteq = self._eq( - p, tcl, x_solver, solver, controller, max_steps + p, + tcl, + x_solver, + solver, + controller, + steady_state_event, + max_steps, ) else: stats_posteq = None @@ -542,8 +566,8 @@ def simulate_condition( x_solver.reshape(1, -1), ts_posteq.shape[0], axis=0 ) - ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0) - x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0) + ts = jnp.concatenate((ts_dyn, ts_posteq), axis=0) + x = jnp.concatenate((x_dyn, x_posteq), axis=0) nllhs = self._nllhs(ts, x, p, tcl, my, iys) llh = -jnp.sum(nllhs) @@ -604,6 +628,9 @@ def preequilibrate_condition( mask_reinit: jt.Bool[jt.Array, "*nx"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: int | jnp.int_, ) -> tuple[jt.Float[jt.Array, "nx"], dict]: r""" @@ -611,6 +638,10 @@ def preequilibrate_condition( :param p: parameters for simulation ordered according to ids in :ivar parameter_ids: + :param x_reinit: + re-initialized state vector. If not provided, the state vector is not re-initialized. + :param mask_reinit: + mask for re-initialization. If `True`, the corresponding state variable is re-initialized. :param solver: ODE solver :param controller: @@ -627,7 +658,13 @@ def preequilibrate_condition( tcl = self._tcl(x0, p) current_x = self._x_solver(x0) current_x, stats_preeq = self._eq( - p, tcl, current_x, solver, controller, max_steps + p, + tcl, + current_x, + solver, + controller, + steady_state_event, + max_steps, ) return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 16774a6e3f..d23c5a1b5e 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -3,6 +3,7 @@ from numbers import Number from collections.abc import Iterable from pathlib import Path +from collections.abc import Callable import diffrax @@ -71,7 +72,7 @@ class JAXProblem(eqx.Module): :ivar _parameter_mappings: :class:`ParameterMappingForCondition` instances for each simulation condition. :ivar _measurements: - Subset measurement dataframes for each simulation condition. + Preprocessed arrays for each simulation condition. :ivar _petab_problem: PEtab problem to simulate. """ @@ -87,7 +88,6 @@ class JAXProblem(eqx.Module): np.ndarray, np.ndarray, np.ndarray, - np.ndarray, ], ] _inputs: dict[str, dict[str, np.ndarray]] @@ -188,7 +188,6 @@ def _get_measurements( np.ndarray, np.ndarray, np.ndarray, - np.ndarray, ], ], dict[tuple[str, ...], tuple[int, ...]], @@ -214,11 +213,9 @@ def _get_measurements( ) ts = m[petab.TIME] - ts_preeq = ts[np.isfinite(ts) & (ts == 0)] - ts_dyn = ts[np.isfinite(ts) & (ts > 0)] + ts_dyn = ts[np.isfinite(ts)] ts_posteq = ts[np.logical_not(np.isfinite(ts))] - index = pd.concat([ts_preeq, ts_dyn, ts_posteq]).index - ts_preeq = ts_preeq.values + index = pd.concat([ts_dyn, ts_posteq]).index ts_dyn = ts_dyn.values ts_posteq = ts_posteq.values my = m[petab.MEASUREMENT].values @@ -246,7 +243,6 @@ def _get_measurements( iy_trafos = np.zeros_like(iys) measurements[tuple(simulation_condition)] = ( - ts_preeq, ts_dyn, ts_posteq, my, @@ -600,6 +596,9 @@ def run_simulation( simulation_condition: tuple[str, ...], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 ret: ReturnValue = ReturnValue.llh, @@ -622,7 +621,7 @@ def run_simulation( :return: Tuple of output value and simulation statistics """ - ts_preeq, ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ + ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ simulation_condition ] p = self.load_model_parameters(simulation_condition[0]) @@ -630,8 +629,7 @@ def run_simulation( simulation_condition[0], p ) return self.model.simulate_condition( - p=eqx.debug.backward_nan(p), - ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)), + p=p, ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)), my=jax.lax.stop_gradient(jnp.array(my)), @@ -643,6 +641,7 @@ def run_simulation( solver=solver, controller=controller, max_steps=max_steps, + steady_state_event=steady_state_event, adjoint=diffrax.RecursiveCheckpointAdjoint() if ret in (ReturnValue.llh, ReturnValue.chi2) else diffrax.DirectAdjoint(), @@ -654,6 +653,9 @@ def run_preequilibration( simulation_condition: str, solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ], max_steps: jnp.int_, ) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821 """ @@ -675,12 +677,13 @@ def run_preequilibration( simulation_condition, p ) return self.model.preequilibrate_condition( - p=eqx.debug.backward_nan(p), + p=p, mask_reinit=mask_reinit, x_reinit=x_reinit, solver=solver, controller=controller, max_steps=max_steps, + steady_state_event=steady_state_event, ) @@ -691,6 +694,9 @@ def run_simulations( controller: diffrax.AbstractStepSizeController = diffrax.PIDController( **DEFAULT_CONTROLLER_SETTINGS ), + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ] = diffrax.steady_state_event(), max_steps: int = 2**10, ret: ReturnValue | str = ReturnValue.llh, ): @@ -705,6 +711,9 @@ def run_simulations( ODE solver to use for simulation. :param controller: Step size controller to use for simulation. + :param steady_state_event: + Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state + condition, see :func:`diffrax.steady_state_event` for details. :param max_steps: Maximum number of steps to take during simulation. :param ret: @@ -719,7 +728,9 @@ def run_simulations( simulation_conditions = problem.get_all_simulation_conditions() preeqs = { - sc: problem.run_preequilibration(sc, solver, controller, max_steps) + sc: problem.run_preequilibration( + sc, solver, controller, steady_state_event, max_steps + ) # only run preequilibration once per condition for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1} } @@ -729,6 +740,7 @@ def run_simulations( sc, solver, controller, + steady_state_event, max_steps, preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]), ret=ret, @@ -753,6 +765,9 @@ def petab_simulate( controller: diffrax.AbstractStepSizeController = diffrax.PIDController( **DEFAULT_CONTROLLER_SETTINGS ), + steady_state_event: Callable[ + ..., diffrax._custom_types.BoolScalarLike + ] = diffrax.steady_state_event(), max_steps: int = 2**10, ): """ @@ -773,6 +788,7 @@ def petab_simulate( problem, solver=solver, controller=controller, + steady_state_event=steady_state_event, max_steps=max_steps, ret=ReturnValue.y, ) @@ -780,7 +796,7 @@ def petab_simulate( for sc, ys in y.items(): obs = [ problem.model.observable_ids[io] - for io in problem._measurements[sc][4] + for io in problem._measurements[sc][3] ] t = jnp.concat(problem._measurements[sc][:2]) df_sc = pd.DataFrame( diff --git a/python/sdist/pyproject.toml b/python/sdist/pyproject.toml index b62903240e..db48bd6766 100644 --- a/python/sdist/pyproject.toml +++ b/python/sdist/pyproject.toml @@ -83,11 +83,11 @@ examples = [ "scipy", ] jax = [ - "jax>=0.4.34,<0.4.36", - "jaxlib>=0.4.34", - "diffrax>=0.6.0", + "jax>=0.4.36", + "jaxlib>=0.4.36", + "diffrax>=0.6.1", "jaxtyping>=0.2.34", - "equinox>=0.11.8", + "equinox>=0.11.10", "optimistix>=0.0.9", "interpax>=0.3.3", ] diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index ef9cbde576..78fa026cfc 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -11,11 +11,12 @@ import diffrax import numpy as np from beartype import beartype +from petab.v1.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID from amici.pysb_import import pysb2amici, pysb2jax from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind from amici.petab.petab_import import import_petab_problem -from amici.jax import JAXProblem, ReturnValue +from amici.jax import JAXProblem, ReturnValue, run_simulations from numpy.testing import assert_allclose from test_petab_objective import lotka_volterra # noqa: F401 @@ -179,8 +180,7 @@ def check_fields_jax( iys = iys.flatten() iy_trafos = np.zeros_like(iys) - ts_init = ts[ts == 0] - ts_dyn = ts[ts > 0] + ts_dyn = ts ts_posteq = np.array([]) par_dict = { @@ -190,7 +190,6 @@ def check_fields_jax( p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids]) kwargs = { - "ts_init": jnp.array(ts_init), "ts_dyn": jnp.array(ts_dyn), "ts_posteq": jnp.array(ts_posteq), "my": jnp.array(my), @@ -200,6 +199,7 @@ def check_fields_jax( "solver": diffrax.Kvaerno5(), "controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), "adjoint": diffrax.RecursiveCheckpointAdjoint(), + "steady_state_event": diffrax.steady_state_event(), "max_steps": 2**8, # max_steps } fun = beartype(jax_model.simulate_condition) @@ -268,6 +268,28 @@ def check_fields_jax( ) +def test_preequilibration_failure(lotka_volterra): # noqa: F811 + petab_problem = lotka_volterra + # oscillating system, preequilibation should fail when interaction is active + with TemporaryDirectoryWinSafe(prefix="normal") as model_dir: + jax_model = import_petab_problem( + petab_problem, jax=True, model_output_dir=model_dir + ) + jax_problem = JAXProblem(jax_model, petab_problem) + r = run_simulations(jax_problem) + assert not np.isinf(r[0].item()) + petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = ( + petab_problem.measurement_df[SIMULATION_CONDITION_ID] + ) + with TemporaryDirectoryWinSafe(prefix="failure") as model_dir: + jax_model = import_petab_problem( + petab_problem, jax=True, model_output_dir=model_dir + ) + jax_problem = JAXProblem(jax_model, petab_problem) + r = run_simulations(jax_problem) + assert np.isinf(r[0].item()) + + @skip_on_valgrind def test_serialisation(lotka_volterra): # noqa: F811 petab_problem = lotka_volterra diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index 9686a25d94..4ead568c9c 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -586,3 +586,44 @@ def test_python_exceptions(sbml_example_presimulation_module): ): # rethrow=True runAmiciSimulation(solver, None, model.get(), True) + + +def test_reporting_mode_obs_llh(sbml_example_presimulation_module): + model_module = sbml_example_presimulation_module + model = model_module.getModel() + solver = model.getSolver() + + solver.setReturnDataReportingMode( + amici.RDataReporting.observables_likelihood + ) + solver.setSensitivityOrder(amici.SensitivityOrder.first) + + for sens_method in ( + amici.SensitivityMethod.none, + amici.SensitivityMethod.forward, + amici.SensitivityMethod.adjoint, + ): + solver.setSensitivityMethod(sens_method) + rdata = amici.runAmiciSimulation( + model, solver, amici.ExpData(1, 1, 1, [1]) + ) + assert ( + rdata.rdata_reporting + == amici.RDataReporting.observables_likelihood + ) + + assert rdata.y.size > 0 + assert rdata.sigmay.size > 0 + assert rdata.J is None + + match solver.getSensitivityMethod(): + case amici.SensitivityMethod.none: + assert rdata.sllh is None + case amici.SensitivityMethod.forward: + assert rdata.sy.size > 0 + assert rdata.ssigmay.size > 0 + assert rdata.sllh.size > 0 + case amici.SensitivityMethod.adjoint: + assert rdata.sy is None + assert rdata.ssigmay is None + assert rdata.sllh.size > 0 diff --git a/src/rdata.cpp b/src/rdata.cpp index 4ec983af2b..c724d29954 100644 --- a/src/rdata.cpp +++ b/src/rdata.cpp @@ -60,13 +60,17 @@ ReturnData::ReturnData( case RDataReporting::likelihood: initializeLikelihoodReporting(quadratic_llh); break; + + case RDataReporting::observables_likelihood: + initializeObservablesLikelihoodReporting(quadratic_llh); + break; } } void ReturnData::initializeLikelihoodReporting(bool enable_fim) { llh = getNaN(); chi2 = getNaN(); - if (sensi >= SensitivityOrder::first) { + if (sensi >= SensitivityOrder::first && sensi_meth != SensitivityMethod::none) { sllh.resize(nplist, getNaN()); if (sensi >= SensitivityOrder::second) s2llh.resize(nplist * (nJ - 1), getNaN()); @@ -78,6 +82,21 @@ void ReturnData::initializeLikelihoodReporting(bool enable_fim) { } } +void ReturnData::initializeObservablesLikelihoodReporting(bool enable_fim) { + initializeLikelihoodReporting(enable_fim); + + y.resize(nt * ny, 0.0); + sigmay.resize(nt * ny, 0.0); + + if ((sensi_meth == SensitivityMethod::forward + && sensi >= SensitivityOrder::first) + || sensi >= SensitivityOrder::second) { + + sy.resize(nt * ny * nplist, 0.0); + ssigmay.resize(nt * ny * nplist, 0.0); + } +} + void ReturnData::initializeResidualReporting(bool enable_res) { y.resize(nt * ny, 0.0); sigmay.resize(nt * ny, 0.0); diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 4a63d8bfda..d9f836b0b4 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -7,6 +7,7 @@ from functools import partial from pathlib import Path + import fiddy import amici import numpy as np @@ -342,8 +343,9 @@ def test_jax_llh(benchmark_problem): [problem_parameters[pid] for pid in jax_problem.parameter_ids] ), ) - llh_jax, _ = beartype(run_simulations)(jax_problem) + if problem_id in problems_for_gradient_check: + beartype(run_simulations)(jax_problem) (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( run_simulations, has_aux=True )(jax_problem) diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index 5fe61adcf2..4fcbe0b631 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -4,6 +4,8 @@ import logging import sys +import diffrax + import amici import pandas as pd import petab.v1 as petab @@ -68,10 +70,17 @@ def _test_case(case, model_type, version, jax): if jax: from amici.jax import JAXProblem, run_simulations, petab_simulate + steady_state_event = diffrax.steady_state_event(rtol=1e-6, atol=1e-6) jax_problem = JAXProblem(model, problem) - llh, ret = run_simulations(jax_problem) - chi2, _ = run_simulations(jax_problem, ret="chi2") - simulation_df = petab_simulate(jax_problem) + llh, ret = run_simulations( + jax_problem, steady_state_event=steady_state_event + ) + chi2, _ = run_simulations( + jax_problem, ret="chi2", steady_state_event=steady_state_event + ) + simulation_df = petab_simulate( + jax_problem, steady_state_event=steady_state_event + ) simulation_df.rename( columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True )