From 138955f196bc53695df403a478b28543a8861f48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=85dne=20Jacobsen?= <88316787+adnejacobsen@users.noreply.github.com> Date: Tue, 3 Oct 2023 14:52:50 +0200 Subject: [PATCH] implement __iter__ and __aiter__ in DocumentCollection (#231) * implement __iter__ and __aiter__ in DocumentCollection + improve pit handling * remove unused imports --- src/fmu/sumo/explorer/explorer.py | 4 +- .../explorer/objects/_document_collection.py | 71 +++++++++++++------ .../sumo/explorer/objects/case_collection.py | 4 ++ .../sumo/explorer/objects/cube_collection.py | 8 ++- .../explorer/objects/dictionary_collection.py | 4 ++ .../explorer/objects/polygons_collection.py | 4 ++ .../explorer/objects/surface_collection.py | 10 ++- .../sumo/explorer/objects/table_collection.py | 4 ++ src/fmu/sumo/explorer/pit.py | 7 +- 9 files changed, 87 insertions(+), 29 deletions(-) diff --git a/src/fmu/sumo/explorer/explorer.py b/src/fmu/sumo/explorer/explorer.py index a41fae67..92757591 100644 --- a/src/fmu/sumo/explorer/explorer.py +++ b/src/fmu/sumo/explorer/explorer.py @@ -114,7 +114,7 @@ def get_case_by_uuid(self, uuid: str) -> Case: Case: case object """ metadata = self._utils.get_object(uuid, _CASE_FIELDS) - return Case(self._sumo, metadata) + return Case(self._sumo, metadata, self._pit) async def get_case_by_uuid_async(self, uuid: str) -> Case: """Get case object by uuid @@ -126,7 +126,7 @@ async def get_case_by_uuid_async(self, uuid: str) -> Case: Case: case object """ metadata = await self._utils.get_object_async(uuid, _CASE_FIELDS) - return Case(self._sumo, metadata) + return Case(self._sumo, metadata, self._pit) def get_surface_by_uuid(self, uuid: str) -> Surface: """Get surface object by uuid diff --git a/src/fmu/sumo/explorer/objects/_document_collection.py b/src/fmu/sumo/explorer/objects/_document_collection.py index 2a9ddcd3..447edf72 100644 --- a/src/fmu/sumo/explorer/objects/_document_collection.py +++ b/src/fmu/sumo/explorer/objects/_document_collection.py @@ -20,16 +20,40 @@ def __init__( self._type = doc_type self._sumo = sumo self._query = self._init_query(doc_type, query) - self._pit = pit + self._pit = pit + self._new_pit_id = None self._after = None self._curr_index = 0 self._len = None self._items = [] self._field_values = {} - self._query = self._init_query(doc_type, query) self._select = select + def __iter__(self): + self._curr_index = 0 + return self + + def __next__(self): + if self._curr_index < self.__len__(): + res = self.__getitem__(self._curr_index) + self._curr_index += 1 + return res + else: + raise StopIteration + + def __aiter__(self): + self._curr_index = 0 + return self + + async def __anext__(self): + if self._curr_index < await self.length_async(): + res = await self.getitem_async(self._curr_index) + self._curr_index += 1 + return res + else: + raise StopAsyncIteration + def __len__(self) -> int: """Get size of document collection @@ -61,17 +85,14 @@ def __getitem__(self, index: int) -> Dict: Returns: A document at a given index """ - if index >= self.__len__(): + if index > self.__len__(): raise IndexError - if len(self._items) <= index: - while len(self._items) <= index: - prev_len = len(self._items) - self._next_batch() - curr_len = len(self._items) + while len(self._items) <= index: + hits_size = self._next_batch() - if prev_len == curr_len: - raise IndexError + if hits_size == 0: + raise IndexError return self._items[index] @@ -84,17 +105,14 @@ async def getitem_async(self, index: int) -> Dict: Returns: A document at a given index """ - if index >= await self.length_async(): + if index > await self.length_async(): raise IndexError - if len(self._items) <= index: - while len(self._items) <= index: - prev_len = len(self._items) - await self._next_batch_async() - curr_len = len(self._items) + while len(self._items) <= index: + hits_size = await self._next_batch_async() - if prev_len == curr_len: - raise IndexError + if hits_size == 0: + raise IndexError return self._items[index] @@ -162,11 +180,14 @@ def _next_batch(self) -> List[Dict]: query["search_after"] = self._after if self._pit is not None: - query["pit"] = self._pit.get_pit_object() + query["pit"] = self._pit.get_pit_object(self._new_pit_id) res = self._sumo.post("/search", json=query).json() hits = res["hits"] + if self._pit is not None: + self._new_pit_id = res["pit_id"] + if self._len is None: self._len = hits["total"]["value"] @@ -174,6 +195,8 @@ def _next_batch(self) -> List[Dict]: self._after = hits["hits"][-1]["sort"] self._items.extend(hits["hits"]) + return len(hits["hits"]) + async def _next_batch_async(self) -> List[Dict]: """Get next batch of documents @@ -196,10 +219,14 @@ async def _next_batch_async(self) -> List[Dict]: query["search_after"] = self._after if self._pit is not None: - query["pit"] = self._pit.get_pit_object() + query["pit"] = self._pit.get_pit_object(self._new_pit_id) res = await self._sumo.post_async("/search", json=query) - hits = res.json()["hits"] + data = res.json() + hits = data["hits"] + + if self._pit is not None: + self._new_pit_id = data["pit_id"] if self._len is None: self._len = hits["total"]["value"] @@ -208,6 +235,8 @@ async def _next_batch_async(self) -> List[Dict]: self._after = hits["hits"][-1]["sort"] self._items.extend(hits["hits"]) + return len(hits["hits"]) + def _init_query(self, doc_type: str, query: Dict = None) -> Dict: """Initialize base filter for document collection diff --git a/src/fmu/sumo/explorer/objects/case_collection.py b/src/fmu/sumo/explorer/objects/case_collection.py index 08c4238e..eb1f8e92 100644 --- a/src/fmu/sumo/explorer/objects/case_collection.py +++ b/src/fmu/sumo/explorer/objects/case_collection.py @@ -85,6 +85,10 @@ def __getitem__(self, index: int) -> Case: doc = super().__getitem__(index) return Case(self._sumo, doc, self._pit) + async def getitem_async(self, index: int) -> Case: + doc = await super().getitem_async(index) + return Case(self._sumo, doc) + def filter( self, uuid: Union[str, List[str]] = None, diff --git a/src/fmu/sumo/explorer/objects/cube_collection.py b/src/fmu/sumo/explorer/objects/cube_collection.py index e4795fd0..2f6d208a 100644 --- a/src/fmu/sumo/explorer/objects/cube_collection.py +++ b/src/fmu/sumo/explorer/objects/cube_collection.py @@ -37,6 +37,10 @@ def __getitem__(self, index) -> Cube: doc = super().__getitem__(index) return Cube(self._sumo, doc) + async def getitem_async(self, index: int) -> Cube: + doc = await super().getitem_async(index) + return Cube(self._sumo, doc) + @property def timestamps(self) -> List[str]: """List of unique timestamps in CubeCollection""" @@ -129,7 +133,7 @@ def filter( time: TimeFilter = None, uuid: Union[str, List[str], bool] = None, is_observation: bool = None, - is_prediction: bool = None + is_prediction: bool = None, ) -> "CubeCollection": """Filter cubes @@ -156,7 +160,7 @@ def filter( time=time, uuid=uuid, is_observation=is_observation, - is_prediction=is_prediction + is_prediction=is_prediction, ) return CubeCollection(self._sumo, self._case_uuid, query, self._pit) diff --git a/src/fmu/sumo/explorer/objects/dictionary_collection.py b/src/fmu/sumo/explorer/objects/dictionary_collection.py index 517abcd1..96966011 100644 --- a/src/fmu/sumo/explorer/objects/dictionary_collection.py +++ b/src/fmu/sumo/explorer/objects/dictionary_collection.py @@ -29,6 +29,10 @@ def __getitem__(self, index) -> Dictionary: doc = super().__getitem__(index) return Dictionary(self._sumo, doc) + async def getitem_async(self, index: int) -> Dictionary: + doc = await super().getitem_async(index) + return Dictionary(self._sumo, doc) + def filter( self, name: Union[str, List[str], bool] = None, diff --git a/src/fmu/sumo/explorer/objects/polygons_collection.py b/src/fmu/sumo/explorer/objects/polygons_collection.py index f021884a..cb2c583a 100644 --- a/src/fmu/sumo/explorer/objects/polygons_collection.py +++ b/src/fmu/sumo/explorer/objects/polygons_collection.py @@ -29,6 +29,10 @@ def __getitem__(self, index) -> Polygons: doc = super().__getitem__(index) return Polygons(self._sumo, doc) + async def getitem_async(self, index: int) -> Polygons: + doc = await super().getitem_async(index) + return Polygons(self._sumo, doc) + def filter( self, name: Union[str, List[str], bool] = None, diff --git a/src/fmu/sumo/explorer/objects/surface_collection.py b/src/fmu/sumo/explorer/objects/surface_collection.py index 6b9b2335..9966a4b8 100644 --- a/src/fmu/sumo/explorer/objects/surface_collection.py +++ b/src/fmu/sumo/explorer/objects/surface_collection.py @@ -41,6 +41,10 @@ def __getitem__(self, index) -> Surface: doc = super().__getitem__(index) return Surface(self._sumo, doc) + async def getitem_async(self, index: int) -> Surface: + doc = await super().getitem_async(index) + return Surface(self._sumo, doc) + @property def timestamps(self) -> List[str]: """List of unique timestamps in SurfaceCollection""" @@ -141,7 +145,9 @@ def _aggregate(self, operation: str) -> xtgeo.RegularSurface: async def _aggregate_async(self, operation: str) -> xtgeo.RegularSurface: if operation not in self._aggregation_cache: - objects = await self._utils.get_objects_async(500, self._query, ["_id"]) + objects = await self._utils.get_objects_async( + 500, self._query, ["_id"] + ) object_ids = list(map(lambda obj: obj["_id"], objects)) res = await self._sumo.post_async( @@ -291,4 +297,4 @@ def p90(self) -> xtgeo.RegularSurface: async def p90_async(self) -> xtgeo.RegularSurface: """Perform a percentile aggregation""" - return await self._aggregate_async("p90") \ No newline at end of file + return await self._aggregate_async("p90") diff --git a/src/fmu/sumo/explorer/objects/table_collection.py b/src/fmu/sumo/explorer/objects/table_collection.py index a072c518..cf79e682 100644 --- a/src/fmu/sumo/explorer/objects/table_collection.py +++ b/src/fmu/sumo/explorer/objects/table_collection.py @@ -29,6 +29,10 @@ def __getitem__(self, index) -> Table: doc = super().__getitem__(index) return Table(self._sumo, doc) + async def getitem_async(self, index: int) -> Table: + doc = await super().getitem_async(index) + return Table(self._sumo, doc) + @property def columns(self) -> List[str]: """List of unique column names""" diff --git a/src/fmu/sumo/explorer/pit.py b/src/fmu/sumo/explorer/pit.py index 9f3723d9..a4194a66 100644 --- a/src/fmu/sumo/explorer/pit.py +++ b/src/fmu/sumo/explorer/pit.py @@ -21,10 +21,13 @@ def __get_pit_id(self, keep_alive) -> str: res = self._sumo.post("/pit", params={"keep-alive": keep_alive}) return res.json()["id"] - def get_pit_object(self) -> Dict: + def get_pit_object(self, pit_id: str = None) -> Dict: """Get the pit object Returns: Dict: dict with id and info about how long to keep alive """ - return {"id": self._pit_id, "keep_alive": self._keep_alive} + return { + "id": pit_id if pit_id is not None else self._pit_id, + "keep_alive": self._keep_alive, + }