diff --git a/python/sdist/amici/jax/nn.py b/python/sdist/amici/jax/nn.py index 1238625f10..343a749ea6 100644 --- a/python/sdist/amici/jax/nn.py +++ b/python/sdist/amici/jax/nn.py @@ -89,14 +89,13 @@ def _process_argval(v): def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str: layer_map = { - "InstanceNorm1d": "eqx.nn.LayerNorm", - "InstanceNorm2d": "eqx.nn.LayerNorm", - "InstanceNorm3d": "eqx.nn.LayerNorm", "Dropout1d": "eqx.nn.Dropout", "Dropout2d": "eqx.nn.Dropout", "Flatten": "amici.jax.nn.Flatten", } - if layer.layer_type.startswith(("BatchNorm", "AlphaDropout")): + if layer.layer_type.startswith( + ("BatchNorm", "AlphaDropout", "InstanceNorm") + ): raise NotImplementedError( f"{layer.layer_type} layers currently not supported" ) @@ -117,30 +116,12 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str: "Conv2d": { "bias": "use_bias", }, - "InstanceNorm1d": { - "affine": "elementwise_affine", - "num_features": "shape", - }, - "InstanceNorm2d": { - "affine": "elementwise_affine", - "num_features": "shape", - }, - "InstanceNorm3d": { - "affine": "elementwise_affine", - "num_features": "shape", - }, "LayerNorm": { "affine": "elementwise_affine", "normalized_shape": "shape", }, } kwarg_ignore = { - "InstanceNorm1d": ("track_running_stats", "momentum"), - "InstanceNorm2d": ("track_running_stats", "momentum"), - "InstanceNorm3d": ("track_running_stats", "momentum"), - "BatchNorm1d": ("track_running_stats", "momentum"), - "BatchNorm2d": ("track_running_stats", "momentum"), - "BatchNorm3d": ("track_running_stats", "momentum"), "Dropout1d": ("inplace",), "Dropout2d": ("inplace",), } @@ -162,13 +143,6 @@ def _generate_layer(layer: Layer, indent: int, ilayer: int) -> str: kwargs += [f"key=keys[{ilayer}]"] type_str = layer_map.get(layer.layer_type, f"eqx.nn.{layer.layer_type}") layer_str = f"{type_str}({', '.join(kwargs)})" - if layer.layer_type.startswith(("InstanceNorm",)): - if layer.layer_type.endswith(("1d", "2d", "3d")): - layer_str = f"jax.vmap({layer_str}, in_axes=1, out_axes=1)" - if layer.layer_type.endswith(("2d", "3d")): - layer_str = f"jax.vmap({layer_str}, in_axes=2, out_axes=2)" - if layer.layer_type.endswith("3d"): - layer_str = f"jax.vmap({layer_str}, in_axes=3, out_axes=3)" return f"{' ' * indent}'{layer.layer_id}': {layer_str}" @@ -179,10 +153,8 @@ def _generate_forward(node: Node, indent, layer_type=str) -> str: if node.op == "call_module": fun_str = f"self.layers['{node.target}']" - if layer_type.startswith( - ("InstanceNorm", "Conv", "Linear", "LayerNorm") - ): - if layer_type in ("LayerNorm", "InstanceNorm"): + if layer_type.startswith(("Conv", "Linear", "LayerNorm")): + if layer_type in ("LayerNorm",): dims = f"len({fun_str}.shape)+1" if layer_type == "Linear": dims = 2 diff --git a/tests/sciml/test_sciml.py b/tests/sciml/test_sciml.py index 75205b0093..4986899fd1 100644 --- a/tests/sciml/test_sciml.py +++ b/tests/sciml/test_sciml.py @@ -40,8 +40,26 @@ def change_directory(destination): cases_dir = Path(__file__).parent / "testsuite" / "test_cases" +def _reshape_flat_array(array_flat): + array_flat["ix"] = array_flat["ix"].astype(str) + ix_cols = [ + f"ix_{i}" for i in range(len(array_flat["ix"].values[0].split(";"))) + ] + if len(ix_cols) == 1: + array_flat[ix_cols[0]] = array_flat["ix"].apply(int) + else: + array_flat[ix_cols] = pd.DataFrame( + array_flat["ix"].str.split(";").apply(np.array).to_list(), + index=array_flat.index, + ).astype(int) + array_flat.sort_values(by=ix_cols, inplace=True) + array_shape = tuple(array_flat[ix_cols].max().astype(int) + 1) + array = np.array(array_flat["value"].values).reshape(array_shape) + return array + + @pytest.mark.parametrize( - "test", [d.stem for d in cases_dir.glob("net_[0-9]*")] + "test", sorted([d.stem for d in cases_dir.glob("net_[0-9]*")]) ) def test_net(test): test_dir = cases_dir / test @@ -59,17 +77,20 @@ def test_net(test): for ml_model in ml_models.models: module_dir = outdir / f"{ml_model.mlmodel_id}.py" if test in ( - "net_022", "net_002", - "net_045", - "net_042", + "net_009", "net_018", + "net_019", "net_020", + "net_021", + "net_022", + "net_042", "net_043", "net_044", - "net_021", - "net_019", - "net_002", + "net_045", + "net_046", + "net_047", + "net_048", ): with pytest.raises(NotImplementedError): generate_equinox(ml_model, module_dir) @@ -84,38 +105,14 @@ def test_net(test): solutions.get("net_ps", solutions["net_input"]), solutions["net_output"], ): - input_flat = pd.read_csv(test_dir / input_file, sep="\t").sort_values( - by="ix" - ) - input_shape = tuple( - np.stack( - input_flat["ix"].astype(str).str.split(";").apply(np.array) - ) - .astype(int) - .max(axis=0) - + 1 - ) - input = jnp.array(input_flat["value"].values).reshape(input_shape) - - output_flat = pd.read_csv( - test_dir / output_file, sep="\t" - ).sort_values(by="ix") - output_shape = tuple( - np.stack( - output_flat["ix"].astype(str).str.split(";").apply(np.array) - ) - .astype(int) - .max(axis=0) - + 1 - ) - output = jnp.array(output_flat["value"].values).reshape(output_shape) + input_flat = pd.read_csv(test_dir / input_file, sep="\t") + input = _reshape_flat_array(input_flat) + + output_flat = pd.read_csv(test_dir / output_file, sep="\t") + output = _reshape_flat_array(output_flat) if "net_ps" in solutions: - par = ( - pd.read_csv(test_dir / par_file, sep="\t") - .set_index("parameterId") - .sort_index() - ) + par = pd.read_csv(test_dir / par_file, sep="\t") for ml_model in ml_models.models: net = nets[ml_model.mlmodel_id](jr.PRNGKey(0)) for layer in net.layers.keys(): @@ -126,14 +123,26 @@ def test_net(test): and net.layers[layer].weight is not None ): prefix = layer_prefix + "_weight" + df = par[ + par[petab.PARAMETER_ID].str.startswith(prefix) + ] + df["ix"] = ( + df[petab.PARAMETER_ID] + .str.split("_") + .str[3:] + .apply(lambda x: ";".join(x)) + ) + w = _reshape_flat_array(df) + if isinstance(net.layers[layer], eqx.nn.ConvTranspose): + # see FAQ in https://docs.kidger.site/equinox/api/nn/conv/#equinox.nn.ConvTranspose + w = np.flip( + w, axis=tuple(range(2, w.ndim)) + ).swapaxes(0, 1) + assert w.shape == net.layers[layer].weight.shape net = eqx.tree_at( lambda x: x.layers[layer].weight, net, - jnp.array( - par[par.index.str.startswith(prefix)][ - "value" - ].values - ).reshape(net.layers[layer].weight.shape), + jnp.array(w), ) if ( isinstance(net.layers[layer], eqx.Module) @@ -141,17 +150,40 @@ def test_net(test): and net.layers[layer].bias is not None ): prefix = layer_prefix + "_bias" + df = par[ + par[petab.PARAMETER_ID].str.startswith(prefix) + ] + df["ix"] = ( + df[petab.PARAMETER_ID] + .str.split("_") + .str[3:] + .apply(lambda x: ";".join(x)) + ) + b = _reshape_flat_array(df) + if isinstance( + net.layers[layer], + eqx.nn.Conv | eqx.nn.ConvTranspose, + ): + b = np.expand_dims( + b, + tuple( + range( + 1, + net.layers[layer].num_spatial_dims + 1, + ) + ), + ) + assert b.shape == net.layers[layer].bias.shape net = eqx.tree_at( lambda x: x.layers[layer].bias, net, - jnp.array( - par[par.index.str.startswith(prefix)][ - "value" - ].values - ).reshape(net.layers[layer].bias.shape), + jnp.array(b), ) net = eqx.nn.inference_mode(net) + if test == "net_004_alt": + return # skipping, no support for non-cross-correlation in equinox + np.testing.assert_allclose( net.forward(input), output, @@ -160,7 +192,9 @@ def test_net(test): ) -@pytest.mark.parametrize("test", [d.stem for d in cases_dir.glob("[0-9]*")]) +@pytest.mark.parametrize( + "test", sorted([d.stem for d in cases_dir.glob("[0-9]*")]) +) def test_ude(test): test_dir = cases_dir / test with open(test_dir / "petab" / "problem_ude.yaml") as f: