diff --git a/python/examples/example_jax/ExampleJax.ipynb b/python/examples/example_jax/ExampleJax.ipynb index 62391ce9be..931dfb7e28 100644 --- a/python/examples/example_jax/ExampleJax.ipynb +++ b/python/examples/example_jax/ExampleJax.ipynb @@ -5,7 +5,10 @@ "id": "d4d2bc5c", "metadata": {}, "source": [ - "# Overview\n", + "# AMICI & JAX\n", + "\n", + "## Overview\n", + "\n", "The purpose of this guide is to showcase how AMICI can be combined with differentiable programming in [JAX](https://jax.readthedocs.io/en/latest/index.html). We will do so by reimplementing the parameter transformations available in AMICI in JAX and comparing it to the native implementation." ] }, @@ -25,9 +28,9 @@ "id": "fb2fe897", "metadata": {}, "source": [ - "# Preparation\n", + "## Preparation\n", "\n", - "To get started we will import a model using the [petab](https://petab.readthedocs.io). To this end, we will use the [benchmark collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which features a variety of different models. For more details about petab import, see the respective notebook petab [notebook](https://amici.readthedocs.io/en/latest/petab.html)." + "To get started, we will import a model using the [petab](https://petab.readthedocs.io). To this end, we will use the [benchmark collection](https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab), which features a variety of different models. For more details about petab import, see the respective notebook petab [notebook](https://amici.readthedocs.io/en/latest/petab.html)." ] }, { @@ -274,7 +277,7 @@ "id": "e2ef051a", "metadata": {}, "source": [ - "# JAX implementation\n", + "## JAX implementation\n", "\n", "For full jax support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require quite a bit of engineering, and in the end wouldn't add much benefit since AMICI can't run on GPUs. Instead, will interface AMICI using the experimental jax module [host_callback](https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html)." ] @@ -439,7 +442,7 @@ "id": "293e29fb", "metadata": {}, "source": [ - "# Testing\n", + "## Testing\n", "\n", "We can now run the function to compute the log-likelihood and the gradient." ] @@ -473,7 +476,7 @@ "id": "6aa4a5f7", "metadata": {}, "source": [ - "As a sanity check, we compare the computed value to native parameter transformation in amici. " + "As a sanity check, we compare the computed value to native parameter transformation in amici." ] }, {