Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Weindl <[email protected]>
  • Loading branch information
FFroehlich and dweindl authored Oct 19, 2024
1 parent c9f7bcb commit a5d6280
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/examples/example_jax/ExampleJax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@
"source": [
"## JAX implementation\n",
"\n",
"For full jax support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require quite a bit of engineering, and in the end wouldn't add much benefit since AMICI can't run on GPUs. Instead, will interface AMICI using the jax method [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html)."
"For full jax support, we would have to implement a new [primitive](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), which would require quite a bit of engineering, and in the end wouldn't add much benefit since AMICI can't run on GPUs. Instead, we will interface AMICI using the jax method [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html)."
]
},
{
Expand Down Expand Up @@ -343,7 +343,7 @@
"id": "6f6201e8",
"metadata": {},
"source": [
"Now we can use this base function to create two separate functions that return the log-likelihood (`llh`) and a tuple with log-likelihood and it's gradient (`sllh`). Both functions use [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html) such that they can be called by other jax functions. Note that, as we are using the same base function here, the log-likelihood computation will also run with sensitivities which is not necessary and will add some overhead. This is only out of convenience and should be fixed in an application where efficiency is important."
"Now we can use this base function to create two separate functions that return the log-likelihood (`llh`) and a tuple with log-likelihood and its gradient (`sllh`). Both functions use [pure_callback](https://jax.readthedocs.io/en/latest/_autosummary/jax.pure_callback.html) such that they can be called by other jax functions. Note that, as we are using the same base function here, the log-likelihood computation will also run with sensitivities which is not necessary and will add some overhead. This is only out of convenience and should be fixed in an application where efficiency is important."
]
},
{
Expand Down Expand Up @@ -408,7 +408,7 @@
"id": "379485ca",
"metadata": {},
"source": [
"As last step, we implement the parameter transformation in jax. This effectively just extracts parameter scales from the petab problem, implements rescaling in jax and then passes the scaled parameters to the previously objective function we previously defined. We add the `jax.value_and_grad` decorator such that the generated jax function returns both function value and function gradient in a tuple. Moreover, we add the `jax.jit` decorator such that the function is [just in time compiled](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) upon the first function call."
"As last step, we implement the parameter transformation in jax. This effectively just extracts parameter scales from the petab problem, implements rescaling in jax and then passes the scaled parameters to the objective function we previously defined. We add the `jax.value_and_grad` decorator such that the generated jax function returns both function value and function gradient in a tuple. Moreover, we add the `jax.jit` decorator such that the function is [just-in-time compiled](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) upon the first function call."
]
},
{
Expand Down

0 comments on commit a5d6280

Please sign in to comment.