diff --git a/CMakeLists.txt b/CMakeLists.txt index 744847930d..d8a1109d90 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -133,9 +133,16 @@ elseif(AMICI_TRY_ENABLE_HDF5) endif() set(VENDORED_SUNDIALS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/ThirdParty/sundials) -set(VENDORED_SUNDIALS_BUILD_DIR ${VENDORED_SUNDIALS_DIR}/build) -set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR}) set(SUNDIALS_PRIVATE_INCLUDE_DIRS "${VENDORED_SUNDIALS_DIR}/src") +# Handle different sundials build/install dirs, depending on whether we are +# building the Python extension only or the full C++ interface +if(AMICI_PYTHON_BUILD_EXT_ONLY) + set(VENDORED_SUNDIALS_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}) + set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR}) +else() + set(VENDORED_SUNDIALS_BUILD_DIR ${VENDORED_SUNDIALS_DIR}/build) + set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR}) +endif() find_package( SUNDIALS REQUIRED PATHS "${VENDORED_SUNDIALS_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/cmake/sundials/") diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index c1aef949c6..f40d0f4c6e 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -55,15 +55,18 @@ def __getitem__(self, item: str) -> Union[np.ndarray, float]: if item in self._cache: return self._cache[item] - if item == "id": - return getattr(self._swigptr, item) + if item in self._field_names: + value = _field_as_numpy( + self._field_dimensions, item, self._swigptr + ) + self._cache[item] = value - if item not in self._field_names: - self.__missing__(item) + return value + + if not item.startswith("_") and hasattr(self._swigptr, item): + return getattr(self._swigptr, item) - value = _field_as_numpy(self._field_dimensions, item, self._swigptr) - self._cache[item] = value - return value + self.__missing__(item) def __missing__(self, key: str) -> None: """ diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index d27f2994ce..19dbe05f89 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -109,7 +109,7 @@ def plot_observable_trajectories( if not ax: fig, ax = plt.subplots() if not observable_indices: - observable_indices = range(rdata["y"].shape[1]) + observable_indices = range(rdata.ny) if marker is None: # Show marker if only one time point is available, diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index 722fb57931..2dfb46e0a7 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -517,6 +517,9 @@ def test_rdataview(sbml_example_presimulation_module): rdata = amici.runAmiciSimulation(model, model.getSolver()) assert isinstance(rdata, amici.ReturnDataView) + # check that non-array attributes are looked up in the wrapped object + assert rdata.ptr.ny == rdata.ny + # fields are accessible via dot notation and [] operator, # __contains__ and __getattr__ are implemented correctly with pytest.raises(AttributeError):