From 560023509f47b229d9aaf01ec5cf5804570922c5 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Fri, 19 Apr 2024 09:51:46 +0200 Subject: [PATCH] SwigPtrView: look up unhandled attributes in _swigptr So far, only the explicitly listed attributes are accessible directly via SwigPtrView. Things like ReturnData.ny are only available through ReturnDataView.ptr.ny, which is inconvenient. Let's just look all non-private attributes that are not already explicitly handled by SwigPtrView on _swigptr and return them as is. --- python/sdist/amici/numpy.py | 17 ++++++++++------- python/sdist/amici/plotting.py | 2 +- python/tests/test_swig_interface.py | 3 +++ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index c1aef949c6..f40d0f4c6e 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -55,15 +55,18 @@ def __getitem__(self, item: str) -> Union[np.ndarray, float]: if item in self._cache: return self._cache[item] - if item == "id": - return getattr(self._swigptr, item) + if item in self._field_names: + value = _field_as_numpy( + self._field_dimensions, item, self._swigptr + ) + self._cache[item] = value - if item not in self._field_names: - self.__missing__(item) + return value + + if not item.startswith("_") and hasattr(self._swigptr, item): + return getattr(self._swigptr, item) - value = _field_as_numpy(self._field_dimensions, item, self._swigptr) - self._cache[item] = value - return value + self.__missing__(item) def __missing__(self, key: str) -> None: """ diff --git a/python/sdist/amici/plotting.py b/python/sdist/amici/plotting.py index d27f2994ce..19dbe05f89 100644 --- a/python/sdist/amici/plotting.py +++ b/python/sdist/amici/plotting.py @@ -109,7 +109,7 @@ def plot_observable_trajectories( if not ax: fig, ax = plt.subplots() if not observable_indices: - observable_indices = range(rdata["y"].shape[1]) + observable_indices = range(rdata.ny) if marker is None: # Show marker if only one time point is available, diff --git a/python/tests/test_swig_interface.py b/python/tests/test_swig_interface.py index f214519f26..b5063ca3cc 100644 --- a/python/tests/test_swig_interface.py +++ b/python/tests/test_swig_interface.py @@ -511,6 +511,9 @@ def test_rdataview(sbml_example_presimulation_module): rdata = amici.runAmiciSimulation(model, model.getSolver()) assert isinstance(rdata, amici.ReturnDataView) + # check that non-array attributes are looked up in the wrapped object + assert rdata.ptr.ny == rdata.ny + # fields are accessible via dot notation and [] operator, # __contains__ and __getattr__ are implemented correctly with pytest.raises(AttributeError):