From 480b75a64a48eaf9dd4cb6573e9c334992ae025a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 25 Oct 2024 15:42:24 +0100 Subject: [PATCH] fix gradients --- python/sdist/amici/jax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index c1f083a799..e798a0138f 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -237,7 +237,7 @@ def srun( dynamic=True, ): (llh, (x, obs, stats)), sllh = ( - jax.value_and_grad(self._run, 1, True) + jax.value_and_grad(self._run, 2, True) )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) return llh, sllh, (x, obs, stats) @@ -254,10 +254,10 @@ def s2run( dynamic=True, ): (llh, (x, obs, stats)), sllh = ( - jax.value_and_grad(self._run, 1, True) + jax.value_and_grad(self._run, 2, True) )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) - s2llh = jax.hessian(self._run, 1, True)( + s2llh = jax.hessian(self._run, 2, True)( ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic )