From 82a01bacb8970f29ee614babbbfc7778c6a131c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Tue, 22 Oct 2024 15:58:12 +0100 Subject: [PATCH] simply tests + add support for non-dynamic simulation in jax --- .../test_benchmark_collection_models.yml | 4 +- python/sdist/amici/jax.py | 53 ++++++++++--------- python/sdist/amici/petab/petab_import.py | 16 +++++- .../test_benchmark_collection.sh | 12 +---- tests/benchmark-models/test_petab_model.py | 34 ++++++------ 5 files changed, 62 insertions(+), 57 deletions(-) diff --git a/.github/workflows/test_benchmark_collection_models.yml b/.github/workflows/test_benchmark_collection_models.yml index 39eef6f9be..81c971be15 100644 --- a/.github/workflows/test_benchmark_collection_models.yml +++ b/.github/workflows/test_benchmark_collection_models.yml @@ -59,9 +59,7 @@ jobs: # retrieve test models - name: Download and test benchmark collection run: | - git clone --depth 1 https://github.com/benchmarking-initiative/Benchmark-Models-PEtab.git \ - && export BENCHMARK_COLLECTION="$(pwd)/Benchmark-Models-PEtab/Benchmark-Models/" \ - && pip3 install -e $BENCHMARK_COLLECTION/../src/python \ + pip install git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python \ && AMICI_PARALLEL_COMPILE="" tests/benchmark-models/test_benchmark_collection.sh # run gradient checks diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index 75e7810a49..5537aef2c8 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -49,48 +49,39 @@ def __init__(self): @staticmethod @abstractmethod - def xdot(t, x, args): - ... + def xdot(t, x, args): ... @staticmethod @abstractmethod - def _w(t, x, p, k, tcl): - ... + def _w(t, x, p, k, tcl): ... @staticmethod @abstractmethod - def x0(p, k): - ... + def x0(p, k): ... @staticmethod @abstractmethod - def x_solver(x): - ... + def x_solver(x): ... @staticmethod @abstractmethod - def x_rdata(x, tcl): - ... + def x_rdata(x, tcl): ... @staticmethod @abstractmethod - def tcl(x, p, k): - ... + def tcl(x, p, k): ... @staticmethod @abstractmethod - def y(t, x, p, k, tcl): - ... + def y(t, x, p, k, tcl): ... @staticmethod @abstractmethod - def sigmay(y, p, k): - ... + def sigmay(y, p, k): ... @staticmethod @abstractmethod - def Jy(y, my, sigmay): - ... + def Jy(y, my, sigmay): ... def unscale_p(self, p, pscale): return jax.vmap( @@ -136,6 +127,7 @@ def _solve(self, ts, p, k, x0, checkpointed): saveat=diffrax.SaveAt(ts=ts), throw=False, ) + return sol.ys, tcl, sol.stats def _obs(self, ts, x, p, k, tcl): @@ -162,13 +154,22 @@ def _run( my: jnp.ndarray, pscale: np.ndarray, checkpointed=True, + dynamic=True, ): ps = self.unscale_p(p, pscale) if k_preeq.shape[0] > 0: x0 = self._preeq(ps, k_preeq) else: x0 = self.x0(ps, k) - x, tcl, stats = self._solve(ts, ps, k, x0, checkpointed=checkpointed) + + if dynamic: + x, tcl, stats = self._solve( + ts, ps, k, x0, checkpointed=checkpointed + ) + else: + x = tuple(jnp.array([x0_i] * len(ts)) for x0_i in x0) + tcl = self.tcl(x0, ps, k) + stats = None obs = self._obs(ts, x, ps, k, tcl) my_r = my.reshape((len(ts), -1)) sigmay = self._sigmay(obs, ps, k) @@ -176,7 +177,7 @@ def _run( x_rdata = self._x_rdata(x, tcl) return llh, (x_rdata, obs, stats) - @eqx.filter_jit + # @eqx.filter_jit def run( self, ts: np.ndarray, @@ -185,8 +186,9 @@ def run( k_preeq: np.ndarray, my: np.ndarray, pscale: np.ndarray, + dynamic=True, ): - return self._run(ts, p, k, k_preeq, my, pscale) + return self._run(ts, p, k, k_preeq, my, pscale, dynamic=dynamic) @eqx.filter_jit def srun( @@ -197,6 +199,7 @@ def srun( k_preeq: np.ndarray, my: np.ndarray, pscale: np.ndarray, + dynamic=True, ): (llh, (x, obs, stats)), sllh = ( jax.value_and_grad(self._run, 1, True) @@ -212,6 +215,7 @@ def s2run( k_preeq: np.ndarray, my: np.ndarray, pscale: np.ndarray, + dynamic=True, ): (llh, (x, obs, stats)), sllh = ( jax.value_and_grad(self._run, 1, True) @@ -232,6 +236,7 @@ def run_simulation( k_preeq = np.asarray(edata.fixedParametersPreequilibration) my = np.asarray(edata.getObservedData()) pscale = np.asarray(edata.pscale) + dynamic = np.max(ts) > 0 rdata_kwargs = dict() @@ -239,20 +244,20 @@ def run_simulation( ( rdata_kwargs["llh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.run(ts, p, k, k_preeq, my, pscale) + ) = self.run(ts, p, k, k_preeq, my, pscale, dynamic) elif sensitivity_order == amici.SensitivityOrder.first: ( rdata_kwargs["llh"], rdata_kwargs["sllh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.srun(ts, p, k, k_preeq, my, pscale) + ) = self.srun(ts, p, k, k_preeq, my, pscale, dynamic) elif sensitivity_order == amici.SensitivityOrder.second: ( rdata_kwargs["llh"], rdata_kwargs["sllh"], rdata_kwargs["s2llh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.s2run(ts, p, k, k_preeq, my, pscale) + ) = self.s2run(ts, p, k, k_preeq, my, pscale, dynamic) for field in rdata_kwargs.keys(): if field == "llh": diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 52b08cfd47..42a4d85dc4 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -37,8 +37,9 @@ def import_petab_problem( model_name: str = None, compile_: bool = None, non_estimated_parameters_as_constants=True, + jax=False, **kwargs, -) -> "amici.Model": +) -> "amici.Model | amici.JAXModel": """ Create an AMICI model for a PEtab problem. @@ -64,6 +65,9 @@ def import_petab_problem( model size and simulation times. If sensitivities with respect to those parameters are required, this should be set to ``False``. + :param jax: + Whether to load the jax version of the model. + :param kwargs: Additional keyword arguments to be passed to :meth:`amici.sbml_import.SbmlImporter.sbml2amici` or @@ -154,6 +158,16 @@ def import_petab_problem( # import model model_module = amici.import_model_module(model_name, model_output_dir) + + if jax: + model = model_module.get_jax_model() + + logger.info( + f"Successfully loaded jax model {model_name} " + f"from {model_output_dir}." + ) + return model + model = model_module.getModel() check_model(amici_model=model, petab_problem=petab_problem) diff --git a/tests/benchmark-models/test_benchmark_collection.sh b/tests/benchmark-models/test_benchmark_collection.sh index 581b8db028..4efd1c55bb 100755 --- a/tests/benchmark-models/test_benchmark_collection.sh +++ b/tests/benchmark-models/test_benchmark_collection.sh @@ -86,17 +86,9 @@ script_path=$(dirname "$BASH_SOURCE") script_path=$(cd "$script_path" && pwd) for model in $models; do - yaml="${model_dir}"/"${model}"/"${model}".yaml - - # different naming scheme - if [[ "$model" == "Bertozzi_PNAS2020" ]]; then - yaml="${model_dir}"/"${model}"/problem.yaml - fi - - amici_model_dir=test_bmc/"${model}" + amici_model_dir=test_bmc mkdir -p "$amici_model_dir" - cmd_import="amici_import_petab ${yaml} -o ${amici_model_dir} -n ${model} --flatten" - cmd_run="$script_path/test_petab_model.py -y ${yaml} -d ${amici_model_dir} -m ${model} -c" + cmd_run="$script_path/test_petab_model.py -d ${amici_model_dir} -m ${model} -c" printf '=%.0s' {1..40} printf " %s " "${model}" diff --git a/tests/benchmark-models/test_petab_model.py b/tests/benchmark-models/test_petab_model.py index 89a482cd7a..d38c1b5f9e 100755 --- a/tests/benchmark-models/test_petab_model.py +++ b/tests/benchmark-models/test_petab_model.py @@ -6,7 +6,6 @@ import argparse import contextlib -import importlib import logging import os import sys @@ -29,6 +28,7 @@ ) from timeit import default_timer as timer from petab.v1.visualize import plot_problem +import benchmark_models_petab logger = get_logger(f"amici.{__name__}", logging.WARNING) @@ -67,15 +67,6 @@ def parse_cli_args(): help="Plot measurement and simulation results", ) - # PEtab problem - parser.add_argument( - "-y", - "--yaml", - dest="yaml_file_name", - required=True, - help="PEtab YAML problem filename", - ) - # Corresponding AMICI model parser.add_argument( "-m", @@ -88,7 +79,7 @@ def parse_cli_args(): "-d", "--model-dir", dest="model_directory", - help="Directory containing the AMICI module of the " + help="Parent directory containing the AMICI module of the " "model to simulate. Required if model is not " "in python path.", ) @@ -113,19 +104,20 @@ def main(): logger.info( f"Simulating '{args.model_name}' " - f"({args.model_directory}) using PEtab data from " - f"{args.yaml_file_name}" + f"({args.model_directory}) with AMICI" ) # load PEtab files - problem = petab.Problem.from_yaml(args.yaml_file_name) + problem = benchmark_models_petab.get_problem(args.model_name) petab.flatten_timepoint_specific_output_overrides(problem) # load model - if args.model_directory: - sys.path.insert(0, args.model_directory) - model_module = importlib.import_module(args.model_name) - amici_model = model_module.getModel() + from amici.petab.petab_import import import_petab_problem + + amici_model = import_petab_problem( + problem, + model_output_dir=Path(args.model_directory) / args.model_name, + ) amici_solver = amici_model.getSolver() amici_solver.setAbsoluteTolerance(1e-8) @@ -145,7 +137,11 @@ def main(): rdatas = res[RDATAS] llh = res[LLH] - jax_model = model_module.get_jax_model() + jax_model = import_petab_problem( + problem, + model_output_dir=Path(args.model_directory) / args.model_name, + jax=True, + ) simulation_conditions = ( problem.get_simulation_conditions_from_measurement_df() )