diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index c1f083a799..e798a0138f 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -237,7 +237,7 @@ def srun( dynamic=True, ): (llh, (x, obs, stats)), sllh = ( - jax.value_and_grad(self._run, 1, True) + jax.value_and_grad(self._run, 2, True) )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) return llh, sllh, (x, obs, stats) @@ -254,10 +254,10 @@ def s2run( dynamic=True, ): (llh, (x, obs, stats)), sllh = ( - jax.value_and_grad(self._run, 1, True) + jax.value_and_grad(self._run, 2, True) )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) - s2llh = jax.hessian(self._run, 1, True)( + s2llh = jax.hessian(self._run, 2, True)( ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic )