Skip to content

Commit

Permalink
add nan safe log&divide
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 30, 2024
1 parent 2366c2e commit f79a96e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 22 deletions.
11 changes: 2 additions & 9 deletions python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# ruff: noqa: F401, F821, F841
import jax.numpy as jnp
from interpax import interp1d

from amici.jax.model import JAXModel
from amici.jax.model import JAXModel, safe_log, safe_div


class JAXModel_TPL_MODEL_NAME(JAXModel):
Expand All @@ -11,7 +12,6 @@ def __init__(self):
super().__init__()

def _xdot(self, t, x, args):

pk, tcl = args

TPL_X_SYMS = x
Expand All @@ -24,7 +24,6 @@ def _xdot(self, t, x, args):
return TPL_XDOT_RET

def _w(self, t, x, pk, tcl):

TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_TCL_SYMS = tcl
Expand All @@ -34,23 +33,20 @@ def _w(self, t, x, pk, tcl):
return TPL_W_RET

def _x0(self, pk):

TPL_PK_SYMS = pk

TPL_X0_EQ

return TPL_X0_RET

def _x_solver(self, x):

TPL_X_RDATA_SYMS = x

TPL_X_SOLVER_EQ

return TPL_X_SOLVER_RET

def _x_rdata(self, x, tcl):

TPL_X_SYMS = x
TPL_TCL_SYMS = tcl

Expand All @@ -59,7 +55,6 @@ def _x_rdata(self, x, tcl):
return TPL_X_RDATA_RET

def _tcl(self, x, pk):

TPL_X_RDATA_SYMS = x
TPL_PK_SYMS = pk

Expand All @@ -68,7 +63,6 @@ def _tcl(self, x, pk):
return TPL_TOTAL_CL_RET

def _y(self, t, x, pk, tcl):

TPL_X_SYMS = x
TPL_PK_SYMS = pk
TPL_W_SYMS = self._w(t, x, pk, tcl)
Expand All @@ -86,7 +80,6 @@ def _sigmay(self, y, pk):

return TPL_SIGMAY_RET


def _nllh(self, t, x, pk, tcl, my, iy):
y = self._y(t, x, pk, tcl)
TPL_Y_SYMS = y
Expand Down
36 changes: 36 additions & 0 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,39 @@ def simulate_condition(
stats_dyn=stats_dyn,
stats_posteq=stats_posteq,
)


def safe_log(x: jnp.float_) -> jnp.float_:
"""
Safe logarithm that returns `jnp.log(jnp.finfo(jnp.float_).eps)` for x <= 0.
:param x:
input
:return:
logarithm of x
"""
# see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard
# against nans in forward & backward passes
safe_x = jnp.where(

Check warning on line 573 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L573

Added line #L573 was not covered by tests
x > jnp.finfo(jnp.float_).eps, x, jnp.finfo(jnp.float_).eps
)
return jnp.where(

Check warning on line 576 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L576

Added line #L576 was not covered by tests
x > 0, jnp.log(safe_x), jnp.log(jnp.finfo(jnp.float_).eps)
)


def safe_div(x: jnp.float_, y: jnp.float_) -> jnp.float_:
"""
Safe division that returns `x/jnp.finfo(jnp.float_).eps` for `y == 0`.
:param x:
numerator
:param y:
denominator
:return:
x / y
"""
# see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard
# against nans in forward & backward passes
safe_y = jnp.where(y != 0, y, jnp.finfo(jnp.float_).eps)
return jnp.where(y != 0, x / safe_y, x / jnp.finfo(jnp.float_).eps)

Check warning on line 595 in python/sdist/amici/jax/model.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax/model.py#L594-L595

Added lines #L594 - L595 were not covered by tests
9 changes: 9 additions & 0 deletions python/sdist/amici/jaxcodeprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@ def _print_AmiciSpline(self, expr: sp.Expr) -> str:
# FIXME: untested, where are spline nodes coming from anyways?
return f'interp1d(time, {self.doprint(expr.args[2:])}, kind="cubic")'

def _print_log(self, expr: sp.Expr) -> str:
return f"safe_log({self.doprint(expr.args[0])})"

def _print_Mul(self, expr: sp.Expr) -> str:
numer, denom = expr.as_numer_denom()
if denom == 1:
return super()._print_Mul(expr)
return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})"

def _get_sym_lines(
self,
symbols: sp.Matrix | Iterable[str],
Expand Down
20 changes: 7 additions & 13 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,8 @@ def test_jax_llh(benchmark_problem):

np.random.seed(cur_settings.rng_seed)

problems_for_gradient_check_jax = list(
set(problems_for_gradient_check) - {"Laske_PLOSComputBiol2019"}
# Laske has nan values in gradient due to nan values in observables that are not used in the likelihood
# but are problematic during backpropagation
)

problem_parameters = None
if problem_id in problems_for_gradient_check_jax:
if problem_id in problems_for_gradient_check:
point = petab_problem.x_nominal_free_scaled
for _ in range(20):
amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint)
Expand Down Expand Up @@ -352,12 +346,12 @@ def test_jax_llh(benchmark_problem):
[problem_parameters[pid] for pid in jax_problem.parameter_ids]
),
)
if problem_id in problems_for_gradient_check_jax:
(llh_jax, _), sllh_jax = eqx.filter_jit(
eqx.filter_value_and_grad(run_simulations, has_aux=True)
if problem_id in problems_for_gradient_check:
(llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
run_simulations, has_aux=True
)(jax_problem, simulation_conditions)
else:
llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(
llh_jax, _ = beartype(run_simulations)(
jax_problem, simulation_conditions
)

Expand All @@ -369,14 +363,14 @@ def test_jax_llh(benchmark_problem):
err_msg=f"LLH mismatch for {problem_id}",
)

if problem_id in problems_for_gradient_check_jax:
if problem_id in problems_for_gradient_check:
sllh_amici = r_amici[SLLH]
np.testing.assert_allclose(
sllh_jax.parameters,
np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]),
rtol=1e-2,
atol=1e-2,
err_msg=f"SLLH mismatch for {problem_id}",
err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}",
)


Expand Down

0 comments on commit f79a96e

Please sign in to comment.