Skip to content

Commit

Permalink
optimize & fix bachmann
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Apr 11, 2024
1 parent 85b8173 commit b213adb
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 145 deletions.
288 changes: 159 additions & 129 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,219 +3,248 @@
from concurrent.futures import ThreadPoolExecutor

import diffrax
import equinox as eqx
import jax.numpy as jnp
import numpy as np
import jax
from functools import partial
from collections.abc import Iterable

import amici

jax.config.update("jax_enable_x64", True)

Check warning on line 14 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L14

Added line #L14 was not covered by tests


class JAXModel:
class JAXModel(eqx.Module):
_unscale_funs = {

Check warning on line 18 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L17-L18

Added lines #L17 - L18 were not covered by tests
amici.ParameterScaling.none: lambda x: x,
amici.ParameterScaling.ln: lambda x: jnp.exp(x),
amici.ParameterScaling.log10: lambda x: jnp.power(10, x),
}
solver: diffrax.AbstractSolver
controller: diffrax.AbstractStepSizeController
atol: float
rtol: float
pcoeff: float
icoeff: float
dcoeff: float
maxsteps: int
term: diffrax.ODETerm
sensi_order: amici.SensitivityOrder

Check warning on line 32 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L23-L32

Added lines #L23 - L32 were not covered by tests

def __init__(self):
self.solver = diffrax.Kvaerno5()
self.atol: float = 1e-8
self.rtol: float = 1e-8
self.pcoeff: float = 0.4
self.icoeff: float = 0.3
self.dcoeff: float = 0.0
self.maxsteps: int = 2**10
self.controller = diffrax.PIDController(

Check warning on line 42 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L34-L42

Added lines #L34 - L42 were not covered by tests
rtol=self.rtol,
atol=self.atol,
pcoeff=self.pcoeff,
icoeff=self.icoeff,
dcoeff=self.dcoeff,
)
self.term = diffrax.ODETerm(self.xdot)
self.sensi_order = amici.SensitivityOrder.none

Check warning on line 50 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L49-L50

Added lines #L49 - L50 were not covered by tests

@staticmethod
@abstractmethod
def xdot(self, t, x, args):
def xdot(t, x, args):
...

Check warning on line 55 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L52-L55

Added lines #L52 - L55 were not covered by tests

@staticmethod
@abstractmethod
def _w(self, t, x, p, k, tcl):
def _w(t, x, p, k, tcl):
...

Check warning on line 60 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L57-L60

Added lines #L57 - L60 were not covered by tests

@staticmethod
@abstractmethod
def x0(self, p, k):
def x0(p, k):
...

Check warning on line 65 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L62-L65

Added lines #L62 - L65 were not covered by tests

@staticmethod
@abstractmethod
def x_solver(self, x):
def x_solver(x):
...

Check warning on line 70 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L67-L70

Added lines #L67 - L70 were not covered by tests

@staticmethod
@abstractmethod
def x_rdata(self, x, tcl):
def x_rdata(x, tcl):
...

Check warning on line 75 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L72-L75

Added lines #L72 - L75 were not covered by tests

@staticmethod
@abstractmethod
def tcl(self, x, p, k):
def tcl(x, p, k):
...

Check warning on line 80 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L77-L80

Added lines #L77 - L80 were not covered by tests

@staticmethod
@abstractmethod
def y(self, t, x, p, k, tcl):
def y(t, x, p, k, tcl):
...

Check warning on line 85 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L82-L85

Added lines #L82 - L85 were not covered by tests

@staticmethod
@abstractmethod
def sigmay(self, y, p, k):
def sigmay(y, p, k):
...

Check warning on line 90 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L87-L90

Added lines #L87 - L90 were not covered by tests

@staticmethod
@abstractmethod
def Jy(self, y, my, sigmay):
def Jy(y, my, sigmay):
...

Check warning on line 95 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L92-L95

Added lines #L92 - L95 were not covered by tests

def unscale_p(self, p, pscale):
return jnp.stack(
[
self._unscale_funs[pscale_i](p_i)
for p_i, pscale_i in zip(p, pscale)
]
)

def get_solver(self):
return JAXSolver(model=self)


class JAXSolver:
def __init__(self, model: JAXModel):
self.model: JAXModel = model
self.solver: diffrax.AbstractSolver = diffrax.Kvaerno5()
self.atol: float = 1e-8
self.rtol: float = 1e-8
self.pcoeff: float = 0.4
self.icoeff: float = 0.3
self.dcoeff: float = 0.0
self.maxsteps: int = int(1e6)
self.sensi_mode: amici.SensitivityMethod = (
amici.SensitivityMethod.adjoint
)
self.sensi_order: amici.SensitivityOrder = amici.SensitivityOrder.none
return jax.vmap(

Check warning on line 98 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L97-L98

Added lines #L97 - L98 were not covered by tests
lambda p_i, pscale_i: jnp.stack(
(p_i, jnp.exp(p_i), jnp.power(10, p_i))
)
.at[pscale_i]
.get()
)(p, pscale)

def _solve(self, ts, p, k):
x0 = self.model.x0(p, k)
tcl = self.model.tcl(x0, p, k)
x0 = self.x0(p, k)
tcl = self.tcl(x0, p, k)
sol = diffrax.diffeqsolve(

Check warning on line 109 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L106-L109

Added lines #L106 - L109 were not covered by tests
diffrax.ODETerm(self.model.xdot),
self.term,
self.solver,
args=(p, k, tcl),
t0=0.0,
t1=ts[-1],
dt0=None,
y0=self.model.x_solver(x0),
stepsize_controller=diffrax.PIDController(
rtol=self.rtol,
atol=self.atol,
pcoeff=self.pcoeff,
icoeff=self.icoeff,
dcoeff=self.dcoeff,
),
y0=self.x_solver(x0),
stepsize_controller=self.controller,
max_steps=self.maxsteps,
saveat=diffrax.SaveAt(ts=ts),
)
return sol.ys, tcl
return sol.ys, tcl, sol.stats

Check warning on line 121 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L121

Added line #L121 was not covered by tests

def _obs(self, ts, x, p, k, tcl):
return jax.vmap(self.model.y, in_axes=(0, 0, None, None, None))(
np.asarray(ts), x, p, k, tcl
return jax.vmap(self.y, in_axes=(0, 0, None, None, None))(

Check warning on line 124 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L123-L124

Added lines #L123 - L124 were not covered by tests
ts, x, p, k, tcl
)

def _sigmay(self, obs, p, k):
return jax.vmap(self.model.sigmay, in_axes=(0, None, None))(obs, p, k)
return jax.vmap(self.sigmay, in_axes=(0, None, None))(obs, p, k)

Check warning on line 129 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L128-L129

Added lines #L128 - L129 were not covered by tests

def _x_rdata(self, x, tcl):
return jax.vmap(self.model.x_rdata, in_axes=(0, None))(x, tcl)
return jax.vmap(self.x_rdata, in_axes=(0, None))(x, tcl)

Check warning on line 132 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L131-L132

Added lines #L131 - L132 were not covered by tests

def _loss(self, obs: jnp.ndarray, sigmay: jnp.ndarray, my: np.ndarray):
loss_fun = jax.vmap(self.model.Jy, in_axes=(0, 0, 0))
loss_fun = jax.vmap(self.Jy, in_axes=(0, 0, 0))
return -jnp.sum(loss_fun(obs, my, sigmay))

Check warning on line 136 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L134-L136

Added lines #L134 - L136 were not covered by tests

def _run(

Check warning on line 138 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L138

Added line #L138 was not covered by tests
self, ts: tuple, p: jnp.ndarray, k: tuple, my: tuple, pscale: tuple
self,
ts: np.ndarray,
p: np.ndarray,
k: jnp.ndarray,
my: jnp.ndarray,
pscale: np.ndarray,
):
ps = self.model.unscale_p(p, pscale)
x, tcl = self._solve(ts, ps, k)
ps = self.unscale_p(p, pscale)
x, tcl, stats = self._solve(ts, ps, k)
obs = self._obs(ts, x, ps, k, tcl)
my_r = np.asarray(my).reshape((len(ts), -1))
my_r = my.reshape((len(ts), -1))
sigmay = self._sigmay(obs, ps, k)
llh = self._loss(obs, sigmay, my_r)
x_rdata = self._x_rdata(x, tcl)
return llh, (x_rdata, obs)
return llh, (x_rdata, obs, stats)

Check warning on line 153 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L146-L153

Added lines #L146 - L153 were not covered by tests

@partial(jax.jit, static_argnames=("self", "ts", "k", "my", "pscale"))
@eqx.filter_jit
def run(

Check warning on line 156 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L155-L156

Added lines #L155 - L156 were not covered by tests
self, ts: tuple, p: jnp.ndarray, k: tuple, my: tuple, pscale: tuple
self,
ts: np.ndarray,
p: jnp.ndarray,
k: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
):
return self._run(ts, p, k, my, pscale)

Check warning on line 164 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L164

Added line #L164 was not covered by tests

@partial(jax.jit, static_argnames=("self", "ts", "k", "my", "pscale"))
@eqx.filter_jit
def srun(

Check warning on line 167 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L166-L167

Added lines #L166 - L167 were not covered by tests
self, ts: tuple, p: jnp.ndarray, k: tuple, my: tuple, pscale: tuple
self,
ts: np.ndarray,
p: jnp.ndarray,
k: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
):
(llh, (x, obs)), sllh = (jax.value_and_grad(self._run, 1, True))(
ts, p, k, my, pscale
)
return llh, sllh, (x, obs)
(llh, (x, obs, stats)), sllh = (

Check warning on line 175 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L175

Added line #L175 was not covered by tests
jax.value_and_grad(self._run, 1, True)
)(ts, p, k, my, pscale)
return llh, sllh, (x, obs, stats)

Check warning on line 178 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L178

Added line #L178 was not covered by tests

@partial(jax.jit, static_argnames=("self", "ts", "k", "my", "pscale"))
@eqx.filter_jit
def s2run(

Check warning on line 181 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L180-L181

Added lines #L180 - L181 were not covered by tests
self, ts: tuple, p: jnp.ndarray, k: tuple, my: tuple, pscale: tuple
self,
ts: np.ndarray,
p: jnp.ndarray,
k: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
):
(llh, (x, obs)), sllh = (jax.value_and_grad(self._run, 1, True))(
(llh, (_, _, _)), sllh = (jax.value_and_grad(self._run, 1, True))(

Check warning on line 189 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L189

Added line #L189 was not covered by tests
ts, p, k, my, pscale
)
s2llh, (x, obs) = jax.jacfwd(jax.grad(self._run, 1, True), 1, True)(
ts, p, k, my, pscale
)
return llh, sllh, s2llh, (x, obs)


def run_simulations(
model: JAXModel,
solver: JAXSolver,
edatas: Iterable[amici.ExpData],
num_threads: int = 1,
):
def run(edata):
return run_simulation(model, solver, edata)

if num_threads > 1:
with ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(run, edatas)
else:
results = map(run, edatas)
return list(results)


def run_simulation(model: JAXModel, solver: JAXSolver, edata: amici.ExpData):
ts = tuple(edata.getTimepoints())
p = jnp.asarray(edata.parameters)
k = tuple(edata.fixedParameters)
my = tuple(edata.getObservedData())
pscale = tuple(edata.pscale)

rdata_kwargs = dict()

if solver.sensi_order == amici.SensitivityOrder.none:
(
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"]),
) = solver.run(ts, p, k, my, pscale)
elif solver.sensi_order == amici.SensitivityOrder.first:
(
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"]),
) = solver.srun(ts, p, k, my, pscale)
elif solver.sensi_order == amici.SensitivityOrder.second:
(
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"]),
) = solver.s2run(ts, p, k, my, pscale)

for field in rdata_kwargs.keys():
if field == "llh":
rdata_kwargs[field] = np.float64(rdata_kwargs[field])
elif field not in ["sllh", "s2llh"]:
rdata_kwargs[field] = np.asarray(rdata_kwargs[field]).T
if rdata_kwargs[field].ndim == 1:
rdata_kwargs[field] = np.expand_dims(rdata_kwargs[field], 1)

return ReturnDataJAX(**rdata_kwargs)
s2llh, (x, obs, stats) = jax.jacfwd(

Check warning on line 192 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L192

Added line #L192 was not covered by tests
jax.grad(self._run, 1, True), 1, True
)(ts, p, k, my, pscale)
return llh, sllh, s2llh, (x, obs, stats)

Check warning on line 195 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L195

Added line #L195 was not covered by tests

def run_simulation(self, edata: amici.ExpData):
ts = np.asarray(edata.getTimepoints())
p = jnp.asarray(edata.parameters)
k = np.asarray(edata.fixedParameters)
my = np.asarray(edata.getObservedData())
pscale = np.asarray(edata.pscale)

Check warning on line 202 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L197-L202

Added lines #L197 - L202 were not covered by tests

rdata_kwargs = dict()

Check warning on line 204 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L204

Added line #L204 was not covered by tests

if self.sensi_order == amici.SensitivityOrder.none:
(

Check warning on line 207 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L206-L207

Added lines #L206 - L207 were not covered by tests
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.run(ts, p, k, my, pscale)
elif self.sensi_order == amici.SensitivityOrder.first:
(

Check warning on line 212 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L211-L212

Added lines #L211 - L212 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.srun(ts, p, k, my, pscale)
elif self.sensi_order == amici.SensitivityOrder.second:
(

Check warning on line 218 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L217-L218

Added lines #L217 - L218 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.s2run(ts, p, k, my, pscale)

for field in rdata_kwargs.keys():
if field == "llh":
rdata_kwargs[field] = np.float64(rdata_kwargs[field])
elif field not in ["sllh", "s2llh"]:
rdata_kwargs[field] = np.asarray(rdata_kwargs[field]).T
if rdata_kwargs[field].ndim == 1:
rdata_kwargs[field] = np.expand_dims(

Check warning on line 231 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L225-L231

Added lines #L225 - L231 were not covered by tests
rdata_kwargs[field], 1
)

return ReturnDataJAX(**rdata_kwargs)

Check warning on line 235 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L235

Added line #L235 was not covered by tests

def run_simulations(

Check warning on line 237 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L237

Added line #L237 was not covered by tests
self,
edatas: Iterable[amici.ExpData],
num_threads: int = 1,
):
if num_threads > 1:
with ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(self.run_simulation, edatas)

Check warning on line 244 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L242-L244

Added lines #L242 - L244 were not covered by tests
else:
results = map(self.run_simulation, edatas)
return list(results)

Check warning on line 247 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L246-L247

Added lines #L246 - L247 were not covered by tests


@dataclass
Expand All @@ -228,6 +257,7 @@ class ReturnDataJAX(dict):
ssigmay: np.array = None
llh: np.array = None
sllh: np.array = None
stats: dict = None

Check warning on line 260 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L250-L260

Added lines #L250 - L260 were not covered by tests

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
Loading

0 comments on commit b213adb

Please sign in to comment.