From c548c935af1a63971f47efa18651511b1ac6acd1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 24 Oct 2024 10:21:45 +0100 Subject: [PATCH] fix for NONCONST_CLS --- python/sdist/amici/jax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index 5537aef2c8..c882658e3e 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -167,7 +167,9 @@ def _run( ts, ps, k, x0, checkpointed=checkpointed ) else: - x = tuple(jnp.array([x0_i] * len(ts)) for x0_i in x0) + x = tuple( + self.x_solver(jnp.array([x0_i] * len(ts)) for x0_i in x0) + ) tcl = self.tcl(x0, ps, k) stats = None obs = self._obs(ts, x, ps, k, tcl)