diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index 7e68c5c34d..6a7da4b42f 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -495,7 +495,7 @@ def run_simulation( simulation_condition[0], p ) return self.model.simulate_condition( - p=eqx.debug.backward_nan(p), + p=p, ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)), ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)), my=jax.lax.stop_gradient(jnp.array(my)),