diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a1a1dfadc0..f16458b29a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,22 @@ repos: args: [--allow-multiple-documents] - id: end-of-file-fixer - id: trailing-whitespace +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.6.7 + hooks: + # Run the linter. + - id: ruff + args: + - --fix + - --config + - python/sdist/pyproject.toml + # Run the formatter. + - id: ruff-format + args: + - --config + - python/sdist/pyproject.toml - repo: https://github.com/asottile/pyupgrade rev: v3.17.0 diff --git a/python/sdist/amici/jax/jax.template.py b/python/sdist/amici/jax/jax.template.py index de10a67ff8..b76c86b021 100644 --- a/python/sdist/amici/jax/jax.template.py +++ b/python/sdist/amici/jax/jax.template.py @@ -9,6 +9,7 @@ TPL_NET_IMPORTS + class JAXModel_TPL_MODEL_NAME(JAXModel): api_version = TPL_MODEL_API_VERSION diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index b2e02b4aae..75e346bfe6 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -126,45 +126,6 @@ def load(cls, directory: Path): with open(directory / "parameters.pkl", "rb") as f: return eqx.tree_deserialise_leaves(f, problem) - def save(self, directory: Path): - """ - Save the problem to a directory. - - :param directory: - Directory to save the problem to. - """ - self._petab_problem.to_files( - prefix_path=directory, - model_file="model", - condition_file="conditions.tsv", - measurement_file="measurements.tsv", - parameter_file="parameters.tsv", - observable_file="observables.tsv", - yaml_file="problem.yaml", - ) - shutil.copy(self.model.jax_py_file, directory / "jax_py_file.py") - with open(directory / "parameters.pkl", "wb") as f: - eqx.tree_serialise_leaves(f, self) - - @classmethod - def load(cls, directory: Path): - """ - Load a problem from a directory. - - :param directory: - Directory to load the problem from. - - :return: - Loaded problem instance. - """ - petab_problem = petab.Problem.from_yaml( - directory / "problem.yaml", - ) - model = _module_from_path("jax", directory / "jax_py_file.py").Model() - problem = cls(model, petab_problem) - with open(directory / "parameters.pkl", "rb") as f: - return eqx.tree_deserialise_leaves(f, problem) - def _get_parameter_mappings( self, simulation_conditions: pd.DataFrame ) -> dict[str, ParameterMappingForCondition]: