From 14c9d22df110db1a03828eb482cdf3230bcf31b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?=
Date: Wed, 4 Dec 2024 17:31:50 +0000
Subject: [PATCH] disentangle sim & preeq

---
 .../example_jax_petab/ExampleJaxPEtab.ipynb    |  4 ++--
 python/sdist/amici/jax/model.py                | 66 ++++++++++++-------
 python/sdist/amici/jax/petab.py                | 60 +++++++++++++----
 .../benchmark-models/test_petab_benchmark.py   |  4 ++--
 4 files changed, 96 insertions(+), 38 deletions(-)

diff --git a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
index 855860e242..fe7b5d8830 100644
--- a/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
+++ b/python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
@@ -373,7 +373,7 @@
     "    return jax_problem.model.simulate_condition(\n",
     "        p=p,\n",
-    "        p_preeq=p_preeq,\n",
-    "        ts_preeq=ts_preeq,\n",
+    "        ts_init=ts_preeq,\n",
+    "        x_preeq=jnp.array([]),\n",
     "        ts_dyn=tt,\n",
     "        ts_posteq=ts_posteq,\n",
     "        my=jnp.array(my),\n",
diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py
index ac86b547a6..8f9650ef0f 100644
--- a/python/sdist/amici/jax/model.py
+++ b/python/sdist/amici/jax/model.py
@@ -427,12 +427,12 @@ def _sigmays(
     def simulate_condition(
         self,
         p: jt.Float[jt.Array, "np"],
-        p_preeq: jt.Float[jt.Array, "*np"],
-        ts_preeq: jt.Float[jt.Array, "nt_preeq"],
+        ts_init: jt.Float[jt.Array, "nt_preeq"],
         ts_dyn: jt.Float[jt.Array, "nt_dyn"],
         ts_posteq: jt.Float[jt.Array, "nt_posteq"],
         my: jt.Float[jt.Array, "nt"],
         iys: jt.Int[jt.Array, "nt"],
+        x_preeq: jt.Float[jt.Array, "*nx"],
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
         adjoint: diffrax.AbstractAdjoint,
@@ -444,12 +444,11 @@ def simulate_condition(
         :param p:
             parameters for simulation ordered according to ids in :ivar parameter_ids:
-        :param p_preeq:
-            parameters for pre-equilibration ordered according to ids in :ivar parameter_ids:. May be empty to
-            disable pre-equilibration.
-        :param ts_preeq:
-            time points for pre-equilibration. Usually valued 0.0, but needs to be shaped according to
-            the number of observables that are evaluated after pre-equilibration.
+        :param ts_init:
+            time points that do not require simulation. Usually valued 0.0, but need to be shaped according to
+            the number of observables that are evaluated before the dynamic simulation.
+        :param x_preeq:
+            pre-equilibrated initial state. May be empty to disable pre-equilibration.
         :param ts_dyn:
             time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order.
            Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time
            points.
@@ -486,24 +485,16 @@ def simulate_condition(
             output according to `ret` and statistics
         """
         # Pre-equilibration
-        if p_preeq.shape[0] > 0:
-            x0 = self._x0(p_preeq)
-            tcl = self._tcl(x0, p_preeq)
-            current_x = self._x_solver(x0)
-            current_x, stats_preeq = self._eq(
-                p_preeq, tcl, current_x, solver, controller, max_steps
-            )
+        if x_preeq.shape[0] > 0:
+            current_x = self._x_solver(x_preeq)
             # update tcl with new parameters
-            tcl = self._tcl(self._x_rdata(current_x, tcl), p)
+            tcl = self._tcl(x_preeq, p)
         else:
             x0 = self._x0(p)
             current_x = self._x_solver(x0)
-            stats_preeq = None
             tcl = self._tcl(x0, p)
 
-        x_preq = jnp.repeat(
-            current_x.reshape(1, -1), ts_preeq.shape[0], axis=0
-        )
+        x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)
 
         # Dynamic simulation
         if ts_dyn.shape[0] > 0:
@@ -536,7 +527,7 @@
                 current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
             )
 
-        ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0)
+        ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
         x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)
 
         nllhs = self._nllhs(ts, x, p, tcl, my, iys)
@@ -555,11 +546,42 @@
             }[ret],
             dict(
                 ts=ts,
                 x=x,
-                stats_preeq=stats_preeq,
                 stats_dyn=stats_dyn,
                 stats_posteq=stats_posteq,
             )
 
+    @eqx.filter_jit
+    def preequilibrate_condition(
+        self,
+        p: jt.Float[jt.Array, "np"],
+        solver: diffrax.AbstractSolver,
+        controller: diffrax.AbstractStepSizeController,
+        max_steps: int | jnp.int_,
+    ) -> tuple[jt.Float[jt.Array, "nx"], dict]:
+        r"""
+        Pre-equilibrate the model.
+
+        :param p:
+            parameters for pre-equilibration ordered according to ids in :ivar parameter_ids:
+        :param solver:
+            ODE solver
+        :param controller:
+            step size controller
+        :param max_steps:
+            maximum number of solver steps
+        :return:
+            pre-equilibrated state variables and statistics
+        """
+        # Pre-equilibration
+        x0 = self._x0(p)
+        tcl = self._tcl(x0, p)
+        current_x = self._x_solver(x0)
+        current_x, stats_preeq = self._eq(
+            p, tcl, current_x, solver, controller, max_steps
+        )
+
+        return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)
+
 
 def safe_log(x: jnp.float_) -> jnp.float_:
     """
diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py
index 2c823259fe..1bd42fe6c1 100644
--- a/python/sdist/amici/jax/petab.py
+++ b/python/sdist/amici/jax/petab.py
@@ -154,7 +154,7 @@ def _get_parameter_mappings(
     def _get_measurements(
         self, simulation_conditions: pd.DataFrame
     ) -> dict[
-        tuple[str],
+        tuple[str, ...],
         tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
     ]:
         """
@@ -307,19 +307,21 @@ def run_simulation(
         solver: diffrax.AbstractSolver,
         controller: diffrax.AbstractStepSizeController,
         max_steps: jnp.int_,
+        x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]),  # noqa: F821, F722
     ) -> tuple[jnp.float_, dict]:
         """
         Run a simulation for a given simulation condition.
 
         :param simulation_condition:
-            Tuple of simulation conditions to run the simulation for. can be a single string (simulation only) or a
-            tuple of strings (pre-equilibration followed by simulation).
+            Tuple of the simulation condition id and, optionally, the pre-equilibration condition id.
         :param solver:
             ODE solver to use for simulation
         :param controller:
             Step size controller to use for simulation
         :param max_steps:
             Maximum number of steps to take during simulation
+        :param x_preeq:
+            Pre-equilibrated initial state, if available
         :return:
             Tuple of log-likelihood and simulation statistics
         """
@@ -327,25 +329,49 @@ def run_simulation(
             simulation_condition
         ]
         p = self.load_parameters(simulation_condition[0])
-        p_preeq = (
-            self.load_parameters(simulation_condition[1])
-            if len(simulation_condition) > 1
-            else jnp.array([])
-        )
         return self.model.simulate_condition(
             p=p,
-            p_preeq=p_preeq,
-            ts_preeq=jax.lax.stop_gradient(jnp.array(ts_preeq)),
+            ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
             ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
             ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
             my=jax.lax.stop_gradient(jnp.array(my)),
             iys=jax.lax.stop_gradient(jnp.array(iys)),
+            x_preeq=x_preeq,
             solver=solver,
             controller=controller,
             max_steps=max_steps,
             adjoint=diffrax.RecursiveCheckpointAdjoint(),
         )
 
+    def run_preequilibration(
+        self,
+        simulation_condition: str,
+        solver: diffrax.AbstractSolver,
+        controller: diffrax.AbstractStepSizeController,
+        max_steps: jnp.int_,
+    ) -> tuple[jt.Float[jt.Array, "nx"], dict]:  # noqa: F821
+        """
+        Run a pre-equilibration simulation for a given simulation condition.
+
+        :param simulation_condition:
+            Simulation condition to run the pre-equilibration for.
+        :param solver:
+            ODE solver to use for simulation
+        :param controller:
+            Step size controller to use for simulation
+        :param max_steps:
+            Maximum number of steps to take during simulation
+        :return:
+            Pre-equilibrated state and simulation statistics
+        """
+        p = self.load_parameters(simulation_condition)
+        return self.model.preequilibrate_condition(
+            p=p,
+            solver=solver,
+            controller=controller,
+            max_steps=max_steps,
+        )
+
 
 def run_simulations(
     problem: JAXProblem,
@@ -379,8 +405,18 @@
     if simulation_conditions is None:
         simulation_conditions = problem.get_all_simulation_conditions()
 
+    preeqs = {
+        sc[1]: problem.run_preequilibration(
+            sc[1], solver, controller, max_steps
+        )
+        for sc in simulation_conditions
+        if len(sc) > 1
+    }
+
     results = {
-        sc: problem.run_simulation(sc, solver, controller, max_steps)
+        sc[0]: problem.run_simulation(
+            sc, solver, controller, max_steps, preeqs[sc[1]][0] if len(sc) > 1 else jnp.array([]),
+        )
         for sc in simulation_conditions
     }
-    return sum(llh for llh, _ in results.values()), results
+    return sum(llh for llh, _ in results.values()), results | preeqs
diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py
index 74c84d37a9..6a388f7493 100644
--- a/tests/benchmark-models/test_petab_benchmark.py
+++ b/tests/benchmark-models/test_petab_benchmark.py
@@ -328,7 +328,7 @@ def test_jax_llh(benchmark_problem):
 
     jax_model = import_petab_problem(
         petab_problem,
-        model_output_dir=benchmark_outdir / problem_id,
+        model_output_dir=benchmark_outdir / (problem_id + "_jax"),
         jax=True,
     )
     jax_problem = JAXProblem(jax_model, petab_problem)
@@ -340,7 +340,7 @@ def test_jax_llh(benchmark_problem):
             [problem_parameters[pid] for pid in jax_problem.parameter_ids]
         ),
     )
-    if problem_id in problems_for_gradient_check_jax:
+    if problem_id in problems_for_gradient_check:
         (llh_jax, _), sllh_jax = eqx.filter_jit(
             eqx.filter_value_and_grad(run_simulations, has_aux=True)
         )(jax_problem)
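
Usage note (editorial, not part of the diff): the sketch below shows how the
disentangled API introduced by this patch fits together, assuming JAXProblem
and run_simulations are importable from amici.jax and that jax_model and
petab_problem already exist, e.g. via import_petab_problem(..., jax=True) as
in the test above. The condition ids ("preeq_c", "sim_c"), the solver and
controller choices, and the max_steps value are illustrative assumptions, not
values prescribed by the patch.

    import diffrax
    import jax.numpy as jnp
    from amici.jax import JAXProblem, run_simulations

    jax_problem = JAXProblem(jax_model, petab_problem)

    # Assumed solver settings; any diffrax solver/controller should work.
    solver = diffrax.Kvaerno5()
    controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
    max_steps = 2**14

    # Pre-equilibrate once for the (hypothetical) pre-equilibration condition
    # id "preeq_c"; returns the pre-equilibrated state and solver statistics.
    x_preeq, preeq_stats = jax_problem.run_preequilibration(
        "preeq_c", solver, controller, max_steps
    )

    # Reuse that state for the dynamic simulation of the condition tuple
    # ("sim_c", "preeq_c"); an empty x_preeq (the default) skips
    # pre-equilibration entirely.
    llh, stats = jax_problem.run_simulation(
        ("sim_c", "preeq_c"), solver, controller, max_steps, x_preeq=x_preeq
    )

    # run_simulations wires both steps together for all conditions, collecting
    # the pre-equilibration results before the per-condition simulations.
    total_llh, results = run_simulations(jax_problem)

Because run_preequilibration is now separate from simulate_condition, a single
pre-equilibrated state can be shared by several simulation conditions instead
of being recomputed inside every simulation call.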