Skip to content

Commit

Permalink
add jax runner to petab testsuite & fix
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 6, 2024
1 parent 8fcfc72 commit 2562371
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 61 deletions.
4 changes: 2 additions & 2 deletions python/sdist/amici/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from warnings import warn

from amici.jax.petab import JAXProblem, run_simulations
from amici.jax.petab import JAXProblem, run_simulations, petab_simulate
from amici.jax.model import JAXModel

warn(
Expand All @@ -18,4 +18,4 @@
stacklevel=2,
)

__all__ = ["JAXModel", "JAXProblem", "run_simulations"]
__all__ = ["JAXModel", "JAXProblem", "run_simulations", "petab_simulate"]
19 changes: 19 additions & 0 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def simulate_condition(
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
iy_trafos: jt.Int[jt.Array, "nt"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
Expand Down Expand Up @@ -488,6 +489,7 @@ def simulate_condition(
- `sigmay`: standard deviations of the observables
- `tcl`: total values for conservation laws (at final timepoint)
- `res`: residuals (observed - simulated)
- 'chi2': sum((observed - simulated) ** 2 / sigma ** 2)
:return:
output according to `ret` and statistics
"""
Expand Down Expand Up @@ -540,6 +542,15 @@ def simulate_condition(

nllhs = self._nllhs(ts, x, p, tcl, my, iys)
llh = -jnp.sum(nllhs)
obs_trafo = jax.vmap(
lambda y, iy_trafo: jnp.array(
[y, safe_log(y), safe_log(y) / jnp.log(10)]
)
.at[iy_trafo]
.get(),
)
ys_obj = obs_trafo(self._ys(ts, x, p, tcl, iys), iy_trafos)
m_obj = obs_trafo(my, iy_trafos)
return {
"llh": llh,
"nllhs": nllhs,
Expand All @@ -551,6 +562,10 @@ def simulate_condition(
"x0_solver": x[0, :],
"tcl": tcl,
"res": self._ys(ts, x, p, tcl, iys) - my,
"chi2": jnp.sum(
jnp.square(ys_obj - m_obj)
/ jnp.square(self._sigmays(ts, x, p, tcl, iys))
),
}[ret], dict(
ts=ts,
x=x,
Expand All @@ -562,6 +577,8 @@ def simulate_condition(
def preequilibrate_condition(
self,
p: jt.Float[jt.Array, "np"],
x_reinit: jt.Float[jt.Array, "*nx"],
mask_reinit: jt.Bool[jt.Array, "*nx"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
max_steps: int | jnp.int_,
Expand All @@ -582,6 +599,8 @@ def preequilibrate_condition(
"""
# Pre-equilibration
x0 = self._x0(p)
if x_reinit.shape[0]:
x0 = jnp.where(mask_reinit, x_reinit, x0)
tcl = self._tcl(x0, p)
current_x = self._x_solver(x0)
current_x, stats_preeq = self._eq(
Expand Down
192 changes: 177 additions & 15 deletions python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
)
from amici.jax.model import JAXModel

DEFAULT_CONTROLLER_SETTINGS = {
"atol": 1e-8,
"rtol": 1e-8,
"pcoeff": 1.0,
"icoeff": 1.0,
"dcoeff": 0.0,
}


def jax_unscale(
parameter: jnp.float_,
Expand Down Expand Up @@ -66,8 +74,16 @@ class JAXProblem(eqx.Module):
_parameter_mappings: dict[str, ParameterMappingForCondition]
_measurements: dict[
tuple[str, ...],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
],
]
_petab_measurement_indices: dict[tuple[str, ...], pd.Index]
_petab_problem: petab.Problem

def __init__(self, model: JAXModel, petab_problem: petab.Problem):
Expand All @@ -83,7 +99,9 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem):
scs = petab_problem.get_simulation_conditions_from_measurement_df()
self._petab_problem = petab_problem
self._parameter_mappings = self._get_parameter_mappings(scs)
self._measurements = self._get_measurements(scs)
self._measurements, self._petab_measurement_indices = (
self._get_measurements(scs)
)
self.parameters = self._get_nominal_parameter_values()

def save(self, directory: Path):
Expand Down Expand Up @@ -155,7 +173,14 @@ def _get_measurements(
self, simulation_conditions: pd.DataFrame
) -> dict[
tuple[str, ...],
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray],
tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
],
]:
"""
Get measurements for the model based on the provided simulation conditions.
Expand All @@ -168,6 +193,7 @@ def _get_measurements(
post-equilibrium time points; measurements and observable indices).
"""
measurements = dict()
indices = dict()
for _, simulation_condition in simulation_conditions.iterrows():
query = " & ".join(
[f"{k} == '{v}'" for k, v in simulation_condition.items()]
Expand All @@ -176,26 +202,53 @@ def _get_measurements(
by=petab.TIME
)

ts = m[petab.TIME].values
ts = m[petab.TIME]
ts_preeq = ts[np.isfinite(ts) & (ts == 0)]
ts_dyn = ts[np.isfinite(ts) & (ts > 0)]
ts_posteq = ts[np.logical_not(np.isfinite(ts))]
index = pd.concat([ts_preeq, ts_dyn, ts_posteq]).index
ts_preeq = ts_preeq.values
ts_dyn = ts_dyn.values
ts_posteq = ts_posteq.values
my = m[petab.MEASUREMENT].values
iys = np.array(
[
self.model.observable_ids.index(oid)
for oid in m[petab.OBSERVABLE_ID].values
]
)
if (
petab.OBSERVABLE_TRANSFORMATION
in self._petab_problem.observable_df
):
trafo_map = {
petab.LIN: 0,
petab.LOG: 1,
petab.LOG10: 2,
}
iy_trafos = np.array(
[
trafo_map[
self._petab_problem.observable_df.loc[
oid, petab.OBSERVABLE_TRANSFORMATION
]
]
for oid in m[petab.OBSERVABLE_ID].values
]
)
else:
iy_trafos = np.zeros_like(iys)

measurements[tuple(simulation_condition)] = (
ts_preeq,
ts_dyn,
ts_posteq,
my,
iys,
iy_trafos,
)
return measurements
indices[tuple(simulation_condition)] = index
return measurements, indices

def get_all_simulation_conditions(self) -> tuple[tuple[str, ...], ...]:
simulation_conditions = (
Expand Down Expand Up @@ -372,6 +425,7 @@ def run_simulation(
controller: diffrax.AbstractStepSizeController,
max_steps: jnp.int_,
x_preeq: jt.Float[jt.Array, "*nx"] = jnp.array([]), # noqa: F821, F722
ret: str = "llh",
) -> tuple[jnp.float_, dict]:
"""
Run a simulation for a given simulation condition.
Expand All @@ -386,10 +440,23 @@ def run_simulation(
Maximum number of steps to take during simulation
:param x_preeq:
Pre-equilibration state if available
:param ret:
which output to return. Valid values are
- `llh`: log-likelihood (default)
- `nllhs`: negative log-likelihood at each time point
- `x0`: full initial state vector (after pre-equilibration)
- `x0_solver`: reduced initial state vector (after pre-equilibration)
- `x`: full state vector
- `x_solver`: reduced state vector
- `y`: observables
- `sigmay`: standard deviations of the observables
- `tcl`: total values for conservation laws (at final timepoint)
- `res`: residuals (observed - simulated)
- 'chi2': sum((observed - simulated) ** 2 / sigma ** 2)
:return:
Tuple of log-likelihood and simulation statistics
Tuple of output value and simulation statistics
"""
ts_preeq, ts_dyn, ts_posteq, my, iys = self._measurements[
ts_preeq, ts_dyn, ts_posteq, my, iys, iy_trafos = self._measurements[
simulation_condition
]
p = self.load_parameters(simulation_condition[0])
Expand All @@ -403,13 +470,17 @@ def run_simulation(
ts_posteq=jax.lax.stop_gradient(jnp.array(ts_posteq)),
my=jax.lax.stop_gradient(jnp.array(my)),
iys=jax.lax.stop_gradient(jnp.array(iys)),
iy_trafos=jax.lax.stop_gradient(jnp.array(iy_trafos)),
x_preeq=x_preeq,
mask_reinit=mask_reinit,
x_reinit=x_reinit,
solver=solver,
controller=controller,
max_steps=max_steps,
adjoint=diffrax.RecursiveCheckpointAdjoint(),
adjoint=diffrax.RecursiveCheckpointAdjoint()
if ret in ("llh", "chi2")
else diffrax.DirectAdjoint(),
ret=ret,
)

def run_preequilibration(
Expand All @@ -434,8 +505,13 @@ def run_preequilibration(
Pre-equilibration state
"""
p = self.load_parameters(simulation_condition)
mask_reinit, x_reinit = self.load_reinitialisation(
simulation_condition, p
)
return self.model.preequilibrate_condition(
p=p,
mask_reinit=mask_reinit,
x_reinit=x_reinit,
solver=solver,
controller=controller,
max_steps=max_steps,
Expand All @@ -447,13 +523,10 @@ def run_simulations(
simulation_conditions: Iterable[tuple[str, ...]] | None = None,
solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=0.3,
dcoeff=0.0,
**DEFAULT_CONTROLLER_SETTINGS
),
max_steps: int = 2**10,
ret: str = "llh",
):
"""
Run simulations for a problem.
Expand All @@ -468,8 +541,21 @@ def run_simulations(
Step size controller to use for simulation.
:param max_steps:
Maximum number of steps to take during simulation.
:param ret:
which output to return. Valid values are
- `llh`: log-likelihood (default)
- `nllhs`: negative log-likelihood at each time point
- `x0`: full initial state vector (after pre-equilibration)
- `x0_solver`: reduced initial state vector (after pre-equilibration)
- `x`: full state vector
- `x_solver`: reduced state vector
- `y`: observables
- `sigmay`: standard deviations of the observables
- `tcl`: total values for conservation laws (at final timepoint)
- `res`: residuals (observed - simulated)
- 'chi2': sum((observed - simulated) ** 2 / sigma ** 2)
:return:
Overall negative log-likelihood and condition specific results and statistics.
Overall output value and condition specific results and statistics.
"""
if simulation_conditions is None:
simulation_conditions = problem.get_all_simulation_conditions()
Expand All @@ -487,10 +573,86 @@ def run_simulations(
controller,
max_steps,
preeqs.get(sc[1])[0] if len(sc) > 1 else jnp.array([]),
ret=ret,
)
for sc in simulation_conditions
}
return sum(llh for llh, _ in results.values()), {
stats = {
sc: res[1] | preeqs[sc[1]][1] if len(sc) > 1 else res[1]
for sc, res in results.items()
}
if ret in ("llh", "chi2"):
output = sum(r for r, _ in results.values())
else:
output = {sc: res[0] for sc, res in results.items()}

return output, stats


def petab_simulate(
problem: JAXProblem,
solver: diffrax.AbstractSolver = diffrax.Kvaerno5(),
controller: diffrax.AbstractStepSizeController = diffrax.PIDController(
**DEFAULT_CONTROLLER_SETTINGS
),
max_steps: int = 2**10,
):
"""
Run simulations for a problem and return the results as a petab simulation dataframe.
:param problem:
Problem to run simulations for.
:param solver:
ODE solver to use for simulation.
:param controller:
Step size controller to use for simulation.
:param max_steps:
Maximum number of steps to take during simulation.
:return:
petab simulation dataframe.
"""
y, r = run_simulations(
problem,
solver=solver,
controller=controller,
max_steps=max_steps,
ret="y",
)
dfs = []
for sc, ys in y.items():
obs = [
problem.model.observable_ids[io]
for io in problem._measurements[sc][4]
]
t = jnp.concat(problem._measurements[sc][:2])
df_sc = pd.DataFrame(
{
petab.SIMULATION: ys,
petab.TIME: t,
petab.OBSERVABLE_ID: obs,
petab.SIMULATION_CONDITION_ID: [sc[0]] * len(t),
},
index=problem._petab_measurement_indices[sc],
)
if (
petab.OBSERVABLE_PARAMETERS
in problem._petab_problem.measurement_df
):
df_sc[petab.OBSERVABLE_PARAMETERS] = (
problem._petab_problem.measurement_df.query(
f"{petab.SIMULATION_CONDITION_ID} == '{sc[0]}'"
)[petab.OBSERVABLE_PARAMETERS]
)
if petab.NOISE_PARAMETERS in problem._petab_problem.measurement_df:
df_sc[petab.NOISE_PARAMETERS] = (
problem._petab_problem.measurement_df.query(
f"{petab.SIMULATION_CONDITION_ID} == '{sc[0]}'"
)[petab.NOISE_PARAMETERS]
)
if (
petab.PREEQUILIBRATION_CONDITION_ID
in problem._petab_problem.measurement_df
):
df_sc[petab.PREEQUILIBRATION_CONDITION_ID] = sc[1]
dfs.append(df_sc)
return pd.concat(dfs).sort_index()
Loading

0 comments on commit 2562371

Please sign in to comment.