Skip to content

Commit

Permalink
[IO-1746] Introduction of Pagination queries and changes to item_ids …
Browse files Browse the repository at this point in the history
…endpoint (#707)

* basic pagination

* changes for meta pagination

* paginated id query

* WIP changes for pagination

* pagination object [untested]

* pagination objects completed

* test fixes

* sensible defaults + test changes

* base pagination collects all test

* meta pagination tests

* tweaks to useage

* removal of no longer needed exception

* linting changes

* reverting 'sensible' defaults

* len changes
  • Loading branch information
Nathanjp91 authored Nov 6, 2023
1 parent 7845410 commit d102e24
Show file tree
Hide file tree
Showing 28 changed files with 579 additions and 116 deletions.
25 changes: 7 additions & 18 deletions darwin/future/core/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Callable, Dict, Optional, overload
from typing import Callable, Dict, Optional
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -135,7 +135,11 @@ class ClientCore:
team: Team, team to make requests to
"""

def __init__(self, config: DarwinConfig, retries: Optional[Retry] = None) -> None:
def __init__(
self,
config: DarwinConfig,
retries: Optional[Retry] = None,
) -> None:
self.config = config
self.session = requests.Session()
if not retries:
Expand Down Expand Up @@ -166,21 +170,6 @@ def headers(self) -> Dict[str, str]:
http_headers["Authorization"] = f"ApiKey {self.config.api_key}"
return http_headers

@overload
def _generic_call(
self, method: Callable[[str], requests.Response], endpoint: str
) -> dict:
...

@overload
def _generic_call(
self,
method: Callable[[str, dict], requests.Response],
endpoint: str,
payload: dict,
) -> dict:
...

def _generic_call(
self, method: Callable, endpoint: str, payload: Optional[dict] = None
) -> JSONType:
Expand Down Expand Up @@ -227,7 +216,7 @@ def delete(
return self._generic_call(
self.session.delete,
self._contain_qs_and_endpoint(endpoint, query_string),
data if data is not None else {},
data,
)

def patch(self, endpoint: str, data: dict) -> JSONType:
Expand Down
13 changes: 7 additions & 6 deletions darwin/future/core/items/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@


def get_item_ids(
api_client: ClientCore, team_slug: str, dataset_id: Union[str, int]
api_client: ClientCore,
team_slug: str,
dataset_id: Union[str, int],
params: QueryString = QueryString({}),
) -> List[UUID]:
"""
Returns a list of item ids for the dataset
Expand All @@ -28,16 +31,14 @@ def get_item_ids(
List[UUID]
A list of item ids
"""

response = api_client.get(
f"/v2/teams/{team_slug}/items/ids",
f"/v2/teams/{team_slug}/items/list_ids",
QueryString(
{
"not_statuses": "archived,error",
"sort[id]": "desc",
"dataset_ids": str(dataset_id),
}
),
)
+ params,
)
assert isinstance(response, dict)
uuids = [UUID(uuid) for uuid in response["item_ids"]]
Expand Down
5 changes: 5 additions & 0 deletions darwin/future/core/types/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any, Dict, List, Union

from darwin.future.data_objects import validators as darwin_validators
Expand Down Expand Up @@ -83,3 +85,6 @@ def __init__(self, value: Dict[str, str]) -> None:

def __str__(self) -> str:
return "?" + "&".join(f"{k}={v}" for k, v in self.value.items())

def __add__(self, other: QueryString) -> QueryString:
return QueryString({**self.value, **other.value})
103 changes: 86 additions & 17 deletions darwin/future/core/types/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar

from darwin.future.core.client import ClientCore
from darwin.future.data_objects.page import Page
from darwin.future.exceptions import (
InvalidQueryFilter,
InvalidQueryModifier,
Expand All @@ -29,8 +30,9 @@ class Modifier(Enum):


class QueryFilter(DefaultDarwin):
"""Basic query filter with a name and a parameter
"""
Basic query filter with a name and a parameter
Modifiers are for client side filtering only, and are not passed to the API
Attributes
----------
name: str
Expand Down Expand Up @@ -104,6 +106,12 @@ def _from_kwarg(cls, key: str, value: str) -> QueryFilter:
modifier = None
return QueryFilter(name=key, param=value, modifier=modifier)

def to_dict(self, ignore_modifier: bool = True) -> Dict[str, str]:
d = {self.name: self.param}
if self.modifier is not None and not ignore_modifier:
d["modifier"] = self.modifier.value
return d


class Query(Generic[T], ABC):
"""
Expand Down Expand Up @@ -154,17 +162,16 @@ def __init__(
self.meta_params: dict = meta_params or {}
self.client = client
self.filters = filters or []
self.results: Optional[List[T]] = None
self._changed_since_last: bool = True
self.results: dict[int, T] = {}
self._changed_since_last: bool = False

def filter(self, filter: QueryFilter) -> Query[T]:
return self + filter

def __add__(self, filter: QueryFilter) -> Query[T]:
self._changed_since_last = True
return self.__class__(
self.client, filters=[*self.filters, filter], meta_params=self.meta_params
)
self.filters.append(filter)
return self

def __sub__(self, filter: QueryFilter) -> Query[T]:
self._changed_since_last = True
Expand All @@ -186,7 +193,7 @@ def __isub__(self, filter: QueryFilter) -> Query[T]:

def __len__(self) -> int:
if not self.results:
self.results = list(self._collect())
self.results = {**self.results, **self._collect()}
return len(self.results)

def __iter__(self) -> Query[T]:
Expand All @@ -195,7 +202,7 @@ def __iter__(self) -> Query[T]:

def __next__(self) -> T:
if not self.results:
self.results = list(self._collect())
self.collect()
if self.n < len(self.results):
result = self.results[self.n]
self.n += 1
Expand All @@ -205,12 +212,12 @@ def __next__(self) -> T:

def __getitem__(self, index: int) -> T:
if not self.results:
self.results = list(self._collect())
self.results = {**self.results, **self._collect()}
return self.results[index]

def __setitem__(self, index: int, value: T) -> None:
if not self.results:
self.results = list(self._collect())
self.results = {**self.results, **self._collect()}
self.results[index] = value

def where(self, *args: object, **kwargs: str) -> Query[T]:
Expand All @@ -222,18 +229,21 @@ def where(self, *args: object, **kwargs: str) -> Query[T]:

def collect(self, force: bool = False) -> List[T]:
if force or self._changed_since_last:
self.results = []
self.results = self._collect()
self.results = {}
self.results = {**self.results, **self._collect()}
self._changed_since_last = False
return self.results
return self._unwrap(self.results)

def _unwrap(self, results: Dict[int, T]) -> List[T]:
return list(results.values())

@abstractmethod
def _collect(self) -> List[T]:
def _collect(self) -> Dict[int, T]:
raise NotImplementedError("Not implemented")

def collect_one(self) -> T:
if not self.results:
self.results = list(self.collect())
self.results = {**self.results, **self._collect()}
if len(self.results) == 0:
raise ResultsNotFound("No results found")
if len(self.results) > 1:
Expand All @@ -242,12 +252,71 @@ def collect_one(self) -> T:

def first(self) -> T:
if not self.results:
self.results = list(self.collect())
self.results = {**self.results, **self._collect()}
if len(self.results) == 0:
raise ResultsNotFound("No results found")

return self.results[0]

def _generic_execute_filter(self, objects: List[T], filter: QueryFilter) -> List[T]:
return [
m for m in objects if filter.filter_attr(getattr(m._element, filter.name))
]


class PaginatedQuery(Query[T]):
def __init__(
self,
client: ClientCore,
filters: List[QueryFilter] | None = None,
meta_params: Param | None = None,
page: Page | None = None,
):
super().__init__(client, filters, meta_params)
self.page = page or Page()
self.completed = False

def collect(self, force: bool = False) -> List[T]:
if force or self._changed_since_last:
self.page = Page()
self.completed = False
if self.completed:
return self._unwrap(self.results)
new_results = self._collect()
self.results = {**self.results, **new_results}
if len(new_results) < self.page.size or len(new_results) == 0:
self.completed = True
else:
self.page.increment()
return self._unwrap(self.results)

def collect_all(self, force: bool = False) -> List[T]:
if force:
self.page = Page()
self.completed = False
self.results = {}
while not self.completed:
self.collect()
return self._unwrap(self.results)

def __getitem__(self, index: int) -> T:
if index not in self.results:
temp_page = self.page
self.page = self.page.get_required_page(index)
self.collect()
self.page = temp_page
return super().__getitem__(index)

def __next__(self) -> T:
if not self.completed and self.n not in self.results:
self.collect()
if self.completed and self.n not in self.results:
raise StopIteration
result = self.results[self.n]
self.n += 1
return result

def __len__(self) -> int:
if not self.completed:
self.collect_all()
return len(self.results)
2 changes: 1 addition & 1 deletion darwin/future/data_objects/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def validate_slot_name(cls, v: UnknownType) -> str:
return v

@classmethod
def validate_fps(cls, values: dict):
def validate_fps(cls, values: dict) -> dict:
value = values.get("fps")

if value is None:
Expand Down
51 changes: 51 additions & 0 deletions darwin/future/data_objects/page.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from math import floor

from pydantic import NonNegativeInt, PositiveInt

from darwin.future.core.types.common import QueryString
from darwin.future.data_objects.pydantic_base import DefaultDarwin


def must_be_positive(v: int) -> int:
if v is not None and v < 0:
raise ValueError("Value must be positive")
return v


class Page(DefaultDarwin):
offset: NonNegativeInt = 0
size: PositiveInt = 500

def get_required_page(self, item_index: int) -> Page:
"""
Get the page that contains the item at the specified index
Args:
item_index (int): The index of the item
Returns:
Page: The page that contains the item
"""
assert self.size is not None
required_offset = floor(item_index / self.size) * self.size
return Page(offset=required_offset, size=self.size)

def to_query_string(self) -> QueryString:
"""
Generate a query string from the page object, some fields are not included if they are None,
and certain fields are renamed. Outgoing and incoming query strings are different and require
dropping certain fields
Returns:
QueryString: Outgoing query string
"""
qs_dict = {"page[offset]": str(self.offset), "page[size]": str(self.size)}
return QueryString(qs_dict)

def increment(self) -> None:
"""
Increment the page offset by the page size
"""
self.offset += self.size
8 changes: 5 additions & 3 deletions darwin/future/meta/objects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from typing import Dict, Generic, Optional, TypeVar

from darwin.future.core.client import ClientCore
from darwin.future.pydantic_base import DefaultDarwin

R = TypeVar("R", bound=DefaultDarwin)
R = TypeVar("R")
Param = Dict[str, object]


Expand Down Expand Up @@ -38,7 +37,10 @@ class Team(MetaBase[TeamCore]):
client: ClientCore

def __init__(
self, client: ClientCore, element: R, meta_params: Optional[Param] = None
self,
element: R,
client: ClientCore,
meta_params: Optional[Param] = None,
) -> None:
self.client = client
self._element = element
Expand Down
Loading

0 comments on commit d102e24

Please sign in to comment.