Skip to content

Commit

Permalink
Support loading non-items in API responses (#1843)
Browse files Browse the repository at this point in the history
  • Loading branch information
haakonvt authored Jul 12, 2024
1 parent 8850a8b commit 4ad52b8
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 72 deletions.
242 changes: 170 additions & 72 deletions cognite/client/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools
import gzip
import itertools
import logging
import re
import warnings
Expand Down Expand Up @@ -326,6 +327,7 @@ def _retrieve_multiple(
params: dict[str, Any] | None = None,
executor: ThreadPoolExecutor | None = None,
api_subversion: str | None = None,
settings_forcing_raw_response_loading: list[str] | None = None,
) -> T_CogniteResource | None: ...

@overload
Expand All @@ -341,6 +343,7 @@ def _retrieve_multiple(
params: dict[str, Any] | None = None,
executor: ThreadPoolExecutor | None = None,
api_subversion: str | None = None,
settings_forcing_raw_response_loading: list[str] | None = None,
) -> T_CogniteResourceList: ...

def _retrieve_multiple(
Expand All @@ -355,6 +358,7 @@ def _retrieve_multiple(
params: dict[str, Any] | None = None,
executor: ThreadPoolExecutor | None = None,
api_subversion: str | None = None,
settings_forcing_raw_response_loading: list[str] | None = None,
) -> T_CogniteResourceList | T_CogniteResource | None:
resource_path = resource_path or self._RESOURCE_PATH

Expand Down Expand Up @@ -389,6 +393,13 @@ def _retrieve_multiple(
return None
raise

if settings_forcing_raw_response_loading:
# The API response include one or more top-level keys than items we care about:
loaded = list_cls._load_raw_api_response(
tasks_summary.raw_api_responses, cognite_client=self._cognite_client
)
return (loaded[0] if loaded else None) if identifiers.is_singleton() else loaded

retrieved_items = tasks_summary.joined_results(lambda res: res.json()["items"])

if identifiers.is_singleton():
Expand All @@ -402,7 +413,7 @@ def _retrieve_multiple(

def _list_generator(
self,
method: str,
method: Literal["GET", "POST"],
list_cls: type[T_CogniteResourceList],
resource_cls: type[T_CogniteResource],
resource_path: str | None = None,
Expand All @@ -418,79 +429,147 @@ def _list_generator(
advanced_filter: dict | Filter | None = None,
api_subversion: str | None = None,
) -> Iterator[T_CogniteResourceList] | Iterator[T_CogniteResource]:
verify_limit(limit)
if is_unlimited(limit):
limit = None
if partitions:
warnings.warn("passing `partitions` to a generator method is not supported, so it's being ignored")
# set chunk_size to None in order to not break the existing API.
# TODO: Remove this and support for partitions (in combo with generator) in the next major version
chunk_size = None

resource_path = resource_path or self._RESOURCE_PATH
total_items_retrieved = 0
current_limit = self._LIST_LIMIT
next_cursor = initial_cursor
filter = filter or {}
unprocessed_items = []
limit, url_path, params = self._prepare_params_for_list_generator(
limit, method, filter, url_path, resource_path, sort, other_params, advanced_filter
)
unprocessed_items: list[dict[str, Any]] = []
total_retrieved, current_limit, next_cursor = 0, self._LIST_LIMIT, initial_cursor
while True:
if limit:
num_of_remaining_items = limit - total_items_retrieved
if num_of_remaining_items < current_limit:
current_limit = num_of_remaining_items
if limit and (n_remaining := limit - total_retrieved) < current_limit:
current_limit = n_remaining

params.update(limit=current_limit, cursor=next_cursor)
if method == "GET":
params = filter.copy()
params["limit"] = current_limit
params["cursor"] = next_cursor
if sort is not None:
params["sort"] = sort
params.update(other_params or {})
res = self._get(url_path=url_path or resource_path, params=params, headers=headers)

elif method == "POST":
body: dict[str, Any] = {"limit": current_limit, "cursor": next_cursor, **(other_params or {})}
if filter:
body["filter"] = filter
if advanced_filter:
body["advancedFilter"] = (
advanced_filter.dump(camel_case_property=True)
if isinstance(advanced_filter, Filter)
else advanced_filter
)
if sort is not None:
body["sort"] = sort
res = self._post(
url_path=url_path or resource_path + "/list",
json=body,
headers=headers,
api_subversion=api_subversion,
)
res = self._get(url_path=url_path, params=params, headers=headers)
else:
raise ValueError(f"_list_generator parameter `method` must be GET or POST, not {method}")
last_received_items = res.json()["items"]
total_items_retrieved += len(last_received_items)
res = self._post(url_path=url_path, json=params, headers=headers, api_subversion=api_subversion)

if not chunk_size:
for item in last_received_items:
yield resource_cls._load(item, cognite_client=self._cognite_client)
else:
unprocessed_items.extend(last_received_items)
if len(unprocessed_items) >= chunk_size:
chunks = split_into_chunks(unprocessed_items, chunk_size)
if chunks and len(chunks[-1]) < chunk_size:
unprocessed_items = chunks.pop(-1)
else:
unprocessed_items = []
for chunk in chunks:
yield list_cls._load(chunk, cognite_client=self._cognite_client)

next_cursor = res.json().get("nextCursor")
if total_items_retrieved == limit or next_cursor is None:
if chunk_size and unprocessed_items:
response = res.json()
yield from self._process_into_chunks(response, chunk_size, resource_cls, list_cls, unprocessed_items)

next_cursor = response.get("nextCursor")
total_retrieved += len(response["items"])
if total_retrieved == limit or next_cursor is None:
if unprocessed_items: # may only happen when -not- yielding one-by-one
yield list_cls._load(unprocessed_items, cognite_client=self._cognite_client)
break

def _list_generator_raw_responses(
self,
method: Literal["GET", "POST"],
settings_forcing_raw_response_loading: list[str],
resource_path: str | None = None,
url_path: str | None = None,
limit: int | None = None,
chunk_size: int | None = None,
filter: dict[str, Any] | None = None,
sort: SequenceNotStr[str | dict[str, Any]] | None = None,
other_params: dict[str, Any] | None = None,
partitions: int | None = None,
headers: dict[str, Any] | None = None,
initial_cursor: str | None = None,
advanced_filter: dict | Filter | None = None,
api_subversion: str | None = None,
) -> Iterator[dict[str, Any]]:
if partitions:
raise ValueError("When fetching additional data (besides items), using partitions is not supported")
if not chunk_size:
raise ValueError(
f"When fetching additional data (besides items), {chunk_size=} must match the "
f"API limit: {self._LIST_LIMIT}"
)
if chunk_size != self._LIST_LIMIT:
warnings.warn(
f"When fetching additional data (besides items), an arbitrary {chunk_size=} setting is "
f"not supported, only {self._LIST_LIMIT} (the API limit). This is caused by the following "
f"settings: {settings_forcing_raw_response_loading}.",
UserWarning,
)
limit, url_path, params = self._prepare_params_for_list_generator(
limit, method, filter, url_path, resource_path, sort, other_params, advanced_filter
)
total_retrieved, current_limit, next_cursor = 0, self._LIST_LIMIT, initial_cursor
while True:
if limit and (n_remaining := limit - total_retrieved) < current_limit:
current_limit = n_remaining

params.update(limit=current_limit, cursor=next_cursor)
if method == "GET":
res = self._get(url_path=url_path, params=params, headers=headers)
else:
res = self._post(url_path=url_path, json=params, headers=headers, api_subversion=api_subversion)

yield (response := res.json())
next_cursor = response.get("nextCursor")
total_retrieved += len(response["items"])
if total_retrieved == limit or next_cursor is None:
break

def _prepare_params_for_list_generator(
self,
limit: int | None,
method: Literal["GET", "POST"],
filter: dict[str, Any] | None,
url_path: str | None,
resource_path: str | None,
sort: SequenceNotStr[str | dict[str, Any]] | None,
other_params: dict[str, Any] | None,
advanced_filter: dict | Filter | None,
) -> tuple[int | None, str, dict[str, Any]]:
verify_limit(limit)
if is_unlimited(limit):
limit = None
filter, other_params = (filter or {}).copy(), (other_params or {}).copy()
if method == "GET":
url_path = url_path or resource_path or self._RESOURCE_PATH
if sort is not None:
filter["sort"] = sort
filter.update(other_params)
return limit, url_path, filter

if method == "POST":
url_path = url_path or (resource_path or self._RESOURCE_PATH) + "/list"
body: dict[str, Any] = {}
if filter:
body["filter"] = filter
if advanced_filter:
if isinstance(advanced_filter, Filter):
# TODO: Does our json.dumps now understand Filter?
body["advancedFilter"] = advanced_filter.dump(camel_case_property=True)
else:
body["advancedFilter"] = advanced_filter
if sort is not None:
body["sort"] = sort
body.update(other_params)
return limit, url_path, body
raise ValueError(f"_list_generator parameter `method` must be GET or POST, not {method}")

def _process_into_chunks(
self,
response: dict[str, Any],
chunk_size: int | None,
resource_cls: type[T_CogniteResource],
list_cls: type[T_CogniteResourceList],
unprocessed_items: list[dict[str, Any]],
) -> Iterator[T_CogniteResourceList] | Iterator[T_CogniteResource]:
if not chunk_size:
for item in response["items"]:
yield resource_cls._load(item, cognite_client=self._cognite_client)
else:
unprocessed_items.extend(response["items"])
if len(unprocessed_items) >= chunk_size:
chunks = split_into_chunks(unprocessed_items, chunk_size)
unprocessed_items.clear()
if chunks and len(chunks[-1]) < chunk_size:
unprocessed_items.extend(chunks.pop(-1))
for chunk in chunks:
yield list_cls._load(chunk, cognite_client=self._cognite_client)

def _list(
self,
method: Literal["POST", "GET"],
Expand All @@ -507,6 +586,7 @@ def _list(
initial_cursor: str | None = None,
advanced_filter: dict | Filter | None = None,
api_subversion: str | None = None,
settings_forcing_raw_response_loading: list[str] | None = None,
) -> T_CogniteResourceList:
verify_limit(limit)
if partitions:
Expand All @@ -516,6 +596,12 @@ def _list(
)
if sort is not None:
raise ValueError("When using sort, partitions is not supported.")
if settings_forcing_raw_response_loading:
raise ValueError(
"When using partitions, the following settings are not "
f"supported (yet): {settings_forcing_raw_response_loading}"
)
assert initial_cursor is api_subversion is None
return self._list_partitioned(
partitions=partitions,
method=method,
Expand All @@ -526,15 +612,9 @@ def _list(
other_params=other_params,
headers=headers,
)

resource_path = resource_path or self._RESOURCE_PATH
items: list[T_CogniteResource] = []
for resource_list in self._list_generator(
list_cls=list_cls,
resource_cls=resource_cls,
resource_path=resource_path,
fetch_kwargs = dict(
resource_path=resource_path or self._RESOURCE_PATH,
url_path=url_path,
method=method,
limit=limit,
chunk_size=self._LIST_LIMIT,
filter=filter,
Expand All @@ -544,9 +624,27 @@ def _list(
initial_cursor=initial_cursor,
advanced_filter=advanced_filter,
api_subversion=api_subversion,
):
items.extend(cast(T_CogniteResourceList, resource_list).data)
return list_cls(items, cognite_client=self._cognite_client)
)
if settings_forcing_raw_response_loading:
raw_response_fetcher = self._list_generator_raw_responses(
method,
settings_forcing_raw_response_loading,
**fetch_kwargs, # type: ignore [arg-type]
)
return list_cls._load_raw_api_response(
list(raw_response_fetcher),
cognite_client=self._cognite_client,
)
# TODO: List generator loads each chunk into 'list_cls', so kind of weird for us to chain
# elements, then do it again. Perhaps a modified version of 'raw responses' should be used:
resource_fetcher = cast(
Iterator[T_CogniteResourceList],
self._list_generator(method, list_cls, resource_cls, **fetch_kwargs), # type: ignore [arg-type]
)
return list_cls(
list(itertools.chain.from_iterable(resource_fetcher)),
cognite_client=self._cognite_client,
)

def _list_partitioned(
self,
Expand Down
6 changes: 6 additions & 0 deletions cognite/client/data_classes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,12 @@ def _load(
resources = [cls._RESOURCE._load(resource, cognite_client=cognite_client) for resource in resource_list]
return cls(resources, cognite_client=cognite_client)

@classmethod
def _load_raw_api_response(cls, responses: list[dict[str, Any]], cognite_client: CogniteClient) -> Self:
# Certain classes may need more than just 'items' from the raw repsonse. These need to provide
# an implementation of this method
raise NotImplementedError


T_CogniteResourceList = TypeVar("T_CogniteResourceList", bound=CogniteResourceList)

Expand Down
4 changes: 4 additions & 0 deletions cognite/client/utils/_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def joined_results(self, unwrap_fn: Callable = no_op) -> list:
joined_results.append(unwrapped)
return joined_results

@property
def raw_api_responses(self) -> list[dict[str, Any]]:
return [res.json() for res in self.results]

def raise_compound_exception_if_failed_tasks(
self,
task_unwrap_fn: Callable = no_op,
Expand Down

0 comments on commit 4ad52b8

Please sign in to comment.