refactor to pytests
FFroehlich committed Dec 3, 2024
1 parent ea1c75e commit b9add9d
Showing 1 changed file with 62 additions and 83 deletions.
145 changes: 62 additions & 83 deletions tests/sciml/testsuite.py → tests/sciml/test_sciml.py
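The essence of the change: the hand-rolled `_test_net`/`_test_ude` driver loop at the bottom of the old module is replaced by pytest parametrization over the test-case directories, and unsupported cases are asserted to raise `NotImplementedError` instead of being skipped via early returns. A minimal, self-contained sketch of that pattern (the `cases_dir` layout, the `run_case` helper, and the skip list are illustrative placeholders, not code from this commit):

from pathlib import Path

import pytest

# Hypothetical test-case layout; the real suite resolves this relative to the test module.
cases_dir = Path(__file__).parent / "testsuite" / "test_cases"


def run_case(case_dir: Path) -> None:
    # Placeholder for the real test body (load solutions.yaml, build the model, compare outputs).
    assert case_dir.is_dir()


# One pytest test per case directory, collected at import time.
@pytest.mark.parametrize("case", [d.stem for d in cases_dir.glob("net_[0-9]*")])
def test_case(case):
    case_dir = cases_dir / case
    if case in ("net_002",):  # cases known to hit unimplemented features
        with pytest.raises(NotImplementedError):
            run_case(case_dir)
        return
    run_case(case_dir)

This gives each case its own test id (e.g. `test_case[net_001]`), so individual cases can be selected with `-k` and failures are reported per case rather than aborting the whole loop.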
@@ -1,7 +1,7 @@
from yaml import safe_load
import pytest

from pathlib import Path
from petab.v2 import Problem
import petab.v1 as petab
from amici.petab import import_petab_problem
from amici.jax import JAXProblem, generate_equinox, run_simulations
@@ -37,40 +37,43 @@ def change_directory(destination):

# pip install git+https://github.com/sebapersson/petab_sciml@add_standard#egg=petab_sciml\&subdirectory=src/python

cases_dir = Path(__file__).parent / "testsuite" / "test_cases"

def _test_net(test):
print(f"Running net test: {test.stem}")
with open(test / "solutions.yaml") as f:
solutions = safe_load(f)

if test.stem in (
"net_042",
"net_043",
"net_044",
"net_045", # BatchNorm
"net_009",
"net_018", # MaxPool with dilation
"net_020", # AlphaDropout
"net_019",
"net_021",
"net_022",
"net_024", # inplace Dropout
"net_002", # Bilinear
):
return
@pytest.mark.parametrize(
"test", [d.stem for d in cases_dir.glob("net_[0-9]*")]
)
def test_net(test):
test_dir = cases_dir / test
with open(test_dir / "solutions.yaml") as f:
solutions = safe_load(f)

if test.stem.endswith("_alt"):
net_file = (
test.parent / test.stem.replace("_alt", "") / solutions["net_file"]
)
if test.endswith("_alt"):
net_file = cases_dir / test.replace("_alt", "") / solutions["net_file"]
else:
net_file = test / solutions["net_file"]
net_file = test_dir / solutions["net_file"]
ml_models = PetabScimlStandard.load_data(net_file)

nets = {}
outdir = Path(__file__).parent / "models" / test.stem
outdir = Path(__file__).parent / "models" / test
for ml_model in ml_models.models:
module_dir = outdir / f"{ml_model.mlmodel_id}.py"
if test in (
"net_022",
"net_002",
"net_045",
"net_042",
"net_018",
"net_020",
"net_043",
"net_044",
"net_021",
"net_019",
"net_002",
):
with pytest.raises(NotImplementedError):
generate_equinox(ml_model, module_dir)
return
generate_equinox(ml_model, module_dir)
nets[ml_model.mlmodel_id] = amici._module_from_path(
ml_model.mlmodel_id, module_dir
@@ -81,7 +84,7 @@ def _test_net(test):
solutions.get("net_ps", solutions["net_input"]),
solutions["net_output"],
):
input_flat = pd.read_csv(test / input_file, sep="\t").sort_values(
input_flat = pd.read_csv(test_dir / input_file, sep="\t").sort_values(
by="ix"
)
input_shape = tuple(
@@ -94,9 +97,9 @@ def _test_net(test):
)
input = jnp.array(input_flat["value"].values).reshape(input_shape)

output_flat = pd.read_csv(test / output_file, sep="\t").sort_values(
by="ix"
)
output_flat = pd.read_csv(
test_dir / output_file, sep="\t"
).sort_values(by="ix")
output_shape = tuple(
np.stack(
output_flat["ix"].astype(str).str.split(";").apply(np.array)
@@ -109,7 +112,7 @@ def _test_net(test):

if "net_ps" in solutions:
par = (
pd.read_csv(test / par_file, sep="\t")
pd.read_csv(test_dir / par_file, sep="\t")
.set_index("parameterId")
.sort_index()
)
@@ -148,22 +151,6 @@ def _test_net(test):
).reshape(net.layers[layer].bias.shape),
)
net = eqx.nn.inference_mode(net)
net.forward(input)
if test.stem in (
"net_046",
"net_047",
"net_048",
"net_050", # Conv layers
"net_021",
"net_022", # Conv layers
"net_004",
"net_004_alt",
"net_005",
"net_006",
"net_007",
"net_008", # Conv layers
):
return

np.testing.assert_allclose(
net.forward(input),
@@ -173,14 +160,15 @@ def _test_net(test):
)


def _test_ude(test):
print(f"Running ude test: {test.stem}")
with open(test / "petab" / "problem_ude.yaml") as f:
@pytest.mark.parametrize("test", [d.stem for d in cases_dir.glob("[0-9]*")])
def test_ude(test):
test_dir = cases_dir / test
with open(test_dir / "petab" / "problem_ude.yaml") as f:
petab_yaml = safe_load(f)
with open(test / "solutions.yaml") as f:
with open(test_dir / "solutions.yaml") as f:
solutions = safe_load(f)

with change_directory(test / "petab"):
with change_directory(test_dir / "petab"):
petab_yaml["format_version"] = "2.0.0"
for problem in petab_yaml["problems"]:
problem["model_files"] = {
@@ -229,16 +217,35 @@ def _test_ude(test):
)
df.to_csv(petab_yaml["parameter_file"][1], sep="\t", index=False)

from petab.v2 import Problem

petab_problem = Problem.from_yaml(petab_yaml)
jax_model = import_petab_problem(
petab_problem,
model_output_dir=Path(__file__).parent / "models" / test.stem,
model_output_dir=Path(__file__).parent / "models" / test,
compile_=True,
jax=True,
)
jax_problem = JAXProblem(jax_model, petab_problem)

# llh

if test in (
"012",
"013",
"014",
"001",
"011",
"016",
"010",
"010",
"003",
"004",
"005",
):
with pytest.raises(NotImplementedError):
run_simulations(jax_problem)
return
llh, r = run_simulations(jax_problem)
np.testing.assert_allclose(
llh,
@@ -248,7 +255,7 @@ def _test_ude(test):
)
simulations = pd.concat(
[
pd.read_csv(test / simulation, sep="\t")
pd.read_csv(test_dir / simulation, sep="\t")
for simulation in solutions["simulation_files"]
]
)
@@ -292,7 +299,7 @@ def _test_ude(test):
expected = (
pd.concat(
[
pd.read_csv(test / simulation, sep="\t")
pd.read_csv(test_dir / simulation, sep="\t")
for simulation in solutions["grad_llh_files"]
]
)
@@ -314,37 +321,9 @@ def _test_ude(test):
sllh.model.nns[net].layers[layer], attribute
)[*index].item()
actual = pd.Series(actual_dict).sort_index()
if test.stem in ("015",):
return
np.testing.assert_allclose(
actual.values,
expected["value"].values,
atol=solutions["tol_grad_llh"],
rtol=solutions["tol_grad_llh"],
)


if __name__ == "__main__":
print("Running from testsuite.py")
test_case_dir = Path(__file__).parent / "testsuite" / "test_cases"
test_cases = list(test_case_dir.glob("*"))
for test in test_cases:
if test.stem.startswith("net_"):
_test_net(test)
elif test.stem.startswith("0"):
if test.stem in (
"002", # nn in ode, rhs assignment
"004", # nn input in condition table
"015", # passing, wrong gradient
"016", # files in condition table
"001",
"005",
"010",
"011",
"012",
"013",
"014", # nn in ode
"008", # nn in initial condition
):
continue
_test_ude(test)
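With the module-level `__main__` driver shown above removed, the suite is now run through pytest's collection and selection machinery rather than by executing the file directly. A hedged usage sketch (the file path and `-k` expressions are examples, not part of the commit):

# Typical invocations from the repository root:
#   python -m pytest tests/sciml/test_sciml.py
#   python -m pytest tests/sciml/test_sciml.py -k "test_net and net_001"
# The same selection can also be driven programmatically:
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["tests/sciml/test_sciml.py", "-k", "net_001"]))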
