From a5d6280fe57df9027e1418aef5b9bdf08095be07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= <fabian.frohlich@crick.ac.uk> Date: Sat, 19 Oct 2024 10:31:17 +0100 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> --- python/examples/example_jax/ExampleJax.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/examples/example_jax/ExampleJax.ipynb b/python/examples/example_jax/ExampleJax.ipynb index 200ecec1b6..1d7d0967e1 100644 --- a/python/examples/example_jax/ExampleJax.ipynb +++ b/python/examples/example_jax/ExampleJax.ipynb @@ -299,7 +299,7 @@ "source": [ "## JAX implementation\n", "\n", - "For full jax support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require quite a bit of engineering, and in the end wouldn't add much benefit since AMICI can't run on GPUs. Instead, will interface AMICI using the jax method [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html)." + "For full jax support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require quite a bit of engineering, and in the end wouldn't add much benefit since AMICI can't run on GPUs. Instead, we will interface AMICI using the jax method [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html)." ] }, { @@ -343,7 +343,7 @@ "id": "6f6201e8", "metadata": {}, "source": [ - "Now we can use this base function to create two separate functions that return the log-likelihood (`llh`) and a tuple with log-likelihood and it's gradient (`sllh`). Both functions use [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html) such that they can be called by other jax functions. Note that, as we are using the same base function here, the log-likelihood computation will also run with sensitivities which is not necessary and will add some overhead. This is only out of convenience and should be fixed in an application where efficiency is important." + "Now we can use this base function to create two separate functions that return the log-likelihood (`llh`) and a tuple with log-likelihood and its gradient (`sllh`). Both functions use [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html) such that they can be called by other jax functions. Note that, as we are using the same base function here, the log-likelihood computation will also run with sensitivities which is not necessary and will add some overhead. This is only out of convenience and should be fixed in an application where efficiency is important." ] }, { @@ -408,7 +408,7 @@ "id": "379485ca", "metadata": {}, "source": [ - "As last step, we implement the parameter transformation in jax. This effectively just extracts parameter scales from the petab problem, implements rescaling in jax and then passes the scaled parameters to the previously objective function we previously defined. We add the `jax.value_and_grad` decorator such that the generated jax function returns both function value and function gradient in a tuple. Moreover, we add the `jax.jit` decorator such that the function is [just in time compiled](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) upon the first function call." + "As last step, we implement the parameter transformation in jax. This effectively just extracts parameter scales from the petab problem, implements rescaling in jax and then passes the scaled parameters to the objective function we previously defined. We add the `jax.value_and_grad` decorator such that the generated jax function returns both function value and function gradient in a tuple. Moreover, we add the `jax.jit` decorator such that the function is [just-in-time compiled](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) upon the first function call." ] }, {