Skip to content

Commit

Permalink
Separate pre-equilibration and dynamic simulation in jax (#2617)
Browse files Browse the repository at this point in the history
* disentangle sim & preeq

* disentangle sim & preeq

* run preequilibration once

* fix symlink

* separate default dirs for jax/cpp, honour model dir/name

* fix notebook

* fix path SNAFU

* fix models without preequilibration

* fix tests

* fixup

* fix doc typehints

* fix notebook

* fix output dict construction

* fix notebook
  • Loading branch information
FFroehlich authored Dec 5, 2024
1 parent d1c8250 commit 449041d
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 77 deletions.
11 changes: 4 additions & 7 deletions python/examples/example_jax_petab/ExampleJaxPEtab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
" results (dict): Simulation results from run_simulations.\n",
" \"\"\"\n",
" # Extract the simulation results for the specific condition\n",
" sim_results = results[simulation_condition][1]\n",
" sim_results = results[simulation_condition]\n",
"\n",
" # Create a new figure for the state trajectories\n",
" plt.figure(figsize=(8, 6))\n",
Expand Down Expand Up @@ -357,27 +357,25 @@
"simulation_condition = (\"model1_data1\",)\n",
"\n",
"# Load condition-specific data\n",
"ts_preeq, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
"ts_init, ts_dyn, ts_posteq, my, iys = jax_problem._measurements[\n",
" simulation_condition\n",
"]\n",
"\n",
"# Load parameters for the specified condition\n",
"p = jax_problem.load_parameters(simulation_condition[0])\n",
"# Disable preequilibration\n",
"p_preeq = jnp.array([])\n",
"\n",
"\n",
"# Define a function to compute the gradient with respect to dynamic timepoints\n",
"@eqx.filter_jacfwd\n",
"def grad_ts_dyn(tt):\n",
" return jax_problem.model.simulate_condition(\n",
" p=p,\n",
" p_preeq=p_preeq,\n",
" ts_preeq=ts_preeq,\n",
" ts_init=ts_init,\n",
" ts_dyn=tt,\n",
" ts_posteq=ts_posteq,\n",
" my=jnp.array(my),\n",
" iys=jnp.array(iys),\n",
" x_preeq=jnp.array([]),\n",
" solver=diffrax.Kvaerno5(),\n",
" controller=diffrax.PIDController(atol=1e-8, rtol=1e-8),\n",
" max_steps=2**10,\n",
Expand Down Expand Up @@ -489,7 +487,6 @@
"amici_model = import_petab_problem(\n",
" petab_problem,\n",
" verbose=False,\n",
" compile_=True,\n",
" jax=False, # load the amici model this time\n",
")\n",
"\n",
Expand Down
64 changes: 42 additions & 22 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,12 +427,12 @@ def _sigmays(
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"],
p_preeq: jt.Float[jt.Array, "*np"],
ts_preeq: jt.Float[jt.Array, "nt_preeq"],
ts_init: jt.Float[jt.Array, "nt_preeq"],
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
x_preeq: jt.Float[jt.Array, "nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
Expand All @@ -444,12 +444,9 @@ def simulate_condition(
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param p_preeq:
parameters for pre-equilibration ordered according to ids in :ivar parameter_ids:. May be empty to
disable pre-equilibration.
:param ts_preeq:
time points for pre-equilibration. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated after pre-equilibration.
:param ts_init:
time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to
the number of observables that are evaluated before dynamic simulation.
:param ts_dyn:
time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order.
Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time
Expand Down Expand Up @@ -486,24 +483,16 @@ def simulate_condition(
output according to `ret` and statistics
"""
# Pre-equilibration
if p_preeq.shape[0] > 0:
x0 = self._x0(p_preeq)
tcl = self._tcl(x0, p_preeq)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p_preeq, tcl, current_x, solver, controller, max_steps
)
if x_preeq.shape[0] > 0:
current_x = self._x_solver(x_preeq)
# update tcl with new parameters
tcl = self._tcl(self._x_rdata(current_x, tcl), p)
tcl = self._tcl(x_preeq, p)
else:
x0 = self._x0(p)
current_x = self._x_solver(x0)
stats_preeq = None

tcl = self._tcl(x0, p)
x_preq = jnp.repeat(
current_x.reshape(1, -1), ts_preeq.shape[0], axis=0
)
x_preq = jnp.repeat(current_x.reshape(1, -1), ts_init.shape[0], axis=0)

# Dynamic simulation
if ts_dyn.shape[0] > 0:
Expand Down Expand Up @@ -536,7 +525,7 @@ def simulate_condition(
current_x.reshape(1, -1), ts_posteq.shape[0], axis=0
)

ts = jnp.concatenate((ts_preeq, ts_dyn, ts_posteq), axis=0)
ts = jnp.concatenate((ts_init, ts_dyn, ts_posteq), axis=0)
x = jnp.concatenate((x_preq, x_dyn, x_posteq), axis=0)

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
Expand All @@ -555,11 +544,42 @@ def simulate_condition(
}[ret], dict(
ts=ts,
x=x,
stats_preeq=stats_preeq,
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)

@eqx.filter_jit
def preequilibrate_condition(
self,
p: jt.Float[jt.Array, "np"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: int | jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]:
r"""
Simulate a condition.
:param p:
parameters for simulation ordered according to ids in :ivar parameter_ids:
:param solver:
ODE solver
:param controller:
step size controller
:param max_steps:
maximum number of solver steps
:return:
pre-equilibrated state variables and statistics
"""
# Pre-equilibration
x0 = self._x0(p)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
p, tcl, current_x, solver, controller, max_steps
)

return self._x_rdata(current_x, tcl), dict(stats_preeq=stats_preeq)


def safe_log(x: jnp.float_) -> jnp.float_:
"""
Expand Down
5 changes: 2 additions & 3 deletions python/sdist/amici/jax/ode_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,10 @@ def _generate_jax_code(self) -> None:
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
},
}
outdir = self.model_path / (self.model_name + "_jax")
outdir.mkdir(parents=True, exist_ok=True)

apply_template(
Path(amiciModulePath) / "jax" / "jax.template.py",
outdir / "__init__.py",
self.model_path / "__init__.py",
tpl_data,
)

Expand All @@ -258,6 +256,7 @@ def set_paths(self, output_dir: str | Path | None = None) -> None:
output_dir = Path(os.getcwd()) / f"amici-{self.model_name}"

self.model_path = Path(output_dir).resolve()
self.model_path.mkdir(parents=True, exist_ok=True)

def set_name(self, model_name: str) -> None:
"""
Expand Down
67 changes: 54 additions & 13 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _get_parameter_mappings(
def _get_measurements(
self, simulation_conditions: pd.DataFrame
) -> dict[
tuple[str],
tuple[str, ...],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
]:
"""
Expand Down Expand Up @@ -307,49 +307,75 @@ def run_simulation(
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
) -> tuple[jnp.float_, dict]:
"""
Run a simulation for a given simulation condition.
:param simulation_condition:
Tuple of simulation conditions to run the simulation for. can be a single string (simulation only) or a
tuple of strings (pre-equilibration followed by simulation).
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param max_steps:
Maximum number of steps to take during simulation
:param x_preeq:
Pre-equilibration state if available
:return:
Tuple of log-likelihood and simulation statistics
"""
ts_preeq, ts_dyn, ts_posteq, my, iys = self._measurements[
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
p_preeq = (
self.load_parameters(simulation_condition[1])
if len(simulation_condition) > 1
else jnp.array([])
)
return self.model.simulate_condition(
p=p,
p_preeq=p_preeq,
ts_preeq=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_init=jax.lax.stop_gradient(jnp.array(ts_preeq)),
ts_dyn=jax.lax.stop_gradient(jnp.array(ts_dyn)),
ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
my=jax.lax.stop_gradient(jnp.array(my)),
iys=jax.lax.stop_gradient(jnp.array(iys)),
x_preeq=x_preeq,
solver=solver,
controller=controller,
max_steps=max_steps,
adjoint=diffrax.RecursiveCheckpointAdjoint(),
)

def run_preequilibration(
self,
simulation_condition: str,
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
) -> tuple[jt.Float[jt.Array, "nx"], dict]: # noqa: F821
"""
Run a pre-equilibration simulation for a given simulation condition.
:param simulation_condition:
Simulation condition to run simulation for.
:param solver:
ODE solver to use for simulation
:param controller:
Step size controller to use for simulation
:param max_steps:
Maximum number of steps to take during simulation
:return:
Pre-equilibration state
"""
p = self.load_parameters(simulation_condition)
return self.model.preequilibrate_condition(
p=p,
solver=solver,
controller=controller,
max_steps=max_steps,
)


def run_simulations(
problem: JAXProblem,
simulation_conditions: Iterable[tuple] | None = None,
simulation_conditions: Iterable[tuple[str, ...]] | None = None,
solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
rtol=1e-8,
Expand Down Expand Up @@ -379,8 +405,23 @@ def run_simulations(
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()

preeqs = {
sc: problem.run_preequilibration(sc, solver, controller, max_steps)
# only run preequilibration once per condition
for sc in {sc[1] for sc in simulation_conditions if len(sc) > 1}
}

results = {
sc: problem.run_simulation(sc, solver, controller, max_steps)
sc: problem.run_simulation(
sc,
solver,
controller,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
)
for sc in simulation_conditions
}
return sum(llh for llh, _ in results.values()), results
return sum(llh for llh, _ in results.values()), {
sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1]
for sc, res in results.items()
}
24 changes: 22 additions & 2 deletions python/sdist/amici/petab/import_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,9 @@ def _can_import_model(
Check whether a module of that name can already be imported.
"""
# try to import (in particular checks version)
suffix = "_jax" if jax else ""
try:
model_module = amici.import_model_module(
model_name + suffix, model_output_dir
*_get_package_name_and_path(model_name, model_output_dir, jax)
)
except ModuleNotFoundError:
return False
Expand Down Expand Up @@ -271,3 +270,24 @@ def check_model(
"the current model might also resolve this. Parameters: "
f"{amici_ids_free_required.difference(amici_ids_free)}"
)


def _get_package_name_and_path(
model_name: str, model_output_dir: str | Path, jax: bool = False
) -> tuple[str, Path]:
"""
Get the package name and path for the generated model module.
:param model_name:
Name of the model
:param model_output_dir:
Target directory for the generated model module
:param jax:
Whether to generate the paths for a JAX or CPP model
:return:
"""
if jax:
outdir = Path(model_output_dir)
return outdir.stem, outdir.parent
else:
return model_name, Path(model_output_dir)
14 changes: 9 additions & 5 deletions python/sdist/amici/petab/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML

from ..logging import get_logger
from .import_helpers import _can_import_model, _create_model_name, check_model
from .import_helpers import (
_can_import_model,
_create_model_name,
check_model,
_get_package_name_and_path,
)
from .sbml_import import import_model_sbml

try:
Expand Down Expand Up @@ -114,7 +119,7 @@ def import_petab_problem(
from .sbml_import import _create_model_output_dir_name

model_output_dir = _create_model_output_dir_name(
petab_problem.sbml_model, model_name
petab_problem.sbml_model, model_name, jax=jax
)
else:
model_output_dir = os.path.abspath(model_output_dir)
Expand All @@ -136,7 +141,7 @@ def import_petab_problem(
)

# remove folder if exists
if os.path.exists(model_output_dir):
if not jax and os.path.exists(model_output_dir):
shutil.rmtree(model_output_dir)

logger.info(f"Compiling model {model_name} to {model_output_dir}.")
Expand All @@ -160,9 +165,8 @@ def import_petab_problem(
)

# import model
suffix = "_jax" if jax else ""
model_module = amici.import_model_module(
model_name + suffix, model_output_dir
*_get_package_name_and_path(model_name, model_output_dir, jax=jax)
)

if jax:
Expand Down
Loading

0 comments on commit 449041d

Please sign in to comment.