diff --git a/.github/workflows/test_petab_sciml.yml b/.github/workflows/test_petab_sciml.yml
new file mode 100644
index 0000000000..eb8ca39394
--- /dev/null
+++ b/.github/workflows/test_petab_sciml.yml
@@ -0,0 +1,87 @@
+name: PEtab
+on:
+  push:
+    branches:
+      - develop
+      - master
+  pull_request:
+    branches:
+      - master
+      - develop
+  merge_group:
+  workflow_dispatch:
+
+jobs:
+  build:
+    name: PEtab SciML Testsuite
+
+    runs-on: ubuntu-latest
+
+    env:
+      ENABLE_GCOV_COVERAGE: TRUE
+
+    strategy:
+      matrix:
+        python-version: ["3.11"]
+
+    steps:
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 20
+
+      - name: Install apt dependencies
+        uses: ./.github/actions/install-apt-dependencies
+
+      # install dependencies
+      - name: apt
+        run: |
+          sudo apt-get update \
+            && sudo apt-get install -y python3-venv
+
+      - run: |
+          echo "${HOME}/.local/bin/" >> $GITHUB_PATH
+
+      # install AMICI
+      - name: Install python package
+        run: scripts/installAmiciSource.sh
+
+      - name: Install python test dependencies
+        run: |
+          source ./venv/bin/activate \
+            && pip3 install wheel pytest shyaml pytest-cov
+
+      # retrieve test models
+      - name: Download and install PEtab SciML test suite
+        run: |
+          git clone --depth 1 --branch main \
+            https://github.com/sebapersson/petab_sciml.git \
+            && export SCIML_TESTSUITE="$(pwd)/petab_sciml" \
+            && source venv/bin/activate \
+            && python -m pip install -e $SCIML_TESTSUITE/src/python
+
+      - name: Install petab development version
+        run: |
+          source ./venv/bin/activate \
+            && python3 -m pip uninstall -y petab \
+            && python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@develop
+
+      - name: Run PEtab SciML testsuite
+        run: |
+          source ./venv/bin/activate \
+            && pytest --cov-report=xml:coverage.xml \
+              --cov=./ tests/sciml/test_sciml.py
+
+      - name: Codecov
+        if: github.event_name == 'pull_request' || github.repository_owner == 'AMICI-dev'
+        uses: codecov/codecov-action@v5
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          file: coverage.xml
+          flags: petab
+          fail_ci_if_error: true
diff --git a/.gitignore b/.gitignore
index e68c2e4f72..95cd525760 100644
--- a/.gitignore
+++ b/.gitignore
@@ -207,3 +207,4 @@ debug/*
 tests/benchmark-models/cache_fiddy/*
 venv/*
 .coverage
+tests/sciml/models/*
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000..327b90c6ad
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "tests/sciml/testsuite"]
+	path = tests/sciml/testsuite
+	url = https://github.com/sebapersson/petab_sciml
diff --git a/documentation/ExampleJaxPEtab.ipynb b/documentation/ExampleJaxPEtab.ipynb
deleted file mode 120000
index 821b14f21f..0000000000
--- a/documentation/ExampleJaxPEtab.ipynb
+++ /dev/null
@@ -1 +0,0 @@
-../python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
\ No newline at end of file
diff --git a/documentation/ExampleJaxPEtab.ipynb b/documentation/ExampleJaxPEtab.ipynb
new file mode 100644
index 0000000000..1310091f4c
--- /dev/null
+++ b/documentation/ExampleJaxPEtab.ipynb
@@ -0,0 +1,671 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "d4d2bc5c",
+   "metadata": {},
+   "source": [
+    "# Simulating AMICI models using JAX\n",
+    "\n",
+    "## Overview\n",
+    "\n",
+    "This guide demonstrates how to use AMICI to export models in a format compatible with the [JAX](https://jax.readthedocs.io/en/latest/) ecosystem, enabling simulations with the [diffrax](https://docs.kidger.site/diffrax/) library. 
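Note that the JAX module is experimental and the API may change in the future.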
" + ] + }, + { + "cell_type": "markdown", + "id": "fb2fe897", + "metadata": {}, + "source": [ + "## Preparation\n", + "\n", + "To begin, we will import a model using [PEtab](https://petab.readthedocs.io). For this demonstration, we will utilize the [Benchmark Collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which provides a diverse set of models. For more information on importing PEtab models, refer to the corresponding [PEtab notebook](https://amici.readthedocs.io/en/latest/petab.html).\n", + "\n", + "In this tutorial, we will import the Böhm model from the Benchmark Collection. Using [amici.petab_import](https://amici.readthedocs.io/en/latest/generated/amici.petab_import.html#amici.petab_import.import_petab_problem), we will load the PEtab problem. To create a [JAXModel](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel) instead of a standard AMICI model, we set the `jax` parameter to `True`.\n" + ] + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from amici.petab.petab_import import import_petab_problem\n", + "import petab.v1 as petab\n", + "\n", + "# Define the model name and YAML file location\n", + "model_name = \"Boehm_JProteomeRes2014\"\n", + "yaml_url = (\n", + " f\"https://raw.githubusercontent.com/Benchmarking-Initiative/Benchmark-Models-PEtab/\"\n", + " f\"master/Benchmark-Models/{model_name}/{model_name}.yaml\"\n", + ")\n", + "\n", + "# Load the PEtab problem from the YAML file\n", + "petab_problem = petab.Problem.from_yaml(yaml_url)\n", + "\n", + "# Import the PEtab problem as a JAX-compatible AMICI model\n", + "jax_model = import_petab_problem(\n", + " petab_problem,\n", + " verbose=False, # no text output\n", + " jax=True, # return jax model\n", + ")" + ], + "id": "c71c96da0da3144a" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Simulation\n", + "\n", + "In principle, we can already use this model for simulation using the [simulate_condition](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel.simulate_condition) method. However, this approach can be cumbersome as timepoints, data etc. need to be specified manually. Instead, we process the PEtab problem into a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem), which enables efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." + ], + "id": "7e0f1c27bd71ee1f" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from amici.jax import JAXProblem, run_simulations\n", + "\n", + "# Create a JAXProblem from the JAX model and PEtab problem\n", + "jax_problem = JAXProblem(jax_model, petab_problem)\n", + "\n", + "# Run simulations and compute the log-likelihood\n", + "llh, results = run_simulations(jax_problem)" + ], + "id": "ccecc9a29acc7b73" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? 
Now, let’s take a look at the simulation results.", + "id": "415962751301c64a" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Define the simulation condition\n", + "simulation_condition = (\"model1_data1\",)\n", + "\n", + "# Access the results for the specified condition\n", + "results[simulation_condition]" + ], + "id": "596b86e45e18fe3d" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "Unfortunately, the simulation failed! As seen in the output, the simulation broke down after the initial timepoint, indicated by the `inf` values in the state variables `results[simulation_condition][1].x` and the `nan` likelihood value. A closer inspection of this variable provides additional clues about what might have gone wrong.\n", + "\n", + "The issue stems from using single precision, as indicated by the `float32` dtype of state variables. Single precision is generally a [bad idea](https://docs.kidger.site/diffrax/examples/stiff_ode/) for stiff systems like the Böhm model. Let’s retry the simulation with double precision." + ], + "id": "a1b173e013f9210a" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "import jax\n", + "\n", + "# Enable double precision in JAX\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + "# Re-run simulations with double precision\n", + "llh, results = run_simulations(jax_problem)\n", + "\n", + "results" + ], + "id": "f4f5ff705a3f7402" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories.", + "id": "fe4d3b40ee3efdf2" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "\n", + "def plot_simulation(results):\n", + " \"\"\"\n", + " Plot the state trajectories from the simulation results.\n", + "\n", + " Parameters:\n", + " results (dict): Simulation results from run_simulations.\n", + " \"\"\"\n", + " # Extract the simulation results for the specific condition\n", + " sim_results = results[simulation_condition]\n", + "\n", + " # Create a new figure for the state trajectories\n", + " plt.figure(figsize=(8, 6))\n", + " for idx in range(sim_results[\"x\"].shape[1]):\n", + " time_points = np.array(sim_results[\"ts\"])\n", + " state_values = np.array(sim_results[\"x\"][:, idx])\n", + " plt.plot(time_points, state_values, label=jax_model.state_ids[idx])\n", + "\n", + " # Add labels, legend, and grid\n", + " plt.xlabel(\"Time\")\n", + " plt.ylabel(\"State Values\")\n", + " plt.title(simulation_condition)\n", + " plt.legend()\n", + " plt.grid(True)\n", + " plt.show()\n", + "\n", + "\n", + "# Plot the simulation results\n", + "plot_simulation(results)" + ], + "id": "72f1ed397105e14a" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. 
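For example, passing `simulation_conditions=((\"model1_data1\",),)` would restrict the simulation to just that condition. 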
Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all.",
   "id": "4fa97c33719c2277"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n",
    "results"
   ],
   "id": "7950774a3e989042"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Updating Parameters\n",
    "\n",
    "As a next step, we will update the parameter values used for simulation. However, if we attempt to directly modify the values in `JAXModel.parameters`, we encounter a `FrozenInstanceError`."
   ],
   "id": "98b8516a75ce4d12"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from dataclasses import FrozenInstanceError\n",
    "import jax\n",
    "\n",
    "# Generate random noise to update the parameters\n",
    "noise = (\n",
    "    jax.random.normal(\n",
    "        key=jax.random.PRNGKey(0), shape=jax_problem.parameters.shape\n",
    "    )\n",
    "    / 10\n",
    ")\n",
    "\n",
    "# Attempt to update the parameters\n",
    "try:\n",
    "    jax_problem.parameters += noise\n",
    "except FrozenInstanceError as e:\n",
    "    print(\"Error:\", e)"
   ],
   "id": "3d278a3d21e709d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The root cause of this error lies in the fact that, to enable autodiff, direct modifications of attributes are not allowed in [equinox](https://docs.kidger.site/equinox/), which AMICI utilizes under the hood. Consequently, attributes of instances like `JAXModel` or `JAXProblem` cannot be updated directly — this is the price we have to pay for autodiff.\n",
    "\n",
    "However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one."
   ],
   "id": "4cc3d595de4a4085"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Update the parameters and create a new JAXProblem instance\n",
    "jax_problem = jax_problem.update_parameters(jax_problem.parameters + noise)\n",
    "\n",
    "# Run simulations with the updated parameters\n",
    "llh, results = run_simulations(jax_problem)\n",
    "\n",
    "# Plot the simulation results\n",
    "plot_simulation(results)"
   ],
   "id": "e47748376059628b"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Computing Gradients\n",
    "\n",
    "Similar to updating attributes, computing gradients can feel a bit unconventional if you’re not familiar with the JAX ecosystem. JAX offers [powerful automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) through the `jax.grad` function. However, to use `jax.grad` with `JAXProblem`, we need to specify which parts of the `JAXProblem` should be treated as static.",
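    "\n\nStatic fields are anything that is not a JAX array; as the next cell demonstrates, plain `jax.grad` raises a `TypeError` when it encounters such fields."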
+ ], + "id": "660baf605a4e8339" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "try:\n", + " # Attempt to compute the gradient of the run_simulations function\n", + " jax.grad(run_simulations, has_aux=True)(jax_problem)\n", + "except TypeError as e:\n", + " print(\"Error:\", e)" + ], + "id": "7033d09cc81b7f69" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`.", + "id": "dc9bc07cde00a926" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "import equinox as eqx\n", + "\n", + "# Compute the gradient using equinox's filter_grad, preserving auxiliary outputs\n", + "grad, _ = eqx.filter_grad(run_simulations, has_aux=True)(jax_problem)" + ], + "id": "a6704182200e6438" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`.", + "id": "851c3ec94cb5d086" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "grad.parameters", + "id": "c00c1581d7173d7a" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`.", + "id": "375b835fecc5a022" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "grad", + "id": "f7c17f7459d0151f" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out.", + "id": "8eb7cc3db510c826" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "grad._measurements[simulation_condition]", + "id": "3badd4402cf6b8c6" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. 
While this might not be particularly practical, it serves as a nice illustration of the power of automatic differentiation.",
   "id": "58eb04393a1463d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import jax.numpy as jnp\n",
    "import diffrax\n",
    "from amici.jax import ReturnValue\n",
    "\n",
    "# Define the simulation condition\n",
    "simulation_condition = (\"model1_data1\",)\n",
    "\n",
    "# Load condition-specific data\n",
    "ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos = jax_problem._measurements[\n",
    "    simulation_condition\n",
    "]\n",
    "\n",
    "# Load parameters for the specified condition\n",
    "p = jax_problem.load_model_parameters(simulation_condition[0])\n",
    "\n",
    "\n",
    "# Define a function to compute the gradient with respect to dynamic timepoints\n",
    "@eqx.filter_jacfwd\n",
    "def grad_ts_dyn(tt):\n",
    "    return jax_problem.model.simulate_condition(\n",
    "        p=p,\n",
    "        ts_init=ts_init,\n",
    "        ts_dyn=tt,\n",
    "        ts_posteq=ts_posteq,\n",
    "        my=jnp.array(my),\n",
    "        iys=jnp.array(iys),\n",
    "        iy_trafos=jnp.array(iy_trafos),\n",
    "        solver=diffrax.Kvaerno5(),\n",
    "        controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
    "        max_steps=2**10,\n",
    "        adjoint=diffrax.DirectAdjoint(),\n",
    "        ret=ReturnValue.y,  # Return observables\n",
    "    )[0]\n",
    "\n",
    "\n",
    "# Compute the gradient with respect to `ts_dyn`\n",
    "g = grad_ts_dyn(ts_dyn)\n",
    "g"
   ],
   "id": "1a91aff44b93157"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Compilation & Profiling\n",
    "\n",
    "To maximize performance with JAX, code should be just-in-time (JIT) compiled. This can be achieved using the `jax.jit` or `equinox.filter_jit` decorators. While JIT compilation introduces some overhead during the first function call, it significantly improves performance for subsequent calls. To demonstrate this, we will first clear the JIT cache and then profile the execution.",
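    "\n\nNote that JIT caching is a general JAX property rather than anything AMICI-specific: compilation is keyed on the structure of the inputs, so a new `JAXProblem` that differs only in array values (e.g. one returned by `update_parameters`) reuses the compiled function, whereas changes to static (non-array) fields trigger recompilation."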
+ ], + "id": "9f870da7754e139c" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from time import time\n", + "\n", + "# Clear JAX caches to ensure a fresh start\n", + "jax.clear_caches()\n", + "\n", + "# Define a JIT-compiled gradient function with auxiliary outputs\n", + "gradfun = eqx.filter_jit(eqx.filter_grad(run_simulations, has_aux=True))" + ], + "id": "58ebdc110ea7457e" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Measure the time taken for the first function call (including compilation)\n", + "start = time()\n", + "run_simulations(jax_problem)\n", + "print(f\"Function compilation time: {time() - start:.2f} seconds\")\n", + "\n", + "# Measure the time taken for the gradient computation (including compilation)\n", + "start = time()\n", + "gradfun(jax_problem)\n", + "print(f\"Gradient compilation time: {time() - start:.2f} seconds\")" + ], + "id": "e1242075f7e0faf" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "%%timeit\n", + "run_simulations(\n", + " jax_problem,\n", + " controller=diffrax.PIDController(\n", + " rtol=1e-8, # same as amici default\n", + " atol=1e-16, # same as amici default\n", + " pcoeff=0.4, # recommended value for stiff systems\n", + " icoeff=0.3, # recommended value for stiff systems\n", + " dcoeff=0.0, # recommended value for stiff systems\n", + " ),\n", + ")" + ], + "id": "27181f367ccb1817" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "%%timeit \n", + "gradfun(\n", + " jax_problem,\n", + " controller=diffrax.PIDController(\n", + " rtol=1e-8, # same as amici default\n", + " atol=1e-16, # same as amici default\n", + " pcoeff=0.4, # recommended value for stiff systems\n", + " icoeff=0.3, # recommended value for stiff systems\n", + " dcoeff=0.0, # recommended value for stiff systems\n", + " ),\n", + ")" + ], + "id": "5b8d3a6162a3ae55" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from amici.petab import simulate_petab\n", + "import amici\n", + "\n", + "# Import the PEtab problem as a standard AMICI model\n", + "amici_model = import_petab_problem(\n", + " petab_problem,\n", + " verbose=False,\n", + " jax=False, # load the amici model this time\n", + ")\n", + "\n", + "# Configure the solver with appropriate tolerances\n", + "solver = amici_model.getSolver()\n", + "solver.setAbsoluteTolerance(1e-8)\n", + "solver.setRelativeTolerance(1e-8)\n", + "\n", + "# Prepare the parameters for the simulation\n", + "problem_parameters = dict(\n", + " zip(jax_problem.parameter_ids, jax_problem.parameters)\n", + ")" + ], + "id": "d733a450635a749b" + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "413ed7c60b2cf4be", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:55.259985Z", + "start_time": "2024-11-19T09:51:55.257937Z" + } + }, + "outputs": [], + "source": [ + "# Profile simulation only\n", + "solver.setSensitivityOrder(amici.SensitivityOrder.none)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "768fa60e439ca8b4", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:57.417608Z", + "start_time": "2024-11-19T09:51:55.273367Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "26.1 ms ± 2.71 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit \n", + "simulate_petab(\n", + " petab_problem,\n", + " amici_model,\n", + " solver=solver,\n", + " problem_parameters=problem_parameters,\n", + " scaled_parameters=True,\n", + " scaled_gradients=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "b8382b0b2b68f49e", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:57.497361Z", + "start_time": "2024-11-19T09:51:57.494502Z" + } + }, + "outputs": [], + "source": [ + "# Profile gradient computation using forward sensitivity analysis\n", + "solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", + "solver.setSensitivityMethod(amici.SensitivityMethod.forward)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "3bae1fab8c416122", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:59.897459Z", + "start_time": "2024-11-19T09:51:57.511889Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "29.1 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit \n", + "simulate_petab(\n", + " petab_problem,\n", + " amici_model,\n", + " solver=solver,\n", + " problem_parameters=problem_parameters,\n", + " scaled_parameters=True,\n", + " scaled_gradients=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "71e0358227e1dc74", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:51:59.972149Z", + "start_time": "2024-11-19T09:51:59.969006Z" + } + }, + "outputs": [], + "source": [ + "# Profile gradient computation using adjoint sensitivity analysis\n", + "solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", + "solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "e3cc7971002b6d06", + "metadata": { + "ExecuteTime": { + "end_time": "2024-11-19T09:52:03.266074Z", + "start_time": "2024-11-19T09:51:59.992465Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "39.3 ms ± 1.6 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "%%timeit \n", + "simulate_petab(\n", + " petab_problem,\n", + " amici_model,\n", + " solver=solver,\n", + " problem_parameters=problem_parameters,\n", + " scaled_parameters=True,\n", + " scaled_gradients=True,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pytest.ini b/pytest.ini index 03d50d80e1..b24e565354 100644 --- a/pytest.ini +++ b/pytest.ini @@ -12,6 +12,7 @@ filterwarnings = ignore:Conservation laws for non-constant species in models with Species-AssignmentRules are currently not supported and will be turned off.:UserWarning ignore:Conservation laws for non-constant species in combination with parameterized stoichiometric coefficients are not currently supported and will be turned off.:UserWarning ignore:Support for PEtab2.0 is experimental!:UserWarning + ignore:PEtab v2.0.0 mapping tables are only partially supported:UserWarning ignore:The JAX module is experimental and the API may change in the future.:ImportWarning # hundreds of SBML <=5.17 warnings ignore:.*inspect.getargspec\(\) is deprecated.*:DeprecationWarning diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb index e3677af346..2acd82bdbe 100644 --- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb +++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb @@ -25,10 +25,11 @@ ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "c71c96da0da3144a", + "metadata": {}, + "outputs": [], "source": [ "from amici.petab.petab_import import import_petab_problem\n", "import petab.v1 as petab\n", @@ -49,24 +50,24 @@ " verbose=False, # no text output\n", " jax=True, # return jax model\n", ")" - ], - "id": "c71c96da0da3144a" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "7e0f1c27bd71ee1f", + "metadata": {}, "source": [ "## Simulation\n", "\n", "In principle, we can already use this model for simulation using the [simulate_condition](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXModel.simulate_condition) method. However, this approach can be cumbersome as timepoints, data etc. need to be specified manually. Instead, we process the PEtab problem into a [JAXProblem](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem), which enables efficient simulation using [amici.jax.run_simulations]((https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.run_simulations)." - ], - "id": "7e0f1c27bd71ee1f" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "ccecc9a29acc7b73", + "metadata": {}, + "outputs": [], "source": [ "from amici.jax import JAXProblem, run_simulations\n", "\n", @@ -75,44 +76,44 @@ "\n", "# Run simulations and compute the log-likelihood\n", "llh, results = run_simulations(jax_problem)" - ], - "id": "ccecc9a29acc7b73" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? 
Now, let’s take a look at the simulation results.", - "id": "415962751301c64a" + "id": "415962751301c64a", + "metadata": {}, + "source": "This simulates the model for all conditions using the nominal parameter values. Simple, right? Now, let’s take a look at the simulation results." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "596b86e45e18fe3d", + "metadata": {}, + "outputs": [], "source": [ "# Define the simulation condition\n", "simulation_condition = (\"model1_data1\",)\n", "\n", "# Access the results for the specified condition\n", "results[simulation_condition]" - ], - "id": "596b86e45e18fe3d" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "a1b173e013f9210a", + "metadata": {}, "source": [ "Unfortunately, the simulation failed! As seen in the output, the simulation broke down after the initial timepoint, indicated by the `inf` values in the state variables `results[simulation_condition][1].x` and the `nan` likelihood value. A closer inspection of this variable provides additional clues about what might have gone wrong.\n", "\n", "The issue stems from using single precision, as indicated by the `float32` dtype of state variables. Single precision is generally a [bad idea](https://docs.kidger.site/diffrax/examples/stiff_ode/) for stiff systems like the Böhm model. Let’s retry the simulation with double precision." - ], - "id": "a1b173e013f9210a" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "f4f5ff705a3f7402", + "metadata": {}, + "outputs": [], "source": [ "import jax\n", "\n", @@ -123,20 +124,20 @@ "llh, results = run_simulations(jax_problem)\n", "\n", "results" - ], - "id": "f4f5ff705a3f7402" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories.", - "id": "fe4d3b40ee3efdf2" + "id": "fe4d3b40ee3efdf2", + "metadata": {}, + "source": "Success! The simulation completed successfully, and we can now plot the resulting state trajectories." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "72f1ed397105e14a", + "metadata": {}, + "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", @@ -171,41 +172,41 @@ "\n", "# Plot the simulation results\n", "plot_simulation(results)" - ], - "id": "72f1ed397105e14a" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all.", - "id": "4fa97c33719c2277" + "id": "4fa97c33719c2277", + "metadata": {}, + "source": "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." 
}, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "7950774a3e989042", + "metadata": {}, + "outputs": [], "source": [ "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n", "results" - ], - "id": "7950774a3e989042" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "98b8516a75ce4d12", + "metadata": {}, "source": [ "## Updating Parameters\n", "\n", "As next step, we will update the parameter values used for simulation. However, if we attempt to directly modify the values in `JAXModel.parameters`, we encounter a `FrozenInstanceError`." - ], - "id": "98b8516a75ce4d12" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "3d278a3d21e709d", + "metadata": {}, + "outputs": [], "source": [ "from dataclasses import FrozenInstanceError\n", "import jax\n", @@ -223,24 +224,24 @@ " jax_problem.parameters += noise\n", "except FrozenInstanceError as e:\n", " print(\"Error:\", e)" - ], - "id": "3d278a3d21e709d" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "4cc3d595de4a4085", + "metadata": {}, "source": [ "The root cause of this error lies in the fact that, to enable autodiff, direct modifications of attributes are not allowed in [equinox](https://docs.kidger.site/equinox/), which AMICI utilizes under the hood. Consequently, attributes of instances like `JAXModel` or `JAXProblem` cannot be updated directly — this is the price we have to pay for autodiff.\n", "\n", "However, `JAXProblem` provides a convenient method called [update_parameters](https://amici.readthedocs.io/en/latest/generated/amici.jax.html#amici.jax.JAXProblem.update_parameters). The caveat is that this method creates a new JAXProblem instance instead of modifying the existing one." - ], - "id": "4cc3d595de4a4085" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "e47748376059628b", + "metadata": {}, + "outputs": [], "source": [ "# Update the parameters and create a new JAXProblem instance\n", "jax_problem = jax_problem.update_parameters(jax_problem.parameters + noise)\n", @@ -250,105 +251,111 @@ "\n", "# Plot the simulation results\n", "plot_simulation(results)" - ], - "id": "e47748376059628b" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "660baf605a4e8339", + "metadata": {}, "source": [ "## Computing Gradients\n", "\n", "Similar to updating attributes, computing gradients in the JAX ecosystem can feel a bit unconventional if you’re not familiar with the JAX ecosysmt. JAX offers [powerful automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) through the `jax.grad` function. However, to use `jax.grad` with `JAXProblem`, we need to specify which parts of the `JAXProblem` should be treated as static." 
- ], - "id": "660baf605a4e8339" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "7033d09cc81b7f69", + "metadata": {}, + "outputs": [], "source": [ "try:\n", " # Attempt to compute the gradient of the run_simulations function\n", " jax.grad(run_simulations, has_aux=True)(jax_problem)\n", "except TypeError as e:\n", " print(\"Error:\", e)" - ], - "id": "7033d09cc81b7f69" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`.", - "id": "dc9bc07cde00a926" + "id": "dc9bc07cde00a926", + "metadata": {}, + "source": "Fortunately, `equinox` simplifies this process by offering [filter_grad](https://docs.kidger.site/equinox/api/transformations/#equinox.filter_grad), which enables autodiff functionality that is compatible with `JAXProblem` and, in theory, also with `JAXModel`." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "a6704182200e6438", + "metadata": {}, + "outputs": [], "source": [ "import equinox as eqx\n", "\n", "# Compute the gradient using equinox's filter_grad, preserving auxiliary outputs\n", "grad, _ = eqx.filter_grad(run_simulations, has_aux=True)(jax_problem)" - ], - "id": "a6704182200e6438" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`.", - "id": "851c3ec94cb5d086" + "id": "851c3ec94cb5d086", + "metadata": {}, + "source": "Functions transformed by `filter_grad` return gradients that share the same structure as the first argument (unless specified otherwise). This allows us to access the gradient with respect to the parameters attribute directly `via grad.parameters`." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad.parameters", - "id": "c00c1581d7173d7a" + "id": "c00c1581d7173d7a", + "metadata": {}, + "outputs": [], + "source": [ + "grad.parameters" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`.", - "id": "375b835fecc5a022" + "id": "375b835fecc5a022", + "metadata": {}, + "source": "Attributes for which derivatives cannot be computed (typically anything that is not a [jax.numpy.array](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html)) are automatically set to `None`." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad", - "id": "f7c17f7459d0151f" + "id": "f7c17f7459d0151f", + "metadata": {}, + "outputs": [], + "source": [ + "grad" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. 
However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out.", - "id": "8eb7cc3db510c826" + "id": "8eb7cc3db510c826", + "metadata": {}, + "source": "Observant readers may notice that the gradient above appears to include numeric values for derivatives with respect to some measurements. However, `simulation_conditions` internally disables gradient computations using `jax.lax.stop_gradient`, resulting in these values being zeroed out." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "grad._measurements[simulation_condition]", - "id": "3badd4402cf6b8c6" + "id": "3badd4402cf6b8c6", + "metadata": {}, + "outputs": [], + "source": [ + "grad._measurements[simulation_condition]" + ] }, { - "metadata": {}, "cell_type": "markdown", - "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation.", - "id": "58eb04393a1463d" + "id": "58eb04393a1463d", + "metadata": {}, + "source": "However, we can compute derivatives with respect to data elements using `JAXModel.simulate_condition`. In the example below, we differentiate the observables `y` (specified by passing `y` to the `ret` argument) with respect to the timepoints at which the model outputs are computed after the solving the differential equation. While this might not be particularly practical, it serves as an nice illustration of the power of automatic differentiation." }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "1a91aff44b93157", + "metadata": {}, + "outputs": [], "source": [ "import jax.numpy as jnp\n", "import diffrax\n", @@ -363,7 +370,7 @@ "]\n", "\n", "# Load parameters for the specified condition\n", - "p = jax_problem.load_parameters(simulation_condition[0])\n", + "p = jax_problem.load_model_parameters(simulation_condition[0])\n", "\n", "\n", "# Define a function to compute the gradient with respect to dynamic timepoints\n", @@ -388,24 +395,24 @@ "# Compute the gradient with respect to `ts_dyn`\n", "g = grad_ts_dyn(ts_dyn)\n", "g" - ], - "id": "1a91aff44b93157" + ] }, { - "metadata": {}, "cell_type": "markdown", + "id": "9f870da7754e139c", + "metadata": {}, "source": [ "## Compilation & Profiling\n", "\n", "To maximize performance with JAX, code should be just-in-time (JIT) compiled. This can be achieved using the `jax.jit` or `equinox.filter_jit` decorators. While JIT compilation introduces some overhead during the first function call, it significantly improves performance for subsequent calls. To demonstrate this, we will first clear the JIT cache and then profile the execution." 
- ], - "id": "9f870da7754e139c" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "58ebdc110ea7457e", + "metadata": {}, + "outputs": [], "source": [ "from time import time\n", "\n", @@ -414,14 +421,14 @@ "\n", "# Define a JIT-compiled gradient function with auxiliary outputs\n", "gradfun = eqx.filter_jit(eqx.filter_grad(run_simulations, has_aux=True))" - ], - "id": "58ebdc110ea7457e" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "e1242075f7e0faf", + "metadata": {}, + "outputs": [], "source": [ "# Measure the time taken for the first function call (including compilation)\n", "start = time()\n", @@ -432,14 +439,14 @@ "start = time()\n", "gradfun(jax_problem)\n", "print(f\"Gradient compilation time: {time() - start:.2f} seconds\")" - ], - "id": "e1242075f7e0faf" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "27181f367ccb1817", + "metadata": {}, + "outputs": [], "source": [ "%%timeit\n", "run_simulations(\n", @@ -452,14 +459,14 @@ " dcoeff=0.0, # recommended value for stiff systems\n", " ),\n", ")" - ], - "id": "27181f367ccb1817" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "5b8d3a6162a3ae55", + "metadata": {}, + "outputs": [], "source": [ "%%timeit \n", "gradfun(\n", @@ -472,14 +479,14 @@ " dcoeff=0.0, # recommended value for stiff systems\n", " ),\n", ")" - ], - "id": "5b8d3a6162a3ae55" + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, + "id": "d733a450635a749b", + "metadata": {}, + "outputs": [], "source": [ "from amici.petab import simulate_petab\n", "import amici\n", @@ -500,8 +507,7 @@ "problem_parameters = dict(\n", " zip(jax_problem.parameter_ids, jax_problem.parameters)\n", ")" - ], - "id": "d733a450635a749b" + ] }, { "cell_type": "code", diff --git a/python/sdist/amici/de_export.py b/python/sdist/amici/de_export.py index f0ec08133f..ea078a42bd 100644 --- a/python/sdist/amici/de_export.py +++ b/python/sdist/amici/de_export.py @@ -169,6 +169,7 @@ def __init__( allow_reinit_fixpar_initcond: bool | None = True, generate_sensitivity_code: bool | None = True, model_name: str | None = "model", + hybridisation: dict | None = None, ): """ Generate AMICI C++ files for the DE provided to the constructor. @@ -232,6 +233,7 @@ def __init__( self.allow_reinit_fixpar_initcond: bool = allow_reinit_fixpar_initcond self._build_hints = set() self.generate_sensitivity_code: bool = generate_sensitivity_code + self.hybridisation = hybridisation @log_execution_time("generating cpp code", logger) def generate_model_code(self) -> None: diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 8ad2e7a998..4f4f9f466b 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -28,10 +28,12 @@ AlgebraicEquation, Observable, EventObservable, + Sigma, SigmaY, SigmaZ, Parameter, Constant, + LogLikelihood, LogLikelihoodY, LogLikelihoodZ, LogLikelihoodRZ, @@ -199,6 +201,7 @@ def __init__( verbose: bool | int | None = False, simplify: Callable | None = _default_simplify, cache_simplify: bool = False, + hybridisation: bool = False, ): """ Create a new DEModel instance. 
@@ -2305,3 +2308,137 @@ def _process_heavisides( dxdt = dxdt.subs(heaviside_sympy, heaviside_amici) return dxdt + + @property + def _components(self) -> list[ModelQuantity]: + """ + Returns the components of the model + + :return: + components of the model + """ + return ( + self._algebraic_states + + self._algebraic_equations + + self._conservation_laws + + self._constants + + self._differential_states + + self._event_observables + + self._events + + self._expressions + + self._log_likelihood_ys + + self._log_likelihood_zs + + self._log_likelihood_rzs + + self._observables + + self._parameters + + self._sigma_ys + + self._sigma_zs + + self._splines + ) + + def _process_hybridization(self, hybridization: dict) -> None: + """ + Parses the hybridisation information and updates the model accordingly + + :param hybridization: + hybridization information + """ + added_expressions = False + for net_id, net in hybridization.items(): + if not ( + net["hybridization"]["output"] == "ode" + or net["hybridization"]["input"] == "ode" + ): + continue # do not integrate into ODEs, handle in amici.jax.petab + inputs = [ + comp + for comp in self._components + if str(comp.get_id()) in net["input_vars"] + ] + # sort inputs by order in input_vars + inputs = sorted( + inputs, + key=lambda comp: net["input_vars"].index(str(comp.get_id())), + ) + if len(inputs) != len(net["input_vars"]): + raise ValueError( + f"Could not find all input variables for neural network {net_id}" + ) + for inp in inputs: + if isinstance( + inp, + Sigma + | LogLikelihood + | Event + | ConservationLaw + | Observable, + ): + raise NotImplementedError( + f"{inp.get_name()} ({type(inp)}) is not supported as neural network input." + ) + + outputs = { + out_var: comp + for comp in self._components + if (out_var := str(comp.get_id())) in net["output_vars"] + # TODO: SYNTAX NEEDS to CHANGE + or (out_var := str(comp.get_id()) + "_dot") + in net["output_vars"] + } + if len(outputs.keys()) != len(net["output_vars"]): + raise ValueError( + f"Could not find all output variables for neural network {net_id}" + ) + for iout, (out_var, comp) in enumerate(outputs.items()): + # remove output from model components + if isinstance(comp, Parameter): + self._parameters.remove(comp) + elif isinstance(comp, Expression): + self._expressions.remove(comp) + elif isinstance(comp, DifferentialState): + pass + else: + raise NotImplementedError( + f"{comp.get_name()} ({type(comp)}) is not supported as neural network output." 
+ ) + + # generate dummy Function + out_val = sp.Function(net_id)( + *[input.get_id() for input in inputs], iout + ) + + # add to the model + if isinstance(comp, DifferentialState): + ix = self._differential_states.index(comp) + # TODO: SYNTAX NEEDS to CHANGE + if out_var.endswith("_dot"): + self._differential_states[ix].set_dt(out_val) + else: + self._differential_states[ix].set_val(out_val) + else: + self.add_component( + Expression( + identifier=comp.get_id(), + name=net_id, + value=out_val, + ) + ) + added_expressions = True + + if added_expressions: + # toposort expressions + w_sorted = toposort_symbols( + dict( + zip( + self.sym("w"), + self.eq("w"), + strict=True, + ) + ) + ) + old_syms = tuple(self._syms["w"]) + topo_expr_syms = tuple(w_sorted.keys()) + new_order = [old_syms.index(s) for s in topo_expr_syms] + self._expressions = [self._expressions[i] for i in new_order] + self._syms["w"] = sp.Matrix(topo_expr_syms) + self._eqs["w"] = sp.Matrix(list(w_sorted.values())) diff --git a/python/sdist/amici/jax/__init__.py b/python/sdist/amici/jax/__init__.py index a5b5dc1cae..945e054178 100644 --- a/python/sdist/amici/jax/__init__.py +++ b/python/sdist/amici/jax/__init__.py @@ -16,6 +16,7 @@ ReturnValue, ) from amici.jax.model import JAXModel +from amici.jax.nn import generate_equinox warn( "The JAX module is experimental and the API may change in the future.", @@ -26,6 +27,7 @@ __all__ = [ "JAXModel", "JAXProblem", + "generate_equinox", "run_simulations", "petab_simulate", "ReturnValue", diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index 5d5521d222..41f38cc4d0 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -1,9 +1,13 @@ # ruff: noqa: F401, F821, F841 import jax.numpy as jnp +import jax.random as jr from interpax import interp1d from pathlib import Path from amici.jax.model import JAXModel, safe_log, safe_div +from amici import _module_from_path + +TPL_NET_IMPORTS class JAXModel_TPL_MODEL_NAME(JAXModel): @@ -11,6 +15,8 @@ class JAXModel_TPL_MODEL_NAME(JAXModel): def __init__(self): self.jax_py_file = Path(__file__).resolve() + self.nns = {TPL_NETS} + super().__init__() def _xdot(self, t, x, args): diff --git a/python/sdist/amici/jax/jaxcodeprinter.py b/python/sdist/amici/jax/jaxcodeprinter.py index 6cfce97b35..2b7d149c81 100644 --- a/python/sdist/amici/jax/jaxcodeprinter.py +++ b/python/sdist/amici/jax/jaxcodeprinter.py @@ -36,6 +36,9 @@ def _print_Mul(self, expr: sp.Expr) -> str: return super()._print_Mul(expr) return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})" + def _print_Function(self, expr): + return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]" + def _get_sym_lines( self, symbols: sp.Matrix | Iterable[str], diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 8b2c09fcc6..8b820498d6 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -46,6 +46,7 @@ class JAXModel(eqx.Module): MODEL_API_VERSION = "0.0.2" api_version: str jax_py_file: Path + nns: dict def __init__(self): if self.api_version != self.MODEL_API_VERSION: diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/jax/nn.py new file mode 100644 index 0000000000..d503df2393 --- /dev/null +++ b/python/sdist/amici/jax/nn.py @@ -0,0 +1,201 @@ +from pathlib import Path + + +import equinox as eqx +import jax.numpy as jnp + +from amici._codegen.template import 
apply_template +from amici import amiciModulePath + + +class Flatten(eqx.Module): + start_dim: int + end_dim: int + + def __init__(self, start_dim: int, end_dim: int): + super().__init__() + self.start_dim = start_dim + self.end_dim = end_dim + + def __call__(self, x): + if self.end_dim == -1: + return jnp.reshape(x, x.shape[: self.start_dim] + (-1,)) + else: + return jnp.reshape( + x, x.shape[: self.start_dim] + (-1,) + x.shape[self.end_dim :] + ) + + +def tanhshrink(x: jnp.ndarray) -> jnp.ndarray: + return x - jnp.tanh(x) + + +def generate_equinox(ml_model: "MLModel", filename: Path | str): # noqa: F821 + # TODO: move to top level import and replace forward type definitions + from petab_sciml import Layer + + filename = Path(filename) + layer_indent = 12 + node_indent = 8 + + layers = {layer.layer_id: layer for layer in ml_model.layers} + + tpl_data = { + "MODEL_ID": ml_model.mlmodel_id, + "LAYERS": ",\n".join( + [ + _generate_layer(layer, layer_indent, ilayer) + for ilayer, layer in enumerate(ml_model.layers) + ] + )[layer_indent:], + "FORWARD": "\n".join( + [ + _generate_forward( + node, + node_indent, + layers.get( + node.target, + Layer(layer_id="dummy", layer_type="Linear"), + ).layer_type, + ) + for node in ml_model.forward + ] + )[node_indent:], + "INPUT": ", ".join([f"'{inp.input_id}'" for inp in ml_model.inputs]), + "OUTPUT": ", ".join( + [ + f"'{arg}'" + for arg in next( + node for node in ml_model.forward if node.op == "output" + ).args + ] + ), + "N_LAYERS": len(ml_model.layers), + } + + filename.parent.mkdir(parents=True, exist_ok=True) + + apply_template( + Path(amiciModulePath) / "jax" / "nn.template.py", + filename, + tpl_data, + ) + + +def _process_argval(v): + if isinstance(v, str): + return f"'{v}'" + if isinstance(v, bool): + return str(v) + return str(v) + + +def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F821 + layer_map = { + "Dropout1d": "eqx.nn.Dropout", + "Dropout2d": "eqx.nn.Dropout", + "Flatten": "amici.jax.nn.Flatten", + } + if layer.layer_type.startswith( + ("BatchNorm", "AlphaDropout", "InstanceNorm") + ): + raise NotImplementedError( + f"{layer.layer_type} layers currently not supported" + ) + if layer.layer_type.startswith("MaxPool") and "dilation" in layer.args: + raise NotImplementedError("MaxPool layers with dilation not supported") + if layer.layer_type.startswith("Dropout") and "inplace" in layer.args: + raise NotImplementedError("Dropout layers with inplace not supported") + if layer.layer_type == "Bilinear": + raise NotImplementedError("Bilinear layers not supported") + + kwarg_map = { + "Linear": { + "bias": "use_bias", + }, + "Conv1d": { + "bias": "use_bias", + }, + "Conv2d": { + "bias": "use_bias", + }, + "LayerNorm": { + "affine": "elementwise_affine", + "normalized_shape": "shape", + }, + } + kwarg_ignore = { + "Dropout1d": ("inplace",), + "Dropout2d": ("inplace",), + } + kwargs = [ + f"{kwarg_map.get(layer.layer_type, {}).get(k, k)}={_process_argval(v)}" + for k, v in layer.args.items() + if k not in kwarg_ignore.get(layer.layer_type, ()) + ] + # add key for initialization + if layer.layer_type in ( + "Linear", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + ): + kwargs += [f"key=keys[{ilayer}]"] + type_str = layer_map.get(layer.layer_type, f"eqx.nn.{layer.layer_type}") + layer_str = f"{type_str}({', '.join(kwargs)})" + return f"{' ' * indent}'{layer.layer_id}': {layer_str}" + + +def _generate_forward(node: "Node", indent, layer_type=str) -> str: # noqa: F821 
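+    # Maps one node of the network's forward graph to a line of generated
+    # equinox code: "call_module" nodes dispatch to the layers built by
+    # _generate_layer (vmapped over a potential batch dimension for
+    # shape-sensitive layers), "call_function"/"call_method" nodes are mapped
+    # to jax.nn (or amici.jax.nn) activations, and "output" assigns the result.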
+ if node.op == "placeholder": + # TODO: inconsistent target vs name + return f"{' ' * indent}{node.name} = input" + + if node.op == "call_module": + fun_str = f"self.layers['{node.target}']" + if layer_type.startswith(("Conv", "Linear", "LayerNorm")): + if layer_type in ("LayerNorm",): + dims = f"len({fun_str}.shape)+1" + if layer_type == "Linear": + dims = 2 + if layer_type.endswith(("1d",)): + dims = 3 + elif layer_type.endswith(("2d",)): + dims = 4 + elif layer_type.endswith("3d"): + dims = 5 + fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims} else {fun_str})" + + if node.op in ("call_function", "call_method"): + map_fun = { + "hardtanh": "jax.nn.hard_tanh", + "hardsigmoid": "jax.nn.hard_sigmoid", + "hardswish": "jax.nn.hard_swish", + "tanhshrink": "amici.jax.nn.tanhshrink", + "softsign": "jax.nn.soft_sign", + } + if node.target == "hardtanh": + if node.kwargs.pop("min_val", -1.0) != -1.0: + raise NotImplementedError( + "min_val != -1.0 not supported for hardtanh" + ) + if node.kwargs.pop("max_val", 1.0) != 1.0: + raise NotImplementedError( + "max_val != 1.0 not supported for hardtanh" + ) + fun_str = map_fun.get(node.target, f"jax.nn.{node.target}") + + args = ", ".join([f"{arg}" for arg in node.args]) + kwargs = [ + f"{k}={v}" for k, v in node.kwargs.items() if k not in ("inplace",) + ] + if layer_type.startswith(("Dropout",)): + kwargs += ["key=key"] + kwargs_str = ", ".join(kwargs) + if node.op in ("call_module", "call_function", "call_method"): + return f"{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})" + if node.op == "output": + return f"{' ' * indent}{node.target} = {args}" diff --git a/python/sdist/amici/jax/nn.template.py b/python/sdist/amici/jax/nn.template.py new file mode 100644 index 0000000000..b07a251e64 --- /dev/null +++ b/python/sdist/amici/jax/nn.template.py @@ -0,0 +1,26 @@ +# ruff: noqa: F401, F821, F841 +import equinox as eqx +import jax.nn +import jax.random as jr +import jax +import amici.jax.nn + + +class TPL_MODEL_ID(eqx.Module): + layers: dict + inputs: list[str] + outputs: list[str] + + def __init__(self, key): + super().__init__() + keys = jr.split(key, TPL_N_LAYERS) + self.layers = {TPL_LAYERS} + self.inputs = [TPL_INPUT] + self.outputs = [TPL_OUTPUT] + + def forward(self, input, key=None): + TPL_FORWARD + return output + + +net = TPL_MODEL_ID diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 4329195441..a374042f4a 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -24,6 +24,7 @@ from amici._codegen.template import apply_template from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter from amici.jax.model import JAXModel +from amici.jax.nn import generate_equinox from amici.de_model import DEModel from amici.de_export import is_valid_identifier from amici.import_utils import ( @@ -129,6 +130,7 @@ def __init__( outdir: Path | str | None = None, verbose: bool | int | None = False, model_name: str | None = "model", + hybridisation: dict[str, dict] = None, ): """ Generate AMICI jax files for the ODE provided to the constructor. 
@@ -157,6 +159,8 @@ def __init__( self.model: DEModel = ode_model + self.hybridisation = hybridisation if hybridisation is not None else {} + self._code_printer = AmiciJaxCodePrinter() @log_execution_time("generating jax code", logger) @@ -169,6 +173,7 @@ def generate_model_code(self) -> None: ): self._prepare_model_folder() self._generate_jax_code() + self._generate_nn_code() def _prepare_model_folder(self) -> None: """ @@ -233,6 +238,14 @@ def _generate_jax_code(self) -> None: # can flag conflicts in the future "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", }, + "NET_IMPORTS": "\n".join( + f"{net} = _module_from_path('{net}', Path(__file__).parent / '{net}.py')" + for net in self.hybridisation.keys() + ), + "NETS": ",\n".join( + f'"{net}": {net}.net(jr.PRNGKey(0))' + for net in self.hybridisation.keys() + ), } apply_template( @@ -241,6 +254,14 @@ def _generate_jax_code(self) -> None: tpl_data, ) + def _generate_nn_code(self) -> None: + for net_name, net in self.hybridisation.items(): + for model in net["model"]: + generate_equinox( + model, + self.model_path / f"{net_name}.py", + ) + def set_paths(self, output_dir: str | Path | None = None) -> None: """ Set output paths for the model and create if necessary diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 43498ce536..9afe3ba622 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -91,6 +91,7 @@ class JAXProblem(eqx.Module): np.ndarray, ], ] + _inputs: dict[str, dict[str, np.ndarray]] _petab_measurement_indices: dict[tuple[str, ...], tuple[int, ...]] _petab_problem: petab.Problem @@ -103,14 +104,14 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): :param petab_problem: PEtab problem to simulate. """ - self.model = model scs = petab_problem.get_simulation_conditions_from_measurement_df() self._petab_problem = petab_problem + self.parameters, self.model = self._get_nominal_parameter_values(model) self._parameter_mappings = self._get_parameter_mappings(scs) + self._inputs = self._get_inputs() self._measurements, self._petab_measurement_indices = ( self._get_measurements(scs) ) - self.parameters = self._get_nominal_parameter_values() def save(self, directory: Path): """ @@ -258,13 +259,71 @@ def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: ) return tuple(tuple(row) for _, row in simulation_conditions.iterrows()) - def _get_nominal_parameter_values(self) -> jt.Float[jt.Array, "np"]: + def _get_nominal_parameter_values( + self, model: JAXModel + ) -> tuple[jt.Float[jt.Array, "np"], JAXModel]: """ Get the nominal parameter values for the model based on the nominal values in the PEtab problem. + Also set nominal values in the model (where applicable). :return: - jax array with nominal parameter values - """ + jax array with nominal parameter values and model with nominal parameter values set. 
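+            Neural-network parameters are matched by dotted parameter ids of the
+            form ``netId.layerId.attribute``; if the attribute or the layer part
+            is omitted, all matching attributes or layers of that network are
+            set to the nominal value.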
+ """ + # initialize everything with zeros + model_pars = { + net_id: { + layer_id: { + attribute: jnp.zeros_like(getattr(layer, attribute)) + for attribute in ["weight", "bias"] + if hasattr(layer, attribute) + } + for layer_id, layer in nn.layers.items() + } + for net_id, nn in model.nns.items() + } + # extract nominal values from petab problem + for pname, row in self._petab_problem.parameter_df.iterrows(): + if (net := pname.split(".")[0]) in model.nns: + to_set = [] + nn = model_pars[net] + if len(pname.split(".")) > 1: + layer = nn[pname.split(".")[1]] + if len(pname.split(".")) > 2: + to_set.append( + (pname.split(".")[1], pname.split(".")[2]) + ) + else: + to_set.extend( + [ + (pname.split(".")[1], attribute) + for attribute in layer.keys() + ] + ) + else: + to_set.extend( + [ + (layer_name, attribute) + for layer_name, layer in nn.items() + for attribute in layer.keys() + ] + ) + + for layer, attribute in to_set: + nn[layer][attribute] = row[ + petab.NOMINAL_VALUE + ] * jnp.ones_like(nn[layer][attribute]) + + # set values in model + for net_id in model_pars: + for layer_id in model_pars[net_id]: + for attribute in model_pars[net_id][layer_id]: + model = eqx.tree_at( + lambda model: getattr( + model.nns[net_id].layers[layer_id], attribute + ), + model, + model_pars[net_id][layer_id][attribute], + ) return jnp.array( [ petab.scale( @@ -277,7 +336,32 @@ def _get_nominal_parameter_values(self) -> jt.Float[jt.Array, "np"]: ) for pval in self.parameter_ids ] - ) + ), model + + def _get_inputs(self): + if self._petab_problem.mapping_df is None: + return {} + inputs = {net: {} for net in self.model.nns.keys()} + for petab_id, row in self._petab_problem.mapping_df.iterrows(): + if (filepath := Path(petab_id)).is_file(): + data_flat = pd.read_csv(filepath, sep="\t").sort_values( + by="ix" + ) + shape = tuple( + np.stack( + data_flat["ix"] + .astype(str) + .str.split(";") + .apply(np.array) + ) + .astype(int) + .max(axis=0) + + 1 + ) + inputs[row["netId"]][row[petab.MODEL_ENTITY_ID]] = data_flat[ + "value" + ].values.reshape(shape) + return inputs @property def parameter_ids(self) -> list[str]: @@ -291,6 +375,23 @@ def parameter_ids(self) -> list[str]: self._petab_problem.parameter_df[petab.ESTIMATE] == 1 ].index.tolist() + @property + def nn_output_ids(self) -> list[str]: + """ + Parameter ids that are estimated in the PEtab problem. Same ordering as values in :attr:`parameters`. + + :return: + PEtab parameter ids + """ + if self._petab_problem.mapping_df is None: + return [] + return self._petab_problem.mapping_df[ + self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[1] + .str.startswith("output") + ].index.tolist() + def get_petab_parameter_by_id(self, name: str) -> jnp.float_: """ Get the value of a PEtab parameter by name. 
@@ -319,7 +420,52 @@ def _unscale( [jax_unscale(pval, scale) for pval, scale in zip(p, scales)] ) - def load_parameters( + def _eval_nn(self, output_par: str): + net_id = self._petab_problem.mapping_df.loc[ + output_par, petab.MODEL_ENTITY_ID + ].split(".")[0] + nn = self.model.nns[net_id] + + model_id_map = ( + self._petab_problem.mapping_df[ + self._petab_problem.mapping_df[petab.MODEL_ENTITY_ID] + .str.split(".") + .str[0] + == net_id + ] + .reset_index() + .set_index(petab.MODEL_ENTITY_ID)[petab.PETAB_ENTITY_ID] + .to_dict() + ) + + net_input = jnp.array( + [ + jax.lax.stop_gradient(self._inputs[net_id][model_id]) + if model_id in self._inputs[net_id] + else self.get_petab_parameter_by_id(petab_id) + if petab_id in self.parameter_ids + else self._petab_problem.parameter_df.loc[ + petab_id, petab.NOMINAL_VALUE + ] + for model_id, petab_id in model_id_map.items() + if model_id.split(".")[1].startswith("input") + ] + ) + return nn.forward(net_input).squeeze() + + def _map_model_parameter_value( + self, + mapping: ParameterMappingForCondition, + pname: str, + ) -> jt.Float[jt.Scalar, ""] | float: # noqa: F722 + if pname in self.nn_output_ids: + return self._eval_nn(pname) + pval = mapping.map_sim_var[pname] + if isinstance(pval, Number): + return pval + return self.get_petab_parameter_by_id(pval) + + def load_model_parameters( self, simulation_condition: str ) -> jt.Float[jt.Array, "np"]: """ @@ -331,17 +477,19 @@ def load_parameters( Parameters for the simulation condition. """ mapping = self._parameter_mappings[simulation_condition] + p = jnp.array( [ - pval - if isinstance(pval := mapping.map_sim_var[pname], Number) - else self.get_petab_parameter_by_id(pval) + self._map_model_parameter_value(mapping, pname) for pname in self.model.parameter_ids ] ) pscale = tuple( [ - mapping.scale_map_sim_var[pname] + petab.LIN + if self._petab_problem.mapping_df is not None + and pname in self._petab_problem.mapping_df.index + else mapping.scale_map_sim_var[pname] for pname in self.model.parameter_ids ] ) @@ -362,6 +510,9 @@ def _state_needs_reinitialisation( :return: True if state needs reinitialisation, False otherwise """ + if state_id in self.nn_output_ids: + return True + if state_id not in self._petab_problem.condition_df: return False xval = self._petab_problem.condition_df.loc[ @@ -389,6 +540,9 @@ def _state_reinitialisation_value( :return: reinitialisation value for the state """ + if state_id in self.nn_output_ids: + return self._eval_nn(state_id) + if state_id not in self._petab_problem.condition_df: # no reinitialisation, return dummy value return 0.0 @@ -433,6 +587,7 @@ def load_reinitialisation( """ if not any( x_id in self._petab_problem.condition_df + or x_id in self.nn_output_ids for x_id in self.model.state_ids ): return jnp.array([]), jnp.array([]) @@ -495,7 +650,7 @@ def run_simulation( ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ simulation_condition ] - p = self.load_parameters(simulation_condition[0]) + p = self.load_model_parameters(simulation_condition[0]) mask_reinit, x_reinit = self.load_reinitialisation( simulation_condition[0], p ) @@ -543,7 +698,7 @@ def run_preequilibration( :return: Pre-equilibration state """ - p = self.load_parameters(simulation_condition) + p = self.load_model_parameters(simulation_condition) mask_reinit, x_reinit = self.load_reinitialisation( simulation_condition, p ) diff --git a/python/sdist/amici/petab/parameter_mapping.py b/python/sdist/amici/petab/parameter_mapping.py index 3bd0e69ac2..77af598f63 100644 --- 
a/python/sdist/amici/petab/parameter_mapping.py
+++ b/python/sdist/amici/petab/parameter_mapping.py
@@ -352,7 +352,11 @@ def create_parameter_mapping(
     if petab_problem.model.type_id == MODEL_TYPE_SBML:
         import libsbml
 
-        if petab_problem.sbml_document:
+        # v1 guard
+        if (
+            isinstance(petab_problem, petab.Problem)
+            and petab_problem.sbml_document
+        ):
             converter_config = (
                 libsbml.SBMLLocalParameterConverter().getDefaultProperties()
             )
diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py
index b7fccca241..a13ab5c4d9 100644
--- a/python/sdist/amici/petab/petab_import.py
+++ b/python/sdist/amici/petab/petab_import.py
@@ -99,9 +99,9 @@ def import_petab_problem(
         )
 
     if petab_problem.mapping_df is not None:
-        # It's partially supported. Remove at your own risk...
-        raise NotImplementedError(
-            "PEtab v2.0.0 mapping tables are not yet supported."
+        warn(
+            "PEtab v2.0.0 mapping tables are only partially supported, use at your own risk.",
+            stacklevel=2,
         )
 
     model_name = model_name or petab_problem.model.model_id
@@ -145,6 +145,53 @@ def import_petab_problem(
         shutil.rmtree(model_output_dir)
 
     logger.info(f"Compiling model {model_name} to {model_output_dir}.")
+
+    if "petab_sciml" in petab_problem.extensions_config:
+        from petab_sciml import PetabScimlStandard
+
+        config = petab_problem.extensions_config["petab_sciml"]
+        hybridization = {
+            net_id: {
+                "model": PetabScimlStandard.load_data(
+                    Path() / net_config["file"]
+                ).models,
+                "input_vars": [
+                    petab_id
+                    for petab_id, model_id in petab_problem.mapping_df.loc[
+                        petab_problem.mapping_df[petab.MODEL_ENTITY_ID]
+                        .str.split(".")
+                        .str[0]
+                        == net_id,
+                        petab.MODEL_ENTITY_ID,
+                    ]
+                    .to_dict()
+                    .items()
+                    if model_id.split(".")[1].startswith("input")
+                ],
+                "output_vars": [
+                    petab_id
+                    for petab_id, model_id in petab_problem.mapping_df.loc[
+                        petab_problem.mapping_df[petab.MODEL_ENTITY_ID]
+                        .str.split(".")
+                        .str[0]
+                        == net_id,
+                        petab.MODEL_ENTITY_ID,
+                    ]
+                    .to_dict()
+                    .items()
+                    if model_id.split(".")[1].startswith("output")
+                ],
+                **net_config,
+            }
+            for net_id, net_config in config.items()
+        }
+        if not jax or petab_problem.model.type_id == MODEL_TYPE_PYSB:
+            raise NotImplementedError(
+                "petab_sciml extension is currently only supported for SBML models imported with jax=True"
+            )
+    else:
+        hybridization = None
+
     # compile the model
     if petab_problem.model.type_id == MODEL_TYPE_PYSB:
         import_model_pysb(
@@ -160,6 +207,7 @@
             model_name=model_name,
             model_output_dir=model_output_dir,
             non_estimated_parameters_as_constants=non_estimated_parameters_as_constants,
+            hybridization=hybridization,
             jax=jax,
             **kwargs,
         )
diff --git a/python/sdist/amici/petab/sbml_import.py b/python/sdist/amici/petab/sbml_import.py
index e605a9cc80..d9659e5fd6 100644
--- a/python/sdist/amici/petab/sbml_import.py
+++ b/python/sdist/amici/petab/sbml_import.py
@@ -205,6 +205,7 @@ def import_model_sbml(
     non_estimated_parameters_as_constants=True,
     output_parameter_defaults: dict[str, float] | None = None,
     discard_sbml_annotations: bool = False,
+    hybridization: dict = None,
     jax: bool = False,
     **kwargs,
 ) -> amici.SbmlImporter:
@@ -380,10 +381,15 @@ def import_model_sbml(
             sigmas=sigmas,
             noise_distributions=noise_distrs,
             verbose=verbose,
+            hybridization=hybridization,
             **kwargs,
         )
         return sbml_importer
     else:
+        if hybridization:
+            raise NotImplementedError(
+                "Hybridization is currently only supported for JAX models."
+            )
         sbml_importer.sbml2amici(
             model_name=model_name,
             output_dir=model_output_dir,
diff --git a/python/sdist/amici/petab/util.py b/python/sdist/amici/petab/util.py
index 48e6ed7786..ebee360953 100644
--- a/python/sdist/amici/petab/util.py
+++ b/python/sdist/amici/petab/util.py
@@ -28,6 +28,9 @@ def get_states_in_condition_table(
     species_check_funs = {
         MODEL_TYPE_SBML: lambda x: _element_is_sbml_state(
-            petab_problem.sbml_model, x
+            petab_problem.sbml_model  # v1
+            if isinstance(petab_problem, petab.Problem)
+            else petab_problem.model.sbml_model,  # v2
+            x,
         ),
         MODEL_TYPE_PYSB: lambda x: _element_is_pysb_pattern(
             petab_problem.model.model, x
diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py
index 557ad02d0f..df540ba1da 100644
--- a/python/sdist/amici/sbml_import.py
+++ b/python/sdist/amici/sbml_import.py
@@ -204,7 +204,7 @@ def _process_document(self) -> None:
         log_execution_time("validating SBML", logger)(
             self.sbml_doc.validateSBML
         )()
-        _check_lib_sbml_errors(self.sbml_doc, self.show_sbml_warnings)
+        # _check_lib_sbml_errors(self.sbml_doc, self.show_sbml_warnings)
 
         # Flatten "comp" model? Do that before any other converters are run
         if any(
@@ -253,7 +253,7 @@ def _process_document(self) -> None:
         # If any of the above calls produces an error, this will be added to
         # the SBMLError log in the sbml document. Thus, it is sufficient to
         # check the error log just once after all conversion/validation calls.
-        _check_lib_sbml_errors(self.sbml_doc, self.show_sbml_warnings)
+        # _check_lib_sbml_errors(self.sbml_doc, self.show_sbml_warnings)
 
         # need to reload the converted model
         self.sbml = self.sbml_doc.getModel()
@@ -458,6 +458,7 @@ def sbml2jax(
         simplify: Callable | None = _default_simplify,
         cache_simplify: bool = False,
         log_as_log10: bool = True,
+        hybridization: dict = None,
     ) -> None:
         """
         Generate and compile AMICI jax files for the model provided to the
@@ -538,6 +539,7 @@ def sbml2jax(
             simplify=simplify,
             cache_simplify=cache_simplify,
             log_as_log10=log_as_log10,
+            hybridization=hybridization,
         )
 
         from amici.jax.ode_export import ODEExporter
@@ -547,6 +549,7 @@
             model_name=model_name,
             outdir=output_dir,
             verbose=verbose,
+            hybridisation=hybridization,
         )
         exporter.generate_model_code()
@@ -565,6 +568,7 @@ def _build_ode_model(
         cache_simplify: bool = False,
         log_as_log10: bool = True,
         hardcode_symbols: Sequence[str] = None,
+        hybridization: dict = None,
     ) -> DEModel:
         """Generate an ODEModel from this SBML model.
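Taken together, the new import path can be exercised roughly as follows, assuming a PEtab v2 problem that declares the `petab_sciml` extension; the YAML file name is hypothetical:

```python
# Minimal end-to-end sketch: import a PEtab problem that uses the
# petab_sciml extension and simulate it with the JAX backend.
from amici.petab import import_petab_problem
from amici.jax import JAXProblem, run_simulations
from petab.v2 import Problem

petab_problem = Problem.from_yaml("problem_ude.yaml")  # hypothetical path
jax_model = import_petab_problem(petab_problem, jax=True)
jax_problem = JAXProblem(jax_model, petab_problem)
llh, results = run_simulations(jax_problem)
```

`import_petab_problem` assembles the `hybridization` dict from the extension config and the mapping table, threads it through `import_model_sbml` and `sbml2jax`, and `ODEExporter` then writes one equinox module per network next to the generated JAX model.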
@@ -727,6 +731,9 @@ def _build_ode_model( if compute_conservation_laws: self._process_conservation_laws(ode_model) + if hybridization: + ode_model._process_hybridization(hybridization) + # fill in 'self._sym' based on prototypes and components in ode_model ode_model.generate_basic_variables() diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py new file mode 100644 index 0000000000..e99b0a88cc --- /dev/null +++ b/tests/sciml/test_sciml.py @@ -0,0 +1,356 @@ +from yaml import safe_load +import pytest + +from pathlib import Path +import petab.v1 as petab +from amici.petab import import_petab_problem +from amici.jax import ( + JAXProblem, + generate_equinox, + run_simulations, + petab_simulate, +) +import amici +import diffrax +import pandas as pd +import jax.numpy as jnp +import jax.random as jr +import jax +import numpy as np +import equinox as eqx +import os +import h5py +from contextlib import contextmanager + +from petab_sciml import PetabScimlStandard + + +@contextmanager +def change_directory(destination): + # Save the current working directory + original_directory = os.getcwd() + try: + # Change to the new directory + os.chdir(destination) + yield + finally: + # Change back to the original directory + os.chdir(original_directory) + + +jax.config.update("jax_enable_x64", True) + + +# pip install git+https://github.com/sebapersson/petab_sciml@add_standard#egg=petab_sciml\&subdirectory=src/python + +cases_dir = Path(__file__).parent / "testsuite" / "test_cases" + + +def _reshape_flat_array(array_flat): + array_flat["ix"] = array_flat["ix"].astype(str) + ix_cols = [ + f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";"))) + ] + if len(ix_cols) == 1: + array_flat[ix_cols[0]] = array_flat["ix"].apply(int) + else: + array_flat[ix_cols] = pd.DataFrame( + array_flat["ix"].str.split(";").apply(np.array).to_list(), + index=array_flat.index, + ).astype(int) + array_flat.sort_values(by=ix_cols, inplace=True) + array_shape = tuple(array_flat[ix_cols].max().astype(int) + 1) + array = np.array(array_flat["value"].values).reshape(array_shape) + return array + + +@pytest.mark.parametrize( + "test", sorted([d.stem for d in cases_dir.glob("net_[0-9]*")]) +) +def test_net(test): + test_dir = cases_dir / test + with open(test_dir / "solutions.yaml") as f: + solutions = safe_load(f) + + if test.endswith("_alt"): + net_file = cases_dir / test.replace("_alt", "") / solutions["net_file"] + else: + net_file = test_dir / solutions["net_file"] + ml_models = PetabScimlStandard.load_data(net_file) + + nets = {} + outdir = Path(__file__).parent / "models" / test + for ml_model in ml_models.models: + module_dir = outdir / f"{ml_model.mlmodel_id}.py" + if test in ( + "net_002", + "net_009", + "net_018", + "net_019", + "net_020", + "net_021", + "net_022", + "net_042", + "net_043", + "net_044", + "net_045", + "net_046", + "net_047", + "net_048", + ): + with pytest.raises(NotImplementedError): + generate_equinox(ml_model, module_dir) + return + generate_equinox(ml_model, module_dir) + nets[ml_model.mlmodel_id] = amici._module_from_path( + ml_model.mlmodel_id, module_dir + ).net + + for input_file, par_file, output_file in zip( + solutions["net_input"], + solutions.get("net_ps", solutions["net_input"]), + solutions["net_output"], + ): + input_flat = pd.read_csv(test_dir / input_file, sep="\t") + input = _reshape_flat_array(input_flat) + + output_flat = pd.read_csv(test_dir / output_file, sep="\t") + output = _reshape_flat_array(output_flat) + + if "net_ps" in solutions: + par = 
pd.read_csv(test_dir / par_file, sep="\t") + for ml_model in ml_models.models: + net = nets[ml_model.mlmodel_id](jr.PRNGKey(0)) + for layer in net.layers.keys(): + layer_prefix = f"net_{layer}" + if ( + isinstance(net.layers[layer], eqx.Module) + and hasattr(net.layers[layer], "weight") + and net.layers[layer].weight is not None + ): + prefix = layer_prefix + "_weight" + df = par[ + par[petab.PARAMETER_ID].str.startswith(prefix) + ] + df["ix"] = ( + df[petab.PARAMETER_ID] + .str.split("_") + .str[3:] + .apply(lambda x: ";".join(x)) + ) + w = _reshape_flat_array(df) + if isinstance(net.layers[layer], eqx.nn.ConvTranspose): + # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose + w = np.flip( + w, axis=tuple(range(2, w.ndim)) + ).swapaxes(0, 1) + assert w.shape == net.layers[layer].weight.shape + net = eqx.tree_at( + lambda x: x.layers[layer].weight, + net, + jnp.array(w), + ) + if ( + isinstance(net.layers[layer], eqx.Module) + and hasattr(net.layers[layer], "bias") + and net.layers[layer].bias is not None + ): + prefix = layer_prefix + "_bias" + df = par[ + par[petab.PARAMETER_ID].str.startswith(prefix) + ] + df["ix"] = ( + df[petab.PARAMETER_ID] + .str.split("_") + .str[3:] + .apply(lambda x: ";".join(x)) + ) + b = _reshape_flat_array(df) + if isinstance( + net.layers[layer], + eqx.nn.Conv | eqx.nn.ConvTranspose, + ): + b = np.expand_dims( + b, + tuple( + range( + 1, + net.layers[layer].num_spatial_dims + 1, + ) + ), + ) + assert b.shape == net.layers[layer].bias.shape + net = eqx.tree_at( + lambda x: x.layers[layer].bias, + net, + jnp.array(b), + ) + net = eqx.nn.inference_mode(net) + + if test == "net_004_alt": + return # skipping, no support for non-cross-correlation in equinox + + np.testing.assert_allclose( + net.forward(input), + output, + atol=1e-3, + rtol=1e-3, + ) + + +@pytest.mark.parametrize( + "test", sorted([d.stem for d in cases_dir.glob("[0-9]*")]) +) +def test_ude(test): + test_dir = cases_dir / test + with open(test_dir / "petab" / "problem_ude.yaml") as f: + petab_yaml = safe_load(f) + with open(test_dir / "solutions.yaml") as f: + solutions = safe_load(f) + + with change_directory(test_dir / "petab"): + from petab.v2 import Problem + + petab_yaml["format_version"] = "2.0.0" + for problem in petab_yaml["problems"]: + problem["model_files"] = { + problem["model_files"]["location"].split(".")[0]: problem[ + "model_files" + ] + } + for mapping_file in problem["mapping_files"]: + df = pd.read_csv( + mapping_file, + sep="\t", + ) + if df[petab.PETAB_ENTITY_ID].str.startswith("net").any(): + df.rename( + columns={ + petab.PETAB_ENTITY_ID: petab.MODEL_ENTITY_ID, + petab.MODEL_ENTITY_ID: petab.PETAB_ENTITY_ID, + } + ).to_csv(mapping_file, sep="\t", index=False) + + petab_problem = Problem.from_yaml(petab_yaml) + jax_model = import_petab_problem( + petab_problem, + model_output_dir=Path(__file__).parent / "models" / test, + compile_=True, + jax=True, + ) + jax_problem = JAXProblem(jax_model, petab_problem) + for net, net_config in petab_problem.extensions_config[ + "petab_sciml" + ].items(): + pars = h5py.File( + net_config["parameters"].replace(".h5", ".hf5"), "r" + ) + for layer_name, layer in jax_problem.model.nns[net].layers.items(): + for attribute in dir(layer): + if not isinstance( + getattr(layer, attribute), jax.numpy.ndarray + ): + continue + value = jnp.array(pars[layer_name][attribute]) + + if ( + isinstance(layer, eqx.nn.ConvTranspose) + and attribute == "weight" + ): + # see FAQ in 
https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose + value = jnp.flip( + value, axis=tuple(range(2, value.ndim)) + ).swapaxes(0, 1) + jax_problem = eqx.tree_at( + lambda x: getattr( + x.model.nns[net].layers[layer_name], attribute + ), + jax_problem, + value, + ) + + # llh + if test in ( + "004", + "016", + ): + with pytest.raises(NotImplementedError): + run_simulations(jax_problem) + return + llh, r = run_simulations(jax_problem) + np.testing.assert_allclose( + llh, + solutions["llh"], + atol=solutions["tol_llh"], + rtol=solutions["tol_llh"], + ) + simulations = pd.concat( + [ + pd.read_csv(test_dir / simulation, sep="\t") + for simulation in solutions["simulation_files"] + ] + ) + + # simulations + sort_by = [petab.OBSERVABLE_ID, petab.TIME, petab.SIMULATION_CONDITION_ID] + actual = petab_simulate(jax_problem).sort_values(by=sort_by) + expected = simulations.sort_values(by=sort_by) + np.testing.assert_allclose( + actual[petab.SIMULATION].values, + expected[petab.SIMULATION].values, + atol=solutions["tol_simulations"], + rtol=solutions["tol_simulations"], + ) + + # gradient + sllh, _ = eqx.filter_grad(run_simulations, has_aux=True)( + jax_problem, + solver=diffrax.Kvaerno5(), + controller=diffrax.PIDController(atol=1e-14, rtol=1e-14), + max_steps=2**16, + ) + for component, file in solutions["grad_llh_files"].items(): + actual_dict = {} + if component == "mech": + expected = pd.read_csv(test_dir / file, sep="\t").set_index( + petab.PARAMETER_ID + ) + for ip in expected.index: + if ip in jax_problem.parameter_ids: + actual_dict[ip] = sllh.parameters[ + jax_problem.parameter_ids.index(ip) + ].item() + actual = pd.Series(actual_dict).loc[expected.index].values + np.testing.assert_allclose( + actual, + expected["value"].values, + atol=solutions["tol_grad_llh"], + rtol=solutions["tol_grad_llh"], + ) + else: + expected = h5py.File(test_dir / file, "r") + for layer_name, layer in jax_problem.model.nns[ + component + ].layers.items(): + for attribute in dir(layer): + if not isinstance( + getattr(layer, attribute), jax.numpy.ndarray + ): + continue + actual = getattr( + sllh.model.nns[component].layers[layer_name], attribute + ) + if ( + isinstance(layer, eqx.nn.ConvTranspose) + and attribute == "weight" + ): + actual = np.flip( + actual.swapaxes(0, 1), + axis=tuple(range(2, actual.ndim)), + ) + np.testing.assert_allclose( + actual, + expected[layer_name][attribute][:], + atol=solutions["tol_grad_llh"], + rtol=solutions["tol_grad_llh"], + ) diff --git a/tests/sciml/testsuite b/tests/sciml/testsuite new file mode 160000 index 0000000000..f836be3852 --- /dev/null +++ b/tests/sciml/testsuite @@ -0,0 +1 @@ +Subproject commit f836be38526da0850f0e540010accc94217bdf53
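Both tests above convert PyTorch `ConvTranspose` weights into the layout Equinox expects, following the Equinox FAQ linked in the comments. A small helper capturing this conversion; the helper name and the stated PyTorch weight layout are our assumptions, not part of the patch:

```python
# Sketch of the PyTorch -> Equinox ConvTranspose weight conversion applied
# in the tests: flip the spatial axes, then swap the channel axes.
import numpy as np


def torch_to_eqx_convtranspose_weight(w: np.ndarray) -> np.ndarray:
    """Convert a PyTorch ConvTranspose weight, assumed to have shape
    (in_channels, out_channels, *kernel), to the Equinox layout."""
    return np.flip(w, axis=tuple(range(2, w.ndim))).swapaxes(0, 1)
```

The gradient check in `test_ude` applies the inverse of this transformation to the computed sensitivities before comparing them against the reference HDF5 files.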