Skip to content

Commit

Permalink
disentangle sim & preeq
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 4, 2024
1 parent 0d49041 commit 14c9d22
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 37 deletions.
2 changes: 1 addition & 1 deletion python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@
" return jax_problem.model.simulate_condition(\n",
" p=p,\n",
" p_preeq=p_preeq,\n",
" ts_preeq=ts_preeq,\n",
" ts_init=ts_preeq,\n",
" ts_dyn=tt,\n",
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
Expand Down
64 changes: 42 additions & 22 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,12 @@ def _sigmays(
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"],
p_preeq: jt.Float[jt.Array, "*np"],
ts_preeq: jt.Float[jt.Array, "nt_preeq"],
ts_init: jt.Float[jt.Array, "nt_preeq"],
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
Expand All @@ -444,12 +444,9 @@ def simulate_condition(
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param p_preeq:
parameters for pre-equilibration ordered according to ids in :ivar parameter_ids:. May be empty to
disable pre-equilibration.
:param ts_preeq:
time points for pre-equilibration. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated after pre-equilibration.
:param ts_init:
time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated before dynamic simulation.
:param ts_dyn:
time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order.
Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time
Expand Down Expand Up @@ -486,24 +483,16 @@ def simulate_condition(
output according to `ret` and statistics
"""
# Pre-equilibration
if p_preeq.shape[0] > 0:
x0 = self._x0(p_preeq)
tcl = self._tcl(x0, p_preeq)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p_preeq, tcl, current_x, solver, controller, max_steps
)
if x_preeq.shape[0] > 0:
current_x = self._x_solver(x_preeq)
# update tcl with new parameters
tcl = self._tcl(self._x_rdata(current_x, tcl), p)
tcl = self._tcl(x_preeq, p)
else:
x0 = self._x0(p)
current_x = self._x_solver(x0)
stats_preeq = None

tcl = self._tcl(x0, p)
x_preq = jnp.repeat(
current_x.reshape(1, -1), ts_preeq.shape[0], axis=0
)
x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)

# Dynamic simulation
if ts_dyn.shape[0] > 0:
Expand Down Expand Up @@ -536,7 +525,7 @@ def simulate_condition(
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0)
ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
Expand All @@ -555,11 +544,42 @@ def simulate_condition(
}[ret], dict(
ts=ts,
x=x,
stats_preeq=stats_preeq,
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)

@eqx.filter_jit
def preequilibrate_condition(
self,
p: jt.Float[jt.Array, "np"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: int | jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:
r"""
Simulate a condition.
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param solver:
ODE solver
:param controller:
step size controller
:param max_steps:
maximum number of solver steps
:return:
pre-equilibrated state variables and statistics
"""
# Pre-equilibration
x0 = self._x0(p)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p, tcl, current_x, solver, controller, max_steps
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)


def safe_log(x: jnp.float_) -> jnp.float_:
"""
Expand Down
60 changes: 48 additions & 12 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _get_parameter_mappings(
def _get_measurements(
self, simulation_conditions: pd.DataFrame
) -> dict[
tuple[str],
str,
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
"""
Expand Down Expand Up @@ -307,45 +307,71 @@ def run_simulation(
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
) -> tuple[jnp.float_, dict]:
"""
Run a simulation for a given simulation condition.
:param simulation_condition:
Tuple of simulation conditions to run the simulation for. can be a single string (simulation only) or a
tuple of strings (pre-equilibration followed by simulation).
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param max_steps:
Maximum number of steps to take during simulation
:param preeq:
Pre-equilibration state if available
:return:
Tuple of log-likelihood and simulation statistics
"""
ts_preeq, ts_dyn, ts_posteq, my, iys = self._measurements[
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
p_preeq = (
self.load_parameters(simulation_condition[1])
if len(simulation_condition) > 1
else jnp.array([])
)
return self.model.simulate_condition(
p=p,
p_preeq=p_preeq,
ts_preeq=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
my=jax.lax.stop_gradient(jnp.array(my)),
iys=jax.lax.stop_gradient(jnp.array(iys)),
x_preeq=x_preeq,
solver=solver,
controller=controller,
max_steps=max_steps,
adjoint=diffrax.RecursiveCheckpointAdjoint(),
)

def run_preequilibration(
self,
simulation_condition: str,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821
"""
Run a pre-equilibration simulation for a given simulation condition.
:param simulation_condition:
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param max_steps:
Maximum number of steps to take during simulation
:return:
Pre-equilibration state
"""
p = self.load_parameters(simulation_condition)
return self.model.preequilibrate_condition(
p=p,
solver=solver,
controller=controller,
max_steps=max_steps,
)


def run_simulations(
problem: JAXProblem,
Expand Down Expand Up @@ -379,8 +405,18 @@ def run_simulations(
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {
sc[1]: problem.run_preequilibration(
sc[1], solver, controller, max_steps
)
for sc in simulation_conditions
if len(sc) > 1
}

results = {
sc: problem.run_simulation(sc, solver, controller, max_steps)
sc[0]: problem.run_simulation(
sc, solver, controller, max_steps, preeqs.get(sc[1])[0]
)
for sc in simulation_conditions
}
return sum(llh for llh, _ in results.values()), results
return sum(llh for llh, _ in results.values()), results | preeqs
4 changes: 2 additions & 2 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_jax_llh(benchmark_problem):

jax_model = import_petab_problem(
petab_problem,
model_output_dir=benchmark_outdir / problem_id,
model_output_dir=benchmark_outdir / (problem_id + "_jax"),
jax=True,
)
jax_problem = JAXProblem(jax_model, petab_problem)
Expand All @@ -340,7 +340,7 @@ def test_jax_llh(benchmark_problem):
[problem_parameters[pid] for pid in jax_problem.parameter_ids]
),
)
if problem_id in problems_for_gradient_check_jax:
if problem_id in problems_for_gradient_check:
(llh_jax, _), sllh_jax = eqx.filter_jit(
eqx.filter_value_and_grad(run_simulations, has_aux=True)
)(jax_problem)
Expand Down

0 comments on commit 14c9d22

Please sign in to comment.