implement __iter__ and __aiter__ in DocumentCollection (#231)
* implement __iter__ and __aiter__ in DocumentCollection + improve pit handling

* remove unused imports
adnejacobsen authored Oct 3, 2023
1 parent 5e9f1cb commit 138955f
Showing 9 changed files with 87 additions and 29 deletions.
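For orientation, a minimal usage sketch of what the new __iter__/__aiter__ support enables on any DocumentCollection subclass. The environment name, case uuid, the case.surfaces accessor and the surface.name property are illustrative assumptions, not part of this diff:

    import asyncio

    from fmu.sumo.explorer import Explorer

    explorer = Explorer(env="prod")  # placeholder environment
    case = explorer.get_case_by_uuid("11111111-2222-3333-4444-555555555555")  # placeholder uuid

    # New in this commit: collections can be iterated directly.
    for surface in case.surfaces:            # uses __iter__ / __next__
        print(surface.name)

    async def list_surfaces_async():
        async for surface in case.surfaces:  # uses __aiter__ / __anext__
            print(surface.name)

    asyncio.run(list_surfaces_async())
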
4 changes: 2 additions & 2 deletions src/fmu/sumo/explorer/explorer.py
@@ -114,7 +114,7 @@ def get_case_by_uuid(self, uuid: str) -> Case:
Case: case object
"""
metadata = self._utils.get_object(uuid, _CASE_FIELDS)
return Case(self._sumo, metadata)
return Case(self._sumo, metadata, self._pit)

async def get_case_by_uuid_async(self, uuid: str) -> Case:
"""Get case object by uuid
@@ -126,7 +126,7 @@ async def get_case_by_uuid_async(self, uuid: str) -> Case:
Case: case object
"""
metadata = await self._utils.get_object_async(uuid, _CASE_FIELDS)
return Case(self._sumo, metadata)
return Case(self._sumo, metadata, self._pit)

def get_surface_by_uuid(self, uuid: str) -> Surface:
"""Get surface object by uuid
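The explorer.py hunks above belong to the "improve pit handling" part of the commit message: a Case fetched by uuid now receives the Explorer's point-in-time object, so collections derived from that case page against one consistent snapshot of the index. A hedged sketch of the intended flow; the keep_alive constructor argument is an assumption and not shown in this diff:

    from fmu.sumo.explorer import Explorer

    # Assumption: keep_alive makes the Explorer create a Pit (self._pit).
    explorer = Explorer(env="prod", keep_alive="15m")

    # The returned Case now carries that Pit, so any collection built from it
    # (surfaces, tables, ...) paginates against the same snapshot.
    case = explorer.get_case_by_uuid("11111111-2222-3333-4444-555555555555")  # placeholder uuid
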
71 changes: 50 additions & 21 deletions src/fmu/sumo/explorer/objects/_document_collection.py
@@ -20,16 +20,40 @@ def __init__(
self._type = doc_type
self._sumo = sumo
self._query = self._init_query(doc_type, query)
self._pit = pit

self._pit = pit
self._new_pit_id = None
self._after = None
self._curr_index = 0
self._len = None
self._items = []
self._field_values = {}
self._query = self._init_query(doc_type, query)
self._select = select

def __iter__(self):
self._curr_index = 0
return self

def __next__(self):
if self._curr_index < self.__len__():
res = self.__getitem__(self._curr_index)
self._curr_index += 1
return res
else:
raise StopIteration

def __aiter__(self):
self._curr_index = 0
return self

async def __anext__(self):
if self._curr_index < await self.length_async():
res = await self.getitem_async(self._curr_index)
self._curr_index += 1
return res
else:
raise StopAsyncIteration

def __len__(self) -> int:
"""Get size of document collection
@@ -61,17 +85,14 @@ def __getitem__(self, index: int) -> Dict:
Returns:
A document at a given index
"""
if index >= self.__len__():
if index > self.__len__():
raise IndexError

if len(self._items) <= index:
while len(self._items) <= index:
prev_len = len(self._items)
self._next_batch()
curr_len = len(self._items)
while len(self._items) <= index:
hits_size = self._next_batch()

if prev_len == curr_len:
raise IndexError
if hits_size == 0:
raise IndexError

return self._items[index]

@@ -84,17 +105,14 @@ async def getitem_async(self, index: int) -> Dict:
Returns:
A document at a given index
"""
if index >= await self.length_async():
if index > await self.length_async():
raise IndexError

if len(self._items) <= index:
while len(self._items) <= index:
prev_len = len(self._items)
await self._next_batch_async()
curr_len = len(self._items)
while len(self._items) <= index:
hits_size = await self._next_batch_async()

if prev_len == curr_len:
raise IndexError
if hits_size == 0:
raise IndexError

return self._items[index]

@@ -162,18 +180,23 @@ def _next_batch(self) -> List[Dict]:
query["search_after"] = self._after

if self._pit is not None:
query["pit"] = self._pit.get_pit_object()
query["pit"] = self._pit.get_pit_object(self._new_pit_id)

res = self._sumo.post("/search", json=query).json()
hits = res["hits"]

if self._pit is not None:
self._new_pit_id = res["pit_id"]

if self._len is None:
self._len = hits["total"]["value"]

if len(hits["hits"]) > 0:
self._after = hits["hits"][-1]["sort"]
self._items.extend(hits["hits"])

return len(hits["hits"])

async def _next_batch_async(self) -> List[Dict]:
"""Get next batch of documents
@@ -196,10 +219,14 @@ async def _next_batch_async(self) -> List[Dict]:
query["search_after"] = self._after

if self._pit is not None:
query["pit"] = self._pit.get_pit_object()
query["pit"] = self._pit.get_pit_object(self._new_pit_id)

res = await self._sumo.post_async("/search", json=query)
hits = res.json()["hits"]
data = res.json()
hits = data["hits"]

if self._pit is not None:
self._new_pit_id = data["pit_id"]

if self._len is None:
self._len = hits["total"]["value"]
Expand All @@ -208,6 +235,8 @@ async def _next_batch_async(self) -> List[Dict]:
self._after = hits["hits"][-1]["sort"]
self._items.extend(hits["hits"])

return len(hits["hits"])

def _init_query(self, doc_type: str, query: Dict = None) -> Dict:
"""Initialize base filter for document collection
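Taken together, the reworked __getitem__ / _next_batch pair above implements lazy, batched fetching: documents are pulled with Elasticsearch-style search_after pagination, _next_batch now reports how many hits it appended (so an empty batch raises IndexError), and the refreshed pit_id from each response is carried into the next request. A condensed sketch of that loop, rewritten as a standalone generator purely for illustration (the batch size of 500 is an assumption, not taken from this diff):

    from typing import Dict, Iterator

    def iter_hits(sumo, query: Dict, pit=None) -> Iterator[Dict]:
        """Yield documents batch by batch, mirroring DocumentCollection._next_batch."""
        after = None       # search_after cursor: sort values of the last hit seen
        new_pit_id = None  # refreshed point-in-time id returned by each response

        while True:
            body = dict(query)
            body["size"] = 500                      # assumed batch size
            if after is not None:
                body["search_after"] = after
            if pit is not None:
                body["pit"] = pit.get_pit_object(new_pit_id)

            res = sumo.post("/search", json=body).json()
            hits = res["hits"]["hits"]

            if pit is not None:
                new_pit_id = res["pit_id"]          # reuse the freshest id next round

            if not hits:                            # hits_size == 0 -> stop iterating
                return

            after = hits[-1]["sort"]
            yield from hits
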
4 changes: 4 additions & 0 deletions src/fmu/sumo/explorer/objects/case_collection.py
@@ -85,6 +85,10 @@ def __getitem__(self, index: int) -> Case:
doc = super().__getitem__(index)
return Case(self._sumo, doc, self._pit)

async def getitem_async(self, index: int) -> Case:
doc = await super().getitem_async(index)
return Case(self._sumo, doc)

def filter(
self,
uuid: Union[str, List[str]] = None,
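The same getitem_async override is repeated in each typed collection below (cubes, dictionaries, polygons, surfaces, tables), so async callers get a wrapped object back instead of a raw document dict. A small hedged usage sketch; the explorer.cases accessor and the Case.uuid property are assumptions here:

    import asyncio

    from fmu.sumo.explorer import Explorer

    async def main():
        explorer = Explorer(env="prod")        # placeholder environment
        cases = explorer.cases                 # assumed CaseCollection accessor
        first = await cases.getitem_async(0)   # returns a Case, not a raw dict
        print(first.uuid)                      # assumed Case property

    asyncio.run(main())
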
8 changes: 6 additions & 2 deletions src/fmu/sumo/explorer/objects/cube_collection.py
@@ -37,6 +37,10 @@ def __getitem__(self, index) -> Cube:
doc = super().__getitem__(index)
return Cube(self._sumo, doc)

async def getitem_async(self, index: int) -> Cube:
doc = await super().getitem_async(index)
return Cube(self._sumo, doc)

@property
def timestamps(self) -> List[str]:
"""List of unique timestamps in CubeCollection"""
@@ -129,7 +133,7 @@ def filter(
time: TimeFilter = None,
uuid: Union[str, List[str], bool] = None,
is_observation: bool = None,
is_prediction: bool = None
is_prediction: bool = None,
) -> "CubeCollection":
"""Filter cubes
@@ -156,7 +160,7 @@
time=time,
uuid=uuid,
is_observation=is_observation,
is_prediction=is_prediction
is_prediction=is_prediction,
)

return CubeCollection(self._sumo, self._case_uuid, query, self._pit)
4 changes: 4 additions & 0 deletions src/fmu/sumo/explorer/objects/dictionary_collection.py
@@ -29,6 +29,10 @@ def __getitem__(self, index) -> Dictionary:
doc = super().__getitem__(index)
return Dictionary(self._sumo, doc)

async def getitem_async(self, index: int) -> Dictionary:
doc = await super().getitem_async(index)
return Dictionary(self._sumo, doc)

def filter(
self,
name: Union[str, List[str], bool] = None,
4 changes: 4 additions & 0 deletions src/fmu/sumo/explorer/objects/polygons_collection.py
@@ -29,6 +29,10 @@ def __getitem__(self, index) -> Polygons:
doc = super().__getitem__(index)
return Polygons(self._sumo, doc)

async def getitem_async(self, index: int) -> Polygons:
doc = await super().getitem_async(index)
return Polygons(self._sumo, doc)

def filter(
self,
name: Union[str, List[str], bool] = None,
10 changes: 8 additions & 2 deletions src/fmu/sumo/explorer/objects/surface_collection.py
@@ -41,6 +41,10 @@ def __getitem__(self, index) -> Surface:
doc = super().__getitem__(index)
return Surface(self._sumo, doc)

async def getitem_async(self, index: int) -> Surface:
doc = await super().getitem_async(index)
return Surface(self._sumo, doc)

@property
def timestamps(self) -> List[str]:
"""List of unique timestamps in SurfaceCollection"""
@@ -141,7 +145,9 @@ def _aggregate(self, operation: str) -> xtgeo.RegularSurface:

async def _aggregate_async(self, operation: str) -> xtgeo.RegularSurface:
if operation not in self._aggregation_cache:
objects = await self._utils.get_objects_async(500, self._query, ["_id"])
objects = await self._utils.get_objects_async(
500, self._query, ["_id"]
)
object_ids = list(map(lambda obj: obj["_id"], objects))

res = await self._sumo.post_async(
@@ -291,4 +297,4 @@ def p90(self) -> xtgeo.RegularSurface:

async def p90_async(self) -> xtgeo.RegularSurface:
"""Perform a percentile aggregation"""
return await self._aggregate_async("p90")
return await self._aggregate_async("p90")
4 changes: 4 additions & 0 deletions src/fmu/sumo/explorer/objects/table_collection.py
@@ -29,6 +29,10 @@ def __getitem__(self, index) -> Table:
doc = super().__getitem__(index)
return Table(self._sumo, doc)

async def getitem_async(self, index: int) -> Table:
doc = await super().getitem_async(index)
return Table(self._sumo, doc)

@property
def columns(self) -> List[str]:
"""List of unique column names"""
7 changes: 5 additions & 2 deletions src/fmu/sumo/explorer/pit.py
@@ -21,10 +21,13 @@ def __get_pit_id(self, keep_alive) -> str:
res = self._sumo.post("/pit", params={"keep-alive": keep_alive})
return res.json()["id"]

def get_pit_object(self) -> Dict:
def get_pit_object(self, pit_id: str = None) -> Dict:
"""Get the pit object
Returns:
Dict: dict with id and info about how long to keep alive
"""
return {"id": self._pit_id, "keep_alive": self._keep_alive}
return {
"id": pit_id if pit_id is not None else self._pit_id,
"keep_alive": self._keep_alive,
}
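The new optional pit_id parameter on get_pit_object is what lets DocumentCollection substitute the refreshed id returned by the previous search response while keeping the original keep_alive value. A minimal sketch; the Pit constructor signature and the SumoClient import are assumptions based on the surrounding code, not shown in this diff:

    from sumo.wrapper import SumoClient        # assumed client class
    from fmu.sumo.explorer.pit import Pit

    sumo = SumoClient(env="prod")              # placeholder environment
    pit = Pit(sumo, keep_alive="1m")           # assumed constructor signature

    body = {"query": {"match_all": {}}, "size": 1, "pit": pit.get_pit_object()}
    res = sumo.post("/search", json=body).json()

    # Feed the refreshed id back into the next request; keep_alive stays the same.
    body["pit"] = pit.get_pit_object(res["pit_id"])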
