Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax export #1861

Merged
merged 94 commits into from
Nov 19, 2024
Merged
Changes from 1 commit
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
328d462
basic prototype
FFroehlich Aug 25, 2022
ffa5afb
Merge branch 'develop' into jax_export
FFroehlich Aug 26, 2022
d4f8552
add dimerization example, add second order code, refactor jit
FFroehlich Aug 26, 2022
d37a850
remove equinox dependency, list dependencies
FFroehlich Aug 26, 2022
ff37c7e
make jax optional
FFroehlich Aug 26, 2022
c3a77f7
Merge branch 'develop' into jax_export
FFroehlich Aug 26, 2022
7cd8553
support conservation laws
FFroehlich Aug 26, 2022
5177ad7
fixup
FFroehlich Aug 26, 2022
5612cfc
fix jit nesting
FFroehlich Aug 26, 2022
2dd0377
use vmap for vectorization
FFroehlich Aug 26, 2022
e9bd14f
fixups
FFroehlich Aug 26, 2022
bbb5246
add multithreaded simulation runner
FFroehlich Aug 26, 2022
9bd1004
fix my
FFroehlich Aug 26, 2022
1b06c24
Merge branch 'develop' into jax_export
FFroehlich Sep 9, 2022
599aa71
fixes
FFroehlich Sep 13, 2022
51812d6
merge
FFroehlich Apr 10, 2024
3fbd17a
fixup merge
FFroehlich Apr 10, 2024
5974d47
fix install
FFroehlich Apr 10, 2024
37cdc81
actually generate code
FFroehlich Apr 10, 2024
9e6a0ff
fix
FFroehlich Apr 10, 2024
22b2b38
fix
FFroehlich Apr 10, 2024
48a2e49
add better default coefficients, fix jax
FFroehlich Apr 10, 2024
481216d
ignore fujita in jax
FFroehlich Apr 10, 2024
85b8173
ignore smith
FFroehlich Apr 10, 2024
b213adb
optimize & fix bachmann
FFroehlich Apr 11, 2024
a1f37b7
fix import/wokflow
FFroehlich Apr 11, 2024
e09bb2f
Update __init__.template.py
FFroehlich Apr 12, 2024
d8d1900
fix jax imports
FFroehlich Apr 12, 2024
c24fe6b
Update setup.cfg
FFroehlich Apr 12, 2024
1ec591c
add preequilibration support
FFroehlich Apr 12, 2024
aebe07c
fix jax tests
FFroehlich Apr 13, 2024
4125c51
add filterwarning
FFroehlich Apr 14, 2024
8143cc2
fix parameter transformation
FFroehlich Apr 14, 2024
781bb3b
Merge branch 'develop' into jax_export
FFroehlich Oct 19, 2024
81e2aeb
reenable ruff format
FFroehlich Oct 19, 2024
c01f707
post merge cleanup
FFroehlich Oct 19, 2024
a5d356a
"fix" splines
FFroehlich Oct 19, 2024
9a021cf
Update .pre-commit-config.yaml
FFroehlich Oct 19, 2024
a02d215
Merge branch 'develop' into jax_export
FFroehlich Oct 19, 2024
50193d8
force optimistix 0.0.9
FFroehlich Oct 21, 2024
d6c5bcd
Merge branch 'jax_export' of https://github.com/AMICI-dev/AMICI into …
FFroehlich Oct 21, 2024
7faae32
add support for heavyside functions
FFroehlich Oct 21, 2024
907acb7
cleanup & actually run tests
FFroehlich Oct 21, 2024
82a01ba
simply tests + add support for non-dynamic simulation in jax
FFroehlich Oct 22, 2024
7c3aef9
Merge branch 'develop' into jax_export
FFroehlich Oct 23, 2024
c548c93
fix for NONCONST_CLS
FFroehlich Oct 24, 2024
7c27a21
fix petab path
FFroehlich Oct 24, 2024
b84dbdb
Merge branch 'develop' into jax_export
FFroehlich Oct 24, 2024
37b9329
Merge branch 'develop' into jax_export
FFroehlich Oct 24, 2024
956b0a6
fixup merge
FFroehlich Oct 24, 2024
2f3834d
support postequilibration
FFroehlich Oct 25, 2024
5366632
fixup
FFroehlich Oct 25, 2024
5a86f4c
fix
FFroehlich Oct 25, 2024
480b75a
fix gradients
FFroehlich Oct 25, 2024
8b9c10a
fix hessian
FFroehlich Oct 25, 2024
7dc81ac
Update test_petab_benchmark.py
FFroehlich Oct 25, 2024
866c811
Merge branch 'develop' into jax_export
FFroehlich Oct 27, 2024
02a1272
skip smith in jax
FFroehlich Oct 27, 2024
51bd18c
exclude more models
FFroehlich Oct 27, 2024
c7c5d4b
refactor: remove use of edatas
FFroehlich Nov 9, 2024
a514deb
update template
FFroehlich Nov 9, 2024
498681a
Update .pre-commit-config.yaml
FFroehlich Nov 9, 2024
4a5e7d2
Merge branch 'develop' into jax_export
FFroehlich Nov 11, 2024
f745be0
fix python jax tests
FFroehlich Nov 12, 2024
a64f89b
simplify petab interface
FFroehlich Nov 12, 2024
7292451
add parameter values to model class
FFroehlich Nov 12, 2024
da02106
refactor parameter mapping
FFroehlich Nov 12, 2024
a46e65d
refactor & simplify
FFroehlich Nov 12, 2024
404d82e
refsctor
FFroehlich Nov 16, 2024
e399f4c
update template
FFroehlich Nov 16, 2024
eaae778
Update .pre-commit-config.yaml
FFroehlich Nov 16, 2024
d79cfc1
refactor fix test
FFroehlich Nov 16, 2024
94aa679
Update petab.py
FFroehlich Nov 16, 2024
b129c86
fixups
FFroehlich Nov 17, 2024
9b6a62b
fixup
FFroehlich Nov 17, 2024
74cd498
add documentation and typing
FFroehlich Nov 17, 2024
d94714b
add runtime typechecks to jax tests
FFroehlich Nov 17, 2024
0a9fcdf
add coverage from benchmark tests
FFroehlich Nov 17, 2024
186805c
add api versioning and reenable jit compilation
FFroehlich Nov 17, 2024
250f9dd
review comments
FFroehlich Nov 18, 2024
dc4992e
use temporary directories
FFroehlich Nov 18, 2024
d547509
fix doc
FFroehlich Nov 18, 2024
82bfe31
Update test_jax.py
FFroehlich Nov 18, 2024
a010803
don't generate code if jax/diffrax not available
FFroehlich Nov 18, 2024
d9ae05e
Merge branch 'develop' into jax_export
FFroehlich Nov 18, 2024
f7c2c10
add example
FFroehlich Nov 19, 2024
5dc8735
fix doc
FFroehlich Nov 19, 2024
784ab2c
fix notebook symlink
FFroehlich Nov 19, 2024
d528168
update notebook
FFroehlich Nov 19, 2024
24d8c09
Update ExampleJaxPEtab.ipynb
FFroehlich Nov 19, 2024
5393e6c
Update ExampleJaxPEtab.ipynb
FFroehlich Nov 19, 2024
a22f099
fix compilation issue
FFroehlich Nov 19, 2024
a585414
Merge branch 'develop' into jax_export
FFroehlich Nov 19, 2024
c242b15
fix
FFroehlich Nov 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor & simplify
FFroehlich committed Nov 12, 2024
commit a46e65d270d6a7beb952cac26c50ccb93e7121c3
207 changes: 107 additions & 100 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
import jax.numpy as jnp
import numpy as np
import jax
import pandas as pd
import petab.v1 as petab

import amici
@@ -24,102 +25,120 @@


class JAXModel(eqx.Module):
_unscale_funs = {
amici.ParameterScaling.none: lambda x: x,
amici.ParameterScaling.ln: lambda x: jnp.exp(x),
amici.ParameterScaling.log10: lambda x: jnp.power(10, x),
}
solver: diffrax.AbstractSolver
controller: diffrax.AbstractStepSizeController
atol: float
rtol: float
pcoeff: float
icoeff: float
dcoeff: float
maxsteps: int
parameters: jnp.ndarray
parameter_mappings: dict[tuple[str], ParameterMappingForCondition]
term: diffrax.ODETerm
parameter_mappings: dict[tuple[str], ParameterMappingForCondition] | None
measurements: dict[tuple[str], pd.DataFrame] | None
petab_problem: petab.Problem | None

def __init__(self):
self.solver = diffrax.Kvaerno5()
self.atol: float = 1e-8
self.rtol: float = 1e-8
self.pcoeff: float = 0.4
self.icoeff: float = 0.3
self.dcoeff: float = 0.0
self.maxsteps: int = 2**14
self.controller = diffrax.PIDController(
rtol=self.rtol,
atol=self.atol,
pcoeff=self.pcoeff,
icoeff=self.icoeff,
dcoeff=self.dcoeff,
rtol=1e-8,
atol=1e-8,
pcoeff=0.4,
icoeff=0.3,
dcoeff=0.0,
)
self.term = diffrax.ODETerm(self.xdot)
self.petab_problem = None
self.parameter_mappings = None
self.measurements = None
self.parameters = jnp.array([])

def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
"""
Set the PEtab problem for the model and updates parameters to the nominal values.
:param petab_problem:
Petab problem to set.
:return: JAXModel instance
"""

is_leaf = lambda x: x is None if self.petab_problem is None else None # noqa: E731
model = eqx.tree_at(
lambda x: x.petab_problem,
self,
petab_problem,
is_leaf=is_leaf,
)

simulation_conditions = (
petab_problem.get_simulation_conditions_from_measurement_df()
)

def _set_parameter_mappings(
self, simulation_conditions: pd.DataFrame
) -> "JAXModel":
mappings = create_parameter_mapping(

Check warning on line 54 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L54

Added line #L54 was not covered by tests
petab_problem=petab_problem,
petab_problem=self.petab_problem,
simulation_conditions=simulation_conditions,
scaled_parameters=False,
amici_model=self,
)

parameter_mappings = {

Check warning on line 61 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L61

Added line #L61 was not covered by tests
tuple(simulation_condition.values): mapping
for (_, simulation_condition), mapping in zip(
simulation_conditions.iterrows(), mappings
)
}

is_leaf = ( # noqa: E731

Check warning on line 68 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L68

Added line #L68 was not covered by tests
lambda x: x is None if self.parameter_mappings is None else None
)
model = eqx.tree_at(
return eqx.tree_at(

Check warning on line 71 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L71

Added line #L71 was not covered by tests
lambda x: x.parameter_mappings,
model,
self,
parameter_mappings,
is_leaf=is_leaf,
)

def _set_measurements(
self, simulation_conditions: pd.DataFrame
) -> "JAXModel":
measurements = dict()
for _, simulation_condition in simulation_conditions.iterrows():
measurements_df = self.petab_problem.measurement_df
for k, v in simulation_condition.items():
measurements_df = measurements_df.query(f"{k} == '{v}'")

Check warning on line 85 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L81-L85

Added lines #L81 - L85 were not covered by tests

ts = _get_timepoints_with_replicates(measurements_df)
my = _get_measurements_and_sigmas(

Check warning on line 88 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L87-L88

Added lines #L87 - L88 were not covered by tests
measurements_df, ts, self.observable_ids
)[0].flatten()
measurements[tuple(simulation_condition)] = np.array(ts), my
is_leaf = ( # noqa: E731

Check warning on line 92 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L91-L92

Added lines #L91 - L92 were not covered by tests
lambda x: x is None if self.measurements is None else None
)
return eqx.tree_at(

Check warning on line 95 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L95

Added line #L95 was not covered by tests
lambda x: x.measurements,
self,
measurements,
is_leaf=is_leaf,
)

def _set_nominal_parameter_values(self) -> "JAXModel":
nominal_values = jnp.array(

Check warning on line 103 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L103

Added line #L103 was not covered by tests
[
petab.scale(
model.petab_problem.parameter_df.loc[
self.petab_problem.parameter_df.loc[
pval, petab.NOMINAL_VALUE
],
model.petab_problem.parameter_df.loc[
self.petab_problem.parameter_df.loc[
pval, petab.PARAMETER_SCALE
],
)
for pval in model.petab_parameter_ids()
for pval in self.petab_parameter_ids()
]
)
return eqx.tree_at(lambda x: x.parameters, self, nominal_values)

Check warning on line 116 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L116

Added line #L116 was not covered by tests

return eqx.tree_at(lambda x: x.parameters, model, nominal_values)
def _set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
is_leaf = lambda x: x is None if self.petab_problem is None else None # noqa: E731
return eqx.tree_at(

Check warning on line 120 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L119-L120

Added lines #L119 - L120 were not covered by tests
lambda x: x.petab_problem,
self,
petab_problem,
is_leaf=is_leaf,
)

def set_petab_problem(self, petab_problem: petab.Problem) -> "JAXModel":
"""
Set the PEtab problem for the model and updates parameters to the nominal values.
:param petab_problem:
Petab problem to set.
:return: JAXModel instance
"""

model = self._set_petab_problem(petab_problem)
simulation_conditions = (

Check warning on line 136 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L135-L136

Added lines #L135 - L136 were not covered by tests
petab_problem.get_simulation_conditions_from_measurement_df()
)
model = model._set_parameter_mappings(simulation_conditions)
model = model._set_measurements(simulation_conditions)
return model._set_nominal_parameter_values()

Check warning on line 141 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L139-L141

Added lines #L139 - L141 were not covered by tests

@staticmethod
@abstractmethod
@@ -179,7 +198,7 @@
parameter mappings via :func:`amici.petab.create_parameter_mapping`.
:return:
"""
return self.parameter_ids

Check warning on line 201 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L201

Added line #L201 was not covered by tests

def getFixedParameterIds(self) -> list[str]: # noqa: N802
"""
@@ -187,15 +206,15 @@
parameter mappings via :func:`amici.petab.create_parameter_mapping`.
:return:
"""
return self.fixed_parameter_ids

Check warning on line 209 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L209

Added line #L209 was not covered by tests

def petab_parameter_ids(self) -> list[str]:
return self.petab_problem.parameter_df[

Check warning on line 212 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L212

Added line #L212 was not covered by tests
self.petab_problem.parameter_df[petab.ESTIMATE] == 1
].index.tolist()

def get_petab_parameter_by_name(self, name: str) -> jnp.float_:
return self.parameters[self.petab_parameter_ids().index(name)]

Check warning on line 217 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L217

Added line #L217 was not covered by tests

def _unscale_p(self, p, pscale):
return jax.vmap(
@@ -207,16 +226,16 @@
)(p, pscale)

def _preeq(self, p, k):
x0 = self.x_solver(self.x0(p, k))
tcl = self.tcl(x0, p, k)
return self._eq(p, k, tcl, x0)

Check warning on line 231 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L229-L231

Added lines #L229 - L231 were not covered by tests

def _posteq(self, p, k, x, tcl):
return self._eq(p, k, tcl, x)

Check warning on line 234 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L234

Added line #L234 was not covered by tests

def _eq(self, p, k, tcl, x0):
sol = diffrax.diffeqsolve(

Check warning on line 237 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L237

Added line #L237 was not covered by tests
self.term,
diffrax.ODETerm(self.xdot),
self.solver,
args=(p, k, tcl),
t0=0.0,
@@ -227,12 +246,12 @@
max_steps=self.maxsteps,
event=diffrax.Event(cond_fn=diffrax.steady_state_event()),
)
return sol.ys

Check warning on line 249 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L249

Added line #L249 was not covered by tests

def _solve(self, ts, p, k, x0, checkpointed):
tcl = self.tcl(x0, p, k)
sol = diffrax.diffeqsolve(
self.term,
diffrax.ODETerm(self.xdot),
self.solver,
args=(p, k, tcl),
t0=0.0,
@@ -264,15 +283,15 @@
loss_fun = jax.vmap(self.Jy, in_axes=(0, 0, 0))
return -jnp.sum(loss_fun(obs, my, sigmay))

def _run(
def run_condition(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
ts: jnp.ndarray,
ts_dyn: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: np.ndarray,
pscale: jnp.ndarray,
checkpointed=True,
dynamic="true",
):
@@ -280,7 +299,7 @@

# Pre-equilibration
if k_preeq.shape[0] > 0:
x0 = self._preeq(ps, k_preeq)

Check warning on line 302 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L302

Added line #L302 was not covered by tests
else:
x0 = self.x0(ps, k)

@@ -290,30 +309,30 @@
ts_dyn, ps, k, x0, checkpointed=checkpointed
)
else:
x = tuple(

Check warning on line 312 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L312

Added line #L312 was not covered by tests
jnp.array([x0_i] * len(ts_dyn)) for x0_i in self.x_solver(x0)
)
tcl = self.tcl(x0, ps, k)
stats = None

Check warning on line 316 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L315-L316

Added lines #L315 - L316 were not covered by tests

# Post-equilibration
if len(ts) > len(ts_dyn):
if len(ts_dyn) > 0:
x_final = tuple(x_i[-1] for x_i in x)

Check warning on line 321 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L320-L321

Added lines #L320 - L321 were not covered by tests
else:
x_final = self.x_solver(x0)
x_posteq = self._posteq(ps, k, x_final, tcl)
x_posteq = tuple(

Check warning on line 325 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L323-L325

Added lines #L323 - L325 were not covered by tests
jnp.array([x0_i] * (len(ts) - len(ts_dyn)))
for x0_i in x_posteq
)
if len(ts_dyn) > 0:
x = tuple(

Check warning on line 330 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L329-L330

Added lines #L329 - L330 were not covered by tests
jnp.concatenate((x_i, x_posteq_i), axis=0)
for x_i, x_posteq_i in zip(x, x_posteq)
)
else:
x = x_posteq

Check warning on line 335 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L335

Added line #L335 was not covered by tests

obs = jnp.stack(self._obs(ts, x, ps, k, tcl), axis=1)
my_r = my.reshape((len(ts), -1))
@@ -323,55 +342,55 @@
return llh, (x_rdata, obs, stats)

@eqx.filter_jit
def run(
def _fun(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
ts: jnp.ndarray,
ts_dyn: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: jnp.ndarray,
dynamic="true",
):
return self._run(
return self.run_condition(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
)

@eqx.filter_jit
def srun(
def _grad(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
ts: jnp.ndarray,
ts_dyn: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: jnp.ndarray,
dynamic="true",
):
(llh, (x, obs, stats)), sllh = (
jax.value_and_grad(self._run, 2, True)
jax.value_and_grad(self.run_condition, 2, True)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic)
return llh, sllh, (x, obs, stats)

@eqx.filter_jit
def s2run(
def _hessian(
self,
ts: np.ndarray,
ts_dyn: np.ndarray,
ts: jnp.ndarray,
ts_dyn: jnp.ndarray,
p: jnp.ndarray,
k: np.ndarray,
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
k: jnp.ndarray,
k_preeq: jnp.ndarray,
my: jnp.ndarray,
pscale: jnp.ndarray,
dynamic="true",
):
(llh, (x, obs, stats)), sllh = (
jax.value_and_grad(self._run, 2, True)
jax.value_and_grad(self.run_condition, 2, True)
)(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic)

s2llh = jax.hessian(self._run, 2, True)(
s2llh = jax.hessian(self.run_condition, 2, True)(
ts,
ts_dyn,
p,
@@ -390,18 +409,9 @@
simulation_condition: tuple[str],
sensitivity_order: amici.SensitivityOrder = amici.SensitivityOrder.none,
):
parameter_mapping = self.parameter_mappings[simulation_condition]
measurements_df = self.petab_problem.measurement_df
for v, k in zip(
simulation_condition,
(
petab.SIMULATION_CONDITION_ID,
petab.PREEQUILIBRATION_CONDITION_ID,
),
):
measurements_df = measurements_df.query(f"{k} == '{v}'")
ts = _get_timepoints_with_replicates(measurements_df)
ts, my = self.measurements[simulation_condition]
p = jnp.array(

Check warning on line 414 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L412-L414

Added lines #L412 - L414 were not covered by tests
[
pval
if isinstance(
@@ -411,72 +421,69 @@
for par in self.parameter_ids
]
)
pscale = jnp.array(

Check warning on line 424 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L424

Added line #L424 was not covered by tests
[
0 if s == petab.LIN else 1 if s == petab.LOG else 2
for s in parameter_mapping.scale_map_sim_var.values()
]
)
k_sim = np.array(

Check warning on line 430 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L430

Added line #L430 was not covered by tests
[
parameter_mapping.map_sim_fix[k]
for k in self.fixed_parameter_ids
]
)
k_preeq = np.array(

Check warning on line 436 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L436

Added line #L436 was not covered by tests
[
parameter_mapping.map_preeq_fix[k]
for k in self.fixed_parameter_ids
if k in parameter_mapping.map_preeq_fix
]
)
my = _get_measurements_and_sigmas(
measurements_df, ts, self.observable_ids
)[0].flatten()
ts = np.array(ts)

ts_dyn = ts[np.isfinite(ts)]
dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false"

Check warning on line 445 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L444-L445

Added lines #L444 - L445 were not covered by tests

rdata_kwargs = dict(

Check warning on line 447 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L447

Added line #L447 was not covered by tests
simulation_condition=simulation_condition,
)

if sensitivity_order == amici.SensitivityOrder.none:
(

Check warning on line 452 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L451-L452

Added lines #L451 - L452 were not covered by tests
rdata_kwargs["llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.run(
) = self._fun(
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)
elif sensitivity_order == amici.SensitivityOrder.first:
(

Check warning on line 459 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L458-L459

Added lines #L458 - L459 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.srun(
) = self._grad(
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)
elif sensitivity_order == amici.SensitivityOrder.second:
(

Check warning on line 467 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L466-L467

Added lines #L466 - L467 were not covered by tests
rdata_kwargs["llh"],
rdata_kwargs["sllh"],
rdata_kwargs["s2llh"],
(rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]),
) = self.s2run(
) = self._hessian(
ts, ts_dyn, p, k_sim, k_preeq, my, pscale, dynamic=dynamic
)

for field in rdata_kwargs.keys():
if field == "llh":
rdata_kwargs[field] = np.float64(rdata_kwargs[field])
elif field not in ["sllh", "s2llh"]:
rdata_kwargs[field] = np.asarray(rdata_kwargs[field]).T
if rdata_kwargs[field].ndim == 1:
rdata_kwargs[field] = np.expand_dims(

Check warning on line 482 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L476-L482

Added lines #L476 - L482 were not covered by tests
rdata_kwargs[field], 1
)

return ReturnDataJAX(**rdata_kwargs)

Check warning on line 486 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L486

Added line #L486 was not covered by tests

def run_simulations(
self,
@@ -484,17 +491,17 @@
num_threads: int = 1,
simulation_conditions: tuple[tuple[str]] = None,
):
fun = eqx.Partial(

Check warning on line 494 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L494

Added line #L494 was not covered by tests
self.run_simulation,
sensitivity_order=sensitivity_order,
)

if num_threads > 1:
with ThreadPoolExecutor(max_workers=num_threads) as pool:
results = pool.map(fun, simulation_conditions)

Check warning on line 501 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L499-L501

Added lines #L499 - L501 were not covered by tests
else:
results = map(fun, simulation_conditions)
return list(results)

Check warning on line 504 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L503-L504

Added lines #L503 - L504 were not covered by tests


@dataclass
@@ -509,5 +516,5 @@
stats: dict = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self

Check warning on line 520 in python/sdist/amici/jax.py

Codecov / codecov/patch

python/sdist/amici/jax.py#L519-L520

Added lines #L519 - L520 were not covered by tests
6 changes: 3 additions & 3 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -151,20 +151,20 @@ def check_fields_jax(
(
r_jax["llh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.run(**kwargs)
) = jax_model._fun(**kwargs)
elif sensi_order == amici.SensitivityOrder.first:
(
r_jax["llh"],
r_jax["sllh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.srun(**kwargs)
) = jax_model._grad(**kwargs)
elif sensi_order == amici.SensitivityOrder.second:
(
r_jax["llh"],
r_jax["sllh"],
r_jax["s2llh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.s2run(**kwargs)
) = jax_model._hessian(**kwargs)

for field in fields:
for r_amici, r_jax in zip(rs_amici, [r_jax]):

Unchanged files with check annotations Beta

code = re.sub(r"numpy\.", r"jnp.", code)
return code
except TypeError as e:

Check warning on line 20 in python/sdist/amici/jaxcodeprinter.py

Codecov / codecov/patch

python/sdist/amici/jaxcodeprinter.py#L20

Added line #L20 was not covered by tests
raise ValueError(
f'Encountered unsupported function in expression "{expr}"'
) from e