diff --git a/CMakeLists.txt b/CMakeLists.txt index 744847930d..d8a1109d90 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -133,9 +133,16 @@ elseif(AMICI_TRY_ENABLE_HDF5) endif() set(VENDORED_SUNDIALS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/ThirdParty/sundials) -set(VENDORED_SUNDIALS_BUILD_DIR ${VENDORED_SUNDIALS_DIR}/build) -set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR}) set(SUNDIALS_PRIVATE_INCLUDE_DIRS "${VENDORED_SUNDIALS_DIR}/src") +# Handle different sundials build/install dirs, depending on whether we are +# building the Python extension only or the full C++ interface +if(AMICI_PYTHON_BUILD_EXT_ONLY) + set(VENDORED_SUNDIALS_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}) + set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR}) +else() + set(VENDORED_SUNDIALS_BUILD_DIR ${VENDORED_SUNDIALS_DIR}/build) + set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR}) +endif() find_package( SUNDIALS REQUIRED PATHS "${VENDORED_SUNDIALS_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/cmake/sundials/") diff --git a/documentation/conf.py b/documentation/conf.py index 8b2379a299..25c6dab647 100644 --- a/documentation/conf.py +++ b/documentation/conf.py @@ -9,19 +9,33 @@ import subprocess import sys from enum import EnumType - -# need to import before setting typing.TYPE_CHECKING=True, fails otherwise -import amici import exhale.deploy -import exhale_multiproject_monkeypatch from unittest import mock -import pandas as pd import sphinx -import sympy as sp from exhale import configs as exhale_configs from sphinx.transforms.post_transforms import ReferencesResolver -exhale_multiproject_monkeypatch, pd, sp # to avoid removal of unused import +try: + import exhale_multiproject_monkeypatch # noqa: F401 +except ModuleNotFoundError: + # for unclear reasons, the import of exhale_multiproject_monkeypatch + # fails on some systems, because the the location of the editable install + # is not automatically added to sys.path ¯\_(ツ)_/¯ + from importlib.metadata import Distribution + import json + from urllib.parse import unquote_plus, urlparse + + dist = Distribution.from_name("sphinx-contrib-exhale-multiproject") + url = json.loads(dist.read_text("direct_url.json"))["url"] + package_dir = unquote_plus(urlparse(url).path) + sys.path.append(package_dir) + import exhale_multiproject_monkeypatch # noqa: F401 + +# need to import before setting typing.TYPE_CHECKING=True, fails otherwise +import amici +import pandas as pd # noqa: F401 +import sympy as sp # noqa: F401 + # BEGIN Monkeypatch exhale from exhale.deploy import _generate_doxygen as exhale_generate_doxygen diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index c1aef949c6..f40d0f4c6e 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -55,15 +55,18 @@ def __getitem__(self, item: str) -> Union[np.ndarray, float]: if item in self._cache: return self._cache[item] - if item == "id": - return getattr(self._swigptr, item) + if item in self._field_names: + value = _field_as_numpy( + self._field_dimensions, item, self._swigptr + ) + self._cache[item] = value - if item not in self._field_names: - self.__missing__(item) + return value + + if not item.startswith("_") and hasattr(self._swigptr, item): + return getattr(self._swigptr, item) - value = _field_as_numpy(self._field_dimensions, item, self._swigptr) - self._cache[item] = value - return value + self.__missing__(item) def __missing__(self, key: str) -> None: """ diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index d27f2994ce..19dbe05f89 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -109,7 +109,7 @@ def plot_observable_trajectories( if not ax: fig, ax = plt.subplots() if not observable_indices: - observable_indices = range(rdata["y"].shape[1]) + observable_indices = range(rdata.ny) if marker is None: # Show marker if only one time point is available, diff --git a/python/sdist/amici/swig.py b/python/sdist/amici/swig.py index fbc486c301..5ba8017005 100644 --- a/python/sdist/amici/swig.py +++ b/python/sdist/amici/swig.py @@ -1,11 +1,12 @@ """Functions related to SWIG or SWIG-generated code""" +from __future__ import annotations import ast import contextlib import re class TypeHintFixer(ast.NodeTransformer): - """Replaces SWIG-generated C++ typehints by corresponding Python types""" + """Replaces SWIG-generated C++ typehints by corresponding Python types.""" mapping = { "void": None, @@ -53,9 +54,13 @@ class TypeHintFixer(ast.NodeTransformer): "std::allocator< amici::ParameterScaling > > const &": ast.Constant( "ParameterScalingVector" ), + "H5::H5File": None, } def visit_FunctionDef(self, node): + # convert type/rtype from docstring to annotation, if possible. + # those may be c++ types, not valid in python, that need to be + # converted to python types below. self._annotation_from_docstring(node) # Has a return type annotation? @@ -67,14 +72,17 @@ def visit_FunctionDef(self, node): for arg in node.args.args: if not arg.annotation: continue - if isinstance(arg.annotation, ast.Name): + if not isinstance(arg.annotation, ast.Constant): # there is already proper annotation continue arg.annotation = self._new_annot(arg.annotation.value) return node - def _new_annot(self, old_annot: str): + def _new_annot(self, old_annot: str | ast.Name): + if isinstance(old_annot, ast.Name): + old_annot = old_annot.id + with contextlib.suppress(KeyError): return self.mapping[old_annot] @@ -117,6 +125,8 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): Swig sometimes generates ``:type solver: :py:class:`Solver`` instead of ``:type solver: Solver``. Those need special treatment. + + Overloaded functions are skipped. """ docstring = ast.get_docstring(node, clean=False) if not docstring or "*Overload 1:*" in docstring: @@ -127,22 +137,18 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): lines_to_remove = set() for line_no, line in enumerate(docstring): - if ( - match := re.match( - r"\s*:rtype:\s*(?::py:class:`)?(\w+)`?\s+$", line - ) - ) and not match.group(1).startswith(":"): - node.returns = ast.Constant(match.group(1)) + if type_str := self.extract_rtype(line): + # handle `:rtype:` + node.returns = ast.Constant(type_str) lines_to_remove.add(line_no) + continue - if ( - match := re.match( - r"\s*:type\s*(\w+):\W*(?::py:class:`)?(\w+)`?\s+$", line - ) - ) and not match.group(1).startswith(":"): + arg_name, type_str = self.extract_type(line) + if arg_name is not None: + # handle `:type ...:` for arg in node.args.args: - if arg.arg == match.group(1): - arg.annotation = ast.Constant(match.group(2)) + if arg.arg == arg_name: + arg.annotation = ast.Constant(type_str) lines_to_remove.add(line_no) if lines_to_remove: @@ -155,13 +161,42 @@ def _annotation_from_docstring(self, node: ast.FunctionDef): ) node.body[0].value = ast.Str(new_docstring) + @staticmethod + def extract_type(line: str) -> tuple[str, str] | tuple[None, None]: + """Extract argument name and type string from ``:type:`` docstring + line.""" + match = re.match(r"\s*:type\s+(\w+):\s+(.+?)(?:, optional)?\s*$", line) + if not match: + return None, None + + arg_name = match.group(1) + + # get rid of any :py:class`...` in the type string if necessary + if not match.group(2).startswith(":py:"): + return arg_name, match.group(2) + + match = re.match(r":py:\w+:`(.+)`", match.group(2)) + assert match + return arg_name, match.group(1) + + @staticmethod + def extract_rtype(line: str) -> str | None: + """Extract type string from ``:rtype:`` docstring line.""" + match = re.match(r"\s*:rtype:\s+(.+)\s*$", line) + if not match: + return None + + # get rid of any :py:class`...` in the type string if necessary + if not match.group(1).startswith(":py:"): + return match.group(1) + + match = re.match(r":py:\w+:`(.+)`", match.group(1)) + assert match + return match.group(1) + def fix_typehints(infilename, outfilename): """Change SWIG-generated C++ typehints to Python typehints""" - # Only available from Python3.9 - if not getattr(ast, "unparse", None): - return - # file -> AST with open(infilename) as f: source = f.read() diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index f214519f26..b5063ca3cc 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -511,6 +511,9 @@ def test_rdataview(sbml_example_presimulation_module): rdata = amici.runAmiciSimulation(model, model.getSolver()) assert isinstance(rdata, amici.ReturnDataView) + # check that non-array attributes are looked up in the wrapped object + assert rdata.ptr.ny == rdata.ny + # fields are accessible via dot notation and [] operator, # __contains__ and __getattr__ are implemented correctly with pytest.raises(AttributeError): diff --git a/src/model.cpp b/src/model.cpp index cefdf1ac97..452e484f58 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -1671,7 +1671,6 @@ int Model::checkFinite( && model_quantity != ModelQuantity::ts) { checkFinite(state_.fixedParameters, ModelQuantity::k, t); checkFinite(state_.unscaledParameters, ModelQuantity::p, t); - checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t); if (!always_check_finite_ && model_quantity != ModelQuantity::w) { // don't check twice if always_check_finite_ is true checkFinite(derived_state_.w_, ModelQuantity::w, t); @@ -1789,7 +1788,6 @@ int Model::checkFinite( // check upstream checkFinite(state_.fixedParameters, ModelQuantity::k, t); checkFinite(state_.unscaledParameters, ModelQuantity::p, t); - checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t); checkFinite(derived_state_.w_, ModelQuantity::w, t); return AMICI_RECOVERABLE_ERROR; @@ -1880,7 +1878,6 @@ int Model::checkFinite(SUNMatrix m, ModelQuantity model_quantity, realtype t) // check upstream checkFinite(state_.fixedParameters, ModelQuantity::k, t); checkFinite(state_.unscaledParameters, ModelQuantity::p, t); - checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t); checkFinite(derived_state_.w_, ModelQuantity::w, t); return AMICI_RECOVERABLE_ERROR; diff --git a/src/solver_idas.cpp b/src/solver_idas.cpp index 4f96c95f86..8093b336e5 100644 --- a/src/solver_idas.cpp +++ b/src/solver_idas.cpp @@ -17,7 +17,7 @@ namespace amici { /* - * The following static members are callback function to CVODES. + * The following static members are callback function to IDAS. * Their signatures must not be changes. */ @@ -437,7 +437,7 @@ void IDASolver::reInitPostProcess( auto status = IDASetStopTime(ida_mem, tout); if (status != IDA_SUCCESS) - throw IDAException(status, "CVodeSetStopTime"); + throw IDAException(status, "IDASetStopTime"); status = IDASolve( ami_mem, tout, t, yout->getNVector(), ypout->getNVector(), IDA_ONE_STEP @@ -853,7 +853,7 @@ void IDASolver::setNonLinearSolver() const { solver_memory_.get(), non_linear_solver_->get() ); if (status != IDA_SUCCESS) - throw CvodeException(status, "CVodeSetNonlinearSolver"); + throw IDAException(status, "IDASetNonlinearSolver"); } void IDASolver::setNonLinearSolverSens() const { @@ -883,7 +883,7 @@ void IDASolver::setNonLinearSolverSens() const { } if (status != IDA_SUCCESS) - throw CvodeException(status, "CVodeSolver::setNonLinearSolverSens"); + throw IDAException(status, "IDASolver::setNonLinearSolverSens"); } void IDASolver::setNonLinearSolverB(int which) const { @@ -891,7 +891,7 @@ void IDASolver::setNonLinearSolverB(int which) const { solver_memory_.get(), which, non_linear_solver_B_->get() ); if (status != IDA_SUCCESS) - throw CvodeException(status, "CVodeSetNonlinearSolverB"); + throw IDAException(status, "IDASetNonlinearSolverB"); } /** diff --git a/swig/misc.i b/swig/misc.i index 8015e28bfe..af166b48ed 100644 --- a/swig/misc.i +++ b/swig/misc.i @@ -4,6 +4,8 @@ %ignore amici::regexErrorToString; %ignore amici::writeSlice; %ignore ContextManager; +%ignore amici::scaleParameters; +%ignore amici::unscaleParameters; // Add necessary symbols to generated header %{