From a0fa1088adc4e42bd050c5fce466e023ceec4a7d Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Sun, 17 Dec 2023 17:52:36 +0100 Subject: [PATCH 1/4] Add SwigPtrView 's fields as properties This way, `dir(SwigPtrView(...))` will show the available fields and sufficiently smart IDEs will show them for code completion. --- python/sdist/amici/numpy.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index b84e52cc2b..97cb624f4a 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -91,6 +91,15 @@ def __init__(self, swigptr): self._cache = {} super(SwigPtrView, self).__init__() + # create properties for all fields + for field in self._field_names: + if not hasattr(self, field): + setattr( + self, + field, + property(lambda self_: self_.__getitem__(field)), + ) + def __len__(self) -> int: """ Returns the number of available keys/fields @@ -237,7 +246,7 @@ def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]): if not isinstance(rdata, (ReturnDataPtr, ReturnData)): raise TypeError( f"Unsupported pointer {type(rdata)}, must be" - f"amici.ExpDataPtr!" + f"amici.ReturnDataPtr or amici.ReturnData!" ) self._field_dimensions = { "ts": [rdata.nt], From 8d4ba4dceaa5fff8406e3f4ccae123c8ff03c340 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Sun, 17 Dec 2023 19:38:53 +0100 Subject: [PATCH 2/4] .. --- python/sdist/amici/numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index 97cb624f4a..487e5d7b98 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -93,7 +93,7 @@ def __init__(self, swigptr): # create properties for all fields for field in self._field_names: - if not hasattr(self, field): + if field not in dir(self): setattr( self, field, From 50cebd44f96cd48b6e7bb8b782ac1471eb88b3ba Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Sun, 17 Dec 2023 19:46:04 +0100 Subject: [PATCH 3/4] .. --- python/sdist/amici/numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index 487e5d7b98..8a1c687388 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -95,7 +95,7 @@ def __init__(self, swigptr): for field in self._field_names: if field not in dir(self): setattr( - self, + self.__class__, field, property(lambda self_: self_.__getitem__(field)), ) From cd18a225e7764d309315b870567e5b01ac6ccb7e Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Mon, 18 Dec 2023 10:50:53 +0100 Subject: [PATCH 4/4] dir --- python/sdist/amici/numpy.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/sdist/amici/numpy.py b/python/sdist/amici/numpy.py index 8a1c687388..93b04603be 100644 --- a/python/sdist/amici/numpy.py +++ b/python/sdist/amici/numpy.py @@ -6,6 +6,7 @@ import collections import copy +import itertools from typing import Dict, Iterator, List, Literal, Union import amici @@ -91,15 +92,6 @@ def __init__(self, swigptr): self._cache = {} super(SwigPtrView, self).__init__() - # create properties for all fields - for field in self._field_names: - if field not in dir(self): - setattr( - self.__class__, - field, - property(lambda self_: self_.__getitem__(field)), - ) - def __len__(self) -> int: """ Returns the number of available keys/fields @@ -173,6 +165,13 @@ def __eq__(self, other): return False return self._swigptr == other._swigptr + def __dir__(self): + return sorted( + set( + itertools.chain(dir(super()), self.__dict__, self._field_names) + ) + ) + class ReturnDataView(SwigPtrView): """ @@ -297,7 +296,7 @@ def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]): "numerrtestfailsB": [rdata.nt], "numnonlinsolvconvfailsB": [rdata.nt], } - super(ReturnDataView, self).__init__(rdata) + super().__init__(rdata) def __getitem__( self, item: str @@ -415,7 +414,7 @@ def __init__(self, edata: Union[ExpDataPtr, ExpData]): edata.observedDataStdDev = edata.getObservedDataStdDev() edata.observedEvents = edata.getObservedEvents() edata.observedEventsStdDev = edata.getObservedEventsStdDev() - super(ExpDataView, self).__init__(edata) + super().__init__(edata) def _field_as_numpy(