Skip to content

Commit

Permalink
Merge branch 'develop' into jax_nan_objective
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich authored Dec 2, 2024
2 parents f79a96e + 1505d90 commit 7c70e05
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 55 deletions.
1 change: 1 addition & 0 deletions include/amici/model_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ struct ModelStateDerived {
dwdx.set_ctx(sunctx_);
}
sspl_.set_ctx(sunctx_);
x_pos_tmp_.set_ctx(sunctx_);
dwdw_.set_ctx(sunctx_);
dJydy_dense_.set_ctx(sunctx_);
}
Expand Down
5 changes: 5 additions & 0 deletions python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# ruff: noqa: F401, F821, F841
import jax.numpy as jnp
from interpax import interp1d
from pathlib import Path

from amici.jax.model import JAXModel, safe_log, safe_div

Expand All @@ -9,6 +10,7 @@ class JAXModel_TPL_MODEL_NAME(JAXModel):
api_version = TPL_MODEL_API_VERSION

def __init__(self):
self.jax_py_file = Path(__file__).resolve()
super().__init__()

def _xdot(self, t, x, args):
Expand Down Expand Up @@ -100,3 +102,6 @@ def state_ids(self):
@property
def parameter_ids(self):
return TPL_PK_IDS


Model = JAXModel_TPL_MODEL_NAME
4 changes: 3 additions & 1 deletion python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ruff: noqa: F821 F722

from abc import abstractmethod
from pathlib import Path

import diffrax
import equinox as eqx
Expand All @@ -18,8 +19,9 @@ class JAXModel(eqx.Module):
classes inheriting from JAXModel.
"""

MODEL_API_VERSION = "0.0.1"
MODEL_API_VERSION = "0.0.2"
api_version: str
jax_py_file: Path

def __init__(self):
if self.api_version != self.MODEL_API_VERSION:
Expand Down
43 changes: 42 additions & 1 deletion python/sdist/amici/jax/petab.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""PEtab wrappers for JAX models.""" ""

import shutil
from numbers import Number
from collections.abc import Iterable
from pathlib import Path

import diffrax
import equinox as eqx
Expand All @@ -12,6 +13,7 @@
import pandas as pd
import petab.v1 as petab

from amici import _module_from_path
from amici.petab.parameter_mapping import (
ParameterMappingForCondition,
create_parameter_mapping,
Expand Down Expand Up @@ -84,6 +86,45 @@ def __init__(self, model: JAXModel, petab_problem: petab.Problem):
self._measurements = self._get_measurements(scs)
self.parameters = self._get_nominal_parameter_values()

def save(self, directory: Path):
"""
Save the problem to a directory.
:param directory:
Directory to save the problem to.
"""
self._petab_problem.to_files(
prefix_path=directory,
model_file="model",
condition_file="conditions.tsv",
measurement_file="measurements.tsv",
parameter_file="parameters.tsv",
observable_file="observables.tsv",
yaml_file="problem.yaml",
)
shutil.copy(self.model.jax_py_file, directory / "jax_py_file.py")
with open(directory / "parameters.pkl", "wb") as f:
eqx.tree_serialise_leaves(f, self)

@classmethod
def load(cls, directory: Path):
"""
Load a problem from a directory.
:param directory:
Directory to load the problem from.
:return:
Loaded problem instance.
"""
petab_problem = petab.Problem.from_yaml(
directory / "problem.yaml",
)
model = _module_from_path("jax", directory / "jax_py_file.py").Model()
problem = cls(model, petab_problem)
with open(directory / "parameters.pkl", "rb") as f:
return eqx.tree_deserialise_leaves(f, problem)

def _get_parameter_mappings(
self, simulation_conditions: pd.DataFrame
) -> dict[str, ParameterMappingForCondition]:
Expand Down
Loading

0 comments on commit 7c70e05

Please sign in to comment.