Fix deadlock in RelationshipsAPI.list (#1849)
haakonvt authored Jul 16, 2024
1 parent d6730a2 commit 252e737
Showing 5 changed files with 77 additions and 97 deletions.
CHANGELOG.md (5 changes: 4 additions & 1 deletion)
@@ -17,8 +17,11 @@ Changes are grouped as follows
- `Fixed` for any bug fixes.
- `Security` in case of vulnerabilities.

-## [7.54.1] - 2024-07-13
+## [7.54.2] - 2024-07-16
+### Fixed
+- A bug in the list method of the RelationshipsAPI that could cause a thread deadlock.
+
+## [7.54.1] - 2024-07-15
### Fixed
- Calling `client.functions.retrieve` or `client.functions.delete` with more than 10 ids no longer
raises a `CogniteAPIError`.
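For context on the changelog entry above: the deadlock could be triggered by a `list` call in which `source_external_ids` or `target_external_ids` exceeded the subquery limit of 1000, forcing the chunked, parallelized code path. A minimal sketch of that call pattern (the client setup and external IDs here are hypothetical, not part of the commit):

```python
from cognite.client import CogniteClient

client = CogniteClient()  # assumes credentials/config are already set up

# More than _LIST_SUBQUERY_LIMIT (1000) ids forces the chunked code path
# that could previously deadlock. These ids are made up:
many_ids = [f"pump-{i}" for i in range(5000)]

# limit must be unlimited (None / -1 / inf) when exceeding the subquery limit:
rels = client.relationships.list(source_external_ids=many_ids, limit=None)
```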
cognite/client/_api/relationships.py (151 changes: 62 additions & 89 deletions)
@@ -1,7 +1,9 @@
from __future__ import annotations

-import copy
-from typing import TYPE_CHECKING, Any, Iterator, Literal, Sequence, overload
+import itertools
+import warnings
+from functools import partial
+from typing import TYPE_CHECKING, Iterator, Literal, Sequence, overload

from cognite.client._api_client import APIClient
from cognite.client._constants import DEFAULT_LIMIT_READ
@@ -14,7 +16,7 @@
)
from cognite.client.data_classes.labels import LabelFilter
from cognite.client.data_classes.relationships import RelationshipCore
-from cognite.client.utils._auxiliary import is_unlimited
+from cognite.client.utils._auxiliary import is_unlimited, split_into_chunks
from cognite.client.utils._concurrency import execute_tasks
from cognite.client.utils._identifier import IdentifierSequence
from cognite.client.utils._validation import assert_type, process_data_set_ids
@@ -32,36 +34,6 @@ def __init__(self, config: ClientConfig, api_version: str | None, cognite_client
        super().__init__(config, api_version, cognite_client)
        self._LIST_SUBQUERY_LIMIT = 1000

-    def _create_filter(
-        self,
-        source_external_ids: SequenceNotStr[str] | None = None,
-        source_types: SequenceNotStr[str] | None = None,
-        target_external_ids: SequenceNotStr[str] | None = None,
-        target_types: SequenceNotStr[str] | None = None,
-        data_set_ids: Sequence[dict[str, Any]] | None = None,
-        start_time: dict[str, int] | None = None,
-        end_time: dict[str, int] | None = None,
-        confidence: dict[str, int] | None = None,
-        last_updated_time: dict[str, int] | None = None,
-        created_time: dict[str, int] | None = None,
-        active_at_time: dict[str, int] | None = None,
-        labels: LabelFilter | None = None,
-    ) -> dict[str, Any]:
-        return RelationshipFilter(
-            source_external_ids=source_external_ids,
-            source_types=source_types,
-            target_external_ids=target_external_ids,
-            target_types=target_types,
-            data_set_ids=data_set_ids,
-            start_time=start_time,
-            end_time=end_time,
-            confidence=confidence,
-            last_updated_time=last_updated_time,
-            created_time=created_time,
-            active_at_time=active_at_time,
-            labels=labels,
-        ).dump(camel_case=True)

    @overload
    def __call__(
        self,
@@ -153,8 +125,7 @@ def __call__(
            Iterator[Relationship] | Iterator[RelationshipList]: yields Relationship one by one if chunk_size is not specified, else RelationshipList objects.
        """
        data_set_ids_processed = process_data_set_ids(data_set_ids, data_set_external_ids)
-
-        filter = self._create_filter(
+        filter = RelationshipFilter(
            source_external_ids=source_external_ids,
            source_types=source_types,
            target_external_ids=target_external_ids,
@@ -167,16 +138,13 @@
            created_time=created_time,
            active_at_time=active_at_time,
            labels=labels,
-        )
-        if (
-            len(filter.get("targetExternalIds", [])) > self._LIST_SUBQUERY_LIMIT
-            or len(filter.get("sourceExternalIds", [])) > self._LIST_SUBQUERY_LIMIT
-        ):
+        ).dump(camel_case=True)
+        n_target_xids, n_source_xids = len(target_external_ids or []), len(source_external_ids or [])
+        if n_target_xids > self._LIST_SUBQUERY_LIMIT or n_source_xids > self._LIST_SUBQUERY_LIMIT:
            raise ValueError(
                f"For queries with more than {self._LIST_SUBQUERY_LIMIT} source_external_ids or "
-                "target_external_ids, only list is supported"
+                "target_external_ids, only `list` is supported"
            )
-
        return self._list_generator(
            list_cls=RelationshipList,
            resource_cls=Relationship,
@@ -313,8 +281,7 @@ def list(
                ...     relationship # do something with the relationship
        """
        data_set_ids_processed = process_data_set_ids(data_set_ids, data_set_external_ids)
-
-        filter = self._create_filter(
+        filter = RelationshipFilter(
            source_external_ids=source_external_ids,
            source_types=source_types,
            target_external_ids=target_external_ids,
@@ -327,53 +294,59 @@
            created_time=created_time,
            active_at_time=active_at_time,
            labels=labels,
-        )
-        target_external_id_list: list[str] = filter.get("targetExternalIds", [])
-        source_external_id_list: list[str] = filter.get("sourceExternalIds", [])
-        if (
-            len(target_external_id_list) > self._LIST_SUBQUERY_LIMIT
-            or len(source_external_id_list) > self._LIST_SUBQUERY_LIMIT
-        ):
-            if not is_unlimited(limit):
-                raise ValueError(
-                    f"Querying more than {self._LIST_SUBQUERY_LIMIT} source_external_ids/target_external_ids only "
-                    f"supported for queries without limit (pass -1 / None / inf instead of {limit})"
-                )
-            tasks = []
-
-            for ti in range(0, max(1, len(target_external_id_list)), self._LIST_SUBQUERY_LIMIT):
-                for si in range(0, max(1, len(source_external_id_list)), self._LIST_SUBQUERY_LIMIT):
-                    task_filter = copy.copy(filter)
-                    if target_external_id_list:  # keep null if it was
-                        task_filter["targetExternalIds"] = target_external_id_list[ti : ti + self._LIST_SUBQUERY_LIMIT]
-                    if source_external_id_list:  # keep null if it was
-                        task_filter["sourceExternalIds"] = source_external_id_list[si : si + self._LIST_SUBQUERY_LIMIT]
-                    tasks.append((task_filter,))
-
-            tasks_summary = execute_tasks(
-                lambda filter: self._list(
-                    list_cls=RelationshipList,
-                    resource_cls=Relationship,
-                    method="POST",
-                    limit=limit,
-                    filter=filter,
-                    other_params={"fetchResources": fetch_resources},
-                    partitions=partitions,
-                ),
-                tasks,
-                max_workers=self._config.max_workers,
-            )
-            tasks_summary.raise_compound_exception_if_failed_tasks()
+        ).dump(camel_case=True)

-            return RelationshipList(tasks_summary.joined_results())
-        return self._list(
-            list_cls=RelationshipList,
-            resource_cls=Relationship,
-            method="POST",
-            limit=limit,
-            filter=filter,
-            other_params={"fetchResources": fetch_resources},
+        target_external_ids, source_external_ids = target_external_ids or [], source_external_ids or []
+        if all(len(xids) <= self._LIST_SUBQUERY_LIMIT for xids in (target_external_ids, source_external_ids)):
+            return self._list(
+                list_cls=RelationshipList,
+                resource_cls=Relationship,
+                method="POST",
+                limit=limit,
+                filter=filter,
+                partitions=partitions,
+                other_params={"fetchResources": fetch_resources},
+            )
+        if not is_unlimited(limit):
+            raise ValueError(
+                f"Querying more than {self._LIST_SUBQUERY_LIMIT} source_external_ids/target_external_ids is only "
+                f"supported for unlimited queries (pass -1 / None / inf instead of {limit})"
+            )
+        tasks = []
+        target_chunks = split_into_chunks(target_external_ids, self._LIST_SUBQUERY_LIMIT) or [[]]
+        source_chunks = split_into_chunks(source_external_ids, self._LIST_SUBQUERY_LIMIT) or [[]]
+
+        # All sources (if any) must be checked against all targets (if any). When either is not
+        # given, we must exhaustively list all matching just the source or the target:
+        for target_xids, source_xids in itertools.product(target_chunks, source_chunks):
+            task_filter = filter.copy()
+            if target_external_ids:  # keep null if it was
+                task_filter["targetExternalIds"] = target_xids
+            if source_external_ids:
+                task_filter["sourceExternalIds"] = source_xids
+            tasks.append({"filter": task_filter})
+
+        if partitions is not None:
+            warnings.warn(
+                f"When one or both of source/target external IDs have more than {self._LIST_SUBQUERY_LIMIT} "
+                "elements, `partitions` is ignored",
+                UserWarning,
+            )
+        tasks_summary = execute_tasks(
+            partial(
+                self._list,
+                list_cls=RelationshipList,
+                resource_cls=Relationship,
+                method="POST",
+                limit=None,
+                partitions=None,  # Otherwise, workers will spawn workers -> deadlock (singleton threadpool)
+                other_params={"fetchResources": fetch_resources},
+            ),
+            tasks,
+            max_workers=self._config.max_workers,
+        )
+        tasks_summary.raise_compound_exception_if_failed_tasks()
+        return RelationshipList(tasks_summary.joined_results(), cognite_client=self._cognite_client)

    @overload
    def create(self, relationship: Relationship | RelationshipWrite) -> Relationship: ...
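The crux of the fix is visible in the `execute_tasks` call above: each chunked subquery now runs with `limit=None` and `partitions=None`, so a worker thread never fans out onto the same thread pool it is already running on. A self-contained sketch of the failure mode being avoided (generic Python, not the SDK's actual concurrency internals):

```python
# A bounded, shared thread pool deadlocks when tasks block on subtasks
# submitted to the same pool: the parents hold every worker while their
# children wait in the queue behind them.
from concurrent.futures import ThreadPoolExecutor

POOL = ThreadPoolExecutor(max_workers=2)  # stand-in for a singleton pool

def subquery(i: int) -> int:
    return i * i

def parent_task(n: int) -> list[int]:
    futures = [POOL.submit(subquery, i) for i in range(n)]
    return [f.result() for f in futures]  # blocks until children finish

# Two parents saturate the pool; their subqueries can never start, so the
# parents wait forever. Left commented out because it really would hang:
# parents = [POOL.submit(parent_task, 3) for _ in range(2)]
# results = [p.result() for p in parents]
```

Keeping the fan-out to a single level, as the new code does, avoids this: only the outer `execute_tasks` parallelizes, and each `self._list` call runs its subquery without spawning further workers.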
cognite/client/_version.py (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
from __future__ import annotations

__version__ = "7.54.1"
__version__ = "7.54.2"
__api_subversion__ = "20230101"
cognite/client/utils/_auxiliary.py (14 changes: 9 additions & 5 deletions)
@@ -27,12 +27,14 @@
    to_snake_case,
)
from cognite.client.utils._version_checker import get_newest_version_in_major_release
+from cognite.client.utils.useful_types import SequenceNotStr

if TYPE_CHECKING:
    from cognite.client import CogniteClient
    from cognite.client.data_classes._base import T_CogniteObject, T_CogniteResource

T = TypeVar("T")
+K = TypeVar("K")
THashable = TypeVar("THashable", bound=Hashable)


@@ -176,19 +178,21 @@ def split_into_n_parts(seq: Sequence[T], *, n: int) -> Iterator[Sequence[T]]:


@overload
-def split_into_chunks(collection: set | list, chunk_size: int) -> list[list]: ...
+def split_into_chunks(collection: set[T] | SequenceNotStr[T], chunk_size: int) -> list[list[T]]: ...


@overload
-def split_into_chunks(collection: dict, chunk_size: int) -> list[dict]: ...
+def split_into_chunks(collection: dict[K, T], chunk_size: int) -> list[dict[K, T]]: ...


-def split_into_chunks(collection: set | list | dict, chunk_size: int) -> list[list] | list[dict]:
+def split_into_chunks(
+    collection: SequenceNotStr[T] | set[T] | dict[K, T], chunk_size: int
+) -> list[list[T]] | list[dict[K, T]]:
    if isinstance(collection, set):
        collection = list(collection)

-    if isinstance(collection, list):
-        return [collection[i : i + chunk_size] for i in range(0, len(collection), chunk_size)]
+    if isinstance(collection, SequenceNotStr):
+        return [list(collection[i : i + chunk_size]) for i in range(0, len(collection), chunk_size)]

    if isinstance(collection, dict):
        collection = list(collection.items())
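To make the retyped helper concrete: a small usage sketch of `split_into_chunks` together with the chunk-pairing trick used by `RelationshipsAPI.list`. The dict-handling tail of the function is cut off above, so the stand-in below infers it from the `list(collection.items())` conversion; the values are illustrative:

```python
import itertools

def split_into_chunks(collection, chunk_size):
    # Simplified stand-in matching the overloads in the diff above.
    if isinstance(collection, set):
        collection = list(collection)
    if isinstance(collection, list):
        return [collection[i : i + chunk_size] for i in range(0, len(collection), chunk_size)]
    if isinstance(collection, dict):
        items = list(collection.items())
        return [dict(items[i : i + chunk_size]) for i in range(0, len(items), chunk_size)]
    raise TypeError(f"unsupported type: {type(collection)}")

assert split_into_chunks([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]
assert split_into_chunks({"a": 1, "b": 2, "c": 3}, 2) == [{"a": 1, "b": 2}, {"c": 3}]

# How RelationshipsAPI.list pairs chunks: every target chunk is combined with
# every source chunk, and `or [[]]` supplies one empty chunk for a missing side
# so the cross product is never empty.
targets = split_into_chunks(["t1", "t2", "t3"], 2) or [[]]
sources = split_into_chunks([], 2) or [[]]  # no source_external_ids given
print(list(itertools.product(targets, sources)))
# [(['t1', 't2'], []), (['t3'], [])]
```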
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
[tool.poetry]
name = "cognite-sdk"

version = "7.54.1"
version = "7.54.2"
description = "Cognite Python SDK"
readme = "README.md"
documentation = "https://cognite-sdk-python.readthedocs-hosted.com"
