From b8632f167a15256bd3eb16b7c1cd22015f405ced Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Tue, 3 Dec 2024 12:24:19 +0000 Subject: [PATCH] fixup merge --- python/sdist/amici/jax/jax.template.py | 3 +-- python/sdist/amici/jax/ode_export.py | 22 ++++++++++++++++++++++ python/sdist/amici/sbml_import.py | 2 ++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index 4eca618143..f9de581b1e 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -20,7 +20,7 @@ def __init__(self): super().__init__() def _xdot(self, t, x, args): - pk, tcl = args + p, tcl = args TPL_X_SYMS = x TPL_P_SYMS = p @@ -31,7 +31,6 @@ def _xdot(self, t, x, args): return TPL_XDOT_RET - def _w(self, t, x, p, tcl): TPL_X_SYMS = x TPL_P_SYMS = p diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 7ea4a29d8a..385bc65e07 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -24,6 +24,7 @@ from amici._codegen.template import apply_template from amici.jax.jaxcodeprinter import AmiciJaxCodePrinter from amici.jax.model import JAXModel +from amici.jax.nn import generate_equinox from amici.de_model import DEModel from amici.de_export import is_valid_identifier from amici.import_utils import ( @@ -129,6 +130,7 @@ def __init__( outdir: Path | str | None = None, verbose: bool | int | None = False, model_name: str | None = "model", + hybridisation: dict[str, str] = {}, ): """ Generate AMICI jax files for the ODE provided to the constructor. @@ -157,6 +159,8 @@ def __init__( self.model: DEModel = ode_model + self.hybridisation = hybridisation + self._code_printer = AmiciJaxCodePrinter() @log_execution_time("generating jax code", logger) @@ -169,6 +173,7 @@ def generate_model_code(self) -> None: ): self._prepare_model_folder() self._generate_jax_code() + self._generate_nn_code() def _prepare_model_folder(self) -> None: """ @@ -233,6 +238,14 @@ def _generate_jax_code(self) -> None: # can flag conflicts in the future "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", }, + "NET_IMPORTS": "\n".join( + f"{net} = _module_from_path('{net}', Path(__file__).parent / '{net}.py')" + for net in self.hybridisation.keys() + ), + "NETS": ",\n".join( + f'"{net}": {net}.net(jr.PRNGKey(0))' + for net in self.hybridisation.keys() + ), } outdir = self.model_path / (self.model_name + "_jax") outdir.mkdir(parents=True, exist_ok=True) @@ -243,6 +256,15 @@ def _generate_jax_code(self) -> None: tpl_data, ) + def _generate_nn_code(self) -> None: + for net_name, net in self.hybridisation.items(): + generate_equinox( + net["model"], + os.path.join( + self.model_path, self.model_name + "_jax", f"{net_name}.py" + ), + ) + def set_paths(self, output_dir: str | Path | None = None) -> None: """ Set output paths for the model and create if necessary diff --git a/python/sdist/amici/sbml_import.py b/python/sdist/amici/sbml_import.py index cb5c80ea88..9e66a5d924 100644 --- a/python/sdist/amici/sbml_import.py +++ b/python/sdist/amici/sbml_import.py @@ -460,6 +460,7 @@ def sbml2jax( simplify: Callable | None = _default_simplify, cache_simplify: bool = False, log_as_log10: bool = True, + hybridisation: dict = None, ) -> None: """ Generate and compile AMICI jax files for the model provided to the @@ -549,6 +550,7 @@ def sbml2jax( model_name=model_name, outdir=output_dir, verbose=verbose, + hybridisation=hybridisation, ) exporter.generate_model_code()