Skip to content

Commit

Permalink
fix jax tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Apr 13, 2024
1 parent 1ec591c commit aebe07c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 80 deletions.
46 changes: 29 additions & 17 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class JAXModel(eqx.Module):
dcoeff: float
maxsteps: int
term: diffrax.ODETerm
sensi_order: amici.SensitivityOrder

def __init__(self):
self.solver = diffrax.Kvaerno5()
Expand All @@ -38,7 +37,7 @@ def __init__(self):
self.pcoeff: float = 0.4
self.icoeff: float = 0.3
self.dcoeff: float = 0.0
self.maxsteps: int = 2**10
self.maxsteps: int = 2**14
self.controller = diffrax.PIDController(

Check warning on line 41 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L34-L41

Added lines #L34 - L41 were not covered by tests
rtol=self.rtol,
atol=self.atol,
Expand All @@ -47,7 +46,6 @@ def __init__(self):
dcoeff=self.dcoeff,
)
self.term = diffrax.ODETerm(self.xdot)

Check warning on line 48 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L48

Added line #L48 was not covered by tests
self.sensi_order = amici.SensitivityOrder.none

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -120,7 +118,7 @@ def _preeq(self, p, k):
)
return sol.ys

Check warning on line 119 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L119

Added line #L119 was not covered by tests

def _solve(self, ts, p, k, x0):
def _solve(self, ts, p, k, x0, checkpointed):
tcl = self.tcl(x0, p, k)
sol = diffrax.diffeqsolve(

Check warning on line 123 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L122-L123

Added lines #L122 - L123 were not covered by tests
self.term,
Expand All @@ -132,6 +130,9 @@ def _solve(self, ts, p, k, x0):
y0=self.x_solver(x0),
stepsize_controller=self.controller,
max_steps=self.maxsteps,
adjoint=diffrax.RecursiveCheckpointAdjoint()
if checkpointed
else diffrax.DirectAdjoint(),
saveat=diffrax.SaveAt(ts=ts),
)
return sol.ys, tcl, sol.stats

Check warning on line 138 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L138

Added line #L138 was not covered by tests
Expand Down Expand Up @@ -159,13 +160,14 @@ def _run(
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: np.ndarray,
checkpointed=True,
):
ps = self.unscale_p(p, pscale)
if k_preeq.shape[0] > 0:
x0 = self._preeq(ps, k_preeq)

Check warning on line 167 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L165-L167

Added lines #L165 - L167 were not covered by tests
else:
x0 = self.x0(p, k)
x, tcl, stats = self._solve(ts, ps, k, x0)
x, tcl, stats = self._solve(ts, ps, k, x0, checkpointed=checkpointed)
obs = self._obs(ts, x, ps, k, tcl)
my_r = my.reshape((len(ts), -1))
sigmay = self._sigmay(obs, ps, k)
Expand All @@ -191,12 +193,13 @@ def srun(
ts: np.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
):
(llh, (x, obs, stats)), sllh = (

Check warning on line 200 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L200

Added line #L200 was not covered by tests
jax.value_and_grad(self._run, 1, True)
)(ts, p, k, my, pscale)
)(ts, p, k, k_preeq, my, pscale)
return llh, sllh, (x, obs, stats)

Check warning on line 203 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L203

Added line #L203 was not covered by tests

@eqx.filter_jit
Expand All @@ -205,18 +208,23 @@ def s2run(
ts: np.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
):
(llh, (_, _, _)), sllh = (jax.value_and_grad(self._run, 1, True))(
ts, p, k, my, pscale
(llh, (x, obs, stats)), sllh = (

Check warning on line 215 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L215

Added line #L215 was not covered by tests
jax.value_and_grad(self._run, 1, True)
)(ts, p, k, k_preeq, my, pscale)

s2llh = jax.hessian(self._run, 1, True)(

Check warning on line 219 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L219

Added line #L219 was not covered by tests
ts, p, k, k_preeq, my, pscale, False
)
s2llh, (x, obs, stats) = jax.jacfwd(
jax.grad(self._run, 1, True), 1, True
)(ts, p, k, my, pscale)

return llh, sllh, s2llh, (x, obs, stats)

Check warning on line 223 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L223

Added line #L223 was not covered by tests

def run_simulation(self, edata: amici.ExpData):
def run_simulation(
self, edata: amici.ExpData, sensitivity_order: amici.SensitivityOrder
):
ts = np.asarray(edata.getTimepoints())
p = jnp.asarray(edata.parameters)
k = np.asarray(edata.fixedParameters)
Expand All @@ -226,18 +234,18 @@ def run_simulation(self, edata: amici.ExpData):

rdata_kwargs = dict()

Check warning on line 235 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L235

Added line #L235 was not covered by tests

if self.sensi_order == amici.SensitivityOrder.none:
if sensitivity_order == amici.SensitivityOrder.none:
(

Check warning on line 238 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L237-L238

Added lines #L237 - L238 were not covered by tests
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.run(ts, p, k, k_preeq, my, pscale)
elif self.sensi_order == amici.SensitivityOrder.first:
elif sensitivity_order == amici.SensitivityOrder.first:
(

Check warning on line 243 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L242-L243

Added lines #L242 - L243 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.srun(ts, p, k, k_preeq, my, pscale)
elif self.sensi_order == amici.SensitivityOrder.second:
elif sensitivity_order == amici.SensitivityOrder.second:
(

Check warning on line 249 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L248-L249

Added lines #L248 - L249 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
Expand All @@ -260,13 +268,17 @@ def run_simulation(self, edata: amici.ExpData):
def run_simulations(
self,
edatas: Iterable[amici.ExpData],
sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
num_threads: int = 1,
):
fun = eqx.Partial(

Check warning on line 274 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L274

Added line #L274 was not covered by tests
self.run_simulation, sensitivity_order=sensitivity_order
)
if num_threads > 1:
with ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(self.run_simulation, edatas)
results = pool.map(fun, edatas)

Check warning on line 279 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L277-L279

Added lines #L277 - L279 were not covered by tests
else:
results = map(self.run_simulation, edatas)
results = map(fun, edatas)
return list(results)

Check warning on line 282 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L281-L282

Added lines #L281 - L282 were not covered by tests


Expand Down
136 changes: 73 additions & 63 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import amici

pytest.importorskip("jax")
import amici.jax

Expand All @@ -16,21 +17,18 @@ def test_conversion():
pysb.SelfExporter.cleanup() # reset pysb
pysb.SelfExporter.do_export = True

model = pysb.Model('conversion')
a = pysb.Monomer('A', sites=['s'], site_states={'s': ['a', 'b']})
pysb.Initial(a(s='a'), pysb.Parameter('aa0', 1.2))
pysb.Rule(
'conv',
a(s='a') >> a(s='b'), pysb.Parameter('kcat', 0.05)
)
pysb.Observable('ab', a(s='b'))
model = pysb.Model("conversion")
a = pysb.Monomer("A", sites=["s"], site_states={"s": ["a", "b"]})
pysb.Initial(a(s="a"), pysb.Parameter("aa0", 1.2))
pysb.Rule("conv", a(s="a") >> a(s="b"), pysb.Parameter("kcat", 0.05))
pysb.Observable("ab", a(s="b"))

outdir = model.name
pysb2amici(model, outdir, verbose=True,
observables=['ab'])
pysb2amici(model, outdir, verbose=True, observables=["ab"])

model_module = amici.import_model_module(module_name=model.name,
module_path=outdir)
model_module = amici.import_model_module(
module_name=model.name, module_path=outdir
)

ts = tuple(np.linspace(0, 1, 10))
p = jnp.stack((1.0, 0.1), axis=-1)
Expand All @@ -42,33 +40,44 @@ def test_dimerization():
pysb.SelfExporter.cleanup() # reset pysb
pysb.SelfExporter.do_export = True

model = pysb.Model('dimerization')
a = pysb.Monomer('A', sites=['b'])
b = pysb.Monomer('B', sites=['a'])

pysb.Rule('turnover_a',
a(b=None) | None,
pysb.Parameter('kdeg_a', 10),
pysb.Parameter('ksyn_a', 0.1))
pysb.Rule('turnover_b',
b(a=None) | None,
pysb.Parameter('kdeg_b', 0.1),
pysb.Parameter('ksyn_b', 10))
pysb.Rule('dimer',
a(b=None) + b(a=None) | a(b=1) % b(a=1),
pysb.Parameter('kon', 1.0),
pysb.Parameter('koff', 0.1))

pysb.Observable('a_obs', a())
pysb.Observable('b_obs', b())
model = pysb.Model("dimerization")
a = pysb.Monomer("A", sites=["b"])
b = pysb.Monomer("B", sites=["a"])

pysb.Rule(
"turnover_a",
a(b=None) | None,
pysb.Parameter("kdeg_a", 10),
pysb.Parameter("ksyn_a", 0.1),
)
pysb.Rule(
"turnover_b",
b(a=None) | None,
pysb.Parameter("kdeg_b", 0.1),
pysb.Parameter("ksyn_b", 10),
)
pysb.Rule(
"dimer",
a(b=None) + b(a=None) | a(b=1) % b(a=1),
pysb.Parameter("kon", 1.0),
pysb.Parameter("koff", 0.1),
)

pysb.Observable("a_obs", a())
pysb.Observable("b_obs", b())

outdir = model.name
pysb2amici(model, outdir, verbose=True,
observables=['a_obs', 'b_obs'],
constant_parameters=['ksyn_a', 'ksyn_b'])
pysb2amici(
model,
outdir,
verbose=True,
observables=["a_obs", "b_obs"],
constant_parameters=["ksyn_a", "ksyn_b"],
)

model_module = amici.import_model_module(module_name=model.name,
module_path=outdir)
model_module = amici.import_model_module(
module_name=model.name, module_path=outdir
)

ts = tuple(np.linspace(0, 1, 10))
p = jnp.stack((5, 0.5, 0.5, 0.5), axis=-1)
Expand All @@ -80,11 +89,11 @@ def _test_model(model_module, ts, p, k):
amici_model = model_module.getModel()

amici_model.setTimepoints(np.asarray(ts, dtype=np.float64))
sol_amici_ref = amici.runAmiciSimulation(amici_model,
amici_model.getSolver())
sol_amici_ref = amici.runAmiciSimulation(
amici_model, amici_model.getSolver()
)

jax_model = model_module.get_jax_model()
jax_solver = jax_model.get_solver()

amici_model.setParameters(np.asarray(p, dtype=np.float64))
amici_model.setFixedParameters(np.asarray(k, dtype=np.float64))
Expand All @@ -99,39 +108,40 @@ def _test_model(model_module, ts, p, k):
amici_solver = amici_model.getSolver()
amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward)
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
rs_amici = amici.runAmiciSimulations(
amici_model,
amici_solver,
edatas
)

check_fields_jax(rs_amici, jax_model, jax_solver, edatas,
['x', 'y', 'llh'])
rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, edatas)

jax_solver.sensi_order = amici.SensitivityOrder.first
check_fields_jax(rs_amici, jax_model, jax_solver, edatas,
['x', 'y', 'llh', 'sllh'])

jax_solver.sensi_order = amici.SensitivityOrder.second
check_fields_jax(rs_amici, jax_model, jax_solver, edatas,
['x', 'y', 'llh', 'sllh'])
check_fields_jax(rs_amici, jax_model, edatas, ["x", "y", "llh"])

check_fields_jax(
rs_amici,
jax_model,
edatas,
["x", "y", "llh", "sllh"],
sensi_order=amici.SensitivityOrder.first,
)

def check_fields_jax(rs_amici,
jax_model,
jax_solver,
edatas,
fields):
rs_jax = amici.jax.run_simulations(
check_fields_jax(
rs_amici,
jax_model,
jax_solver,
edatas
edatas,
["x", "y", "llh", "sllh"],
sensi_order=amici.SensitivityOrder.second,
)


def check_fields_jax(
rs_amici,
jax_model,
edatas,
fields,
sensi_order=amici.SensitivityOrder.none,
):
rs_jax = jax_model.run_simulations(edatas, sensitivity_order=sensi_order)
for field in fields:
for r_amici, r_jax in zip(rs_amici, rs_jax):
assert_allclose(
actual=r_amici[field],
desired=r_jax[field],
atol=1e-6,
rtol=1e-6
rtol=1e-6,
)

0 comments on commit aebe07c

Please sign in to comment.