diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py
index b5834223fb..b2b42f5c2a 100644
--- a/python/sdist/amici/jax/petab.py
+++ b/python/sdist/amici/jax/petab.py
@@ -500,7 +500,7 @@ def run_simulation(
             simulation_condition[0], p
         )
         return self.model.simulate_condition(
-            p=eqx.debug.backward_nan(p),
+            p=p,
             ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
             ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
             ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),