Skip to content

Commit

Permalink
change integration tolerances JAX test
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 26, 2024
1 parent 2ebc149 commit 25aa00e
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
jax.config.update("jax_enable_x64", True)


ATOL_SIM = 1e-12
RTOL_SIM = 1e-12


def test_conversion():
pysb.SelfExporter.cleanup() # reset pysb
pysb.SelfExporter.do_export = True
Expand Down Expand Up @@ -113,6 +117,8 @@ def _test_model(model_module, ts, p, k):
amici_solver = amici_model.getSolver()
amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward)
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
amici_solver.setAbsoluteTolerance(ATOL_SIM)
amici_solver.setRelativeTolerance(RTOL_SIM)
rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata])

check_fields_jax(
Expand Down Expand Up @@ -156,7 +162,7 @@ def check_fields_jax(
jnp.array(my), # my
jnp.array(iys), # iys
diffrax.Kvaerno5(), # solver
diffrax.PIDController(atol=1e-8, rtol=1e-8), # controller
diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), # controller
diffrax.RecursiveCheckpointAdjoint(), # adjoint
2**8, # max_steps
)
Expand Down

0 comments on commit 25aa00e

Please sign in to comment.