Skip to content

Commit

Permalink
add mixin class for CogniteClient reference handling (#1345)
Browse files Browse the repository at this point in the history
  • Loading branch information
haakonvt authored Sep 15, 2023
1 parent b1cdf81 commit 96cb9b4
Show file tree
Hide file tree
Showing 20 changed files with 187 additions and 139 deletions.
2 changes: 1 addition & 1 deletion cognite/client/_api/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def status(self) -> FunctionsStatus:
def _create_session_and_return_nonce(
client: CogniteClient,
client_credentials: dict | ClientCredentials | None = None,
) -> str | None:
) -> str:
if client_credentials is None:
if isinstance(client._config.credentials, OAuthClientCertificate):
raise CogniteAuthError("Client certificate credentials is not supported with the Functions API")
Expand Down
10 changes: 5 additions & 5 deletions cognite/client/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def _list(
advanced_filter=advanced_filter,
api_subversion=api_subversion,
):
items.extend(resource_list.data)
items.extend(cast(T_CogniteResourceList, resource_list).data)
return list_cls(items, cognite_client=self._cognite_client)

def _list_partitioned(
Expand Down Expand Up @@ -986,11 +986,11 @@ def _upsert_multiple(
if mode not in ["patch", "replace"]:
raise ValueError(f"mode must be either 'patch' or 'replace', got {mode!r}")
is_single = isinstance(items, CogniteResource)
items = cast(Sequence[CogniteResource], [items] if is_single else items)
items = cast(Sequence[T_CogniteResource], [items] if is_single else items)
try:
result = self._update_multiple(items, list_cls, resource_cls, update_cls, mode=mode)
except CogniteNotFoundError as not_found_error:
items_by_external_id = {item.external_id: item for item in items if item.external_id is not None}
items_by_external_id = {item.external_id: item for item in items if item.external_id is not None} # type: ignore [attr-defined]
items_by_id = {item.id: item for item in items if hasattr(item, "id") and item.id is not None}
# Not found must have an external id as they do not exist in CDF:
try:
Expand Down Expand Up @@ -1038,7 +1038,7 @@ def _upsert_multiple(
successful.extend(item.external_id for item in created)
if updated is None and created is not None:
# The created call failed
failed.extend(item.external_id if item.external_id is not None else item.id for item in to_update)
failed.extend(item.external_id if item.external_id is not None else item.id for item in to_update) # type: ignore [attr-defined]
raise CogniteAPIError(
api_error.message, code=api_error.code, successful=successful, failed=failed, unknown=unknown
)
Expand All @@ -1059,7 +1059,7 @@ def _upsert_multiple(
# Reorder to match the order of the input items
result.data = [
result.get(
**Identifier.load(item.id if hasattr(item, "id") else None, item.external_id).as_dict(
**Identifier.load(item.id if hasattr(item, "id") else None, item.external_id).as_dict( # type: ignore [attr-defined]
camel_case=False
)
)
Expand Down
77 changes: 40 additions & 37 deletions cognite/client/data_classes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from collections import UserList
from collections.abc import Iterable
from contextlib import suppress
from dataclasses import dataclass
from enum import Enum
from typing import (
Expand Down Expand Up @@ -57,12 +58,6 @@ def __repr__(self) -> str:
def __eq__(self, other: Any) -> bool:
return type(other) is type(self) and other.dump() == self.dump()

def __getattribute__(self, item: Any) -> Any:
attr = super().__getattribute__(item)
if item == "_cognite_client" and attr is None:
raise CogniteMissingClientError(self)
return attr

def dump(self, camel_case: bool = False) -> dict[str, Any]:
"""Dump the instance into a json serializable Python data type.
Expand All @@ -85,15 +80,35 @@ def to_pandas(self) -> pandas.DataFrame:
T_CogniteResponse = TypeVar("T_CogniteResponse", bound=CogniteResponse)


class CogniteResource:
_cognite_client: Any
class _WithClientMixin:
@property
def _cognite_client(self) -> CogniteClient:
with suppress(AttributeError):
if self.__cognite_client is not None:
return self.__cognite_client
raise CogniteMissingClientError(self)

@_cognite_client.setter
def _cognite_client(self, value: CogniteClient | None) -> None:
from cognite.client import CogniteClient

if value is None or isinstance(value, CogniteClient):
self.__cognite_client = value
else:
raise AttributeError(
"Can't set the CogniteClient reference to anything else than a CogniteClient instance or None"
)

def _get_cognite_client(self) -> CogniteClient | None:
"""Get Cognite client reference without raising (when missing)"""
return self.__cognite_client


class CogniteResource(_WithClientMixin):
__cognite_client: CogniteClient | None

def __new__(cls, *args: Any, **kwargs: Any) -> CogniteResource:
obj = super().__new__(cls)
obj._cognite_client = None
if "cognite_client" in kwargs:
obj._cognite_client = kwargs["cognite_client"]
return obj
def __init__(self, cognite_client: CogniteClient | None = None) -> None:
raise NotImplementedError

def __eq__(self, other: Any) -> bool:
return type(self) is type(other) and self.dump() == other.dump()
Expand All @@ -102,12 +117,6 @@ def __str__(self) -> str:
item = convert_time_attributes_to_datetime(self.dump())
return json.dumps(item, default=utils._auxiliary.json_dump_default, indent=4)

def __getattribute__(self, item: Any) -> Any:
attr = super().__getattribute__(item)
if item == "_cognite_client" and attr is None:
raise CogniteMissingClientError(self)
return attr

def dump(self, camel_case: bool = False) -> dict[str, Any]:
"""Dump the instance into a json serializable Python data type.
Expand Down Expand Up @@ -188,8 +197,9 @@ def _property_setter(self: Any, schema_name: str, value: Any) -> None:
self[schema_name] = value


class CogniteResourceList(UserList, Generic[T_CogniteResource]):
_RESOURCE: type[CogniteResource]
class CogniteResourceList(UserList, Generic[T_CogniteResource], _WithClientMixin):
_RESOURCE: type[T_CogniteResource]
__cognite_client: CogniteClient | None

def __init__(self, resources: Collection[Any], cognite_client: CogniteClient | None = None) -> None:
for resource in resources:
Expand All @@ -209,34 +219,27 @@ def __init__(self, resources: Collection[Any], cognite_client: CogniteClient | N
if hasattr(self.data[0], "id"):
self._id_to_item = {item.id: item for item in self.data if item.id is not None}

def __getattribute__(self, item: Any) -> Any:
attr = super().__getattribute__(item)
if item == "_cognite_client" and attr is None:
raise CogniteMissingClientError(self)
return attr

def pop(self, i: int = -1) -> T_CogniteResource:
return super().pop(i)

def __iter__(self) -> Iterator[T_CogniteResource]:
return super().__iter__()

@overload
def __getitem__(self, item: SupportsIndex) -> T_CogniteResource:
def __getitem__(self: T_CogniteResourceList, item: SupportsIndex) -> T_CogniteResource:
...

@overload
def __getitem__(self, item: slice) -> CogniteResourceList[T_CogniteResource]:
def __getitem__(self: T_CogniteResourceList, item: slice) -> T_CogniteResourceList:
...

def __getitem__(self, item: Any) -> Any:
value = super().__getitem__(item)
def __getitem__(
self: T_CogniteResourceList, item: SupportsIndex | slice
) -> T_CogniteResource | T_CogniteResourceList:
value = self.data[item]
if isinstance(item, slice):
c = None
if super().__getattribute__("_cognite_client") is not None:
c = self._cognite_client
return self.__class__(value, cognite_client=c)
return value
return type(self)(value, cognite_client=self._get_cognite_client())
return cast(T_CogniteResource, value)

def __str__(self) -> str:
item = convert_time_attributes_to_datetime(self.dump())
Expand Down
13 changes: 9 additions & 4 deletions cognite/client/data_classes/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,21 +174,22 @@ def __hash__(self) -> int:
return hash(self.external_id)

def parent(self) -> Asset:
"""Returns this assets parent.
"""Returns this asset's parent.
Returns:
Asset: The parent asset.
"""
if self.parent_id is None:
raise ValueError("parent_id is None")
return self._cognite_client.assets.retrieve(id=self.parent_id)
raise ValueError("parent_id is None, is this a root asset?")
return cast(Asset, self._cognite_client.assets.retrieve(id=self.parent_id))

def children(self) -> AssetList:
"""Returns the children of this asset.
Returns:
AssetList: The requested assets
"""
assert self.id is not None
return self._cognite_client.assets.list(parent_ids=[self.id], limit=None)

def subtree(self, depth: int | None = None) -> AssetList:
Expand All @@ -200,6 +201,7 @@ def subtree(self, depth: int | None = None) -> AssetList:
Returns:
AssetList: The requested assets sorted topologically.
"""
assert self.id is not None
return self._cognite_client.assets.retrieve_subtree(id=self.id, depth=depth)

def time_series(self, **kwargs: Any) -> TimeSeriesList:
Expand All @@ -210,6 +212,7 @@ def time_series(self, **kwargs: Any) -> TimeSeriesList:
Returns:
TimeSeriesList: All time series related to this asset.
"""
assert self.id is not None
return self._cognite_client.time_series.list(asset_ids=[self.id], **kwargs)

def sequences(self, **kwargs: Any) -> SequenceList:
Expand All @@ -220,6 +223,7 @@ def sequences(self, **kwargs: Any) -> SequenceList:
Returns:
SequenceList: All sequences related to this asset.
"""
assert self.id is not None
return self._cognite_client.sequences.list(asset_ids=[self.id], **kwargs)

def events(self, **kwargs: Any) -> EventList:
Expand All @@ -230,7 +234,7 @@ def events(self, **kwargs: Any) -> EventList:
Returns:
EventList: All events related to this asset.
"""

assert self.id is not None
return self._cognite_client.events.list(asset_ids=[self.id], **kwargs)

def files(self, **kwargs: Any) -> FileMetadataList:
Expand All @@ -241,6 +245,7 @@ def files(self, **kwargs: Any) -> FileMetadataList:
Returns:
FileMetadataList: Metadata about all files related to this asset.
"""
assert self.id is not None
return self._cognite_client.files.list(asset_ids=[self.id], **kwargs)

def dump(self, camel_case: bool = False) -> dict[str, Any]:
Expand Down
2 changes: 1 addition & 1 deletion cognite/client/data_classes/data_modeling/ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@dataclass(frozen=True)
class AbstractDataclass(ABC):
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
if cls == AbstractDataclass or cls.__bases__[0] == AbstractDataclass:
if cls is AbstractDataclass or cls.__bases__[0] is AbstractDataclass:
raise TypeError("Cannot instantiate abstract class.")
return super().__new__(cls)

Expand Down
6 changes: 3 additions & 3 deletions cognite/client/data_classes/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Literal, Union
from typing import TYPE_CHECKING, Any, List, Literal, Union, cast

from typing_extensions import TypeAlias

Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
self.dataset_id = dataset_id
self.security_categories = security_categories
self.metadata: dict[str, str] = metadata or {}
self._cognite_client = cognite_client
self._cognite_client = cast("CogniteClient", cognite_client)

@classmethod
def _load(cls, resource: dict | str, cognite_client: CogniteClient | None = None) -> SourceFile:
Expand Down Expand Up @@ -163,7 +163,7 @@ def __init__(
self.asset_ids: list[int] = asset_ids or []
self.labels: list[Label] = Label._load_list(labels) or []
self.geo_location = geo_location
self._cognite_client = cognite_client
self._cognite_client = cast("CogniteClient", cognite_client)

@classmethod
def _load(cls, resource: dict | str, cognite_client: CogniteClient | None = None) -> Document:
Expand Down
2 changes: 1 addition & 1 deletion cognite/client/data_classes/extractionpipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def __init__(
self.revision = revision
self.description = description
self.created_time = created_time
self._cognite_client = cognite_client
self._cognite_client = cast("CogniteClient", cognite_client)


class ExtractionPipelineConfig(ExtractionPipelineConfigRevision):
Expand Down
25 changes: 15 additions & 10 deletions cognite/client/data_classes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,36 +312,41 @@ def __init__(
self.function_id = function_id
self._cognite_client = cast("CogniteClient", cognite_client)

def get_response(self) -> dict:
def get_response(self) -> dict | None:
"""Retrieve the response from this function call.
Returns:
dict: Response from the function call.
dict | None: Response from the function call.
"""
if self.id is None or self.function_id is None:
raise ValueError("FunctionCall is missing one or more of: [id, function_id]")

return self._cognite_client.functions.calls.get_response(call_id=self.id, function_id=self.function_id)
call_id, function_id = self._get_identifiers_or_raise(self.id, self.function_id)
return self._cognite_client.functions.calls.get_response(call_id=call_id, function_id=function_id)

def get_logs(self) -> FunctionCallLog:
"""`Retrieve logs for this function call. <https://docs.cognite.com/api/v1/#operation/getFunctionCallLogs>`_
Returns:
FunctionCallLog: Log for the function call.
"""
if self.id is None or self.function_id is None:
raise ValueError("FunctionCall is missing one or more of: [id, function_id]")
return self._cognite_client.functions.calls.get_logs(call_id=self.id, function_id=self.function_id)
call_id, function_id = self._get_identifiers_or_raise(self.id, self.function_id)
return self._cognite_client.functions.calls.get_logs(call_id=call_id, function_id=function_id)

def update(self) -> None:
"""Update the function call object. Can be useful if the call was made with wait=False."""
latest = self._cognite_client.functions.calls.retrieve(call_id=self.id, function_id=self.function_id)
call_id, function_id = self._get_identifiers_or_raise(self.id, self.function_id)
latest = self._cognite_client.functions.calls.retrieve(call_id=call_id, function_id=function_id)
if latest is None:
raise RuntimeError("Unable to update the function call object (it was not found)")
self.status = latest.status
self.end_time = latest.end_time
self.error = latest.error

@staticmethod
def _get_identifiers_or_raise(call_id: int | None, function_id: int | None) -> tuple[int, int]:
# Mostly a mypy thing, but for sure nice with an error message :D
if call_id is None or function_id is None:
raise ValueError("FunctionCall is missing one or more of: [id, function_id]")
return call_id, function_id

def wait(self) -> None:
while self.status == "Running":
self.update()
Expand Down
28 changes: 18 additions & 10 deletions cognite/client/data_classes/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,33 +119,41 @@ def dump(self, camel_case: bool = False) -> dict[str, Any]:
return dumped


class CreatedSession(CogniteResource):
class CreatedSession(CogniteResponse):
"""Session creation related information
Args:
id (int | None): ID of the created session.
id (int): ID of the created session.
status (str): Current status of the session.
nonce (str): Nonce to be passed to the internal service that will bind the session
type (str | None): Credentials kind used to create the session.
status (str | None): Current status of the session.
nonce (str | None): Nonce to be passed to the internal service that will bind the session
client_id (str | None): Client ID in identity provider. Returned only if the session was created using client credentials
cognite_client (CogniteClient | None): No description.
"""

def __init__(
self,
id: int | None = None,
id: int,
status: str,
nonce: str,
type: str | None = None,
status: str | None = None,
nonce: str | None = None,
client_id: str | None = None,
cognite_client: CogniteClient | None = None,
) -> None:
self.id = id
self.type = type
self.status = status
self.nonce = nonce
self.type = type
self.client_id = client_id

@classmethod
def _load(cls, response: dict[str, Any]) -> CreatedSession:
return cls(
id=response["id"],
status=response["status"],
nonce=response["nonce"],
type=response.get("type"),
client_id=response.get("clientId"),
)


class Session(CogniteResource):
"""Session status
Expand Down
Loading

0 comments on commit 96cb9b4

Please sign in to comment.