From 14c9d22df110db1a03828eb482cdf3230bcf31b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Wed, 4 Dec 2024 17:31:50 +0000 Subject: [PATCH 01/14] disentangle sim & preeq --- .../example_jax_petab/ExampleJaxPEtab.ipynb | 2 +- python/sdist/amici/jax/model.py | 64 ++++++++++++------- python/sdist/amici/jax/petab.py | 60 +++++++++++++---- .../benchmark-models/test_petab_benchmark.py | 4 +- 4 files changed, 93 insertions(+), 37 deletions(-) diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 855860e242..fe7b5d8830 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -373,7 +373,7 @@ " return jax_problem.model.simulate_condition(\n", " p=p,\n", " p_preeq=p_preeq,\n", - " ts_preeq=ts_preeq,\n", + " ts_init=ts_preeq,\n", " ts_dyn=tt,\n", " ts_posteq=ts_posteq,\n", " my=jnp.array(my),\n", diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index ac86b547a6..8f9650ef0f 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -427,12 +427,12 @@ def _sigmays( def simulate_condition( self, p: jt.Float[jt.Array, "np"], - p_preeq: jt.Float[jt.Array, "*np"], - ts_preeq: jt.Float[jt.Array, "nt_preeq"], + ts_init: jt.Float[jt.Array, "nt_preeq"], ts_dyn: jt.Float[jt.Array, "nt_dyn"], ts_posteq: jt.Float[jt.Array, "nt_posteq"], my: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], + x_preeq: jt.Float[jt.Array, "nx"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, @@ -444,12 +444,9 @@ def simulate_condition( :param p: parameters for simulation ordered according to ids in :ivar parameter_ids: - :param p_preeq: - parameters for pre-equilibration ordered according to ids in :ivar parameter_ids:. May be empty to - disable pre-equilibration. - :param ts_preeq: - time points for pre-equilibration. Usually valued 0.0, but needs to be shaped according to - the number of observables that are evaluated after pre-equilibration. + :param ts_init: + time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to + the number of observables that are evaluated before dynamic simulation. :param ts_dyn: time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order. Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time @@ -486,24 +483,16 @@ def simulate_condition( output according to `ret` and statistics """ # Pre-equilibration - if p_preeq.shape[0] > 0: - x0 = self._x0(p_preeq) - tcl = self._tcl(x0, p_preeq) - current_x = self._x_solver(x0) - current_x, stats_preeq = self._eq( - p_preeq, tcl, current_x, solver, controller, max_steps - ) + if x_preeq.shape[0] > 0: + current_x = self._x_solver(x_preeq) # update tcl with new parameters - tcl = self._tcl(self._x_rdata(current_x, tcl), p) + tcl = self._tcl(x_preeq, p) else: x0 = self._x0(p) current_x = self._x_solver(x0) - stats_preeq = None tcl = self._tcl(x0, p) - x_preq = jnp.repeat( - current_x.reshape(1, -1), ts_preeq.shape[0], axis=0 - ) + x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0) # Dynamic simulation if ts_dyn.shape[0] > 0: @@ -536,7 +525,7 @@ def simulate_condition( current_x.reshape(1, -1), ts_posteq.shape[0], axis=0 ) - ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0) + ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0) x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0) nllhs = self._nllhs(ts, x, p, tcl, my, iys) @@ -555,11 +544,42 @@ def simulate_condition( }[ret], dict( ts=ts, x=x, - stats_preeq=stats_preeq, stats_dyn=stats_dyn, stats_posteq=stats_posteq, ) + @eqx.filter_jit + def preequilibrate_condition( + self, + p: jt.Float[jt.Array, "np"], + solver: diffrax.AbstractSolver, + controller: diffrax.AbstractStepSizeController, + max_steps: int | jnp.int_, + ) -> tuple[jt.Float[jt.Array, "nx"], dict]: + r""" + Simulate a condition. + + :param p: + parameters for simulation ordered according to ids in :ivar parameter_ids: + :param solver: + ODE solver + :param controller: + step size controller + :param max_steps: + maximum number of solver steps + :return: + pre-equilibrated state variables and statistics + """ + # Pre-equilibration + x0 = self._x0(p) + tcl = self._tcl(x0, p) + current_x = self._x_solver(x0) + current_x, stats_preeq = self._eq( + p, tcl, current_x, solver, controller, max_steps + ) + + return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq) + def safe_log(x: jnp.float_) -> jnp.float_: """ diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 2c823259fe..1bd42fe6c1 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -154,7 +154,7 @@ def _get_parameter_mappings( def _get_measurements( self, simulation_conditions: pd.DataFrame ) -> dict[ - tuple[str], + str, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], ]: """ @@ -307,19 +307,21 @@ def run_simulation( solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, max_steps: jnp.int_, + x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 ) -> tuple[jnp.float_, dict]: """ Run a simulation for a given simulation condition. :param simulation_condition: - Tuple of simulation conditions to run the simulation for. can be a single string (simulation only) or a - tuple of strings (pre-equilibration followed by simulation). + Simulation condition to run simulation for. :param solver: ODE solver to use for simulation :param controller: Step size controller to use for simulation :param max_steps: Maximum number of steps to take during simulation + :param preeq: + Pre-equilibration state if available :return: Tuple of log-likelihood and simulation statistics """ @@ -327,25 +329,49 @@ def run_simulation( simulation_condition ] p = self.load_parameters(simulation_condition[0]) - p_preeq = ( - self.load_parameters(simulation_condition[1]) - if len(simulation_condition) > 1 - else jnp.array([]) - ) return self.model.simulate_condition( p=p, - p_preeq=p_preeq, - ts_preeq=jax.lax.stop_gradient(jnp.array(ts_preeq)), + ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)), ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)), my=jax.lax.stop_gradient(jnp.array(my)), iys=jax.lax.stop_gradient(jnp.array(iys)), + x_preeq=x_preeq, solver=solver, controller=controller, max_steps=max_steps, adjoint=diffrax.RecursiveCheckpointAdjoint(), ) + def run_preequilibration( + self, + simulation_condition: str, + solver: diffrax.AbstractSolver, + controller: diffrax.AbstractStepSizeController, + max_steps: jnp.int_, + ) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821 + """ + Run a pre-equilibration simulation for a given simulation condition. + + :param simulation_condition: + Simulation condition to run simulation for. + :param solver: + ODE solver to use for simulation + :param controller: + Step size controller to use for simulation + :param max_steps: + Maximum number of steps to take during simulation + :return: + Pre-equilibration state + """ + p = self.load_parameters(simulation_condition) + return self.model.preequilibrate_condition( + p=p, + solver=solver, + controller=controller, + max_steps=max_steps, + ) + def run_simulations( problem: JAXProblem, @@ -379,8 +405,18 @@ def run_simulations( if simulation_conditions is None: simulation_conditions = problem.get_all_simulation_conditions() + preeqs = { + sc[1]: problem.run_preequilibration( + sc[1], solver, controller, max_steps + ) + for sc in simulation_conditions + if len(sc) > 1 + } + results = { - sc: problem.run_simulation(sc, solver, controller, max_steps) + sc[0]: problem.run_simulation( + sc, solver, controller, max_steps, preeqs.get(sc[1])[0] + ) for sc in simulation_conditions } - return sum(llh for llh, _ in results.values()), results + return sum(llh for llh, _ in results.values()), results | preeqs diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 74c84d37a9..6a388f7493 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -328,7 +328,7 @@ def test_jax_llh(benchmark_problem): jax_model = import_petab_problem( petab_problem, - model_output_dir=benchmark_outdir / problem_id, + model_output_dir=benchmark_outdir / (problem_id + "_jax"), jax=True, ) jax_problem = JAXProblem(jax_model, petab_problem) @@ -340,7 +340,7 @@ def test_jax_llh(benchmark_problem): [problem_parameters[pid] for pid in jax_problem.parameter_ids] ), ) - if problem_id in problems_for_gradient_check_jax: + if problem_id in problems_for_gradient_check: (llh_jax, _), sllh_jax = eqx.filter_jit( eqx.filter_value_and_grad(run_simulations, has_aux=True) )(jax_problem) From 42d4767047c2875ab67b5ebee70f6eebacdcb997 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Wed, 4 Dec 2024 17:32:07 +0000 Subject: [PATCH 02/14] disentangle sim & preeq --- documentation/ExampleJaxPEtab.ipynb | 674 +++++++++++++++++++++++++++- 1 file changed, 673 insertions(+), 1 deletion(-) mode change 120000 => 100644 documentation/ExampleJaxPEtab.ipynb diff --git a/documentation/ExampleJaxPEtab.ipynb b/documentation/ExampleJaxPEtab.ipynb deleted file mode 120000 index 821b14f21f..0000000000 --- a/documentation/ExampleJaxPEtab.ipynb +++ /dev/null @@ -1 +0,0 @@ -../python/examples/example_jax_petab/ExampleJaxPEtab.ipynb \ No newline at end of file diff --git a/documentation/ExampleJaxPEtab.ipynb b/documentation/ExampleJaxPEtab.ipynb new file mode 100644 index 0000000000..855860e242 --- /dev/null +++ b/documentation/ExampleJaxPEtab.ipynb @@ -0,0 +1,673 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d4d2bc5c", + "metadata": {}, + "source": [ + "# Simulating AMICI models using JAX\n", + "\n", + "## Overview\n", + "\n", + "This guide demonstrates how to use AMICI to export models in a format compatible with the [JAX](https://jax.readthedocs.io/en/latest/) ecosystem, enabling simulations with the [diffrax](https://docs.kidger.site/diffrax/) library. " + ] + }, + { + "cell_type": "markdown", + "id": "fb2fe897", + "metadata": {}, + "source": [ + "## Preparation\n", + "\n", + "To begin, we will import a model using [PEtab](https://petab.readthedocs.io). For this demonstration, we will utilize the [Benchmark Collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which provides a diverse set of models. For more information on importing PEtab models, refer to the corresponding [PEtab notebook](https://amici.readthedocs.io/en/latest/petab.html).\n", + "\n", + "In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXModel](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from amici.petab.petab_import import import_petab_problem\n", + "import petab.v1 as petab\n", + "\n", + "# Define the model name and YAML file location\n", + "model_name = \"Boehm_JProteomeRes2014\"\n", + "yaml_url = (\n", + " f\"https://raw.githubusercontent.com/Benchmarking-Initiative/Benchmark-Models-PEtab/\"\n", + " f\"master/Benchmark-Models/{model_name}/{model_name}.yaml\"\n", + ")\n", + "\n", + "# Load the PEtab problem from the YAML file\n", + "petab_problem = petab.Problem.from_yaml(yaml_url)\n", + "\n", + "# Import the PEtab problem as a JAX-compatible AMICI model\n", + "jax_model = import_petab_problem(\n", + " petab_problem,\n", + " verbose=False, # no text output\n", + " jax=True, # return jax model\n", + ")" + ], + "id": "c71c96da0da3144a" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Simulation\n", + "\n", + "In principle, we can already use this model for simulation using the [simulate_condition](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel.simulate_condition) method. However, this approach can be cumbersome as timepoints, data etc. need to be specified manually. Instead, we process the PEtab problem into a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem), which enables efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." + ], + "id": "7e0f1c27bd71ee1f" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from amici.jax import JAXProblem, run_simulations\n", + "\n", + "# Create a JAXProblem from the JAX model and PEtab problem\n", + "jax_problem = JAXProblem(jax_model, petab_problem)\n", + "\n", + "# Run simulations and compute the log-likelihood\n", + "llh, results = run_simulations(jax_problem)" + ], + "id": "ccecc9a29acc7b73" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results.", + "id": "415962751301c64a" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Define the simulation condition\n", + "simulation_condition = (\"model1_data1\",)\n", + "\n", + "# Access the results for the specified condition\n", + "results[simulation_condition]" + ], + "id": "596b86e45e18fe3d" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "Unfortunately, the simulation failed! As seen in the output, the simulation broke down after the initial timepoint, indicated by the `inf` values in the state variables `results[simulation_condition][1].x` and the `nan` likelihood value. A closer inspection of this variable provides additional clues about what might have gone wrong.\n", + "\n", + "The issue stems from using single precision, as indicated by the `float32` dtype of state variables. Single precision is generally a [bad idea](https://docs.kidger.site/diffrax/examples/stiff_ode/) for stiff systems like the Böhm model. Let’s retry the simulation with double precision." + ], + "id": "a1b173e013f9210a" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "import jax\n", + "\n", + "# Enable double precision in JAX\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "# Re-run simulations with double precision\n", + "llh, results = run_simulations(jax_problem)\n", + "\n", + "results" + ], + "id": "f4f5ff705a3f7402" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories.", + "id": "fe4d3b40ee3efdf2" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "\n", + "def plot_simulation(results):\n", + " \"\"\"\n", + " Plot the state trajectories from the simulation results.\n", + "\n", + " Parameters:\n", + " results (dict): Simulation results from run_simulations.\n", + " \"\"\"\n", + " # Extract the simulation results for the specific condition\n", + " sim_results = results[simulation_condition][1]\n", + "\n", + " # Create a new figure for the state trajectories\n", + " plt.figure(figsize=(8, 6))\n", + " for idx in range(sim_results[\"x\"].shape[1]):\n", + " time_points = np.array(sim_results[\"ts\"])\n", + " state_values = np.array(sim_results[\"x\"][:, idx])\n", + " plt.plot(time_points, state_values, label=jax_model.state_ids[idx])\n", + "\n", + " # Add labels, legend, and grid\n", + " plt.xlabel(\"Time\")\n", + " plt.ylabel(\"State Values\")\n", + " plt.title(simulation_condition)\n", + " plt.legend()\n", + " plt.grid(True)\n", + " plt.show()\n", + "\n", + "\n", + "# Plot the simulation results\n", + "plot_simulation(results)" + ], + "id": "72f1ed397105e14a" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all.", + "id": "4fa97c33719c2277" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n", + "results" + ], + "id": "7950774a3e989042" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Updating Parameters\n", + "\n", + "As next step, we will update the parameter values used for simulation. However, if we attempt to directly modify the values in `JAXModel.parameters`, we encounter a `FrozenInstanceError`." + ], + "id": "98b8516a75ce4d12" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from dataclasses import FrozenInstanceError\n", + "import jax\n", + "\n", + "# Generate random noise to update the parameters\n", + "noise = (\n", + " jax.random.normal(\n", + " key=jax.random.PRNGKey(0), shape=jax_problem.parameters.shape\n", + " )\n", + " / 10\n", + ")\n", + "\n", + "# Attempt to update the parameters\n", + "try:\n", + " jax_problem.parameters += noise\n", + "except FrozenInstanceError as e:\n", + " print(\"Error:\", e)" + ], + "id": "3d278a3d21e709d" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "The root cause of this error lies in the fact that, to enable autodiff, direct modifications of attributes are not allowed in [equinox](https://docs.kidger.site/equinox/), which AMICI utilizes under the hood. Consequently, attributes of instances like `JAXModel` or `JAXProblem` cannot be updated directly — this is the price we have to pay for autodiff.\n", + "\n", + "However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one." + ], + "id": "4cc3d595de4a4085" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Update the parameters and create a new JAXProblem instance\n", + "jax_problem = jax_problem.update_parameters(jax_problem.parameters + noise)\n", + "\n", + "# Run simulations with the updated parameters\n", + "llh, results = run_simulations(jax_problem)\n", + "\n", + "# Plot the simulation results\n", + "plot_simulation(results)" + ], + "id": "e47748376059628b" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Computing Gradients\n", + "\n", + "Similar to updating attributes, computing gradients in the JAX ecosystem can feel a bit unconventional if you’re not familiar with the JAX ecosysmt. JAX offers [powerful automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) through the `jax.grad` function. However, to use `jax.grad` with `JAXProblem`, we need to specify which parts of the `JAXProblem` should be treated as static." + ], + "id": "660baf605a4e8339" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "try:\n", + " # Attempt to compute the gradient of the run_simulations function\n", + " jax.grad(run_simulations, has_aux=True)(jax_problem)\n", + "except TypeError as e:\n", + " print(\"Error:\", e)" + ], + "id": "7033d09cc81b7f69" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`.", + "id": "dc9bc07cde00a926" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "import equinox as eqx\n", + "\n", + "# Compute the gradient using equinox's filter_grad, preserving auxiliary outputs\n", + "grad, _ = eqx.filter_grad(run_simulations, has_aux=True)(jax_problem)" + ], + "id": "a6704182200e6438" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`.", + "id": "851c3ec94cb5d086" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "grad.parameters", + "id": "c00c1581d7173d7a" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`.", + "id": "375b835fecc5a022" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "grad", + "id": "f7c17f7459d0151f" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out.", + "id": "8eb7cc3db510c826" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "grad._measurements[simulation_condition]", + "id": "3badd4402cf6b8c6" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation.", + "id": "58eb04393a1463d" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "import jax.numpy as jnp\n", + "import diffrax\n", + "\n", + "# Define the simulation condition\n", + "simulation_condition = (\"model1_data1\",)\n", + "\n", + "# Load condition-specific data\n", + "ts_preeq, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n", + " simulation_condition\n", + "]\n", + "\n", + "# Load parameters for the specified condition\n", + "p = jax_problem.load_parameters(simulation_condition[0])\n", + "# Disable preequilibration\n", + "p_preeq = jnp.array([])\n", + "\n", + "\n", + "# Define a function to compute the gradient with respect to dynamic timepoints\n", + "@eqx.filter_jacfwd\n", + "def grad_ts_dyn(tt):\n", + " return jax_problem.model.simulate_condition(\n", + " p=p,\n", + " p_preeq=p_preeq,\n", + " ts_preeq=ts_preeq,\n", + " ts_dyn=tt,\n", + " ts_posteq=ts_posteq,\n", + " my=jnp.array(my),\n", + " iys=jnp.array(iys),\n", + " solver=diffrax.Kvaerno5(),\n", + " controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n", + " max_steps=2**10,\n", + " adjoint=diffrax.DirectAdjoint(),\n", + " ret=\"y\", # Return observables\n", + " )[0]\n", + "\n", + "\n", + "# Compute the gradient with respect to `ts_dyn`\n", + "g = grad_ts_dyn(ts_dyn)\n", + "g" + ], + "id": "1a91aff44b93157" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Compilation & Profiling\n", + "\n", + "To maximize performance with JAX, code should be just-in-time (JIT) compiled. This can be achieved using the `jax.jit` or `equinox.filter_jit` decorators. While JIT compilation introduces some overhead during the first function call, it significantly improves performance for subsequent calls. To demonstrate this, we will first clear the JIT cache and then profile the execution." + ], + "id": "9f870da7754e139c" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from time import time\n", + "\n", + "# Clear JAX caches to ensure a fresh start\n", + "jax.clear_caches()\n", + "\n", + "# Define a JIT-compiled gradient function with auxiliary outputs\n", + "gradfun = eqx.filter_jit(eqx.filter_grad(run_simulations, has_aux=True))" + ], + "id": "58ebdc110ea7457e" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Measure the time taken for the first function call (including compilation)\n", + "start = time()\n", + "run_simulations(jax_problem)\n", + "print(f\"Function compilation time: {time() - start:.2f} seconds\")\n", + "\n", + "# Measure the time taken for the gradient computation (including compilation)\n", + "start = time()\n", + "gradfun(jax_problem)\n", + "print(f\"Gradient compilation time: {time() - start:.2f} seconds\")" + ], + "id": "e1242075f7e0faf" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "%%timeit\n", + "run_simulations(\n", + " jax_problem,\n", + " controller=diffrax.PIDController(\n", + " rtol=1e-8, # same as amici default\n", + " atol=1e-16, # same as amici default\n", + " pcoeff=0.4, # recommended value for stiff systems\n", + " icoeff=0.3, # recommended value for stiff systems\n", + " dcoeff=0.0, # recommended value for stiff systems\n", + " ),\n", + ")" + ], + "id": "27181f367ccb1817" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "%%timeit \n", + "gradfun(\n", + " jax_problem,\n", + " controller=diffrax.PIDController(\n", + " rtol=1e-8, # same as amici default\n", + " atol=1e-16, # same as amici default\n", + " pcoeff=0.4, # recommended value for stiff systems\n", + " icoeff=0.3, # recommended value for stiff systems\n", + " dcoeff=0.0, # recommended value for stiff systems\n", + " ),\n", + ")" + ], + "id": "5b8d3a6162a3ae55" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from amici.petab import simulate_petab\n", + "import amici\n", + "\n", + "# Import the PEtab problem as a standard AMICI model\n", + "amici_model = import_petab_problem(\n", + " petab_problem,\n", + " verbose=False,\n", + " compile_=True,\n", + " jax=False, # load the amici model this time\n", + ")\n", + "\n", + "# Configure the solver with appropriate tolerances\n", + "solver = amici_model.getSolver()\n", + "solver.setAbsoluteTolerance(1e-8)\n", + "solver.setRelativeTolerance(1e-8)\n", + "\n", + "# Prepare the parameters for the simulation\n", + "problem_parameters = dict(\n", + " zip(jax_problem.parameter_ids, jax_problem.parameters)\n", + ")" + ], + "id": "d733a450635a749b" + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "413ed7c60b2cf4be", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:55.259985Z", + "start_time": "2024-11-19T09:51:55.257937Z" + } + }, + "outputs": [], + "source": [ + "# Profile simulation only\n", + "solver.setSensitivityOrder(amici.SensitivityOrder.none)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "768fa60e439ca8b4", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:57.417608Z", + "start_time": "2024-11-19T09:51:55.273367Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "26.1 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit \n", + "simulate_petab(\n", + " petab_problem,\n", + " amici_model,\n", + " solver=solver,\n", + " problem_parameters=problem_parameters,\n", + " scaled_parameters=True,\n", + " scaled_gradients=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "b8382b0b2b68f49e", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:57.497361Z", + "start_time": "2024-11-19T09:51:57.494502Z" + } + }, + "outputs": [], + "source": [ + "# Profile gradient computation using forward sensitivity analysis\n", + "solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", + "solver.setSensitivityMethod(amici.SensitivityMethod.forward)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "3bae1fab8c416122", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:59.897459Z", + "start_time": "2024-11-19T09:51:57.511889Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "29.1 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit \n", + "simulate_petab(\n", + " petab_problem,\n", + " amici_model,\n", + " solver=solver,\n", + " problem_parameters=problem_parameters,\n", + " scaled_parameters=True,\n", + " scaled_gradients=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "71e0358227e1dc74", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:59.972149Z", + "start_time": "2024-11-19T09:51:59.969006Z" + } + }, + "outputs": [], + "source": [ + "# Profile gradient computation using adjoint sensitivity analysis\n", + "solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", + "solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "e3cc7971002b6d06", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:52:03.266074Z", + "start_time": "2024-11-19T09:51:59.992465Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "39.3 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit \n", + "simulate_petab(\n", + " petab_problem,\n", + " amici_model,\n", + " solver=solver,\n", + " problem_parameters=problem_parameters,\n", + " scaled_parameters=True,\n", + " scaled_gradients=True,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From f53d4a02f7953bb749bb5b575823fbf37e86b0f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 11:00:06 +0000 Subject: [PATCH 03/14] run preequilibration once --- python/sdist/amici/jax/petab.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 1bd42fe6c1..e97f7d2e7a 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -406,11 +406,9 @@ def run_simulations( simulation_conditions = problem.get_all_simulation_conditions() preeqs = { - sc[1]: problem.run_preequilibration( - sc[1], solver, controller, max_steps - ) - for sc in simulation_conditions - if len(sc) > 1 + sc: problem.run_preequilibration(sc, solver, controller, max_steps) + # only run preequilibration once per condition + for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1} } results = { From 019efd6cf03e79dd1f53641a47b560f03ab07794 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 11:02:02 +0000 Subject: [PATCH 04/14] fix symlink --- documentation/ExampleJaxPEtab.ipynb | 674 +--------------------------- 1 file changed, 1 insertion(+), 673 deletions(-) mode change 100644 => 120000 documentation/ExampleJaxPEtab.ipynb diff --git a/documentation/ExampleJaxPEtab.ipynb b/documentation/ExampleJaxPEtab.ipynb deleted file mode 100644 index 855860e242..0000000000 --- a/documentation/ExampleJaxPEtab.ipynb +++ /dev/null @@ -1,673 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "d4d2bc5c", - "metadata": {}, - "source": [ - "# Simulating AMICI models using JAX\n", - "\n", - "## Overview\n", - "\n", - "This guide demonstrates how to use AMICI to export models in a format compatible with the [JAX](https://jax.readthedocs.io/en/latest/) ecosystem, enabling simulations with the [diffrax](https://docs.kidger.site/diffrax/) library. " - ] - }, - { - "cell_type": "markdown", - "id": "fb2fe897", - "metadata": {}, - "source": [ - "## Preparation\n", - "\n", - "To begin, we will import a model using [PEtab](https://petab.readthedocs.io). For this demonstration, we will utilize the [Benchmark Collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which provides a diverse set of models. For more information on importing PEtab models, refer to the corresponding [PEtab notebook](https://amici.readthedocs.io/en/latest/petab.html).\n", - "\n", - "In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXModel](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n" - ] - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "from amici.petab.petab_import import import_petab_problem\n", - "import petab.v1 as petab\n", - "\n", - "# Define the model name and YAML file location\n", - "model_name = \"Boehm_JProteomeRes2014\"\n", - "yaml_url = (\n", - " f\"https://raw.githubusercontent.com/Benchmarking-Initiative/Benchmark-Models-PEtab/\"\n", - " f\"master/Benchmark-Models/{model_name}/{model_name}.yaml\"\n", - ")\n", - "\n", - "# Load the PEtab problem from the YAML file\n", - "petab_problem = petab.Problem.from_yaml(yaml_url)\n", - "\n", - "# Import the PEtab problem as a JAX-compatible AMICI model\n", - "jax_model = import_petab_problem(\n", - " petab_problem,\n", - " verbose=False, # no text output\n", - " jax=True, # return jax model\n", - ")" - ], - "id": "c71c96da0da3144a" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "## Simulation\n", - "\n", - "In principle, we can already use this model for simulation using the [simulate_condition](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel.simulate_condition) method. However, this approach can be cumbersome as timepoints, data etc. need to be specified manually. Instead, we process the PEtab problem into a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem), which enables efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." - ], - "id": "7e0f1c27bd71ee1f" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "from amici.jax import JAXProblem, run_simulations\n", - "\n", - "# Create a JAXProblem from the JAX model and PEtab problem\n", - "jax_problem = JAXProblem(jax_model, petab_problem)\n", - "\n", - "# Run simulations and compute the log-likelihood\n", - "llh, results = run_simulations(jax_problem)" - ], - "id": "ccecc9a29acc7b73" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results.", - "id": "415962751301c64a" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "# Define the simulation condition\n", - "simulation_condition = (\"model1_data1\",)\n", - "\n", - "# Access the results for the specified condition\n", - "results[simulation_condition]" - ], - "id": "596b86e45e18fe3d" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "Unfortunately, the simulation failed! As seen in the output, the simulation broke down after the initial timepoint, indicated by the `inf` values in the state variables `results[simulation_condition][1].x` and the `nan` likelihood value. A closer inspection of this variable provides additional clues about what might have gone wrong.\n", - "\n", - "The issue stems from using single precision, as indicated by the `float32` dtype of state variables. Single precision is generally a [bad idea](https://docs.kidger.site/diffrax/examples/stiff_ode/) for stiff systems like the Böhm model. Let’s retry the simulation with double precision." - ], - "id": "a1b173e013f9210a" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "import jax\n", - "\n", - "# Enable double precision in JAX\n", - "jax.config.update(\"jax_enable_x64\", True)\n", - "\n", - "# Re-run simulations with double precision\n", - "llh, results = run_simulations(jax_problem)\n", - "\n", - "results" - ], - "id": "f4f5ff705a3f7402" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories.", - "id": "fe4d3b40ee3efdf2" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "\n", - "\n", - "def plot_simulation(results):\n", - " \"\"\"\n", - " Plot the state trajectories from the simulation results.\n", - "\n", - " Parameters:\n", - " results (dict): Simulation results from run_simulations.\n", - " \"\"\"\n", - " # Extract the simulation results for the specific condition\n", - " sim_results = results[simulation_condition][1]\n", - "\n", - " # Create a new figure for the state trajectories\n", - " plt.figure(figsize=(8, 6))\n", - " for idx in range(sim_results[\"x\"].shape[1]):\n", - " time_points = np.array(sim_results[\"ts\"])\n", - " state_values = np.array(sim_results[\"x\"][:, idx])\n", - " plt.plot(time_points, state_values, label=jax_model.state_ids[idx])\n", - "\n", - " # Add labels, legend, and grid\n", - " plt.xlabel(\"Time\")\n", - " plt.ylabel(\"State Values\")\n", - " plt.title(simulation_condition)\n", - " plt.legend()\n", - " plt.grid(True)\n", - " plt.show()\n", - "\n", - "\n", - "# Plot the simulation results\n", - "plot_simulation(results)" - ], - "id": "72f1ed397105e14a" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all.", - "id": "4fa97c33719c2277" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n", - "results" - ], - "id": "7950774a3e989042" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "## Updating Parameters\n", - "\n", - "As next step, we will update the parameter values used for simulation. However, if we attempt to directly modify the values in `JAXModel.parameters`, we encounter a `FrozenInstanceError`." - ], - "id": "98b8516a75ce4d12" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "from dataclasses import FrozenInstanceError\n", - "import jax\n", - "\n", - "# Generate random noise to update the parameters\n", - "noise = (\n", - " jax.random.normal(\n", - " key=jax.random.PRNGKey(0), shape=jax_problem.parameters.shape\n", - " )\n", - " / 10\n", - ")\n", - "\n", - "# Attempt to update the parameters\n", - "try:\n", - " jax_problem.parameters += noise\n", - "except FrozenInstanceError as e:\n", - " print(\"Error:\", e)" - ], - "id": "3d278a3d21e709d" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "The root cause of this error lies in the fact that, to enable autodiff, direct modifications of attributes are not allowed in [equinox](https://docs.kidger.site/equinox/), which AMICI utilizes under the hood. Consequently, attributes of instances like `JAXModel` or `JAXProblem` cannot be updated directly — this is the price we have to pay for autodiff.\n", - "\n", - "However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one." - ], - "id": "4cc3d595de4a4085" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "# Update the parameters and create a new JAXProblem instance\n", - "jax_problem = jax_problem.update_parameters(jax_problem.parameters + noise)\n", - "\n", - "# Run simulations with the updated parameters\n", - "llh, results = run_simulations(jax_problem)\n", - "\n", - "# Plot the simulation results\n", - "plot_simulation(results)" - ], - "id": "e47748376059628b" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "## Computing Gradients\n", - "\n", - "Similar to updating attributes, computing gradients in the JAX ecosystem can feel a bit unconventional if you’re not familiar with the JAX ecosysmt. JAX offers [powerful automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) through the `jax.grad` function. However, to use `jax.grad` with `JAXProblem`, we need to specify which parts of the `JAXProblem` should be treated as static." - ], - "id": "660baf605a4e8339" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "try:\n", - " # Attempt to compute the gradient of the run_simulations function\n", - " jax.grad(run_simulations, has_aux=True)(jax_problem)\n", - "except TypeError as e:\n", - " print(\"Error:\", e)" - ], - "id": "7033d09cc81b7f69" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`.", - "id": "dc9bc07cde00a926" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "import equinox as eqx\n", - "\n", - "# Compute the gradient using equinox's filter_grad, preserving auxiliary outputs\n", - "grad, _ = eqx.filter_grad(run_simulations, has_aux=True)(jax_problem)" - ], - "id": "a6704182200e6438" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`.", - "id": "851c3ec94cb5d086" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "grad.parameters", - "id": "c00c1581d7173d7a" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`.", - "id": "375b835fecc5a022" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "grad", - "id": "f7c17f7459d0151f" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out.", - "id": "8eb7cc3db510c826" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "grad._measurements[simulation_condition]", - "id": "3badd4402cf6b8c6" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation.", - "id": "58eb04393a1463d" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "import jax.numpy as jnp\n", - "import diffrax\n", - "\n", - "# Define the simulation condition\n", - "simulation_condition = (\"model1_data1\",)\n", - "\n", - "# Load condition-specific data\n", - "ts_preeq, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n", - " simulation_condition\n", - "]\n", - "\n", - "# Load parameters for the specified condition\n", - "p = jax_problem.load_parameters(simulation_condition[0])\n", - "# Disable preequilibration\n", - "p_preeq = jnp.array([])\n", - "\n", - "\n", - "# Define a function to compute the gradient with respect to dynamic timepoints\n", - "@eqx.filter_jacfwd\n", - "def grad_ts_dyn(tt):\n", - " return jax_problem.model.simulate_condition(\n", - " p=p,\n", - " p_preeq=p_preeq,\n", - " ts_preeq=ts_preeq,\n", - " ts_dyn=tt,\n", - " ts_posteq=ts_posteq,\n", - " my=jnp.array(my),\n", - " iys=jnp.array(iys),\n", - " solver=diffrax.Kvaerno5(),\n", - " controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n", - " max_steps=2**10,\n", - " adjoint=diffrax.DirectAdjoint(),\n", - " ret=\"y\", # Return observables\n", - " )[0]\n", - "\n", - "\n", - "# Compute the gradient with respect to `ts_dyn`\n", - "g = grad_ts_dyn(ts_dyn)\n", - "g" - ], - "id": "1a91aff44b93157" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "## Compilation & Profiling\n", - "\n", - "To maximize performance with JAX, code should be just-in-time (JIT) compiled. This can be achieved using the `jax.jit` or `equinox.filter_jit` decorators. While JIT compilation introduces some overhead during the first function call, it significantly improves performance for subsequent calls. To demonstrate this, we will first clear the JIT cache and then profile the execution." - ], - "id": "9f870da7754e139c" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "from time import time\n", - "\n", - "# Clear JAX caches to ensure a fresh start\n", - "jax.clear_caches()\n", - "\n", - "# Define a JIT-compiled gradient function with auxiliary outputs\n", - "gradfun = eqx.filter_jit(eqx.filter_grad(run_simulations, has_aux=True))" - ], - "id": "58ebdc110ea7457e" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "# Measure the time taken for the first function call (including compilation)\n", - "start = time()\n", - "run_simulations(jax_problem)\n", - "print(f\"Function compilation time: {time() - start:.2f} seconds\")\n", - "\n", - "# Measure the time taken for the gradient computation (including compilation)\n", - "start = time()\n", - "gradfun(jax_problem)\n", - "print(f\"Gradient compilation time: {time() - start:.2f} seconds\")" - ], - "id": "e1242075f7e0faf" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "%%timeit\n", - "run_simulations(\n", - " jax_problem,\n", - " controller=diffrax.PIDController(\n", - " rtol=1e-8, # same as amici default\n", - " atol=1e-16, # same as amici default\n", - " pcoeff=0.4, # recommended value for stiff systems\n", - " icoeff=0.3, # recommended value for stiff systems\n", - " dcoeff=0.0, # recommended value for stiff systems\n", - " ),\n", - ")" - ], - "id": "27181f367ccb1817" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "%%timeit \n", - "gradfun(\n", - " jax_problem,\n", - " controller=diffrax.PIDController(\n", - " rtol=1e-8, # same as amici default\n", - " atol=1e-16, # same as amici default\n", - " pcoeff=0.4, # recommended value for stiff systems\n", - " icoeff=0.3, # recommended value for stiff systems\n", - " dcoeff=0.0, # recommended value for stiff systems\n", - " ),\n", - ")" - ], - "id": "5b8d3a6162a3ae55" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "from amici.petab import simulate_petab\n", - "import amici\n", - "\n", - "# Import the PEtab problem as a standard AMICI model\n", - "amici_model = import_petab_problem(\n", - " petab_problem,\n", - " verbose=False,\n", - " compile_=True,\n", - " jax=False, # load the amici model this time\n", - ")\n", - "\n", - "# Configure the solver with appropriate tolerances\n", - "solver = amici_model.getSolver()\n", - "solver.setAbsoluteTolerance(1e-8)\n", - "solver.setRelativeTolerance(1e-8)\n", - "\n", - "# Prepare the parameters for the simulation\n", - "problem_parameters = dict(\n", - " zip(jax_problem.parameter_ids, jax_problem.parameters)\n", - ")" - ], - "id": "d733a450635a749b" - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "413ed7c60b2cf4be", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-19T09:51:55.259985Z", - "start_time": "2024-11-19T09:51:55.257937Z" - } - }, - "outputs": [], - "source": [ - "# Profile simulation only\n", - "solver.setSensitivityOrder(amici.SensitivityOrder.none)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "768fa60e439ca8b4", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-19T09:51:57.417608Z", - "start_time": "2024-11-19T09:51:55.273367Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "26.1 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "%%timeit \n", - "simulate_petab(\n", - " petab_problem,\n", - " amici_model,\n", - " solver=solver,\n", - " problem_parameters=problem_parameters,\n", - " scaled_parameters=True,\n", - " scaled_gradients=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "b8382b0b2b68f49e", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-19T09:51:57.497361Z", - "start_time": "2024-11-19T09:51:57.494502Z" - } - }, - "outputs": [], - "source": [ - "# Profile gradient computation using forward sensitivity analysis\n", - "solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", - "solver.setSensitivityMethod(amici.SensitivityMethod.forward)" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "3bae1fab8c416122", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-19T09:51:59.897459Z", - "start_time": "2024-11-19T09:51:57.511889Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "29.1 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "%%timeit \n", - "simulate_petab(\n", - " petab_problem,\n", - " amici_model,\n", - " solver=solver,\n", - " problem_parameters=problem_parameters,\n", - " scaled_parameters=True,\n", - " scaled_gradients=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "71e0358227e1dc74", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-19T09:51:59.972149Z", - "start_time": "2024-11-19T09:51:59.969006Z" - } - }, - "outputs": [], - "source": [ - "# Profile gradient computation using adjoint sensitivity analysis\n", - "solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", - "solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "e3cc7971002b6d06", - "metadata": { - "ExecuteTime": { - "end_time": "2024-11-19T09:52:03.266074Z", - "start_time": "2024-11-19T09:51:59.992465Z" - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "39.3 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" - ] - } - ], - "source": [ - "%%timeit \n", - "simulate_petab(\n", - " petab_problem,\n", - " amici_model,\n", - " solver=solver,\n", - " problem_parameters=problem_parameters,\n", - " scaled_parameters=True,\n", - " scaled_gradients=True,\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/documentation/ExampleJaxPEtab.ipynb b/documentation/ExampleJaxPEtab.ipynb new file mode 120000 index 0000000000..821b14f21f --- /dev/null +++ b/documentation/ExampleJaxPEtab.ipynb @@ -0,0 +1 @@ +../python/examples/example_jax_petab/ExampleJaxPEtab.ipynb \ No newline at end of file From 1613aa245d437a40a21ac9bdb66ead43be6bdf39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 11:11:27 +0000 Subject: [PATCH 05/14] separate default dirs for jax/cpp, honour model dir/name --- python/examples/example_jax_petab/ExampleJaxPEtab.ipynb | 1 - python/sdist/amici/petab/import_helpers.py | 5 +---- python/sdist/amici/petab/petab_import.py | 7 ++----- python/sdist/amici/petab/sbml_import.py | 9 ++++++--- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index fe7b5d8830..a77cf64c69 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -489,7 +489,6 @@ "amici_model = import_petab_problem(\n", " petab_problem,\n", " verbose=False,\n", - " compile_=True,\n", " jax=False, # load the amici model this time\n", ")\n", "\n", diff --git a/python/sdist/amici/petab/import_helpers.py b/python/sdist/amici/petab/import_helpers.py index daa902efb0..57bc551205 100644 --- a/python/sdist/amici/petab/import_helpers.py +++ b/python/sdist/amici/petab/import_helpers.py @@ -138,11 +138,8 @@ def _can_import_model( Check whether a module of that name can already be imported. """ # try to import (in particular checks version) - suffix = "_jax" if jax else "" try: - model_module = amici.import_model_module( - model_name + suffix, model_output_dir - ) + model_module = amici.import_model_module(model_name, model_output_dir) except ModuleNotFoundError: return False diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 63bade9bb8..24cb21a466 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -114,7 +114,7 @@ def import_petab_problem( from .sbml_import import _create_model_output_dir_name model_output_dir = _create_model_output_dir_name( - petab_problem.sbml_model, model_name + petab_problem.sbml_model, model_name, jax=jax ) else: model_output_dir = os.path.abspath(model_output_dir) @@ -160,10 +160,7 @@ def import_petab_problem( ) # import model - suffix = "_jax" if jax else "" - model_module = amici.import_model_module( - model_name + suffix, model_output_dir - ) + model_module = amici.import_model_module(model_name, model_output_dir) if jax: model = model_module.Model() diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py index 02a2c4e12c..e4e7efd7fc 100644 --- a/python/sdist/amici/petab/sbml_import.py +++ b/python/sdist/amici/petab/sbml_import.py @@ -588,7 +588,9 @@ def _get_fixed_parameters_sbml( def _create_model_output_dir_name( - sbml_model: "libsbml.Model", model_name: str | None = None + sbml_model: "libsbml.Model", + model_name: str | None = None, + jax: bool = False, ) -> Path: """ Find a folder for storing the compiled amici model. @@ -599,12 +601,13 @@ def _create_model_output_dir_name( BASE_DIR = Path("amici_models").absolute() BASE_DIR.mkdir(exist_ok=True) # try model_name + suffix = "_jax" if jax else "" if model_name: - return BASE_DIR / model_name + return BASE_DIR / (model_name + suffix) # try sbml model id if sbml_model_id := sbml_model.getId(): - return BASE_DIR / sbml_model_id + return BASE_DIR / (sbml_model_id + suffix) # create random folder name return Path(tempfile.mkdtemp(dir=BASE_DIR)) From 513ca2a902d9cddeac867893aa264336f59a03ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 11:17:32 +0000 Subject: [PATCH 06/14] fix notebook --- python/examples/example_jax_petab/ExampleJaxPEtab.ipynb | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index a77cf64c69..a937f76d47 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -357,14 +357,12 @@ "simulation_condition = (\"model1_data1\",)\n", "\n", "# Load condition-specific data\n", - "ts_preeq, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n", + "ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n", " simulation_condition\n", "]\n", "\n", "# Load parameters for the specified condition\n", "p = jax_problem.load_parameters(simulation_condition[0])\n", - "# Disable preequilibration\n", - "p_preeq = jnp.array([])\n", "\n", "\n", "# Define a function to compute the gradient with respect to dynamic timepoints\n", @@ -372,8 +370,7 @@ "def grad_ts_dyn(tt):\n", " return jax_problem.model.simulate_condition(\n", " p=p,\n", - " p_preeq=p_preeq,\n", - " ts_init=ts_preeq,\n", + " ts_init=ts_init,\n", " ts_dyn=tt,\n", " ts_posteq=ts_posteq,\n", " my=jnp.array(my),\n", From 69f2fa451239b4b957fc234bface63308f9c1caa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 15:02:47 +0000 Subject: [PATCH 07/14] fix path SNAFU --- python/sdist/amici/jax/ode_export.py | 5 ++--- python/sdist/amici/petab/import_helpers.py | 25 +++++++++++++++++++++- python/sdist/amici/petab/petab_import.py | 13 ++++++++--- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 7ea4a29d8a..cec5104ded 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -234,12 +234,10 @@ def _generate_jax_code(self) -> None: "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", }, } - outdir = self.model_path / (self.model_name + "_jax") - outdir.mkdir(parents=True, exist_ok=True) apply_template( Path(amiciModulePath) / "jax" / "jax.template.py", - outdir / "__init__.py", + self.model_path / "__init__.py", tpl_data, ) @@ -258,6 +256,7 @@ def set_paths(self, output_dir: str | Path | None = None) -> None: output_dir = Path(os.getcwd()) / f"amici-{self.model_name}" self.model_path = Path(output_dir).resolve() + self.model_path.mkdir(parents=True, exist_ok=True) def set_name(self, model_name: str) -> None: """ diff --git a/python/sdist/amici/petab/import_helpers.py b/python/sdist/amici/petab/import_helpers.py index 57bc551205..d42e99b1e3 100644 --- a/python/sdist/amici/petab/import_helpers.py +++ b/python/sdist/amici/petab/import_helpers.py @@ -139,7 +139,9 @@ def _can_import_model( """ # try to import (in particular checks version) try: - model_module = amici.import_model_module(model_name, model_output_dir) + model_module = amici.import_model_module( + *_get_package_name_and_path(model_name, model_output_dir, jax) + ) except ModuleNotFoundError: return False @@ -268,3 +270,24 @@ def check_model( "the current model might also resolve this. Parameters: " f"{amici_ids_free_required.difference(amici_ids_free)}" ) + + +def _get_package_name_and_path( + model_name: str, model_output_dir: str | Path, jax: bool = False +) -> tuple[str, Path]: + """ + Get the package name and path for the generated model module. + + :param model_name: + Name of the model + :param model_output_dir: + Target directory for the generated model module + :param jax: + Whether to generate the paths for a JAX or CPP model + :return: + """ + if jax: + outdir = Path(model_output_dir) + return outdir.stem, outdir.parent + else: + return model_name, Path(model_output_dir) diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 24cb21a466..b7fccca241 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -16,7 +16,12 @@ from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML from ..logging import get_logger -from .import_helpers import _can_import_model, _create_model_name, check_model +from .import_helpers import ( + _can_import_model, + _create_model_name, + check_model, + _get_package_name_and_path, +) from .sbml_import import import_model_sbml try: @@ -136,7 +141,7 @@ def import_petab_problem( ) # remove folder if exists - if os.path.exists(model_output_dir): + if not jax and os.path.exists(model_output_dir): shutil.rmtree(model_output_dir) logger.info(f"Compiling model {model_name} to {model_output_dir}.") @@ -160,7 +165,9 @@ def import_petab_problem( ) # import model - model_module = amici.import_model_module(model_name, model_output_dir) + model_module = amici.import_model_module( + *_get_package_name_and_path(model_name, model_output_dir, jax=jax) + ) if jax: model = model_module.Model() From ba37d0be8f202f2148bde06b418e6b495181952b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 15:30:05 +0000 Subject: [PATCH 08/14] fix models without preequilibration --- python/sdist/amici/jax/petab.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index e97f7d2e7a..172aac7043 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -413,7 +413,11 @@ def run_simulations( results = { sc[0]: problem.run_simulation( - sc, solver, controller, max_steps, preeqs.get(sc[1])[0] + sc, + solver, + controller, + max_steps, + preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]), ) for sc in simulation_conditions } From 1f0af69bc478337db999dcb2a85c863462011f09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 15:45:05 +0000 Subject: [PATCH 09/14] fix tests --- python/tests/test_jax.py | 46 +++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 8f4c68510b..ce7018e078 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -47,7 +47,7 @@ def test_conversion(): module_name=model.name, module_path=outdir ) jax_module = amici.import_model_module( - module_name=model.name + "_jax", module_path=outdir + module_name=Path(outdir).stem, module_path=Path(outdir).parent ) ts = tuple(np.linspace(0, 1, 10)) @@ -108,7 +108,7 @@ def test_dimerization(): module_name=model.name, module_path=outdir ) jax_module = amici.import_model_module( - module_name=model.name + "_jax", module_path=outdir + module_name=Path(outdir).stem, module_path=Path(outdir).parent ) ts = tuple(np.linspace(0, 1, 10)) @@ -178,7 +178,7 @@ def check_fields_jax( ts = ts.flatten() iys = iys.flatten() - ts_preeq = ts[ts == 0] + ts_init = ts[ts == 0] ts_dyn = ts[ts > 0] ts_posteq = np.array([]) @@ -188,31 +188,37 @@ def check_fields_jax( } p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids]) - args = ( - jnp.array([]), # p_preeq - jnp.array(ts_preeq), # ts_preeq - jnp.array(ts_dyn), # ts_dyn - jnp.array(ts_posteq), # ts_posteq - jnp.array(my), # my - jnp.array(iys), # iys - diffrax.Kvaerno5(), # solver - diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), # controller - diffrax.RecursiveCheckpointAdjoint(), # adjoint - 2**8, # max_steps - ) + kwargs = { + "ts_init": jnp.array(ts_init), + "ts_dyn": jnp.array(ts_dyn), + "ts_posteq": jnp.array(ts_posteq), + "my": jnp.array(my), + "iys": jnp.array(iys), + "x_preeq": jnp.array([]), + "solver": diffrax.Kvaerno5(), + "controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), + "adjoint": diffrax.RecursiveCheckpointAdjoint(), + "max_steps": 2**8, # max_steps + } fun = beartype(jax_model.simulate_condition) for output in ["llh", "x0", "x", "y", "res"]: - oargs = (*args[:-2], diffrax.DirectAdjoint(), 2**8, output) + okwargs = kwargs | { + "adjoint": diffrax.DirectAdjoint(), + "max_steps": 2**8, + "ret": output, + } if sensi_order == amici.SensitivityOrder.none: - r_jax[output] = fun(p, *oargs)[0] + r_jax[output] = fun(p, **okwargs)[0] if sensi_order == amici.SensitivityOrder.first: if output == "llh": - r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, *args)[0] - else: - r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)(p, *oargs)[ + r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, **kwargs)[ 0 ] + else: + r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)( + p, **okwargs + )[0] amici_par_idx = np.array( [jax_model.parameter_ids.index(par_id) for par_id in parameter_ids] From 08afd8bd3ba6581cd80ae3ea92942aa23a8308fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 19:57:35 +0000 Subject: [PATCH 10/14] fixup --- python/sdist/amici/jax/petab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 172aac7043..b9f4b18881 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -412,7 +412,7 @@ def run_simulations( } results = { - sc[0]: problem.run_simulation( + sc: problem.run_simulation( sc, solver, controller, From 7a780ad100f97938ed4e09d2972115c9a9097079 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 20:01:49 +0000 Subject: [PATCH 11/14] fix doc typehints --- python/sdist/amici/jax/petab.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index b9f4b18881..bb8749e27c 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -154,7 +154,7 @@ def _get_parameter_mappings( def _get_measurements( self, simulation_conditions: pd.DataFrame ) -> dict[ - str, + tuple[str, ...], tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], ]: """ @@ -320,7 +320,7 @@ def run_simulation( Step size controller to use for simulation :param max_steps: Maximum number of steps to take during simulation - :param preeq: + :param x_preeq: Pre-equilibration state if available :return: Tuple of log-likelihood and simulation statistics @@ -375,7 +375,7 @@ def run_preequilibration( def run_simulations( problem: JAXProblem, - simulation_conditions: Iterable[tuple] | None = None, + simulation_conditions: Iterable[tuple[str, ...]] | None = None, solver: diffrax.AbstractSolver = diffrax.Kvaerno5(), controller: diffrax.AbstractStepSizeController = diffrax.PIDController( rtol=1e-8, From 9d82a6c6a2c3311e1006edfac9e6e9bb56859d70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 20:37:22 +0000 Subject: [PATCH 12/14] fix notebook --- python/examples/example_jax_petab/ExampleJaxPEtab.ipynb | 1 + 1 file changed, 1 insertion(+) diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index a937f76d47..7f541ea7e4 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -375,6 +375,7 @@ " ts_posteq=ts_posteq,\n", " my=jnp.array(my),\n", " iys=jnp.array(iys),\n", + " x_preeq=jnp.array([]),\n", " solver=diffrax.Kvaerno5(),\n", " controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n", " max_steps=2**10,\n", From 9a2542052961bd0029ad851b5b577e33928fb484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 22:22:12 +0000 Subject: [PATCH 13/14] fix output dict construction --- python/sdist/amici/jax/petab.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index bb8749e27c..0411e5e2df 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -421,4 +421,7 @@ def run_simulations( ) for sc in simulation_conditions } - return sum(llh for llh, _ in results.values()), results | preeqs + return sum(llh for llh, _ in results.values()), { + sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1] + for sc, res in results.items() + } From ecaf5a966ba251c5931c2b07fe895f5d9f0fdabe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 22:41:52 +0000 Subject: [PATCH 14/14] fix notebook --- python/examples/example_jax_petab/ExampleJaxPEtab.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index 7f541ea7e4..f6a4f10e98 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -151,7 +151,7 @@ " results (dict): Simulation results from run_simulations.\n", " \"\"\"\n", " # Extract the simulation results for the specific condition\n", - " sim_results = results[simulation_condition][1]\n", + " sim_results = results[simulation_condition]\n", "\n", " # Create a new figure for the state trajectories\n", " plt.figure(figsize=(8, 6))\n",