diff --git a/darwin/future/core/client.py b/darwin/future/core/client.py
index c7e813599..f2c7e9874 100644
--- a/darwin/future/core/client.py
+++ b/darwin/future/core/client.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from pathlib import Path
-from typing import Callable, Dict, Optional, overload
+from typing import Callable, Dict, Optional
 from urllib.parse import urlparse
 
 import requests
@@ -135,7 +135,11 @@ class ClientCore:
         team: Team, team to make requests to
     """
 
-    def __init__(self, config: DarwinConfig, retries: Optional[Retry] = None) -> None:
+    def __init__(
+        self,
+        config: DarwinConfig,
+        retries: Optional[Retry] = None,
+    ) -> None:
         self.config = config
         self.session = requests.Session()
         if not retries:
@@ -166,21 +170,6 @@ def headers(self) -> Dict[str, str]:
             http_headers["Authorization"] = f"ApiKey {self.config.api_key}"
         return http_headers
 
-    @overload
-    def _generic_call(
-        self, method: Callable[[str], requests.Response], endpoint: str
-    ) -> dict:
-        ...
-
-    @overload
-    def _generic_call(
-        self,
-        method: Callable[[str, dict], requests.Response],
-        endpoint: str,
-        payload: dict,
-    ) -> dict:
-        ...
-
     def _generic_call(
         self, method: Callable, endpoint: str, payload: Optional[dict] = None
     ) -> JSONType:
@@ -227,7 +216,7 @@ def delete(
         return self._generic_call(
             self.session.delete,
             self._contain_qs_and_endpoint(endpoint, query_string),
-            data if data is not None else {},
+            data,
         )
 
     def patch(self, endpoint: str, data: dict) -> JSONType:
diff --git a/darwin/future/core/items/get.py b/darwin/future/core/items/get.py
index b0e74785f..8701d5828 100644
--- a/darwin/future/core/items/get.py
+++ b/darwin/future/core/items/get.py
@@ -9,7 +9,10 @@
 
 
 def get_item_ids(
-    api_client: ClientCore, team_slug: str, dataset_id: Union[str, int]
+    api_client: ClientCore,
+    team_slug: str,
+    dataset_id: Union[str, int],
+    params: QueryString = QueryString({}),
 ) -> List[UUID]:
     """
     Returns a list of item ids for the dataset
@@ -28,16 +31,14 @@ def get_item_ids(
     List[UUID]
         A list of item ids
     """
-    response = api_client.get(
-        f"/v2/teams/{team_slug}/items/ids",
+    response = api_client.get(
+        f"/v2/teams/{team_slug}/items/list_ids",
         QueryString(
             {
-                "not_statuses": "archived,error",
-                "sort[id]": "desc",
                 "dataset_ids": str(dataset_id),
             }
-        ),
+        )
+        + params,
     )
     assert isinstance(response, dict)
     uuids = [UUID(uuid) for uuid in response["item_ids"]]
diff --git a/darwin/future/core/types/common.py b/darwin/future/core/types/common.py
index 561ec5eaa..489916509 100644
--- a/darwin/future/core/types/common.py
+++ b/darwin/future/core/types/common.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any, Dict, List, Union
 
 from darwin.future.data_objects import validators as darwin_validators
@@ -83,3 +85,6 @@ def __init__(self, value: Dict[str, str]) -> None:
 
     def __str__(self) -> str:
         return "?" + "&".join(f"{k}={v}" for k, v in self.value.items())
+
+    def __add__(self, other: QueryString) -> QueryString:
+        return QueryString({**self.value, **other.value})
diff --git a/darwin/future/core/types/query.py b/darwin/future/core/types/query.py
index 0465c0a61..194f4eb38 100644
--- a/darwin/future/core/types/query.py
+++ b/darwin/future/core/types/query.py
@@ -5,6 +5,7 @@
 from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar
 
 from darwin.future.core.client import ClientCore
+from darwin.future.data_objects.page import Page
 from darwin.future.exceptions import (
     InvalidQueryFilter,
     InvalidQueryModifier,
@@ -29,8 +30,9 @@ class Modifier(Enum):
 
 class QueryFilter(DefaultDarwin):
-    """Basic query filter with a name and a parameter
-
+    """
+    Basic query filter with a name and a parameter
+    Modifiers are for client side filtering only, and are not passed to the API
     Attributes
     ----------
     name: str
@@ -104,6 +106,12 @@ def _from_kwarg(cls, key: str, value: str) -> QueryFilter:
             modifier = None
         return QueryFilter(name=key, param=value, modifier=modifier)
 
+    def to_dict(self, ignore_modifier: bool = True) -> Dict[str, str]:
+        d = {self.name: self.param}
+        if self.modifier is not None and not ignore_modifier:
+            d["modifier"] = self.modifier.value
+        return d
+
 
 class Query(Generic[T], ABC):
     """
@@ -154,17 +162,16 @@ def __init__(
         self.meta_params: dict = meta_params or {}
         self.client = client
         self.filters = filters or []
-        self.results: Optional[List[T]] = None
-        self._changed_since_last: bool = True
+        self.results: dict[int, T] = {}
+        self._changed_since_last: bool = False
 
     def filter(self, filter: QueryFilter) -> Query[T]:
         return self + filter
 
     def __add__(self, filter: QueryFilter) -> Query[T]:
         self._changed_since_last = True
-        return self.__class__(
-            self.client, filters=[*self.filters, filter], meta_params=self.meta_params
-        )
+        self.filters.append(filter)
+        return self
 
     def __sub__(self, filter: QueryFilter) -> Query[T]:
         self._changed_since_last = True
@@ -186,7 +193,7 @@ def __isub__(self, filter: QueryFilter) -> Query[T]:
 
     def __len__(self) -> int:
         if not self.results:
-            self.results = list(self._collect())
+            self.results = {**self.results, **self._collect()}
         return len(self.results)
 
     def __iter__(self) -> Query[T]:
@@ -195,7 +202,7 @@ def __iter__(self) -> Query[T]:
 
     def __next__(self) -> T:
         if not self.results:
-            self.results = list(self._collect())
+            self.collect()
         if self.n < len(self.results):
             result = self.results[self.n]
             self.n += 1
@@ -205,12 +212,12 @@ def __next__(self) -> T:
 
     def __getitem__(self, index: int) -> T:
         if not self.results:
-            self.results = list(self._collect())
+            self.results = {**self.results, **self._collect()}
         return self.results[index]
 
     def __setitem__(self, index: int, value: T) -> None:
         if not self.results:
-            self.results = list(self._collect())
+            self.results = {**self.results, **self._collect()}
         self.results[index] = value
 
     def where(self, *args: object, **kwargs: str) -> Query[T]:
@@ -222,18 +229,21 @@ def where(self, *args: object, **kwargs: str) -> Query[T]:
 
     def collect(self, force: bool = False) -> List[T]:
         if force or self._changed_since_last:
-            self.results = []
-            self.results = self._collect()
+            self.results = {}
+            self.results = {**self.results, **self._collect()}
             self._changed_since_last = False
-        return self.results
+        return self._unwrap(self.results)
+
+    def _unwrap(self, results: Dict[int, T]) -> List[T]:
+        return list(results.values())
 
     @abstractmethod
-    def _collect(self) -> List[T]:
+    def _collect(self) -> Dict[int, T]:
         raise NotImplementedError("Not implemented")
 
     def collect_one(self) -> T:
         if not self.results:
-            self.results = list(self.collect())
+            self.results = {**self.results, **self._collect()}
         if len(self.results) == 0:
             raise ResultsNotFound("No results found")
         if len(self.results) > 1:
@@ -242,12 +252,71 @@ def collect_one(self) -> T:
 
     def first(self) -> T:
         if not self.results:
-            self.results = list(self.collect())
+            self.results = {**self.results, **self._collect()}
         if len(self.results) == 0:
             raise ResultsNotFound("No results found")
+
         return self.results[0]
 
     def _generic_execute_filter(self, objects: List[T], filter: QueryFilter) -> List[T]:
         return [
             m for m in objects if filter.filter_attr(getattr(m._element, filter.name))
         ]
+
+
+class PaginatedQuery(Query[T]):
+    def __init__(
+        self,
+        client: ClientCore,
+        filters: List[QueryFilter] | None = None,
+        meta_params: Param | None = None,
+        page: Page | None = None,
+    ):
+        super().__init__(client, filters, meta_params)
+        self.page = page or Page()
+        self.completed = False
+
+    def collect(self, force: bool = False) -> List[T]:
+        if force or self._changed_since_last:
+            self.page = Page()
+            self.completed = False
+        if self.completed:
+            return self._unwrap(self.results)
+        new_results = self._collect()
+        self.results = {**self.results, **new_results}
+        if len(new_results) < self.page.size or len(new_results) == 0:
+            self.completed = True
+        else:
+            self.page.increment()
+        return self._unwrap(self.results)
+
+    def collect_all(self, force: bool = False) -> List[T]:
+        if force:
+            self.page = Page()
+            self.completed = False
+            self.results = {}
+        while not self.completed:
+            self.collect()
+        return self._unwrap(self.results)
+
+    def __getitem__(self, index: int) -> T:
+        if index not in self.results:
+            temp_page = self.page
+            self.page = self.page.get_required_page(index)
+            self.collect()
+            self.page = temp_page
+        return super().__getitem__(index)
+
+    def __next__(self) -> T:
+        if not self.completed and self.n not in self.results:
+            self.collect()
+        if self.completed and self.n not in self.results:
+            raise StopIteration
+        result = self.results[self.n]
+        self.n += 1
+        return result
+
+    def __len__(self) -> int:
+        if not self.completed:
+            self.collect_all()
+        return len(self.results)
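The core of this patch is the switch from `List[T]` results to an index-keyed `Dict[int, T]` cache, plus the `PaginatedQuery` subclass above. Below is a minimal sketch (not part of the patch) of how those pieces interact, using a hypothetical `NumberQuery` that pages through a local list instead of calling the API; it assumes the modules from this diff are importable.

```python
from typing import Dict

from darwin.future.core.client import ClientCore
from darwin.future.core.types.query import PaginatedQuery
from darwin.future.data_objects.page import Page


class NumberQuery(PaginatedQuery[int]):
    """Toy query that 'fetches' slices of a local list, one page at a time."""

    DATA = list(range(12))  # stand-in for a server-side collection

    def _collect(self) -> Dict[int, int]:
        # One page of results, keyed by absolute index -- the same
        # `i + self.page.offset` scheme ItemIDQuery uses later in this diff.
        start = self.page.offset
        chunk = self.DATA[start : start + self.page.size]
        return {start + i: v for i, v in enumerate(chunk)}


def demo(client: ClientCore) -> None:
    q = NumberQuery(client, page=Page(size=5))
    assert q.collect() == list(range(5))       # first page only
    assert q.collect() == list(range(10))      # next page merged into the cache
    assert q.collect_all() == list(range(12))  # drains the remaining pages
    # A fresh query can jump straight to the page containing an index:
    assert NumberQuery(client, page=Page(size=5))[11] == 11
```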
item + """ + assert self.size is not None + required_offset = floor(item_index / self.size) * self.size + return Page(offset=required_offset, size=self.size) + + def to_query_string(self) -> QueryString: + """ + Generate a query string from the page object, some fields are not included if they are None, + and certain fields are renamed. Outgoing and incoming query strings are different and require + dropping certain fields + + Returns: + QueryString: Outgoing query string + """ + qs_dict = {"page[offset]": str(self.offset), "page[size]": str(self.size)} + return QueryString(qs_dict) + + def increment(self) -> None: + """ + Increment the page offset by the page size + """ + self.offset += self.size diff --git a/darwin/future/meta/objects/base.py b/darwin/future/meta/objects/base.py index 3015dbc8b..77eafcd7e 100644 --- a/darwin/future/meta/objects/base.py +++ b/darwin/future/meta/objects/base.py @@ -3,9 +3,8 @@ from typing import Dict, Generic, Optional, TypeVar from darwin.future.core.client import ClientCore -from darwin.future.pydantic_base import DefaultDarwin -R = TypeVar("R", bound=DefaultDarwin) +R = TypeVar("R") Param = Dict[str, object] @@ -38,7 +37,10 @@ class Team(MetaBase[TeamCore]): client: ClientCore def __init__( - self, client: ClientCore, element: R, meta_params: Optional[Param] = None + self, + element: R, + client: ClientCore, + meta_params: Optional[Param] = None, ) -> None: self.client = client self._element = element diff --git a/darwin/future/meta/objects/dataset.py b/darwin/future/meta/objects/dataset.py index 714d8bc2f..e44eff0ae 100644 --- a/darwin/future/meta/objects/dataset.py +++ b/darwin/future/meta/objects/dataset.py @@ -1,17 +1,16 @@ from __future__ import annotations from typing import List, Optional, Sequence, Union -from uuid import UUID from darwin.cli_functions import upload_data from darwin.dataset.upload_manager import LocalFile from darwin.datatypes import PathLike from darwin.future.core.client import ClientCore from darwin.future.core.datasets import create_dataset, remove_dataset -from darwin.future.core.items import get_item_ids from darwin.future.data_objects.dataset import DatasetCore from darwin.future.helpers.assertion import assert_is from darwin.future.meta.objects.base import MetaBase +from darwin.future.meta.queries.item_id import ItemIDQuery class Dataset(MetaBase[DatasetCore]): @@ -62,7 +61,7 @@ def id(self) -> int: return self._element.id @property - def item_ids(self) -> List[UUID]: + def item_ids(self) -> ItemIDQuery: """Returns a list of item ids for the dataset Returns: @@ -72,9 +71,8 @@ def item_ids(self) -> List[UUID]: assert self.meta_params["team_slug"] is not None and isinstance( self.meta_params["team_slug"], str ) - return get_item_ids( - self.client, self.meta_params["team_slug"], str(self._element.id) - ) + meta_params = {"dataset_ids": self.id, **self.meta_params} + return ItemIDQuery(self.client, meta_params=meta_params) @classmethod def create_dataset(cls, client: ClientCore, slug: str) -> DatasetCore: diff --git a/darwin/future/meta/objects/stage.py b/darwin/future/meta/objects/stage.py index 80e1432ea..d0aa53219 100644 --- a/darwin/future/meta/objects/stage.py +++ b/darwin/future/meta/objects/stage.py @@ -3,9 +3,11 @@ from typing import List from uuid import UUID -from darwin.future.core.items import get_item_ids_stage, move_items_to_stage +from darwin.future.core.items import move_items_to_stage +from darwin.future.core.types.query import QueryFilter from darwin.future.data_objects.workflow import WFEdgeCore, 
diff --git a/darwin/future/meta/objects/base.py b/darwin/future/meta/objects/base.py
index 3015dbc8b..77eafcd7e 100644
--- a/darwin/future/meta/objects/base.py
+++ b/darwin/future/meta/objects/base.py
@@ -3,9 +3,8 @@
 from typing import Dict, Generic, Optional, TypeVar
 
 from darwin.future.core.client import ClientCore
-from darwin.future.pydantic_base import DefaultDarwin
 
-R = TypeVar("R", bound=DefaultDarwin)
+R = TypeVar("R")
 Param = Dict[str, object]
 
@@ -38,7 +37,10 @@ class Team(MetaBase[TeamCore]):
         client: ClientCore
 
     def __init__(
-        self, client: ClientCore, element: R, meta_params: Optional[Param] = None
+        self,
+        element: R,
+        client: ClientCore,
+        meta_params: Optional[Param] = None,
     ) -> None:
         self.client = client
         self._element = element
diff --git a/darwin/future/meta/objects/dataset.py b/darwin/future/meta/objects/dataset.py
index 714d8bc2f..e44eff0ae 100644
--- a/darwin/future/meta/objects/dataset.py
+++ b/darwin/future/meta/objects/dataset.py
@@ -1,17 +1,16 @@
 from __future__ import annotations
 
 from typing import List, Optional, Sequence, Union
-from uuid import UUID
 
 from darwin.cli_functions import upload_data
 from darwin.dataset.upload_manager import LocalFile
 from darwin.datatypes import PathLike
 from darwin.future.core.client import ClientCore
 from darwin.future.core.datasets import create_dataset, remove_dataset
-from darwin.future.core.items import get_item_ids
 from darwin.future.data_objects.dataset import DatasetCore
 from darwin.future.helpers.assertion import assert_is
 from darwin.future.meta.objects.base import MetaBase
+from darwin.future.meta.queries.item_id import ItemIDQuery
 
 
 class Dataset(MetaBase[DatasetCore]):
@@ -62,7 +61,7 @@ def id(self) -> int:
         return self._element.id
 
     @property
-    def item_ids(self) -> List[UUID]:
+    def item_ids(self) -> ItemIDQuery:
         """Returns a list of item ids for the dataset
 
         Returns:
@@ -72,9 +71,8 @@ def item_ids(self) -> ItemIDQuery:
         assert self.meta_params["team_slug"] is not None and isinstance(
             self.meta_params["team_slug"], str
         )
-        return get_item_ids(
-            self.client, self.meta_params["team_slug"], str(self._element.id)
-        )
+        meta_params = {"dataset_ids": self.id, **self.meta_params}
+        return ItemIDQuery(self.client, meta_params=meta_params)
 
     @classmethod
     def create_dataset(cls, client: ClientCore, slug: str) -> DatasetCore:
diff --git a/darwin/future/meta/objects/stage.py b/darwin/future/meta/objects/stage.py
index 80e1432ea..d0aa53219 100644
--- a/darwin/future/meta/objects/stage.py
+++ b/darwin/future/meta/objects/stage.py
@@ -3,9 +3,11 @@
 from typing import List
 from uuid import UUID
 
-from darwin.future.core.items import get_item_ids_stage, move_items_to_stage
+from darwin.future.core.items import move_items_to_stage
+from darwin.future.core.types.query import QueryFilter
 from darwin.future.data_objects.workflow import WFEdgeCore, WFStageCore
 from darwin.future.meta.objects.base import MetaBase
+from darwin.future.meta.queries.item_id import ItemIDQuery
 
 
 class Stage(MetaBase[WFStageCore]):
@@ -42,18 +44,19 @@ class Stage(MetaBase[WFStageCore]):
     """
 
     @property
-    def item_ids(self) -> List[UUID]:
+    def item_ids(self) -> ItemIDQuery:
        """Item ids attached to the stage
 
         Returns:
             List[UUID]: List of item ids
         """
         assert self._element.id is not None
-        return get_item_ids_stage(
+        return ItemIDQuery(
             self.client,
-            str(self.meta_params["team_slug"]),
-            str(self.meta_params["dataset_id"]),
-            self.id,
+            meta_params=self.meta_params,
+            filters=[
+                QueryFilter(name="workflow_stage_ids", param=str(self._element.id))
+            ],
         )
 
    def move_attached_files_to_stage(self, new_stage_id: UUID) -> Stage:
@@ -71,7 +74,8 @@ def move_attached_files_to_stage(self, new_stage_id: UUID) -> Stage:
             self.meta_params["workflow_id"],
             self.meta_params["dataset_id"],
         )
-        move_items_to_stage(self.client, slug, w_id, d_id, new_stage_id, self.item_ids)
+        ids = [x.id for x in self.item_ids.collect_all()]
+        move_items_to_stage(self.client, slug, w_id, d_id, new_stage_id, ids)
         return self
 
     @property
diff --git a/darwin/future/meta/objects/team.py b/darwin/future/meta/objects/team.py
index 275e0c686..078693230 100644
--- a/darwin/future/meta/objects/team.py
+++ b/darwin/future/meta/objects/team.py
@@ -68,7 +68,7 @@ class Team(MetaBase[TeamCore]):
 
     def __init__(self, client: ClientCore, team: Optional[TeamCore] = None) -> None:
         team = team or get_team(client)
-        super().__init__(client, team)
+        super().__init__(client=client, element=team)
 
     @property
     def name(self) -> str:
@@ -172,7 +172,9 @@ def _delete_dataset_by_id(client: ClientCore, dataset_id: int) -> int:
 
     def create_dataset(self, slug: str) -> Dataset:
         core = Dataset.create_dataset(self.client, slug)
-        return Dataset(self.client, core, meta_params={"team_slug": self.slug})
+        return Dataset(
+            client=self.client, element=core, meta_params={"team_slug": self.slug}
+        )
 
     def __str__(self) -> str:
         return f"Team\n\
diff --git a/darwin/future/meta/objects/v7_id.py b/darwin/future/meta/objects/v7_id.py
new file mode 100644
index 000000000..cbd0c149b
--- /dev/null
+++ b/darwin/future/meta/objects/v7_id.py
@@ -0,0 +1,15 @@
+from uuid import UUID
+
+from darwin.future.meta.objects.base import MetaBase
+
+
+class V7ID(MetaBase[UUID]):
+    @property
+    def id(self) -> UUID:
+        return self._element
+
+    def __str__(self) -> str:
+        return str(self._element)
+
+    def __repr__(self) -> str:
+        return str(self)
diff --git a/darwin/future/meta/queries/dataset.py b/darwin/future/meta/queries/dataset.py
index 5fbbc4aa7..7768a1def 100644
--- a/darwin/future/meta/queries/dataset.py
+++ b/darwin/future/meta/queries/dataset.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List
+from typing import Dict, List
 
 from darwin.future.core.datasets import list_datasets
 from darwin.future.core.types.query import Query, QueryFilter
@@ -18,13 +18,14 @@ class DatasetQuery(Query[Dataset]):
         collect: Executes the query and returns the filtered data
     """
 
-    def _collect(self) -> List[Dataset]:
+    def _collect(self) -> Dict[int, Dataset]:
         datasets, exceptions = list_datasets(self.client)
         if exceptions:
             # TODO: print and or raise exceptions, tbd how we want to handle this
             pass
         datasets_meta = [
-            Dataset(self.client, dataset, self.meta_params) for dataset in datasets
+            Dataset(client=self.client, element=dataset, meta_params=self.meta_params)
+            for dataset in datasets
         ]
         if not self.filters:
             self.filters = []
@@ -32,7 +33,7 @@ def _collect(self) -> Dict[int, Dataset]:
 
         for filter in self.filters:
             datasets_meta = self._execute_filters(datasets_meta, filter)
-        return datasets_meta
+        return dict(enumerate(datasets_meta))
 
     def _execute_filters(
         self, datasets: List[Dataset], filter: QueryFilter
diff --git a/darwin/future/meta/queries/item_id.py b/darwin/future/meta/queries/item_id.py
new file mode 100644
index 000000000..d792fec26
--- /dev/null
+++ b/darwin/future/meta/queries/item_id.py
@@ -0,0 +1,41 @@
+from functools import reduce
+from typing import Dict
+
+from darwin.future.core.items.get import get_item_ids
+from darwin.future.core.types.common import QueryString
+from darwin.future.core.types.query import PaginatedQuery
+from darwin.future.meta.objects.v7_id import V7ID
+
+
+class ItemIDQuery(PaginatedQuery[V7ID]):
+    def _collect(self) -> Dict[int, V7ID]:
+        if "team_slug" not in self.meta_params:
+            raise ValueError("Must specify team_slug to query item ids")
+        if (
+            "dataset_ids" not in self.meta_params
+            and "dataset_id" not in self.meta_params
+        ):
+            raise ValueError("Must specify dataset_id to query item ids")
+        team_slug: str = self.meta_params["team_slug"]
+        dataset_ids: int = (
+            self.meta_params["dataset_ids"]
+            if "dataset_ids" in self.meta_params
+            else self.meta_params["dataset_id"]
+        )
+        params: QueryString = reduce(
+            lambda s1, s2: s1 + s2,
+            [
+                self.page.to_query_string(),
+                *[QueryString(f.to_dict()) for f in self.filters],
+            ],
+        )
+        uuids = get_item_ids(self.client, team_slug, dataset_ids, params)
+
+        results = {
+            i
+            + self.page.offset: V7ID(
+                client=self.client, element=uuid, meta_params=self.meta_params
+            )
+            for i, uuid in enumerate(uuids)
+        }
+        return results
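How `ItemIDQuery._collect` assembles its request parameters: the page is rendered with `Page.to_query_string`, each `QueryFilter` contributes its `to_dict()`, and the pieces are folded together with the new `QueryString.__add__`. A sketch (not part of the patch), using a hypothetical stage-id value:

```python
from functools import reduce

from darwin.future.core.types.common import QueryString
from darwin.future.core.types.query import QueryFilter
from darwin.future.data_objects.page import Page

page = Page(offset=0, size=500)
filters = [QueryFilter(name="workflow_stage_ids", param="stage-uuid-here")]  # hypothetical value

params = reduce(
    lambda s1, s2: s1 + s2,
    [page.to_query_string(), *[QueryString(f.to_dict()) for f in filters]],
)
# __add__ merges the underlying dicts, so later keys win on collision.
assert params.value == {
    "page[offset]": "0",
    "page[size]": "500",
    "workflow_stage_ids": "stage-uuid-here",
}
```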
diff --git a/darwin/future/meta/queries/stage.py b/darwin/future/meta/queries/stage.py
index 3a18bc0c1..723664e2e 100644
--- a/darwin/future/meta/queries/stage.py
+++ b/darwin/future/meta/queries/stage.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List
+from typing import Dict, List
 from uuid import UUID
 
 from darwin.future.core.types.query import Query, QueryFilter
@@ -9,7 +9,7 @@
 
 
 class StageQuery(Query[Stage]):
-    def _collect(self) -> List[Stage]:
+    def _collect(self) -> Dict[int, Stage]:
         if "workflow_id" not in self.meta_params:
             raise ValueError("Must specify workflow_id to query stages")
         workflow_id: UUID = self.meta_params["workflow_id"]
@@ -17,13 +17,14 @@ def _collect(self) -> Dict[int, Stage]:
         workflow = get_workflow(self.client, str(workflow_id))
         assert workflow is not None
         stages = [
-            Stage(self.client, s, meta_params=meta_params) for s in workflow.stages
+            Stage(client=self.client, element=s, meta_params=meta_params)
+            for s in workflow.stages
         ]
         if not self.filters:
             self.filters = []
         for filter in self.filters:
             stages = self._execute_filter(stages, filter)
-        return stages
+        return dict(enumerate(stages))
 
     def _execute_filter(self, stages: List[Stage], filter: QueryFilter) -> List[Stage]:
         """Executes filtering on the local list of stages
diff --git a/darwin/future/meta/queries/team_member.py b/darwin/future/meta/queries/team_member.py
index a76e1f7a9..c7ac11b30 100644
--- a/darwin/future/meta/queries/team_member.py
+++ b/darwin/future/meta/queries/team_member.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List
+from typing import Dict, List
 
 from darwin.future.core.team.get_team import get_team_members
 from darwin.future.core.types.query import Query, QueryFilter
@@ -16,9 +16,11 @@ class TeamMemberQuery(Query[TeamMember]):
         _execute_filter: Executes a filter on a list of objects
     """
 
-    def _collect(self) -> List[TeamMember]:
+    def _collect(self) -> Dict[int, TeamMember]:
         members, exceptions = get_team_members(self.client)
-        members_meta = [TeamMember(self.client, member) for member in members]
+        members_meta = [
+            TeamMember(client=self.client, element=member) for member in members
+        ]
         if exceptions:
             # TODO: print and or raise exceptions, tbd how we want to handle this
             pass
@@ -27,7 +29,7 @@ def _collect(self) -> Dict[int, TeamMember]:
 
         for filter in self.filters:
             members_meta = self._execute_filter(members_meta, filter)
-        return members_meta
+        return dict(enumerate(members_meta))
 
     def _execute_filter(
         self, members: List[TeamMember], filter: QueryFilter
diff --git a/darwin/future/meta/queries/workflow.py b/darwin/future/meta/queries/workflow.py
index 75c3a6442..80b47792a 100644
--- a/darwin/future/meta/queries/workflow.py
+++ b/darwin/future/meta/queries/workflow.py
@@ -1,5 +1,5 @@
 from datetime import datetime, timezone
-from typing import List
+from typing import Dict, List
 from uuid import UUID
 
 from darwin.exceptions import DarwinException
@@ -21,22 +21,22 @@ class WorkflowQuery(Query[Workflow]):
         collect: Executes the query and returns the filtered data
     """
 
-    def _collect(self) -> List[Workflow]:
+    def _collect(self) -> Dict[int, Workflow]:
         workflows_core, exceptions = list_workflows(self.client)
         if exceptions:
             handle_exception(exceptions)
             raise DarwinException from exceptions[0]
         workflows = [
-            Workflow(self.client, workflow, self.meta_params)
+            Workflow(client=self.client, element=workflow, meta_params=self.meta_params)
             for workflow in workflows_core
         ]
         if not self.filters:
-            return workflows
+            self.filters = []
 
         for filter in self.filters:
             workflows = self._execute_filters(workflows, filter)
-        return workflows
+        return dict(enumerate(workflows))
 
     def _execute_filters(
         self, workflows: List[Workflow], filter: QueryFilter
diff --git a/darwin/future/tests/core/items/test_get_items.py b/darwin/future/tests/core/items/test_get_items.py
index ed808cd4b..3e7a40386 100644
--- a/darwin/future/tests/core/items/test_get_items.py
+++ b/darwin/future/tests/core/items/test_get_items.py
@@ -19,8 +19,8 @@ def test_get_item_ids(
     with responses.RequestsMock() as rsps:
         rsps.add(
             rsps.GET,
-            base_client.config.api_endpoint + "v2/teams/default-team/items/ids"
-            "?not_statuses=archived,error&sort[id]=desc&dataset_ids=1337",
+            base_client.config.api_endpoint + "v2/teams/default-team/items/list_ids"
+            "?dataset_ids=1337",
             json={"item_ids": UUIDs_str},
             status=200,
         )
diff --git a/darwin/future/tests/core/items/test_set_priority.py b/darwin/future/tests/core/items/test_set_priority.py
index 9b056bcf0..b6aec7d77 100644
--- a/darwin/future/tests/core/items/test_set_priority.py
+++ b/darwin/future/tests/core/items/test_set_priority.py
@@ -11,7 +11,7 @@
 
 
 @responses.activate
-def test_set_item_priority(base_client) -> None:
+def test_set_item_priority(base_client: ClientCore) -> None:
     responses.add(
         responses.POST,
         base_client.config.api_endpoint + "v2/teams/test-team/items/priority",
diff --git a/darwin/future/tests/core/test_query.py b/darwin/future/tests/core/test_query.py
index 8c5692390..57453291f 100644
--- a/darwin/future/tests/core/test_query.py
+++ b/darwin/future/tests/core/test_query.py
@@ -135,7 +135,7 @@ def test_QF_from_asteriks() -> None:
 
 def test_query_first(non_abc_query: Type[Query.Query], base_client: ClientCore) -> None:
     query = non_abc_query(base_client)
-    query.results = [1, 2, 3]
+    query.results = {0: 1, 1: 2, 2: 3}
     first = query.first()
     assert first == 1
 
@@ -144,9 +144,9 @@ def test_query_collect_one(
     non_abc_query: Type[Query.Query], base_client: ClientCore
 ) -> None:
     query = non_abc_query(base_client)
-    query.results = [1, 2, 3]
+    query.results = {0: 1, 1: 2, 2: 3}
     with pytest.raises(MoreThanOneResultFound):
         query.collect_one()
 
-    query.results = [1]
+    query.results = {0: 1}
     assert query.collect_one() == 1
diff --git a/darwin/future/tests/data_objects/test_page.py b/darwin/future/tests/data_objects/test_page.py
new file mode 100644
index 000000000..9097722cf
--- /dev/null
+++ b/darwin/future/tests/data_objects/test_page.py
@@ -0,0 +1,32 @@
+import pytest
+
+from darwin.future.data_objects.page import Page
+
+
+def test_default_page() -> None:
+    page = Page()
+    assert page.offset == 0
+    assert page.size == 500
+
+
+def test_to_query_string() -> None:
+    page = Page(offset=0, size=10)
+    qs = page.to_query_string()
+    assert qs.value == {"page[offset]": "0", "page[size]": "10"}
+
+
+def test_increment() -> None:
+    page = Page(offset=0, size=10)
+    page.increment()
+    assert page.offset == 10
+    assert page.size == 10
+
+
+@pytest.mark.parametrize(
+    "size, index, expected_offset", [(10, 0, 0), (10, 9, 0), (10, 10, 10), (10, 11, 10)]
+)
+def test_get_required_page(size: int, index: int, expected_offset: int) -> None:
+    page = Page(size=size, offset=0)
+    required_page = page.get_required_page(index)
+    assert required_page.offset == expected_offset
+    assert required_page.size == size
diff --git a/darwin/future/tests/meta/objects/fixtures.py b/darwin/future/tests/meta/objects/fixtures.py
index a4a75c14f..6aebae03d 100644
--- a/darwin/future/tests/meta/objects/fixtures.py
+++ b/darwin/future/tests/meta/objects/fixtures.py
@@ -21,21 +21,21 @@ def base_UUID() -> UUID:
 
 @fixture
 def base_meta_team(base_client: ClientCore, base_team: TeamCore) -> Team:
-    return Team(base_client, base_team)
+    return Team(client=base_client, team=base_team)
 
 
 @fixture
 def base_meta_workflow(
     base_client: ClientCore, base_workflow: WorkflowCore
 ) -> Workflow:
-    return Workflow(base_client, base_workflow)
+    return Workflow(client=base_client, element=base_workflow)
 
 
 @fixture
 def base_meta_stage(
     base_client: ClientCore, base_stage: WFStageCore, base_UUID: UUID
 ) -> Stage:
-    return Stage(base_client, base_stage)
+    return Stage(client=base_client, element=base_stage)
 
 
 @fixture
@@ -45,4 +45,6 @@ def base_meta_stage_list(base_meta_stage: Stage, base_UUID: UUID) -> List[Stage]
 
 @fixture
 def base_meta_dataset(base_client: ClientCore, base_dataset: DatasetCore) -> Dataset:
-    return Dataset(base_client, base_dataset, meta_params={"team_slug": "test_team"})
+    return Dataset(
+        client=base_client, element=base_dataset, meta_params={"team_slug": "test_team"}
+    )
diff --git a/darwin/future/tests/meta/objects/test_stagemeta.py b/darwin/future/tests/meta/objects/test_stagemeta.py
index 06eeff96d..f83b81ede 100644
--- a/darwin/future/tests/meta/objects/test_stagemeta.py
+++ b/darwin/future/tests/meta/objects/test_stagemeta.py
@@ -3,6 +3,7 @@
 
 import responses
 from pytest import fixture
+from responses.matchers import query_param_matcher
 
 from darwin.future.data_objects.workflow import WFEdgeCore, WFStageCore, WFTypeCore
 from darwin.future.meta.client import Client
@@ -33,9 +34,13 @@ def stage_meta(
     base_meta_client: Client, base_WFStage: WFStageCore, workflow_id: UUID
 ) -> Stage:
     return Stage(
-        base_meta_client,
-        base_WFStage,
-        {"team_slug": "default-team", "dataset_id": 1337, "workflow_id": workflow_id},
+        client=base_meta_client,
+        element=base_WFStage,
+        meta_params={
+            "team_slug": "default-team",
+            "dataset_id": 1337,
+            "workflow_id": workflow_id,
+        },
     )
 
 
@@ -46,12 +51,21 @@ def test_item_ids(
    base_meta_client: Client, stage_meta: Stage
 ) -> None:
     with responses.RequestsMock() as rsps:
         rsps.add(
             rsps.GET,
             base_meta_client.config.api_endpoint
-            + f"v2/teams/default-team/items/ids?workflow_stage_ids={str(stage_meta.id)}"
-            "&dataset_ids=1337",
+            + "v2/teams/default-team/items/list_ids",
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "0",
+                        "page[size]": "500",
+                        "workflow_stage_ids": str(stage_meta.id),
+                        "dataset_ids": "1337",
+                    }
+                )
+            ],
             json={"item_ids": UUIDs_str},
             status=200,
         )
-        item_ids = stage_meta.item_ids
+        item_ids = [x.id for x in stage_meta.item_ids.collect_all()]
         assert item_ids == UUIDs
 
 
@@ -62,9 +76,18 @@ def test_move_attached_files_to_stage(
         rsps.add(
             rsps.GET,
             base_meta_client.config.api_endpoint
-            + f"v2/teams/default-team/items/ids?workflow_stage_ids={str(stage_meta.id)}"
-            "&dataset_ids=1337",
+            + "v2/teams/default-team/items/list_ids",
             json={"item_ids": UUIDs_str},
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "0",
+                        "page[size]": "500",
+                        "workflow_stage_ids": str(stage_meta.id),
+                        "dataset_ids": "1337",
+                    }
+                )
+            ],
             status=200,
         )
         rsps.add(
@@ -74,16 +97,6 @@ def test_move_attached_files_to_stage(
             status=200,
         )
         stage_meta.move_attached_files_to_stage(stage_meta.id)
-        assert rsps.assert_call_count(
-            base_meta_client.config.api_endpoint + "v2/teams/default-team/items/stage",
-            1,
-        )
-        assert rsps.assert_call_count(
-            base_meta_client.config.api_endpoint
-            + f"v2/teams/default-team/items/ids?workflow_stage_ids={str(stage_meta.id)}"
-            "&dataset_ids=1337",
-            1,
-        )
 
 
 def test_get_stage_id(stage_meta: Stage) -> None:
@@ -114,15 +127,15 @@ def test_get_stage_edges(stage_meta: Stage) -> None:
         ),
     ]
     test_stage = Stage(
-        stage_meta.client,
-        WFStageCore(
+        client=stage_meta.client,
+        element=WFStageCore(
             id=UUID("00000000-0000-0000-0000-000000000000"),
             name="test-stage",
             type=WFTypeCore.ANNOTATE,
             assignable_users=[],
             edges=edges,
         ),
-        {
+        meta_params={
             "team_slug": "default-team",
             "dataset_id": 000000,
             "workflow_id": UUID("00000000-0000-0000-0000-000000000000"),
diff --git a/darwin/future/tests/meta/objects/test_v7_id.py b/darwin/future/tests/meta/objects/test_v7_id.py
new file mode 100644
index 000000000..62171cc2d
--- /dev/null
+++ b/darwin/future/tests/meta/objects/test_v7_id.py
@@ -0,0 +1,18 @@
+from uuid import UUID
+
+from darwin.future.meta.client import Client
+from darwin.future.meta.objects.v7_id import V7ID
+from darwin.future.tests.meta.fixtures import *
+
+
+def test_v7_id(base_meta_client: Client) -> None:
+    # Test creating a V7ID object
+    uuid = UUID("123e4567-e89b-12d3-a456-426655440000")
+    v7_id = V7ID(uuid, base_meta_client)
+    assert v7_id.id == uuid
+
+    # Test __str__ method
+    assert str(v7_id) == str(uuid)
+
+    # Test __repr__ method
+    assert repr(v7_id) == str(uuid)
diff --git a/darwin/future/tests/meta/queries/test_dataset.py b/darwin/future/tests/meta/queries/test_dataset.py
index d06e123c6..33399c9d2 100644
--- a/darwin/future/tests/meta/queries/test_dataset.py
+++ b/darwin/future/tests/meta/queries/test_dataset.py
@@ -13,7 +13,7 @@ def test_dataset_collects_basic(
     with responses.RequestsMock() as rsps:
         endpoint = base_client.config.api_endpoint + "datasets"
         rsps.add(responses.GET, endpoint, json=base_datasets_json)
-        datasets = query._collect()
+        datasets = query._collect().values()
         assert len(datasets) == 2
         assert all(isinstance(dataset, Dataset) for dataset in datasets)
diff --git a/darwin/future/tests/meta/queries/test_stage.py b/darwin/future/tests/meta/queries/test_stage.py
index 1d7ad0721..0d7546733 100644
--- a/darwin/future/tests/meta/queries/test_stage.py
+++ b/darwin/future/tests/meta/queries/test_stage.py
@@ -20,7 +20,9 @@ def filled_query(base_client: ClientCore, base_workflow_meta: Workflow) -> Stage
 def base_workflow_meta(
     base_client: ClientCore, base_single_workflow_object: dict
 ) -> Workflow:
-    return Workflow(base_client, WorkflowCore.parse_obj(base_single_workflow_object))
+    return Workflow(
+        client=base_client, element=WorkflowCore.parse_obj(base_single_workflow_object)
+    )
 
 
 @pytest.fixture
@@ -94,5 +96,5 @@ def test_stage_filters_WFType(
     stages = filled_query.where({"name": "type", "param": wf_type.value})._collect()
     assert len(stages) == 3
     assert isinstance(stages[0], Stage)
-    for stage in stages:
+    for key, stage in stages.items():
         assert stage._element.type == wf_type
diff --git a/darwin/future/tests/meta/queries/test_team_id.py b/darwin/future/tests/meta/queries/test_team_id.py
new file mode 100644
index 000000000..1c4b4c4f1
--- /dev/null
+++ b/darwin/future/tests/meta/queries/test_team_id.py
@@ -0,0 +1,215 @@
+from typing import List
+from uuid import UUID, uuid4
+
+import pytest
+import responses
+from responses.matchers import query_param_matcher
+
+from darwin.future.core.client import ClientCore
+from darwin.future.data_objects.page import Page
+from darwin.future.meta.queries.item_id import ItemIDQuery
+from darwin.future.tests.core.fixtures import *
+
+
+@pytest.fixture
+def base_ItemIDQuery(base_client: ClientCore) -> ItemIDQuery:
+    return ItemIDQuery(
+        base_client, meta_params={"dataset_id": 0000, "team_slug": "test_team"}
+    )
+
+
+@pytest.fixture
+def list_of_uuids() -> List[UUID]:
+    return [uuid4() for _ in range(10)]
+
+
+def test_pagination_collects_all(
+    base_client: ClientCore, base_ItemIDQuery: ItemIDQuery, list_of_uuids: List[UUID]
+) -> None:
+    base_ItemIDQuery.page = Page(size=5)
+    team_slug = base_ItemIDQuery.meta_params["team_slug"]
+    dataset_id = base_ItemIDQuery.meta_params["dataset_id"]
+    str_ids = [str(uuid) for uuid in list_of_uuids]
+    with responses.RequestsMock() as rsps:
+        endpoint = (
+            base_client.config.api_endpoint + f"v2/teams/{team_slug}/items/list_ids"
+        )
+        rsps.add(
+            responses.GET,
+            endpoint,
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "0",
+                        "page[size]": "5",
+                        "dataset_ids": str(dataset_id),
+                    }
+                )
+            ],
+            json={"item_ids": [str(uuid) for uuid in str_ids[:5]]},
+        )
+        rsps.add(
+            responses.GET,
+            endpoint,
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "5",
+                        "page[size]": "5",
+                        "dataset_ids": str(dataset_id),
+                    }
+                )
+            ],
+            json={"item_ids": [str(uuid) for uuid in str_ids[5:]]},
+        )
+        rsps.add(
+            responses.GET,
+            endpoint,
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "10",
+                        "page[size]": "5",
+                        "dataset_ids": str(dataset_id),
+                    }
+                )
+            ],
+            json={"item_ids": []},
+        )
+
+        ids = base_ItemIDQuery.collect_all()
+        raw_ids = [x.id for x in ids]
+        assert len(rsps.calls) == 3
+        assert len(ids) == 10
+        assert raw_ids == list_of_uuids
+        assert base_ItemIDQuery.page.offset == 10
+        assert base_ItemIDQuery.completed is True
+
+
+def test_iterable_collects_all(
+    base_client: ClientCore, base_ItemIDQuery: ItemIDQuery, list_of_uuids: List[UUID]
+) -> None:
+    base_ItemIDQuery.page = Page(size=5)
+    team_slug = base_ItemIDQuery.meta_params["team_slug"]
+    dataset_id = base_ItemIDQuery.meta_params["dataset_id"]
+    str_ids = [str(uuid) for uuid in list_of_uuids]
+    with responses.RequestsMock() as rsps:
+        endpoint = (
+            base_client.config.api_endpoint + f"v2/teams/{team_slug}/items/list_ids"
+        )
+        rsps.add(
+            responses.GET,
+            endpoint,
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "0",
+                        "page[size]": "5",
+                        "dataset_ids": str(dataset_id),
+                    }
+                )
+            ],
+            json={"item_ids": [str(uuid) for uuid in str_ids[:5]]},
+        )
+        rsps.add(
+            responses.GET,
+            endpoint,
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "5",
+                        "page[size]": "5",
+                        "dataset_ids": str(dataset_id),
+                    }
+                )
+            ],
+            json={"item_ids": [str(uuid) for uuid in str_ids[5:]]},
+        )
+        rsps.add(
+            responses.GET,
+            endpoint,
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "10",
+                        "page[size]": "5",
+                        "dataset_ids": str(dataset_id),
+                    }
+                )
+            ],
+            json={"item_ids": []},
+        )
+
+        ids = base_ItemIDQuery
+        for i, item in enumerate(ids):
+            if i < 5:
+                assert item.id in list_of_uuids[:5]
+                assert len(rsps.calls) == 1
+            elif i < 10:
+                assert item.id in list_of_uuids[:10]
+                assert len(rsps.calls) == 2
+
+        assert len(rsps.calls) == 3
+        assert base_ItemIDQuery.page.offset == 10
+        assert base_ItemIDQuery.completed is True
+        assert len(ids) == 10
+
+
+def test_can_become_iterable(
+    base_client: ClientCore, base_ItemIDQuery: ItemIDQuery, list_of_uuids: List[UUID]
+) -> None:
+    base_ItemIDQuery.page = Page(size=20)
+    team_slug = base_ItemIDQuery.meta_params["team_slug"]
+    dataset_id = base_ItemIDQuery.meta_params["dataset_id"]
+    str_ids = [str(uuid) for uuid in list_of_uuids]
+    with responses.RequestsMock() as rsps:
+        endpoint = (
+            base_client.config.api_endpoint + f"v2/teams/{team_slug}/items/list_ids"
+        )
+        rsps.add(
+            responses.GET,
+            endpoint,
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "0",
+                        "page[size]": "20",
+                        "dataset_ids": str(dataset_id),
+                    }
+                )
+            ],
+            json={"item_ids": [str(uuid) for uuid in str_ids]},
+        )
+
+        ids = list(base_ItemIDQuery)
+        ids_raw = [x.id for x in ids]
+        assert len(rsps.calls) == 1
+        assert ids_raw == list_of_uuids
+
+
+def test_get_specific_index_collects_correct_page(
+    base_client: ClientCore, base_ItemIDQuery: ItemIDQuery, list_of_uuids: List[UUID]
+) -> None:
+    base_ItemIDQuery.page = Page(size=5)
+    team_slug = base_ItemIDQuery.meta_params["team_slug"]
+    dataset_id = base_ItemIDQuery.meta_params["dataset_id"]
+    str_ids = [str(uuid) for uuid in list_of_uuids]
+    with responses.RequestsMock() as rsps:
+        endpoint = (
+            base_client.config.api_endpoint + f"v2/teams/{team_slug}/items/list_ids"
+        )
+        rsps.add(
+            responses.GET,
+            endpoint,
+            match=[
+                query_param_matcher(
+                    {
+                        "page[offset]": "5",
+                        "page[size]": "5",
+                        "dataset_ids": str(dataset_id),
+                    }
+                )
+            ],
+            json={"item_ids": [str(uuid) for uuid in str_ids[5:]]},
+        )
+        base_ItemIDQuery[7]
diff --git a/darwin/future/tests/meta/queries/test_team_member.py b/darwin/future/tests/meta/queries/test_team_member.py
index 6fe314e58..cf480eff8 100644
--- a/darwin/future/tests/meta/queries/test_team_member.py
+++ b/darwin/future/tests/meta/queries/test_team_member.py
@@ -57,7 +57,7 @@ def test_team_member_filters_role(
         rsps.add(responses.GET, endpoint, json=base_team_members_json)
         members = query._collect()
         assert len(members) == len(TeamMemberRole) - 1
-        for member in members:
+        for member in members.values():
             assert member._element.role != role
diff --git a/darwin/future/tests/meta/queries/test_workflow.py b/darwin/future/tests/meta/queries/test_workflow.py
index c5630c761..3cbafde37 100644
--- a/darwin/future/tests/meta/queries/test_workflow.py
+++ b/darwin/future/tests/meta/queries/test_workflow.py
@@ -30,7 +30,7 @@ def test_workflowquery_collects_basic(
     workflows = query._collect()
 
     assert len(workflows) == 3
-    assert all(isinstance(workflow, Workflow) for workflow in workflows)
+    assert all(isinstance(workflow, Workflow) for workflow in workflows.values())
 
 
 @responses.activate
@@ -84,7 +84,7 @@ def test_workflowquery_filters_inserted_at(
     workflows = query._collect()
 
     assert len(workflows) == 2
-    ids = [str(workflow.id) for workflow in workflows]
+    ids = [str(workflow.id) for workflow in workflows.values()]
     assert WORKFLOW_1 in ids
     assert WORKFLOW_2 in ids
 
@@ -119,7 +119,7 @@ def test_workflowquery_filters_updated_at(
     workflows = query._collect()
 
     assert len(workflows) == 2
-    ids = [str(workflow.id) for workflow in workflows]
+    ids = [str(workflow.id) for workflow in workflows.values()]
     assert WORKFLOW_1 in ids
     assert WORKFLOW_2 in ids
 
@@ -249,7 +249,7 @@ def test_workflowquery_filters_stages_multiple(
     workflows = query._collect()
 
     assert len(workflows) == 2
-    workflow_names = [workflow.name for workflow in workflows]
+    workflow_names = [workflow.name for workflow in workflows.values()]
    assert "test-workflow-3" in workflow_names
    assert "test-workflow-1" in workflow_names
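Taken together, `Dataset.item_ids` and `Stage.item_ids` now return a lazy `ItemIDQuery` instead of an eager `List[UUID]`. A usage sketch (not part of the patch), given an already-built `Stage`; the exact fetch cadence depends on the `Query` iteration internals not shown in this diff:

```python
from darwin.future.meta.objects.stage import Stage


def print_stage_items(stage: Stage) -> None:
    query = stage.item_ids          # no request has been made yet
    first = query[0]                # fetches just the page containing index 0
    print(first.id)                 # V7ID wraps the raw UUID
    for item in query:              # remaining pages load lazily, 500 ids at a time
        print(item)                 # __str__/__repr__ show the UUID
    ids = [x.id for x in query.collect_all()]
    print(f"{len(ids)} items in stage")
```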