From 25623714975c37e6e9417bf11840974c0eda033a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 6 Dec 2024 13:25:38 +0000 Subject: [PATCH] add jax runner to petab testsuite & fix --- python/sdist/amici/jax/__init__.py | 4 +- python/sdist/amici/jax/model.py | 19 ++ python/sdist/amici/jax/petab.py | 192 +++++++++++++++++++-- tests/petab_test_suite/conftest.py | 16 +- tests/petab_test_suite/test_petab_suite.py | 94 ++++++---- 5 files changed, 264 insertions(+), 61 deletions(-) diff --git a/python/sdist/amici/jax/__init__.py b/python/sdist/amici/jax/__init__.py index 8b67abda27..34642e3d49 100644 --- a/python/sdist/amici/jax/__init__.py +++ b/python/sdist/amici/jax/__init__.py @@ -9,7 +9,7 @@ from warnings import warn -from amici.jax.petab import JAXProblem, run_simulations +from amici.jax.petab import JAXProblem, run_simulations, petab_simulate from amici.jax.model import JAXModel warn( @@ -18,4 +18,4 @@ stacklevel=2, ) -__all__ = ["JAXModel", "JAXProblem", "run_simulations"] +__all__ = ["JAXModel", "JAXProblem", "run_simulations", "petab_simulate"] diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 71000587d4..6c41b8b179 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -432,6 +432,7 @@ def simulate_condition( ts_posteq: jt.Float[jt.Array, "nt_posteq"], my: jt.Float[jt.Array, "nt"], iys: jt.Int[jt.Array, "nt"], + iy_trafos: jt.Int[jt.Array, "nt"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, @@ -488,6 +489,7 @@ def simulate_condition( - `sigmay`: standard deviations of the observables - `tcl`: total values for conservation laws (at final timepoint) - `res`: residuals (observed - simulated) + - 'chi2': sum((observed - simulated) ** 2 / sigma ** 2) :return: output according to `ret` and statistics """ @@ -540,6 +542,15 @@ def simulate_condition( nllhs = self._nllhs(ts, x, p, tcl, my, iys) llh = -jnp.sum(nllhs) + obs_trafo = jax.vmap( + lambda y, iy_trafo: jnp.array( + [y, safe_log(y), safe_log(y) / jnp.log(10)] + ) + .at[iy_trafo] + .get(), + ) + ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos) + m_obj = obs_trafo(my, iy_trafos) return { "llh": llh, "nllhs": nllhs, @@ -551,6 +562,10 @@ def simulate_condition( "x0_solver": x[0, :], "tcl": tcl, "res": self._ys(ts, x, p, tcl, iys) - my, + "chi2": jnp.sum( + jnp.square(ys_obj - m_obj) + / jnp.square(self._sigmays(ts, x, p, tcl, iys)) + ), }[ret], dict( ts=ts, x=x, @@ -562,6 +577,8 @@ def simulate_condition( def preequilibrate_condition( self, p: jt.Float[jt.Array, "np"], + x_reinit: jt.Float[jt.Array, "*nx"], + mask_reinit: jt.Bool[jt.Array, "*nx"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, max_steps: int | jnp.int_, @@ -582,6 +599,8 @@ def preequilibrate_condition( """ # Pre-equilibration x0 = self._x0(p) + if x_reinit.shape[0]: + x0 = jnp.where(mask_reinit, x_reinit, x0) tcl = self._tcl(x0, p) current_x = self._x_solver(x0) current_x, stats_preeq = self._eq( diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index b39051390b..750cadc56e 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -20,6 +20,14 @@ ) from amici.jax.model import JAXModel +DEFAULT_CONTROLLER_SETTINGS = { + "atol": 1e-8, + "rtol": 1e-8, + "pcoeff": 1.0, + "icoeff": 1.0, + "dcoeff": 0.0, +} + def jax_unscale( parameter: jnp.float_, @@ -66,8 +74,16 @@ class JAXProblem(eqx.Module): _parameter_mappings: dict[str, ParameterMappingForCondition] _measurements: dict[ tuple[str, ...], - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], + tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + ], ] + _petab_measurement_indices: dict[tuple[str, ...], pd.Index] _petab_problem: petab.Problem def __init__(self, model: JAXModel, petab_problem: petab.Problem): @@ -83,7 +99,9 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem): scs = petab_problem.get_simulation_conditions_from_measurement_df() self._petab_problem = petab_problem self._parameter_mappings = self._get_parameter_mappings(scs) - self._measurements = self._get_measurements(scs) + self._measurements, self._petab_measurement_indices = ( + self._get_measurements(scs) + ) self.parameters = self._get_nominal_parameter_values() def save(self, directory: Path): @@ -155,7 +173,14 @@ def _get_measurements( self, simulation_conditions: pd.DataFrame ) -> dict[ tuple[str, ...], - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], + tuple[ + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + np.ndarray, + ], ]: """ Get measurements for the model based on the provided simulation conditions. @@ -168,6 +193,7 @@ def _get_measurements( post-equilibrium time points; measurements and observable indices). """ measurements = dict() + indices = dict() for _, simulation_condition in simulation_conditions.iterrows(): query = " & ".join( [f"{k} == '{v}'" for k, v in simulation_condition.items()] @@ -176,10 +202,14 @@ def _get_measurements( by=petab.TIME ) - ts = m[petab.TIME].values + ts = m[petab.TIME] ts_preeq = ts[np.isfinite(ts) & (ts == 0)] ts_dyn = ts[np.isfinite(ts) & (ts > 0)] ts_posteq = ts[np.logical_not(np.isfinite(ts))] + index = pd.concat([ts_preeq, ts_dyn, ts_posteq]).index + ts_preeq = ts_preeq.values + ts_dyn = ts_dyn.values + ts_posteq = ts_posteq.values my = m[petab.MEASUREMENT].values iys = np.array( [ @@ -187,6 +217,27 @@ def _get_measurements( for oid in m[petab.OBSERVABLE_ID].values ] ) + if ( + petab.OBSERVABLE_TRANSFORMATION + in self._petab_problem.observable_df + ): + trafo_map = { + petab.LIN: 0, + petab.LOG: 1, + petab.LOG10: 2, + } + iy_trafos = np.array( + [ + trafo_map[ + self._petab_problem.observable_df.loc[ + oid, petab.OBSERVABLE_TRANSFORMATION + ] + ] + for oid in m[petab.OBSERVABLE_ID].values + ] + ) + else: + iy_trafos = np.zeros_like(iys) measurements[tuple(simulation_condition)] = ( ts_preeq, @@ -194,8 +245,10 @@ def _get_measurements( ts_posteq, my, iys, + iy_trafos, ) - return measurements + indices[tuple(simulation_condition)] = index + return measurements, indices def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]: simulation_conditions = ( @@ -372,6 +425,7 @@ def run_simulation( controller: diffrax.AbstractStepSizeController, max_steps: jnp.int_, x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722 + ret: str = "llh", ) -> tuple[jnp.float_, dict]: """ Run a simulation for a given simulation condition. @@ -386,10 +440,23 @@ def run_simulation( Maximum number of steps to take during simulation :param x_preeq: Pre-equilibration state if available + :param ret: + which output to return. Valid values are + - `llh`: log-likelihood (default) + - `nllhs`: negative log-likelihood at each time point + - `x0`: full initial state vector (after pre-equilibration) + - `x0_solver`: reduced initial state vector (after pre-equilibration) + - `x`: full state vector + - `x_solver`: reduced state vector + - `y`: observables + - `sigmay`: standard deviations of the observables + - `tcl`: total values for conservation laws (at final timepoint) + - `res`: residuals (observed - simulated) + - 'chi2': sum((observed - simulated) ** 2 / sigma ** 2) :return: - Tuple of log-likelihood and simulation statistics + Tuple of output value and simulation statistics """ - ts_preeq, ts_dyn, ts_posteq, my, iys = self._measurements[ + ts_preeq, ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[ simulation_condition ] p = self.load_parameters(simulation_condition[0]) @@ -403,13 +470,17 @@ def run_simulation( ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)), my=jax.lax.stop_gradient(jnp.array(my)), iys=jax.lax.stop_gradient(jnp.array(iys)), + iy_trafos=jax.lax.stop_gradient(jnp.array(iy_trafos)), x_preeq=x_preeq, mask_reinit=mask_reinit, x_reinit=x_reinit, solver=solver, controller=controller, max_steps=max_steps, - adjoint=diffrax.RecursiveCheckpointAdjoint(), + adjoint=diffrax.RecursiveCheckpointAdjoint() + if ret in ("llh", "chi2") + else diffrax.DirectAdjoint(), + ret=ret, ) def run_preequilibration( @@ -434,8 +505,13 @@ def run_preequilibration( Pre-equilibration state """ p = self.load_parameters(simulation_condition) + mask_reinit, x_reinit = self.load_reinitialisation( + simulation_condition, p + ) return self.model.preequilibrate_condition( p=p, + mask_reinit=mask_reinit, + x_reinit=x_reinit, solver=solver, controller=controller, max_steps=max_steps, @@ -447,13 +523,10 @@ def run_simulations( simulation_conditions: Iterable[tuple[str, ...]] | None = None, solver: diffrax.AbstractSolver = diffrax.Kvaerno5(), controller: diffrax.AbstractStepSizeController = diffrax.PIDController( - rtol=1e-8, - atol=1e-8, - pcoeff=0.4, - icoeff=0.3, - dcoeff=0.0, + **DEFAULT_CONTROLLER_SETTINGS ), max_steps: int = 2**10, + ret: str = "llh", ): """ Run simulations for a problem. @@ -468,8 +541,21 @@ def run_simulations( Step size controller to use for simulation. :param max_steps: Maximum number of steps to take during simulation. + :param ret: + which output to return. Valid values are + - `llh`: log-likelihood (default) + - `nllhs`: negative log-likelihood at each time point + - `x0`: full initial state vector (after pre-equilibration) + - `x0_solver`: reduced initial state vector (after pre-equilibration) + - `x`: full state vector + - `x_solver`: reduced state vector + - `y`: observables + - `sigmay`: standard deviations of the observables + - `tcl`: total values for conservation laws (at final timepoint) + - `res`: residuals (observed - simulated) + - 'chi2': sum((observed - simulated) ** 2 / sigma ** 2) :return: - Overall negative log-likelihood and condition specific results and statistics. + Overall output value and condition specific results and statistics. """ if simulation_conditions is None: simulation_conditions = problem.get_all_simulation_conditions() @@ -487,10 +573,86 @@ def run_simulations( controller, max_steps, preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]), + ret=ret, ) for sc in simulation_conditions } - return sum(llh for llh, _ in results.values()), { + stats = { sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1] for sc, res in results.items() } + if ret in ("llh", "chi2"): + output = sum(r for r, _ in results.values()) + else: + output = {sc: res[0] for sc, res in results.items()} + + return output, stats + + +def petab_simulate( + problem: JAXProblem, + solver: diffrax.AbstractSolver = diffrax.Kvaerno5(), + controller: diffrax.AbstractStepSizeController = diffrax.PIDController( + **DEFAULT_CONTROLLER_SETTINGS + ), + max_steps: int = 2**10, +): + """ + Run simulations for a problem and return the results as a petab simulation dataframe. + + :param problem: + Problem to run simulations for. + :param solver: + ODE solver to use for simulation. + :param controller: + Step size controller to use for simulation. + :param max_steps: + Maximum number of steps to take during simulation. + :return: + petab simulation dataframe. + """ + y, r = run_simulations( + problem, + solver=solver, + controller=controller, + max_steps=max_steps, + ret="y", + ) + dfs = [] + for sc, ys in y.items(): + obs = [ + problem.model.observable_ids[io] + for io in problem._measurements[sc][4] + ] + t = jnp.concat(problem._measurements[sc][:2]) + df_sc = pd.DataFrame( + { + petab.SIMULATION: ys, + petab.TIME: t, + petab.OBSERVABLE_ID: obs, + petab.SIMULATION_CONDITION_ID: [sc[0]] * len(t), + }, + index=problem._petab_measurement_indices[sc], + ) + if ( + petab.OBSERVABLE_PARAMETERS + in problem._petab_problem.measurement_df + ): + df_sc[petab.OBSERVABLE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petab.SIMULATION_CONDITION_ID} == '{sc[0]}'" + )[petab.OBSERVABLE_PARAMETERS] + ) + if petab.NOISE_PARAMETERS in problem._petab_problem.measurement_df: + df_sc[petab.NOISE_PARAMETERS] = ( + problem._petab_problem.measurement_df.query( + f"{petab.SIMULATION_CONDITION_ID} == '{sc[0]}'" + )[petab.NOISE_PARAMETERS] + ) + if ( + petab.PREEQUILIBRATION_CONDITION_ID + in problem._petab_problem.measurement_df + ): + df_sc[petab.PREEQUILIBRATION_CONDITION_ID] = sc[1] + dfs.append(df_sc) + return pd.concat(dfs).sort_index() diff --git a/tests/petab_test_suite/conftest.py b/tests/petab_test_suite/conftest.py index 2e1c6d3cea..b51f240ffd 100644 --- a/tests/petab_test_suite/conftest.py +++ b/tests/petab_test_suite/conftest.py @@ -60,7 +60,7 @@ def pytest_generate_tests(metafunc): if metafunc.config.getoption("--only-sbml"): argvalues = [ - (case, "sbml", version) + (case, "sbml", version, False) for version in ("v1.0.0", "v2.0.0") for case in ( test_numbers @@ -70,7 +70,7 @@ def pytest_generate_tests(metafunc): ] elif metafunc.config.getoption("--only-pysb"): argvalues = [ - (case, "pysb", "v2.0.0") + (case, "pysb", "v2.0.0", False) for case in ( test_numbers if test_numbers @@ -81,8 +81,10 @@ def pytest_generate_tests(metafunc): argvalues = [] for version in ("v1.0.0", "v2.0.0"): for format in ("sbml", "pysb"): - argvalues.extend( - (case, format, version) - for case in test_numbers or get_cases(format, version) - ) - metafunc.parametrize("case,model_type,version", argvalues) + for jax in (True, False): + argvalues.extend( + (case, format, version, jax) + for case in test_numbers + or get_cases(format, version) + ) + metafunc.parametrize("case,model_type,version,jax", argvalues) diff --git a/tests/petab_test_suite/test_petab_suite.py b/tests/petab_test_suite/test_petab_suite.py index f5bf354cd3..bb13c25d84 100755 --- a/tests/petab_test_suite/test_petab_suite.py +++ b/tests/petab_test_suite/test_petab_suite.py @@ -23,10 +23,10 @@ logger.addHandler(stream_handler) -def test_case(case, model_type, version): +def test_case(case, model_type, version, jax): """Wrapper for _test_case for handling test outcomes""" try: - _test_case(case, model_type, version) + _test_case(case, model_type, version, jax) except Exception as e: if isinstance( e, NotImplementedError @@ -41,10 +41,10 @@ def test_case(case, model_type, version): raise e -def _test_case(case, model_type, version): +def _test_case(case, model_type, version, jax): """Run a single PEtab test suite case""" case = petabtests.test_id_str(case) - logger.debug(f"Case {case} [{model_type}] [{version}]") + logger.debug(f"Case {case} [{model_type}] [{version}] [{jax}]") # load case_dir = petabtests.get_case_dir(case, model_type, version) @@ -57,34 +57,46 @@ def _test_case(case, model_type, version): model_name = ( f"petab_{model_type}_test_case_{case}" f"_{version.replace('.', '_')}" ) - model_output_dir = f"amici_models/{model_name}" + model_output_dir = f"amici_models/{model_name}" + ("_jax" if jax else "") model = import_petab_problem( petab_problem=problem, model_output_dir=model_output_dir, model_name=model_name, compile_=True, + jax=jax, ) - solver = model.getSolver() - solver.setSteadyStateToleranceFactor(1.0) - problem_parameters = dict( - zip(problem.x_free_ids, problem.x_nominal_free, strict=True) - ) + if jax: + from amici.jax import JAXProblem, run_simulations, petab_simulate + + jax_problem = JAXProblem(model, problem) + llh, ret = run_simulations(jax_problem) + chi2, _ = run_simulations(jax_problem, ret="chi2") + simulation_df = petab_simulate(jax_problem) + simulation_df.rename( + columns={petab.SIMULATION: petab.MEASUREMENT}, inplace=True + ) + else: + solver = model.getSolver() + solver.setSteadyStateToleranceFactor(1.0) + problem_parameters = dict( + zip(problem.x_free_ids, problem.x_nominal_free, strict=True) + ) - # simulate - ret = simulate_petab( - problem, - model, - problem_parameters=problem_parameters, - solver=solver, - log_level=logging.DEBUG, - ) + # simulate + ret = simulate_petab( + problem, + model, + problem_parameters=problem_parameters, + solver=solver, + log_level=logging.DEBUG, + ) - rdatas = ret["rdatas"] - chi2 = sum(rdata["chi2"] for rdata in rdatas) - llh = ret["llh"] - simulation_df = rdatas_to_measurement_df( - rdatas, model, problem.measurement_df - ) + rdatas = ret["rdatas"] + chi2 = sum(rdata["chi2"] for rdata in rdatas) + llh = ret["llh"] + simulation_df = rdatas_to_measurement_df( + rdatas, model, problem.measurement_df + ) petab.check_measurement_df(simulation_df, problem.observable_df) simulation_df = simulation_df.rename( columns={petab.MEASUREMENT: petab.SIMULATION} @@ -142,7 +154,10 @@ def _test_case(case, model_type, version): f"LLH: simulated: {llh}, expected: {gt_llh}, " f"match = {llhs_match}", ) - check_derivatives(problem, model, solver, problem_parameters) + if jax: + pass # skip derivative checks for now + else: + check_derivatives(problem, model, solver, problem_parameters) if not all([llhs_match, simulations_match]) or not chi2s_match: logger.error(f"Case {case} failed.") @@ -196,18 +211,23 @@ def run(): n_skipped = 0 n_total = 0 for version in ("v1.0.0", "v2.0.0"): - cases = petabtests.get_cases("sbml", version=version) - n_total += len(cases) - for case in cases: - try: - test_case(case, "sbml", version=version) - n_success += 1 - except Skipped: - n_skipped += 1 - except Exception as e: - # run all despite failures - logger.error(f"Case {case} failed.") - logger.error(e) + for model_lang in ("cpp", "jax"): + cases = petabtests.get_cases( + "sbml", version=version, jax=model_lang == "jax" + ) + n_total += len(cases) + for case in cases: + try: + test_case( + case, "sbml", version=version, jax=model_lang == "jax" + ) + n_success += 1 + except Skipped: + n_skipped += 1 + except Exception as e: + # run all despite failures + logger.error(f"Case {case} failed.") + logger.error(e) logger.info(f"{n_success} / {n_total} successful, " f"{n_skipped} skipped") if n_success != len(cases):