diff --git a/cognite/client/data_classes/_base.py b/cognite/client/data_classes/_base.py index 771501ec10..8d8e137592 100644 --- a/cognite/client/data_classes/_base.py +++ b/cognite/client/data_classes/_base.py @@ -28,9 +28,10 @@ from cognite.client import utils from cognite.client.exceptions import CogniteMissingClientError +from cognite.client.utils._auxiliary import fast_dict_load from cognite.client.utils._identifier import IdentifierSequence from cognite.client.utils._pandas_helpers import convert_nullable_int_cols, notebook_display_with_fallback -from cognite.client.utils._text import convert_all_keys_to_camel_case, to_camel_case, to_snake_case +from cognite.client.utils._text import convert_all_keys_to_camel_case, to_camel_case from cognite.client.utils._time import convert_time_attributes_to_datetime if TYPE_CHECKING: @@ -132,15 +133,10 @@ def dump(self, camel_case: bool = False) -> dict[str, Any]: def _load( cls: type[T_CogniteResource], resource: dict | str, cognite_client: CogniteClient | None = None ) -> T_CogniteResource: - if isinstance(resource, str): + if isinstance(resource, dict): + return fast_dict_load(cls, resource, cognite_client=cognite_client) + elif isinstance(resource, str): return cls._load(json.loads(resource), cognite_client=cognite_client) - elif isinstance(resource, dict): - instance = cls(cognite_client=cognite_client) - for key, value in resource.items(): - snake_case_key = to_snake_case(key) - if hasattr(instance, snake_case_key): - setattr(instance, snake_case_key, value) - return instance raise TypeError(f"Resource must be json str or dict, not {type(resource)}") def to_pandas( diff --git a/cognite/client/data_classes/contextualization.py b/cognite/client/data_classes/contextualization.py index 00f95e28a7..a86d55b7ff 100644 --- a/cognite/client/data_classes/contextualization.py +++ b/cognite/client/data_classes/contextualization.py @@ -160,6 +160,9 @@ def _load_with_status( ) -> T_ContextualizationJob: obj = cls._load({**data, "jobToken": headers.get("X-Job-Token")}, cognite_client=cognite_client) obj._status_path = status_path + # '_load' does not see properties (real attribute stored under a different name, e.g. '_items' not 'items'): + if "items" in data and hasattr(obj, "items"): + obj.items = data["items"] return obj diff --git a/cognite/client/data_classes/transformations/__init__.py b/cognite/client/data_classes/transformations/__init__.py index 6a145913fd..16094158af 100644 --- a/cognite/client/data_classes/transformations/__init__.py +++ b/cognite/client/data_classes/transformations/__init__.py @@ -27,7 +27,7 @@ from cognite.client.data_classes.transformations.schedules import TransformationSchedule from cognite.client.data_classes.transformations.schema import TransformationSchemaColumnList from cognite.client.exceptions import CogniteAPIError -from cognite.client.utils._text import convert_all_keys_to_camel_case, convert_all_keys_to_snake_case +from cognite.client.utils._text import convert_all_keys_to_camel_case if TYPE_CHECKING: from cognite.client import CogniteClient @@ -55,6 +55,14 @@ def __init__( self.client_id = client_id self.project_name = project_name + @classmethod + def _load(cls, resource: dict[str, Any]) -> SessionDetails: + return cls( + session_id=resource.get("sessionId"), + client_id=resource.get("clientId"), + project_name=resource.get("projectName"), + ) + def dump(self, camel_case: bool = False) -> dict[str, Any]: """Dump the instance into a json serializable Python data type. @@ -287,29 +295,24 @@ def _load(cls, resource: dict | str, cognite_client: CogniteClient | None = None instance.destination = _load_destination_dct(instance.destination) if isinstance(instance.running_job, dict): - snake_dict = convert_all_keys_to_snake_case(instance.running_job) - instance.running_job = TransformationJob._load(snake_dict, cognite_client=cognite_client) + instance.running_job = TransformationJob._load(instance.running_job, cognite_client=cognite_client) if isinstance(instance.last_finished_job, dict): - snake_dict = convert_all_keys_to_snake_case(instance.last_finished_job) - instance.last_finished_job = TransformationJob._load(snake_dict, cognite_client=cognite_client) + instance.last_finished_job = TransformationJob._load( + instance.last_finished_job, cognite_client=cognite_client + ) if isinstance(instance.blocked, dict): - snake_dict = convert_all_keys_to_snake_case(instance.blocked) - snake_dict.pop("time") - instance.blocked = TransformationBlockedInfo(**snake_dict) + instance.blocked = TransformationBlockedInfo._load(instance.blocked) if isinstance(instance.schedule, dict): - snake_dict = convert_all_keys_to_snake_case(instance.schedule) - instance.schedule = TransformationSchedule._load(snake_dict, cognite_client=cognite_client) + instance.schedule = TransformationSchedule._load(instance.schedule, cognite_client=cognite_client) if isinstance(instance.source_session, dict): - snake_dict = convert_all_keys_to_snake_case(instance.source_session) - instance.source_session = SessionDetails(**snake_dict) + instance.source_session = SessionDetails._load(instance.source_session) if isinstance(instance.destination_session, dict): - snake_dict = convert_all_keys_to_snake_case(instance.destination_session) - instance.destination_session = SessionDetails(**snake_dict) + instance.destination_session = SessionDetails._load(instance.destination_session) return instance def dump(self, camel_case: bool = False) -> dict[str, Any]: diff --git a/cognite/client/data_classes/transformations/common.py b/cognite/client/data_classes/transformations/common.py index 8a8d1fcfef..9a30a102cf 100644 --- a/cognite/client/data_classes/transformations/common.py +++ b/cognite/client/data_classes/transformations/common.py @@ -343,14 +343,18 @@ class TransformationBlockedInfo: """Information about the reason why and when a transformation is blocked. Args: - reason (str | None): Reason why the transformation is blocked. - created_time (int | None): Timestamp when the transformation was blocked. + reason (str): Reason why the transformation is blocked. + created_time (int): Timestamp when the transformation was blocked. """ - def __init__(self, reason: str | None = None, created_time: int | None = None) -> None: + def __init__(self, reason: str, created_time: int) -> None: self.reason = reason self.created_time = created_time + @classmethod + def _load(cls, resource: dict[str, Any]) -> TransformationBlockedInfo: + return cls(reason=resource["reason"], created_time=resource["createdTime"]) + def _load_destination_dct(dct: dict[str, Any]) -> RawTable | Nodes | Edges | SequenceRows | TransformationDestination: """Helper function to load destination from dictionary""" diff --git a/cognite/client/utils/_auxiliary.py b/cognite/client/utils/_auxiliary.py index e13ff57182..513ef801a1 100644 --- a/cognite/client/utils/_auxiliary.py +++ b/cognite/client/utils/_auxiliary.py @@ -9,6 +9,7 @@ from decimal import Decimal from types import ModuleType from typing import ( + TYPE_CHECKING, Any, Hashable, Iterable, @@ -21,9 +22,19 @@ import cognite.client from cognite.client.exceptions import CogniteImportError -from cognite.client.utils._text import convert_all_keys_to_camel_case, convert_all_keys_to_snake_case, to_snake_case +from cognite.client.utils._text import ( + convert_all_keys_to_camel_case, + convert_all_keys_to_snake_case, + to_camel_case, + to_snake_case, +) from cognite.client.utils._version_checker import get_newest_version_in_major_release +if TYPE_CHECKING: + from cognite.client import CogniteClient + from cognite.client.data_classes._base import T_CogniteResource + + T = TypeVar("T") THashable = TypeVar("THashable", bound=Hashable) @@ -32,6 +43,25 @@ def is_unlimited(limit: float | int | None) -> bool: return limit in {None, -1, math.inf} +@functools.lru_cache(None) +def get_accepted_params(cls: type[T_CogniteResource]) -> dict[str, str]: + return {to_camel_case(k): k for k in vars(cls()) if not k.startswith("_")} + + +def fast_dict_load( + cls: type[T_CogniteResource], item: dict[str, Any], cognite_client: CogniteClient | None +) -> T_CogniteResource: + instance = cls(cognite_client=cognite_client) + # Note: Do not use cast(Hashable, cls) here as this is often called in a hot loop + accepted = get_accepted_params(cls) # type: ignore [arg-type] + for camel_attr, value in item.items(): + try: + setattr(instance, accepted[camel_attr], value) + except KeyError: + pass + return instance + + def basic_obj_dump(obj: Any, camel_case: bool) -> dict[str, Any]: if camel_case: return convert_all_keys_to_camel_case(vars(obj)) diff --git a/tests/tests_unit/test_base.py b/tests/tests_unit/test_base.py index 7079e805a6..21ec759e27 100644 --- a/tests/tests_unit/test_base.py +++ b/tests/tests_unit/test_base.py @@ -135,8 +135,8 @@ def test_dump_camel_case(self): def test_load(self): assert MyResource(1).dump() == MyResource._load({"varA": 1}).dump() - assert MyResource(1, 2).dump() == MyResource._load({"var_a": 1, "var_b": 2}).dump() - assert {"var_a": 1} == MyResource._load({"var_a": 1, "var_c": 1}).dump() + assert MyResource().dump() == MyResource._load({"var_a": 1, "var_b": 2}).dump() + assert {"var_a": 1} == MyResource._load({"varA": 1, "varC": 1}).dump() def test_load_unknown_attribute(self): assert {"var_a": 1, "var_b": 2} == MyResource._load({"varA": 1, "varB": 2, "varC": 3}).dump() diff --git a/tests/tests_unit/test_utils/test_auxiliary.py b/tests/tests_unit/test_utils/test_auxiliary.py index d5213a46a7..d35a1567de 100644 --- a/tests/tests_unit/test_utils/test_auxiliary.py +++ b/tests/tests_unit/test_utils/test_auxiliary.py @@ -7,11 +7,14 @@ import pytest +from cognite.client.data_classes._base import CogniteResource from cognite.client.exceptions import CogniteImportError from cognite.client.utils._auxiliary import ( assert_type, exactly_one_is_not_none, + fast_dict_load, find_duplicates, + get_accepted_params, handle_deprecated_camel_case_argument, interpolate_and_url_encode, json_dump_default, @@ -273,3 +276,32 @@ class TestExactlyOneIsNotNone: ) def test_exactly_one_is_not_none(self, inp, expected): assert exactly_one_is_not_none(*inp) == expected + + +class MyTestResource(CogniteResource): + # Test resource for fast_dict_load below + def __init__(self, foo=None, foo_bar=None, foo_bar_baz=None, cognite_client=None): + self.foo = foo + self.foo_bar = foo_bar + self.foo_bar_baz = foo_bar_baz + self._cognite_client = cognite_client + + def _load(*a, **kw): + raise NotImplementedError + + +class TestFastDictLoad: + @pytest.mark.parametrize( + "item, expected", + ( + # Simple load test for all keys: + ({"foo": "a", "fooBar": "b", "fooBarBaz": "c"}, MyTestResource(*"abc")), + # Ensure unknown keys are skipped silently: + ({"f": "a", "foot": "b", "fooBarBaz": "c"}, MyTestResource(foo_bar_baz="c")), + # Ensure keys must be camel cased: + ({"foo": "a", "foo_bar": "b", "foo_bar_baz": "c"}, MyTestResource(foo="a")), + ), + ) + def test_load(self, item, expected): + get_accepted_params.cache_clear() # For good measure + assert expected == fast_dict_load(MyTestResource, item, cognite_client=None)