From ea1c75e392cc7adaf370b5cbfc3c670f885efa37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Mon, 2 Dec 2024 20:00:16 +0000 Subject: [PATCH] update net_004_alt test --- tests/sciml/testsuite | 2 +- tests/sciml/testsuite.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/sciml/testsuite b/tests/sciml/testsuite index da2bd1bb23..4518f54dd6 160000 --- a/tests/sciml/testsuite +++ b/tests/sciml/testsuite @@ -1 +1 @@ -Subproject commit da2bd1bb2370468389a99933d48e12f89030e1f4 +Subproject commit 4518f54dd62c1256fb1803b9f5e9817f4f78c26d diff --git a/tests/sciml/testsuite.py b/tests/sciml/testsuite.py index 364298efad..c4297efc01 100644 --- a/tests/sciml/testsuite.py +++ b/tests/sciml/testsuite.py @@ -59,7 +59,13 @@ def _test_net(test): ): return - ml_models = PetabScimlStandard.load_data(test / solutions["net_file"]) + if test.stem.endswith("_alt"): + net_file = ( + test.parent / test.stem.replace("_alt", "") / solutions["net_file"] + ) + else: + net_file = test / solutions["net_file"] + ml_models = PetabScimlStandard.load_data(net_file) nets = {} outdir = Path(__file__).parent / "models" / test.stem @@ -151,6 +157,7 @@ def _test_net(test): "net_021", "net_022", # Conv layers "net_004", + "net_004_alt", "net_005", "net_006", "net_007", @@ -277,7 +284,6 @@ def _test_ude(test): ) # gradient - sllh, _ = eqx.filter_grad(run_simulations, has_aux=True)( jax_problem, solver=diffrax.Tsit5(),