From f79a96ee159b5be84bc89843b60c4700f659636f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Sat, 30 Nov 2024 22:26:04 +0000 Subject: [PATCH] add nan safe log÷ --- python/sdist/amici/jax.template.py | 11 ++---- python/sdist/amici/jax/model.py | 36 +++++++++++++++++++ python/sdist/amici/jaxcodeprinter.py | 9 +++++ .../benchmark-models/test_petab_benchmark.py | 20 ++++------- 4 files changed, 54 insertions(+), 22 deletions(-) diff --git a/python/sdist/amici/jax.template.py b/python/sdist/amici/jax.template.py index 367ba9e500..683ebe3c02 100644 --- a/python/sdist/amici/jax.template.py +++ b/python/sdist/amici/jax.template.py @@ -1,7 +1,8 @@ +# ruff: noqa: F401, F821, F841 import jax.numpy as jnp from interpax import interp1d -from amici.jax.model import JAXModel +from amici.jax.model import JAXModel, safe_log, safe_div class JAXModel_TPL_MODEL_NAME(JAXModel): @@ -11,7 +12,6 @@ def __init__(self): super().__init__() def _xdot(self, t, x, args): - pk, tcl = args TPL_X_SYMS = x @@ -24,7 +24,6 @@ def _xdot(self, t, x, args): return TPL_XDOT_RET def _w(self, t, x, pk, tcl): - TPL_X_SYMS = x TPL_PK_SYMS = pk TPL_TCL_SYMS = tcl @@ -34,7 +33,6 @@ def _w(self, t, x, pk, tcl): return TPL_W_RET def _x0(self, pk): - TPL_PK_SYMS = pk TPL_X0_EQ @@ -42,7 +40,6 @@ def _x0(self, pk): return TPL_X0_RET def _x_solver(self, x): - TPL_X_RDATA_SYMS = x TPL_X_SOLVER_EQ @@ -50,7 +47,6 @@ def _x_solver(self, x): return TPL_X_SOLVER_RET def _x_rdata(self, x, tcl): - TPL_X_SYMS = x TPL_TCL_SYMS = tcl @@ -59,7 +55,6 @@ def _x_rdata(self, x, tcl): return TPL_X_RDATA_RET def _tcl(self, x, pk): - TPL_X_RDATA_SYMS = x TPL_PK_SYMS = pk @@ -68,7 +63,6 @@ def _tcl(self, x, pk): return TPL_TOTAL_CL_RET def _y(self, t, x, pk, tcl): - TPL_X_SYMS = x TPL_PK_SYMS = pk TPL_W_SYMS = self._w(t, x, pk, tcl) @@ -86,7 +80,6 @@ def _sigmay(self, y, pk): return TPL_SIGMAY_RET - def _nllh(self, t, x, pk, tcl, my, iy): y = self._y(t, x, pk, tcl) TPL_Y_SYMS = y diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index a7b274027a..3425f5c015 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -557,3 +557,39 @@ def simulate_condition( stats_dyn=stats_dyn, stats_posteq=stats_posteq, ) + + +def safe_log(x: jnp.float_) -> jnp.float_: + """ + Safe logarithm that returns `jnp.log(jnp.finfo(jnp.float_).eps)` for x <= 0. + + :param x: + input + :return: + logarithm of x + """ + # see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard + # against nans in forward & backward passes + safe_x = jnp.where( + x > jnp.finfo(jnp.float_).eps, x, jnp.finfo(jnp.float_).eps + ) + return jnp.where( + x > 0, jnp.log(safe_x), jnp.log(jnp.finfo(jnp.float_).eps) + ) + + +def safe_div(x: jnp.float_, y: jnp.float_) -> jnp.float_: + """ + Safe division that returns `x/jnp.finfo(jnp.float_).eps` for `y == 0`. + + :param x: + numerator + :param y: + denominator + :return: + x / y + """ + # see https://docs.kidger.site/equinox/api/debug/, need double jnp.where to guard + # against nans in forward & backward passes + safe_y = jnp.where(y != 0, y, jnp.finfo(jnp.float_).eps) + return jnp.where(y != 0, x / safe_y, x / jnp.finfo(jnp.float_).eps) diff --git a/python/sdist/amici/jaxcodeprinter.py b/python/sdist/amici/jaxcodeprinter.py index ed9181cc09..6cfce97b35 100644 --- a/python/sdist/amici/jaxcodeprinter.py +++ b/python/sdist/amici/jaxcodeprinter.py @@ -27,6 +27,15 @@ def _print_AmiciSpline(self, expr: sp.Expr) -> str: # FIXME: untested, where are spline nodes coming from anyways? return f'interp1d(time, {self.doprint(expr.args[2:])}, kind="cubic")' + def _print_log(self, expr: sp.Expr) -> str: + return f"safe_log({self.doprint(expr.args[0])})" + + def _print_Mul(self, expr: sp.Expr) -> str: + numer, denom = expr.as_numer_denom() + if denom == 1: + return super()._print_Mul(expr) + return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})" + def _get_sym_lines( self, symbols: sp.Matrix | Iterable[str], diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index 7a0afc6832..dc20fab2d3 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -299,14 +299,8 @@ def test_jax_llh(benchmark_problem): np.random.seed(cur_settings.rng_seed) - problems_for_gradient_check_jax = list( - set(problems_for_gradient_check) - {"Laske_PLOSComputBiol2019"} - # Laske has nan values in gradient due to nan values in observables that are not used in the likelihood - # but are problematic during backpropagation - ) - problem_parameters = None - if problem_id in problems_for_gradient_check_jax: + if problem_id in problems_for_gradient_check: point = petab_problem.x_nominal_free_scaled for _ in range(20): amici_solver.setSensitivityMethod(amici.SensitivityMethod.adjoint) @@ -352,12 +346,12 @@ def test_jax_llh(benchmark_problem): [problem_parameters[pid] for pid in jax_problem.parameter_ids] ), ) - if problem_id in problems_for_gradient_check_jax: - (llh_jax, _), sllh_jax = eqx.filter_jit( - eqx.filter_value_and_grad(run_simulations, has_aux=True) + if problem_id in problems_for_gradient_check: + (llh_jax, _), sllh_jax = eqx.filter_value_and_grad( + run_simulations, has_aux=True )(jax_problem, simulation_conditions) else: - llh_jax, _ = beartype(eqx.filter_jit(run_simulations))( + llh_jax, _ = beartype(run_simulations)( jax_problem, simulation_conditions ) @@ -369,14 +363,14 @@ def test_jax_llh(benchmark_problem): err_msg=f"LLH mismatch for {problem_id}", ) - if problem_id in problems_for_gradient_check_jax: + if problem_id in problems_for_gradient_check: sllh_amici = r_amici[SLLH] np.testing.assert_allclose( sllh_jax.parameters, np.array([sllh_amici[pid] for pid in jax_problem.parameter_ids]), rtol=1e-2, atol=1e-2, - err_msg=f"SLLH mismatch for {problem_id}", + err_msg=f"SLLH mismatch for {problem_id}, {dict(zip(jax_problem.parameter_ids, sllh_jax.parameters))}", )