
Commit 2672be2

some net cases and first ude testcase passing

FFroehlich committed Dec 2, 2024
1 parent aec0712 commit 2672be2

Showing 15 changed files with 657 additions and 45 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "tests/sciml/testsuite"]
path = tests/sciml/testsuite
url = https://github.com/sebapersson/petab_sciml
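(Note: after checkout, the new test suite submodule needs to be fetched locally, e.g. with `git submodule update --init tests/sciml/testsuite`.)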
35 changes: 28 additions & 7 deletions python/sdist/amici/de_export.py
@@ -56,7 +56,7 @@
AmiciCxxCodePrinter,
get_switch_statement,
)
-from .jaxcodeprinter import AmiciJaxCodePrinter
+from amici.jaxcodeprinter import AmiciJaxCodePrinter
from .de_model import DEModel
from .de_model_components import *
from .import_utils import (
@@ -174,6 +174,7 @@ def __init__(
allow_reinit_fixpar_initcond: bool | None = True,
generate_sensitivity_code: bool | None = True,
model_name: str | None = "model",
+        hybridisation: dict | None = None,
):
"""
Generate AMICI C++ files for the DE provided to the constructor.
@@ -238,6 +239,7 @@ def __init__(
self.allow_reinit_fixpar_initcond: bool = allow_reinit_fixpar_initcond
self._build_hints = set()
self.generate_sensitivity_code: bool = generate_sensitivity_code
+        self.hybridisation = hybridisation

@log_execution_time("generating cpp code", logger)
def generate_model_code(self) -> None:
@@ -380,15 +382,35 @@ def jnp_array_str(array) -> str:
# keep track of the API version that the model was generated with so we
# can flag conflicts in the future
"MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'",
"NET_IMPORTS": "\n".join(
f"{net} = _module_from_path('{net}', Path(__file__).parent / '{net}.py')"
for net in self.hybridisation.keys()
),
"NETS": ",\n".join(
f'"{net}": {net}.net(jr.PRNGKey(0))'
for net in self.hybridisation.keys()
),
},
}
os.makedirs(
-            os.path.join(self.model_path, self.model_name), exist_ok=True
+            os.path.join(self.model_path, self.model_name + "_jax"),
+            exist_ok=True,
)
+        from amici.jax.nn import generate_equinox
+
+        for net_name, net in self.hybridisation.items():
+            generate_equinox(
+                net["model"],
+                os.path.join(
+                    self.model_path, self.model_name + "_jax", f"{net_name}.py"
+                ),
+            )

apply_template(
-            os.path.join(amiciModulePath, "jax.template.py"),
-            os.path.join(self.model_path, self.model_name, "jax.py"),
+            os.path.join(amiciModulePath, "jax", "jax.template.py"),
+            os.path.join(
+                self.model_path, self.model_name + "_jax", "__init__.py"
+            ),
tpl_data,
)
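For orientation, with this change the generator emits a package directory rather than a single jax.py module. Assuming a model named "mymodel" with one network "net1" (names hypothetical), the emitted layout would be:

mymodel_jax/
    __init__.py   # rendered from jax/jax.template.py
    net1.py       # written by generate_equinox from the network spec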

@@ -795,7 +817,7 @@ def _get_function_body(
lines = []

if len(equations) == 0 or (
-            isinstance(equations, (sp.Matrix, sp.ImmutableDenseMatrix))
+            isinstance(equations, sp.Matrix | sp.ImmutableDenseMatrix)
and min(equations.shape) == 0
):
# dJydy is a list
@@ -1136,8 +1158,7 @@ def _write_model_header_cpp(self) -> None:
)
),
"NDXDOTDX_EXPLICIT": len(self.model.sparsesym("dxdotdx_explicit")),
"NDJYDY": "std::vector<int>{%s}"
% ",".join(str(len(x)) for x in self.model.sparsesym("dJydy")),
"NDJYDY": f"std::vector<int>{{{','.join(str(len(x)) for x in self.model.sparsesym('dJydy'))}}}",
"NDXRDATADXSOLVER": len(self.model.sparsesym("dx_rdatadx_solver")),
"NDXRDATADTCL": len(self.model.sparsesym("dx_rdatadtcl")),
"NDTOTALCLDXRDATA": len(self.model.sparsesym("dtotal_cldx_rdata")),
3 changes: 2 additions & 1 deletion python/sdist/amici/jax/__init__.py
@@ -2,5 +2,6 @@

from amici.jax.petab import JAXProblem, run_simulations
from amici.jax.model import JAXModel
+from amici.jax.nn import generate_equinox

__all__ = ["JAXModel", "JAXProblem", "run_simulations"]
__all__ = ["JAXModel", "JAXProblem", "run_simulations", "generate_equinox"]
python/sdist/amici/{jax.template.py → jax/jax.template.py}
@@ -1,16 +1,21 @@
# ruff: noqa: F401, F821, F841
import jax.numpy as jnp
+import jax.random as jr
from interpax import interp1d
from pathlib import Path

from amici.jax.model import JAXModel, safe_log, safe_div
+from amici import _module_from_path

+TPL_NET_IMPORTS


class JAXModel_TPL_MODEL_NAME(JAXModel):
api_version = TPL_MODEL_API_VERSION

def __init__(self):
self.jax_py_file = Path(__file__).resolve()
+        self.nns = {TPL_NETS}
super().__init__()

def _xdot(self, t, x, args):
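To make the placeholders concrete, a hypothetical rendering for a model "mymodel" with a single network "net1" would read:

net1 = _module_from_path('net1', Path(__file__).parent / 'net1.py')


class JAXModel_mymodel(JAXModel):
    api_version = '0.0.2'

    def __init__(self):
        self.jax_py_file = Path(__file__).resolve()
        self.nns = {"net1": net1.net(jr.PRNGKey(0))}
        super().__init__()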
1 change: 1 addition & 0 deletions python/sdist/amici/jax/model.py
@@ -22,6 +22,7 @@ class JAXModel(eqx.Module):
MODEL_API_VERSION = "0.0.2"
api_version: str
jax_py_file: Path
+    nns: dict

def __init__(self):
if self.api_version != self.MODEL_API_VERSION:
180 changes: 180 additions & 0 deletions python/sdist/amici/jax/nn.py
@@ -0,0 +1,180 @@
from pathlib import Path

from petab_sciml import MLModel, Layer, Node
import equinox as eqx
import jax.numpy as jnp

from amici._codegen.template import apply_template
from amici import amiciModulePath


class Flatten(eqx.Module):
start_dim: int
end_dim: int

def __init__(self, start_dim: int, end_dim: int):
super().__init__()
self.start_dim = start_dim
self.end_dim = end_dim

def __call__(self, x):
if self.end_dim == -1:
return jnp.reshape(x, x.shape[: self.start_dim] + (-1,))
        else:
            # flatten dims start_dim..end_dim inclusive (torch.nn.Flatten
            # semantics); dims strictly after end_dim are kept as-is
            return jnp.reshape(
                x,
                x.shape[: self.start_dim] + (-1,) + x.shape[self.end_dim + 1 :],
            )
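A quick shape check of the intended torch.nn.Flatten semantics (array sizes made up):

x = jnp.zeros((2, 3, 4, 5))
assert Flatten(1, -1)(x).shape == (2, 60)    # flatten everything after the batch dim
assert Flatten(1, 2)(x).shape == (2, 12, 5)  # flatten dims 1..2 inclusive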


def generate_equinox(ml_model: MLModel, filename: Path | str):
filename = Path(filename)
layer_indent = 12
node_indent = 8

layers = {layer.layer_id: layer for layer in ml_model.layers}

tpl_data = {
"MODEL_ID": ml_model.mlmodel_id,
"LAYERS": ",\n".join(
[
_generate_layer(layer, layer_indent, ilayer)
for ilayer, layer in enumerate(ml_model.layers)
]
)[layer_indent:],
"FORWARD": "\n".join(
[
_generate_forward(
node,
node_indent,
layers.get(
node.target,
Layer(layer_id="dummy", layer_type="Linear"),
).layer_type,
)
for node in ml_model.forward
]
)[node_indent:],
"INPUT": ", ".join([f"'{inp.input_id}'" for inp in ml_model.inputs]),
"N_LAYERS": len(ml_model.layers),
}

filename.parent.mkdir(parents=True, exist_ok=True)

apply_template(
Path(amiciModulePath) / "jax" / "nn.template.py",
filename,
tpl_data,
)
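A minimal usage sketch, assuming ml_model is an already-parsed petab_sciml MLModel and the output path is hypothetical:

from amici.jax.nn import generate_equinox

generate_equinox(ml_model, "mymodel_jax/net1.py")  # writes an equinox module for the network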


def _process_argval(v):
if isinstance(v, str):
return f"'{v}'"
if isinstance(v, bool):
return str(v)
return str(v)


def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str:
layer_map = {
"InstanceNorm1d": "eqx.nn.LayerNorm",
"InstanceNorm2d": "eqx.nn.LayerNorm",
"InstanceNorm3d": "eqx.nn.LayerNorm",
"Dropout1d": "eqx.nn.Dropout",
"Dropout2d": "eqx.nn.Dropout",
"Flatten": "Flatten",
}
kwarg_map = {
"Linear": {
"bias": "use_bias",
},
"Conv1d": {
"bias": "use_bias",
},
"Conv2d": {
"bias": "use_bias",
},
"InstanceNorm1d": {
"affine": "elementwise_affine",
"num_features": "shape",
},
"InstanceNorm2d": {
"affine": "elementwise_affine",
"num_features": "shape",
},
"InstanceNorm3d": {
"affine": "elementwise_affine",
"num_features": "shape",
},
}
kwarg_ignore = {
"InstanceNorm1d": ("track_running_stats", "momentum"),
"InstanceNorm2d": ("track_running_stats", "momentum"),
"InstanceNorm3d": ("track_running_stats", "momentum"),
"Dropout1d": ("inplace",),
"Dropout2d": ("inplace",),
}
kwargs = [
f"{kwarg_map.get(layer.layer_type, {}).get(k, k)}={_process_argval(v)}"
for k, v in layer.args.items()
if k not in kwarg_ignore.get(layer.layer_type, ())
]
# add key for initialization
if layer.layer_type in ("Linear", "Conv1d", "Conv2d", "Conv3d"):
kwargs += [f"key=keys[{ilayer}]"]
type_str = layer_map.get(layer.layer_type, f"eqx.nn.{layer.layer_type}")
layer_str = f"{type_str}({', '.join(kwargs)})"
if layer.layer_type.startswith(("InstanceNorm",)):
if layer.layer_type.endswith(("1d", "2d", "3d")):
layer_str = f"jax.vmap({layer_str}, in_axes=1, out_axes=1)"
if layer.layer_type.endswith(("2d", "3d")):
layer_str = f"jax.vmap({layer_str}, in_axes=2, out_axes=2)"
if layer.layer_type.endswith("3d"):
layer_str = f"jax.vmap({layer_str}, in_axes=3, out_axes=3)"
return f"{' ' * indent}'{layer.layer_id}': {layer_str}"


def _generate_forward(node: Node, indent: int, layer_type: str) -> str:
if node.op == "placeholder":
# TODO: inconsistent target vs name
return f"{' ' * indent}{node.name} = input"

if node.op == "call_module":
fun_str = f"self.layers['{node.target}']"
if layer_type.startswith(("InstanceNorm", "Conv", "Linear")):
if layer_type == "Linear":
dims = 1
if layer_type.endswith(("1d",)):
dims = 2
elif layer_type.endswith(("2d",)):
dims = 3
elif layer_type.endswith("3d"):
dims = 4
fun_str = f"(jax.vmap({fun_str}) if len({node.args[0]}.shape) == {dims + 1} else {fun_str})"

if node.op in ("call_function", "call_method"):
map_fun = {
"hardtanh": "jax.nn.hard_tanh",
}
if node.target == "hardtanh":
if node.kwargs.pop("min_val", -1.0) != -1.0:
raise NotImplementedError(
"min_val != -1.0 not supported for hardtanh"
)
if node.kwargs.pop("max_val", 1.0) != 1.0:
raise NotImplementedError(
"max_val != 1.0 not supported for hardtanh"
)
fun_str = map_fun.get(node.target, f"jax.nn.{node.target}")

args = ", ".join([f"{arg}" for arg in node.args])
kwargs = [
f"{k}={v}" for k, v in node.kwargs.items() if k not in ("inplace",)
]
if layer_type.startswith(("Dropout",)):
kwargs += ["inference=inference", "key=key"]
kwargs_str = ", ".join(kwargs)
if node.op in ("call_module", "call_function", "call_method"):
return f"{' ' * indent}{node.name} = {fun_str}({args + ', ' + kwargs_str})"
if node.op == "output":
return f"{' ' * indent}{node.target} = {args}"
24 changes: 24 additions & 0 deletions python/sdist/amici/jax/nn.template.py
@@ -0,0 +1,24 @@
# ruff: noqa: F401, F821, F841
import equinox as eqx
import jax.nn
import jax.random as jr
import jax
from amici.jax.nn import Flatten


class TPL_MODEL_ID(eqx.Module):
layers: dict
inputs: list[str]

def __init__(self, key):
super().__init__()
keys = jr.split(key, TPL_N_LAYERS)
self.layers = {TPL_LAYERS}
self.inputs = [TPL_INPUT]

def forward(self, input, inference=False, key=None):
TPL_FORWARD
return output


net = TPL_MODEL_ID
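The generated class is aliased as net (last line above) and instantiated with a PRNG key, matching the "{net}.net(jr.PRNGKey(0))" construction in de_export.py; a hypothetical direct use:

import jax.numpy as jnp
import jax.random as jr

model = net(jr.PRNGKey(0))
y = model.forward(jnp.ones((2,)))  # inference=False by default; dropout layers would need a key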