Skip to content

Commit

Permalink
Speed up basic API response translation (_load) (#1348)
Browse files Browse the repository at this point in the history
  • Loading branch information
haakonvt authored Sep 20, 2023
1 parent 2046b0b commit aad22ae
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 29 deletions.
14 changes: 5 additions & 9 deletions cognite/client/data_classes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@

from cognite.client import utils
from cognite.client.exceptions import CogniteMissingClientError
from cognite.client.utils._auxiliary import fast_dict_load
from cognite.client.utils._identifier import IdentifierSequence
from cognite.client.utils._pandas_helpers import convert_nullable_int_cols, notebook_display_with_fallback
from cognite.client.utils._text import convert_all_keys_to_camel_case, to_camel_case, to_snake_case
from cognite.client.utils._text import convert_all_keys_to_camel_case, to_camel_case
from cognite.client.utils._time import convert_time_attributes_to_datetime

if TYPE_CHECKING:
Expand Down Expand Up @@ -132,15 +133,10 @@ def dump(self, camel_case: bool = False) -> dict[str, Any]:
def _load(
cls: type[T_CogniteResource], resource: dict | str, cognite_client: CogniteClient | None = None
) -> T_CogniteResource:
if isinstance(resource, str):
if isinstance(resource, dict):
return fast_dict_load(cls, resource, cognite_client=cognite_client)
elif isinstance(resource, str):
return cls._load(json.loads(resource), cognite_client=cognite_client)
elif isinstance(resource, dict):
instance = cls(cognite_client=cognite_client)
for key, value in resource.items():
snake_case_key = to_snake_case(key)
if hasattr(instance, snake_case_key):
setattr(instance, snake_case_key, value)
return instance
raise TypeError(f"Resource must be json str or dict, not {type(resource)}")

def to_pandas(
Expand Down
3 changes: 3 additions & 0 deletions cognite/client/data_classes/contextualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def _load_with_status(
) -> T_ContextualizationJob:
obj = cls._load({**data, "jobToken": headers.get("X-Job-Token")}, cognite_client=cognite_client)
obj._status_path = status_path
# '_load' does not see properties (real attribute stored under a different name, e.g. '_items' not 'items'):
if "items" in data and hasattr(obj, "items"):
obj.items = data["items"]
return obj


Expand Down
31 changes: 17 additions & 14 deletions cognite/client/data_classes/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from cognite.client.data_classes.transformations.schedules import TransformationSchedule
from cognite.client.data_classes.transformations.schema import TransformationSchemaColumnList
from cognite.client.exceptions import CogniteAPIError
from cognite.client.utils._text import convert_all_keys_to_camel_case, convert_all_keys_to_snake_case
from cognite.client.utils._text import convert_all_keys_to_camel_case

if TYPE_CHECKING:
from cognite.client import CogniteClient
Expand Down Expand Up @@ -55,6 +55,14 @@ def __init__(
self.client_id = client_id
self.project_name = project_name

@classmethod
def _load(cls, resource: dict[str, Any]) -> SessionDetails:
    """Build an instance from a camelCase-keyed dict; keys that are absent load as None."""
    session_id = resource.get("sessionId")
    client_id = resource.get("clientId")
    project_name = resource.get("projectName")
    return cls(session_id=session_id, client_id=client_id, project_name=project_name)

def dump(self, camel_case: bool = False) -> dict[str, Any]:
"""Dump the instance into a json serializable Python data type.
Expand Down Expand Up @@ -287,29 +295,24 @@ def _load(cls, resource: dict | str, cognite_client: CogniteClient | None = None
instance.destination = _load_destination_dct(instance.destination)

if isinstance(instance.running_job, dict):
snake_dict = convert_all_keys_to_snake_case(instance.running_job)
instance.running_job = TransformationJob._load(snake_dict, cognite_client=cognite_client)
instance.running_job = TransformationJob._load(instance.running_job, cognite_client=cognite_client)

if isinstance(instance.last_finished_job, dict):
snake_dict = convert_all_keys_to_snake_case(instance.last_finished_job)
instance.last_finished_job = TransformationJob._load(snake_dict, cognite_client=cognite_client)
instance.last_finished_job = TransformationJob._load(
instance.last_finished_job, cognite_client=cognite_client
)

if isinstance(instance.blocked, dict):
snake_dict = convert_all_keys_to_snake_case(instance.blocked)
snake_dict.pop("time")
instance.blocked = TransformationBlockedInfo(**snake_dict)
instance.blocked = TransformationBlockedInfo._load(instance.blocked)

if isinstance(instance.schedule, dict):
snake_dict = convert_all_keys_to_snake_case(instance.schedule)
instance.schedule = TransformationSchedule._load(snake_dict, cognite_client=cognite_client)
instance.schedule = TransformationSchedule._load(instance.schedule, cognite_client=cognite_client)

if isinstance(instance.source_session, dict):
snake_dict = convert_all_keys_to_snake_case(instance.source_session)
instance.source_session = SessionDetails(**snake_dict)
instance.source_session = SessionDetails._load(instance.source_session)

if isinstance(instance.destination_session, dict):
snake_dict = convert_all_keys_to_snake_case(instance.destination_session)
instance.destination_session = SessionDetails(**snake_dict)
instance.destination_session = SessionDetails._load(instance.destination_session)
return instance

def dump(self, camel_case: bool = False) -> dict[str, Any]:
Expand Down
10 changes: 7 additions & 3 deletions cognite/client/data_classes/transformations/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,18 @@ class TransformationBlockedInfo:
"""Information about the reason why and when a transformation is blocked.
Args:
reason (str | None): Reason why the transformation is blocked.
created_time (int | None): Timestamp when the transformation was blocked.
reason (str): Reason why the transformation is blocked.
created_time (int): Timestamp when the transformation was blocked.
"""

def __init__(self, reason: str | None = None, created_time: int | None = None) -> None:
def __init__(self, reason: str, created_time: int) -> None:
self.reason = reason
self.created_time = created_time

@classmethod
def _load(cls, resource: dict[str, Any]) -> TransformationBlockedInfo:
    """Build an instance from the "reason" and "createdTime" keys (KeyError if either is missing)."""
    reason = resource["reason"]
    created_time = resource["createdTime"]
    return cls(reason=reason, created_time=created_time)


def _load_destination_dct(dct: dict[str, Any]) -> RawTable | Nodes | Edges | SequenceRows | TransformationDestination:
"""Helper function to load destination from dictionary"""
Expand Down
32 changes: 31 additions & 1 deletion cognite/client/utils/_auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from decimal import Decimal
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
Hashable,
Iterable,
Expand All @@ -21,9 +22,19 @@

import cognite.client
from cognite.client.exceptions import CogniteImportError
from cognite.client.utils._text import convert_all_keys_to_camel_case, convert_all_keys_to_snake_case, to_snake_case
from cognite.client.utils._text import (
convert_all_keys_to_camel_case,
convert_all_keys_to_snake_case,
to_camel_case,
to_snake_case,
)
from cognite.client.utils._version_checker import get_newest_version_in_major_release

if TYPE_CHECKING:
from cognite.client import CogniteClient
from cognite.client.data_classes._base import T_CogniteResource


T = TypeVar("T")
THashable = TypeVar("THashable", bound=Hashable)

Expand All @@ -32,6 +43,25 @@ def is_unlimited(limit: float | int | None) -> bool:
return limit in {None, -1, math.inf}


@functools.lru_cache(None)
def get_accepted_params(cls: type[T_CogniteResource]) -> dict[str, str]:
    """Map camelCase keys to the snake_case attribute names of ``cls``.

    A default-constructed instance is inspected with ``vars``; names starting
    with an underscore are left out. The result is memoized per class.
    """
    mapping: dict[str, str] = {}
    for attr_name in vars(cls()):
        if attr_name.startswith("_"):
            continue
        mapping[to_camel_case(attr_name)] = attr_name
    return mapping


def fast_dict_load(
    cls: type[T_CogniteResource], item: dict[str, Any], cognite_client: CogniteClient | None
) -> T_CogniteResource:
    """Create a ``cls`` instance and populate it from a camelCase-keyed dict.

    Args:
        cls: Resource class to instantiate; it is called as ``cls(cognite_client=...)``.
        item: Mapping of camelCase keys to values.
        cognite_client: Passed through to the constructor.

    Returns:
        The new instance. Keys of ``item`` that are not in ``get_accepted_params(cls)``
        are skipped silently.
    """
    instance = cls(cognite_client=cognite_client)
    # Note: Do not use cast(Hashable, cls) here as this is often called in a hot loop
    accepted = get_accepted_params(cls)  # type: ignore [arg-type]
    for camel_attr, value in item.items():
        # Resolve the name separately from the assignment: a KeyError raised by
        # setattr itself must not be mistaken for an unknown key, and unknown
        # keys no longer pay for raising and catching an exception.
        attr_name = accepted.get(camel_attr)
        if attr_name is not None:
            setattr(instance, attr_name, value)
    return instance


def basic_obj_dump(obj: Any, camel_case: bool) -> dict[str, Any]:
if camel_case:
return convert_all_keys_to_camel_case(vars(obj))
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_unit/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ def test_dump_camel_case(self):

def test_load(self):
assert MyResource(1).dump() == MyResource._load({"varA": 1}).dump()
assert MyResource(1, 2).dump() == MyResource._load({"var_a": 1, "var_b": 2}).dump()
assert {"var_a": 1} == MyResource._load({"var_a": 1, "var_c": 1}).dump()
assert MyResource().dump() == MyResource._load({"var_a": 1, "var_b": 2}).dump()
assert {"var_a": 1} == MyResource._load({"varA": 1, "varC": 1}).dump()

def test_load_unknown_attribute(self):
    # "varC" has no matching attribute on MyResource and must not appear in the dump
    expected = {"var_a": 1, "var_b": 2}
    dumped = MyResource._load({"varA": 1, "varB": 2, "varC": 3}).dump()
    assert expected == dumped
Expand Down
32 changes: 32 additions & 0 deletions tests/tests_unit/test_utils/test_auxiliary.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@

import pytest

from cognite.client.data_classes._base import CogniteResource
from cognite.client.exceptions import CogniteImportError
from cognite.client.utils._auxiliary import (
assert_type,
exactly_one_is_not_none,
fast_dict_load,
find_duplicates,
get_accepted_params,
handle_deprecated_camel_case_argument,
interpolate_and_url_encode,
json_dump_default,
Expand Down Expand Up @@ -273,3 +276,32 @@ class TestExactlyOneIsNotNone:
)
def test_exactly_one_is_not_none(self, inp, expected):
assert exactly_one_is_not_none(*inp) == expected


class MyTestResource(CogniteResource):
    # Minimal resource with three public attributes, used by the fast_dict_load tests below
    def __init__(self, foo=None, foo_bar=None, foo_bar_baz=None, cognite_client=None):
        self.foo, self.foo_bar, self.foo_bar_baz = foo, foo_bar, foo_bar_baz
        self._cognite_client = cognite_client

    def _load(*args, **kwargs):
        # Overridden so that any call to _load raises NotImplementedError
        raise NotImplementedError


class TestFastDictLoad:
    @pytest.mark.parametrize(
        "item, expected",
        [
            # Every camelCase key maps onto its attribute:
            ({"foo": "a", "fooBar": "b", "fooBarBaz": "c"}, MyTestResource("a", "b", "c")),
            # Keys without a matching attribute are dropped without error:
            ({"f": "a", "foot": "b", "fooBarBaz": "c"}, MyTestResource(foo_bar_baz="c")),
            # snake_case keys are not accepted; only 'foo' is spelled the same in both cases:
            ({"foo": "a", "foo_bar": "b", "foo_bar_baz": "c"}, MyTestResource(foo="a")),
        ],
    )
    def test_load(self, item, expected):
        get_accepted_params.cache_clear()  # For good measure
        loaded = fast_dict_load(MyTestResource, item, cognite_client=None)
        assert expected == loaded

0 comments on commit aad22ae

Please sign in to comment.