Revise partitions+generator logic and move max_workers to global_config (#1526)
erlendvollset authored Dec 6, 2023
1 parent 3ae8328 commit 6e287fe
Showing 7 changed files with 122 additions and 122 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -16,7 +16,13 @@ Changes are grouped as follows
 - `Removed` for now removed features.
 - `Fixed` for any bug fixes.
 - `Security` in case of vulnerabilities.
 
+## [7.5.4] - 2023-12-06
+### Changed
+- The `partitions` parameter is no longer respected when using generator methods to list resources.
+- The `max_workers` config option has been moved from `ClientConfig` to the global config.
+
 ## [7.5.3] - 2023-12-06
 ### Added
 - Support for `subworkflow` tasks in `workflows`.
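
Taken together, the two changes look like this from user code. A minimal sketch, assuming a client whose default configuration and credentials are already set up (the iteration target and printed field are illustrative):

```python
from cognite.client import CogniteClient, global_config

# Parallelism is now tuned once, globally, before any client is created
# (previously this was a per-client setting on ClientConfig).
global_config.max_workers = 20

client = CogniteClient()  # assumes a default ClientConfig has been configured

# Generator methods now paginate sequentially: passing `partitions` here
# would emit a UserWarning and be ignored instead of fanning out requests.
for asset in client.assets(limit=None):
    print(asset.external_id)
```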
183 changes: 71 additions & 112 deletions cognite/client/_api_client.py
@@ -5,6 +5,7 @@
 import json as _json
 import logging
 import re
+import warnings
 from collections import UserList
 from json.decoder import JSONDecodeError
 from typing import (
@@ -406,123 +407,75 @@ def _list_generator(
         verify_limit(limit)
         if is_unlimited(limit):
             limit = None
-        resource_path = resource_path or self._RESOURCE_PATH
 
         if partitions:
-            if limit is not None:
-                raise ValueError("When using partitions, limit should be `None`, `-1` or `inf`.")
-            if sort is not None:
-                raise ValueError("When using sort, partitions is not supported.")
-
-            yield from self._list_generator_partitioned(
-                partitions=partitions,
-                resource_cls=resource_cls,
-                resource_path=resource_path,
-                filter=filter,
-                other_params=other_params,
-                headers=headers,
-            )
+            warnings.warn("passing `partitions` to a generator method is not supported, so it's being ignored")
+            # set chunk_size to None in order to not break the existing API.
+            # TODO: Remove this and support for partitions (in combo with generator) in the next major version
+            chunk_size = None
 
-        else:
-            total_items_retrieved = 0
-            current_limit = self._LIST_LIMIT
-            if chunk_size and chunk_size <= self._LIST_LIMIT:
-                current_limit = chunk_size
-            next_cursor = initial_cursor
-            filter = filter or {}
-            current_items = []
-            while True:
-                if limit:
-                    num_of_remaining_items = limit - total_items_retrieved
-                    if num_of_remaining_items < current_limit:
-                        current_limit = num_of_remaining_items
-
-                if method == "GET":
-                    params = filter.copy()
-                    params["limit"] = current_limit
-                    params["cursor"] = next_cursor
-                    if sort is not None:
-                        params["sort"] = sort
-                    params.update(other_params or {})
-                    res = self._get(url_path=url_path or resource_path, params=params, headers=headers)
-
-                elif method == "POST":
-                    body: dict[str, Any] = {"limit": current_limit, "cursor": next_cursor, **(other_params or {})}
-                    if filter:
-                        body["filter"] = filter
-                    if advanced_filter:
-                        body["advancedFilter"] = (
-                            advanced_filter.dump(camel_case_property=True)
-                            if isinstance(advanced_filter, Filter)
-                            else advanced_filter
-                        )
-                    if sort is not None:
-                        body["sort"] = sort
-                    res = self._post(
-                        url_path=url_path or resource_path + "/list",
-                        json=body,
-                        headers=headers,
-                        api_subversion=api_subversion,
-                    )
-                else:
-                    raise ValueError(f"_list_generator parameter `method` must be GET or POST, not {method}")
-                last_received_items = res.json()["items"]
-                total_items_retrieved += len(last_received_items)
-
-                if not chunk_size:
-                    for item in last_received_items:
-                        yield resource_cls._load(item, cognite_client=self._cognite_client)
-                else:
-                    current_items.extend(last_received_items)
-                    if len(current_items) >= chunk_size:
-                        items_to_yield = current_items[:chunk_size]
-                        current_items = current_items[chunk_size:]
-                        yield list_cls._load(items_to_yield, cognite_client=self._cognite_client)
-
-                next_cursor = res.json().get("nextCursor")
-                if total_items_retrieved == limit or next_cursor is None:
-                    if chunk_size and current_items:
-                        yield list_cls._load(current_items, cognite_client=self._cognite_client)
-                    break
-
-    def _list_generator_partitioned(
-        self,
-        partitions: int,
-        resource_cls: type[T_CogniteResource],
-        resource_path: str,
-        filter: dict | None = None,
-        other_params: dict | None = None,
-        headers: dict | None = None,
-    ) -> Iterator[T_CogniteResource]:
-        next_cursors = {i + 1: None for i in range(partitions)}
-
-        def get_partition(partition_num: int) -> list[dict[str, Any]]:
-            next_cursor = next_cursors[partition_num]
-
-            body = {
-                "filter": filter or {},
-                "limit": self._LIST_LIMIT,
-                "cursor": next_cursor,
-                "partition": f"{partition_num}/{partitions}",
-                **(other_params or {}),
-            }
-            res = self._post(url_path=resource_path + "/list", json=body, headers=headers)
-            next_cursor = res.json().get("nextCursor")
-            next_cursors[partition_num] = next_cursor
-
-            return res.json()["items"]
-
-        while len(next_cursors) > 0:
-            tasks_summary = execute_tasks(
-                get_partition, [(partition,) for partition in next_cursors], max_workers=partitions, fail_fast=True
-            )
-            tasks_summary.raise_compound_exception_if_failed_tasks()
-
-            for item in tasks_summary.joined_results():
-                yield resource_cls._load(item, cognite_client=self._cognite_client)
-
-            # Remove None from cursor dict
-            next_cursors = {partition: next_cursors[partition] for partition in next_cursors if next_cursors[partition]}
+        resource_path = resource_path or self._RESOURCE_PATH
+        total_items_retrieved = 0
+        current_limit = self._LIST_LIMIT
+        next_cursor = initial_cursor
+        filter = filter or {}
+        unprocessed_items = []
+        while True:
+            if limit:
+                num_of_remaining_items = limit - total_items_retrieved
+                if num_of_remaining_items < current_limit:
+                    current_limit = num_of_remaining_items
+
+            if method == "GET":
+                params = filter.copy()
+                params["limit"] = current_limit
+                params["cursor"] = next_cursor
+                if sort is not None:
+                    params["sort"] = sort
+                params.update(other_params or {})
+                res = self._get(url_path=url_path or resource_path, params=params, headers=headers)
+
+            elif method == "POST":
+                body: dict[str, Any] = {"limit": current_limit, "cursor": next_cursor, **(other_params or {})}
+                if filter:
+                    body["filter"] = filter
+                if advanced_filter:
+                    body["advancedFilter"] = (
+                        advanced_filter.dump(camel_case_property=True)
+                        if isinstance(advanced_filter, Filter)
+                        else advanced_filter
+                    )
+                if sort is not None:
+                    body["sort"] = sort
+                res = self._post(
+                    url_path=url_path or resource_path + "/list",
+                    json=body,
+                    headers=headers,
+                    api_subversion=api_subversion,
+                )
+            else:
+                raise ValueError(f"_list_generator parameter `method` must be GET or POST, not {method}")
+            last_received_items = res.json()["items"]
+            total_items_retrieved += len(last_received_items)
+
+            if not chunk_size:
+                for item in last_received_items:
+                    yield resource_cls._load(item, cognite_client=self._cognite_client)
+            else:
+                unprocessed_items.extend(last_received_items)
+                if len(unprocessed_items) >= chunk_size:
+                    chunks = split_into_chunks(unprocessed_items, chunk_size)
+                    if chunks and len(chunks[-1]) < chunk_size:
+                        unprocessed_items = chunks.pop(-1)
+                    else:
+                        unprocessed_items = []
+                    for chunk in chunks:
+                        yield list_cls._load(chunk, cognite_client=self._cognite_client)
+
+            next_cursor = res.json().get("nextCursor")
+            if total_items_retrieved == limit or next_cursor is None:
+                if chunk_size and unprocessed_items:
+                    yield list_cls._load(unprocessed_items, cognite_client=self._cognite_client)
+                break
 
     def _list(
         self,
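
The reworked chunking above leans on `split_into_chunks` to emit only full chunks mid-stream, holding back a short tail until pagination ends. A standalone sketch of that contract, assuming `split_into_chunks` is a plain list splitter (illustrative, not the SDK's internal helper):

```python
def split_into_chunks(items: list, chunk_size: int) -> list:
    # Plain splitter: every chunk is full except possibly the last one.
    return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]

unprocessed_items = list(range(2500))
chunks = split_into_chunks(unprocessed_items, 1000)
assert [len(c) for c in chunks] == [1000, 1000, 500]
# The generator yields the two full chunks immediately and keeps the
# 500-item tail buffered, yielding it only once the API returns no
# nextCursor (or the requested limit is reached).
```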
@@ -628,7 +581,7 @@ def get_partition(partition: int) -> list[dict[str, Any]]:
             return retrieved_items
 
         tasks = [(f"{i + 1}/{partitions}",) for i in range(partitions)]
-        tasks_summary = execute_tasks(get_partition, tasks, max_workers=partitions, fail_fast=True)
+        tasks_summary = execute_tasks(get_partition, tasks, max_workers=self._config.max_workers, fail_fast=True)
         tasks_summary.raise_compound_exception_if_failed_tasks()
 
         return list_cls._load(tasks_summary.joined_results(), cognite_client=self._cognite_client)
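
The effect of this one-line change is that concurrency no longer grows with the number of partitions requested but is capped by the shared worker pool. A rough sketch of the idea using a plain ThreadPoolExecutor in place of the SDK's `execute_tasks` (`fetch_partition` is hypothetical):

```python
from concurrent.futures import ThreadPoolExecutor

def fetch_partition(spec: str) -> list:
    # Hypothetical stand-in: the real code POSTs to /list with {"partition": spec}.
    return [{"partition": spec}]

partitions = 50
# Before: max_workers=partitions could spawn 50 concurrent requests.
# After: the fan-out is bounded by the configured pool size, e.g. 10.
with ThreadPoolExecutor(max_workers=10) as pool:
    specs = (f"{i + 1}/{partitions}" for i in range(partitions))
    items = [item for batch in pool.map(fetch_partition, specs) for item in batch]
```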
@@ -863,7 +816,13 @@ def str_format_element(el: T) -> str | dict | T:
     if isinstance(el, CogniteResource):
         dumped = el.dump()
         if "external_id" in dumped:
+            if "space" in dumped:
+                return f"{dumped['space']}:{dumped['external_id']}"
             return dumped["external_id"]
         if "externalId" in dumped:
+            if "space" in dumped:
+                return f"{dumped['space']}:{dumped['externalId']}"
             return dumped["externalId"]
         return dumped
     return el
2 changes: 1 addition & 1 deletion cognite/client/_version.py
@@ -1,4 +1,4 @@
 from __future__ import annotations
 
-__version__ = "7.5.3"
+__version__ = "7.5.4"
 __api_subversion__ = "V20220125"
14 changes: 12 additions & 2 deletions cognite/client/config.py
@@ -2,6 +2,7 @@
 
 import getpass
 import pprint
+import warnings
 from contextlib import suppress
 
 from cognite.client._version import __api_subversion__
@@ -26,6 +27,7 @@ class GlobalConfig:
             Defaults to 50.
         disable_ssl (bool): Whether or not to disable SSL. Defaults to False
         proxies (Dict[str, str]): Dictionary mapping from protocol to url. e.g. {"https": "http://10.10.1.10:1080"}
+        max_workers (int): Max number of workers to spawn when parallelizing API calls. Defaults to 10.
     """
 
     def __init__(self) -> None:
@@ -39,6 +41,7 @@ def __init__(self) -> None:
         self.max_connection_pool_size: int = 50
         self.disable_ssl: bool = False
         self.proxies: dict[str, str] | None = {}
+        self.max_workers: int = 10
 
 
 global_config = GlobalConfig()
@@ -53,7 +56,8 @@ class ClientConfig:
         credentials (CredentialProvider): Credentials. e.g. Token, ClientCredentials.
         api_subversion (str | None): API subversion
         base_url (str | None): Base url to send requests to. Defaults to "https://api.cognitedata.com"
-        max_workers (int | None): Max number of workers to spawn when parallelizing data fetching. Defaults to 10. Can not be changed after your first API call.
+        max_workers (int | None): DEPRECATED. Use global_config.max_workers instead.
+            Max number of workers to spawn when parallelizing data fetching. Defaults to global_config.max_workers.
         headers (dict[str, str] | None): Additional headers to add to all requests.
         timeout (int | None): Timeout on requests sent to the api. Defaults to 30 seconds.
         file_transfer_timeout (int | None): Timeout on file upload/download requests. Defaults to 600 seconds.
@@ -78,7 +82,13 @@ def __init__(
         self.credentials = credentials
         self.api_subversion = api_subversion or __api_subversion__
         self.base_url = (base_url or "https://api.cognitedata.com").rstrip("/")
-        self.max_workers = max_workers if max_workers is not None else 10
+        if max_workers is not None:
+            # TODO: Remove max_workers from ClientConfig in next major version
+            warnings.warn(
+                "Passing max_workers to ClientConfig is deprecated. Please use global_config.max_workers instead",
+                DeprecationWarning,
+            )
+        self.max_workers = max_workers if max_workers is not None else global_config.max_workers
         self.headers = headers or {}
         self.timeout = timeout or 30
         self.file_transfer_timeout = file_transfer_timeout or 600
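
A quick sketch of the migration path this enables, with `Token` credentials used purely for illustration: setting the global knob is silent, while the old per-client argument still works but now warns.

```python
import warnings

from cognite.client import ClientConfig, global_config
from cognite.client.credentials import Token

global_config.max_workers = 10  # the supported way to tune parallelism

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cfg = ClientConfig(
        client_name="my-app",
        project="my-project",
        credentials=Token("not-a-real-token"),
        max_workers=20,  # deprecated: triggers a DeprecationWarning
    )

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert cfg.max_workers == 20  # an explicit value still takes precedence
```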
17 changes: 12 additions & 5 deletions cognite/client/exceptions.py
@@ -94,15 +94,22 @@ def __init__(
     def _unwrap_list(self, lst: list) -> list:
         return [self._unwrap_fn(elem) for elem in lst]
 
+    def _truncate_elements(self, lst: list) -> str:
+        truncate_at = 10
+        elements = ",".join([str(element) for element in lst[:truncate_at]])
+        if len(lst) > truncate_at:
+            elements += ", ..."
+        return f"[{elements}]"
+
     def _get_multi_exception_summary(self) -> str:
         if len(self.successful) == 0 and len(self.unknown) == 0 and len(self.failed) == 0 and len(self.skipped) == 0:
             return ""
         summary = [
             "",  # start string with newline
             "The API Failed to process some items.",
-            f"Successful (2xx): {self._unwrap_list(self.successful)}",
-            f"Unknown (5xx): {self._unwrap_list(self.unknown)}",
-            f"Failed (4xx): {self._unwrap_list(self.failed)}",
+            f"Successful (2xx): {self._truncate_elements(self._unwrap_list(self.successful))}",
+            f"Unknown (5xx): {self._truncate_elements(self._unwrap_list(self.unknown))}",
+            f"Failed (4xx): {self._truncate_elements(self._unwrap_list(self.failed))}",
         ]
         # Only show 'skipped' when tasks were skipped to avoid confusion:
         if skipped := self._unwrap_list(self.skipped):

@@ -175,9 +182,9 @@ def __init__(
     def __str__(self) -> str:
         msg = f"{self.message} | code: {self.code} | X-Request-ID: {self.x_request_id}"
         if self.missing:
-            msg += f"\nMissing: {self.missing}"
+            msg += f"\nMissing: {self._truncate_elements(self.missing)}"
         if self.duplicated:
-            msg += f"\nDuplicated: {self.duplicated}"
+            msg += f"\nDuplicated: {self._truncate_elements(self.duplicated)}"
         msg += self._get_multi_exception_summary()
         if self.extra:
             pretty_extra = json.dumps(self.extra, indent=4, sort_keys=True)
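
For illustration, a standalone function mirroring `_truncate_elements` and the output it produces for a long failure list (a sketch, not the SDK's method):

```python
def truncate_elements(lst: list, truncate_at: int = 10) -> str:
    # Show at most `truncate_at` items, then an ellipsis.
    elements = ",".join([str(element) for element in lst[:truncate_at]])
    if len(lst) > truncate_at:
        elements += ", ..."
    return f"[{elements}]"

print(f"Failed (4xx): {truncate_elements(list(range(100)))}")
# Failed (4xx): [0,1,2,3,4,5,6,7,8,9, ...]
```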
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "cognite-sdk"
 
-version = "7.5.3"
+version = "7.5.4"
 description = "Cognite Python SDK"
 readme = "README.md"
 documentation = "https://cognite-sdk-python.readthedocs-hosted.com"
18 changes: 18 additions & 0 deletions tests/tests_unit/test_api_client.py
@@ -663,6 +663,24 @@ def test_standard_list_generator__chunk_size_exceeds_max(self, api_client_with_token):
             total_resources += 1001
         assert 2002 == total_resources
 
+    @pytest.mark.usefixtures("mock_get_for_autopaging")
+    def test_standard_list_generator_vs_partitions(self, api_client_with_token):
+        total_resources = 0
+        for resource_chunk in api_client_with_token._list_generator(
+            list_cls=SomeResourceList,
+            resource_cls=SomeResource,
+            resource_path=URL_PATH,
+            method="GET",
+            partitions=1,
+            limit=2000,
+            chunk_size=1001,
+        ):
+            # TODO: chunk_size is ignored when partitions is set, fix in next major version
+            assert isinstance(resource_chunk, SomeResource)
+            total_resources += 1
+
+        assert 2000 == total_resources
+
     @pytest.mark.usefixtures("mock_get_for_autopaging")
     def test_standard_list_autopaging(self, api_client_with_token):
         res = api_client_with_token._list(
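
A possible companion test (not part of this commit) could pin down the new warning itself; a sketch reusing the same fixtures and helpers as the tests above:

```python
@pytest.mark.usefixtures("mock_get_for_autopaging")
def test_standard_list_generator_warns_on_partitions(self, api_client_with_token):
    with pytest.warns(UserWarning, match="partitions"):
        next(
            api_client_with_token._list_generator(
                list_cls=SomeResourceList,
                resource_cls=SomeResource,
                resource_path=URL_PATH,
                method="GET",
                partitions=2,
                limit=None,
            )
        )
```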
