diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index c92b35e570..8f4c68510b 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -98,7 +98,11 @@ def test_dimerization(): observables=["a_obs", "b_obs"], constant_parameters=["ksyn_a", "ksyn_b"], ) - pysb2jax(model, outdir, verbose=True, observables=["ab"]) + pysb2jax( + model, + outdir, + observables=["a_obs", "b_obs"], + ) amici_module = amici.import_model_module( module_name=model.name, module_path=outdir @@ -137,12 +141,19 @@ def _test_model(amici_module, jax_module, ts, p, k): rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata]) check_fields_jax( - rs_amici, jax_model, edata, ["x", "y", "llh", "res", "x0"] + rs_amici, + jax_model, + amici_model.getParameterIds(), + amici_model.getFixedParameterIds(), + edata, + ["x", "y", "llh", "res", "x0"], ) check_fields_jax( rs_amici, jax_model, + amici_model.getParameterIds(), + amici_model.getFixedParameterIds(), edata, ["sllh", "sx0", "sx", "sres", "sy"], sensi_order=amici.SensitivityOrder.first, @@ -152,6 +163,8 @@ def _test_model(amici_module, jax_module, ts, p, k): def check_fields_jax( rs_amici, jax_model, + parameter_ids, + fixed_parameter_ids, edata, fields, sensi_order=amici.SensitivityOrder.none, @@ -168,7 +181,13 @@ def check_fields_jax( ts_preeq = ts[ts == 0] ts_dyn = ts[ts > 0] ts_posteq = np.array([]) - p = jnp.array(list(edata.parameters) + list(edata.fixedParameters)) + + par_dict = { + **dict(zip(parameter_ids, edata.parameters)), + **dict(zip(fixed_parameter_ids, edata.fixedParameters)), + } + + p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids]) args = ( jnp.array([]), # p_preeq jnp.array(ts_preeq), # ts_preeq @@ -195,6 +214,10 @@ def check_fields_jax( 0 ] + amici_par_idx = np.array( + [jax_model.parameter_ids.index(par_id) for par_id in parameter_ids] + ) + for field in fields: for r_amici, r_jax in zip(rs_amici, [r_jax]): actual = r_jax[field] @@ -207,16 +230,16 @@ def check_fields_jax( axis=1, ) elif field == "sllh": - actual = actual[: len(edata.parameters)] + actual = actual[amici_par_idx] elif field == "sx": - actual = np.permute_dims( - actual[iys == 0, :, : len(edata.parameters)], (0, 2, 1) - ) + actual = actual[:, :, amici_par_idx] + actual = np.permute_dims(actual[iys == 0, :, :], (0, 2, 1)) elif field == "sy": + actual = actual[:, amici_par_idx] actual = np.permute_dims( np.stack( [ - actual[iys == iy, : len(edata.parameters)] + actual[iys == iy, :] for iy in sorted(np.unique(iys)) ], axis=1, @@ -224,9 +247,9 @@ def check_fields_jax( (0, 2, 1), ) elif field == "sx0": - actual = actual[:, : len(edata.parameters)].T + actual = actual[:, amici_par_idx].T elif field == "sres": - actual = actual[:, : len(edata.parameters)] + actual = actual[:, amici_par_idx] assert_allclose( actual=actual,