Skip to content

Commit

Permalink
refactor: robust loading
Browse files Browse the repository at this point in the history
  • Loading branch information
doctrino committed Aug 29, 2024
1 parent bb27dd2 commit 1318e92
Showing 1 changed file with 156 additions and 12 deletions.
168 changes: 156 additions & 12 deletions cognite/client/data_classes/hosted_extractors/sources.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from __future__ import annotations

from abc import ABC
import itertools
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast

from typing_extensions import Self, TypeAlias
from typing_extensions import Self

from cognite.client.data_classes._base import (
CogniteObject,
CogniteResource,
CogniteResourceList,
ExternalIDTransformerMixin,
T_WriteClass,
UnknownCogniteObject,
WriteableCogniteResource,
WriteableCogniteResourceList,
)
from cognite.client.utils._auxiliary import fast_dict_load

if TYPE_CHECKING:
from cognite.client import CogniteClient


SourceType: TypeAlias = Literal["mqtt5", "mqtt3", "eventhub"]


class SourceWrite(CogniteResource, ABC):
"""A hosted extractor source represents an external source system on the internet.
The source resource in CDF contains all the information the extractor needs to
Expand All @@ -39,6 +39,28 @@ class SourceWrite(CogniteResource, ABC):
def __init__(self, external_id: str) -> None:
self.external_id = external_id

@classmethod
@abstractmethod
def _load_source(cls, resource: dict[str, Any]) -> Self:
raise NotImplementedError()

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
type_ = resource.get("type")
if type_ is None and hasattr(cls, "_type"):
type_ = cls._type
else:
raise KeyError("type")
source_class = _SOURCE_WRITE_CLASS_BY_TYPE.get(type_)
if source_class is None:
return UnknownCogniteObject(resource) # type: ignore[return-value]
return cast(Self, source_class._load_source(resource))

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output = super().dump(camel_case)
output["type"] = self._type
return output


class Source(WriteableCogniteResource[T_WriteClass], ABC):
"""A hosted extractor source represents an external source system on the internet.
Expand All @@ -56,6 +78,28 @@ class Source(WriteableCogniteResource[T_WriteClass], ABC):
def __init__(self, external_id: str) -> None:
self.external_id = external_id

@classmethod
@abstractmethod
def _load_source(cls, resource: dict[str, Any]) -> Self:
raise NotImplementedError()

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
type_ = resource.get("type")
if type_ is None and hasattr(cls, "_type"):
type_ = cls._type
else:
raise KeyError("type")
source_class = _SOURCE_CLASS_BY_TYPE.get(type_)
if source_class is None:
return UnknownCogniteObject(resource) # type: ignore[return-value]
return cast(Self, source_class._load(resource))

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output = super().dump(camel_case)
output["type"] = self._type
return output


class EventHubSourceWrite(SourceWrite):
"""A hosted extractor source represents an external source system on the internet.
Expand Down Expand Up @@ -92,6 +136,10 @@ def __init__(
def as_write(self) -> SourceWrite:
return self

@classmethod
def _load_source(cls, resource: dict[str, Any]) -> Self:
return fast_dict_load(cls, resource, None)


class EventHubSource(Source):
"""A hosted extractor source represents an external source system on the internet.
Expand Down Expand Up @@ -140,17 +188,48 @@ def as_write(self, key_value: str | None = None) -> EventHubSourceWrite:
consumer_group=self.consumer_group,
)

@classmethod
def _load_source(cls, resource: dict[str, Any]) -> Self:
return fast_dict_load(cls, resource, None)


@dataclass
class MQTTAuthenticationWrite(CogniteObject, ABC):
_type: ClassVar[str]

@classmethod
@abstractmethod
def _load_authentication(cls, resource: dict[str, Any]) -> Self:
raise NotImplementedError()

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
type_ = resource.get("type")
if type_ is None and hasattr(cls, "_type"):
type_ = cls._type
else:
raise KeyError("type is required")
authentication_class = _MQTTAUTHENTICATION_WRITE_CLASS_BY_TYPE.get(type_)
if authentication_class is None:
return UnknownCogniteObject(resource) # type: ignore[return-value]
return cast(Self, authentication_class._load_authentication(resource))

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output = super().dump(camel_case)
output["type"] = self._type
return output


@dataclass
class BasicMQTTAuthenticationWrite(MQTTAuthenticationWrite):
_type = "basic"
username: str
password: str | None

@classmethod
def _load_authentication(cls, resource: dict[str, Any]) -> Self:
return fast_dict_load(cls, resource, None)


@dataclass
class CACertificateWrite(CogniteObject):
Expand Down Expand Up @@ -179,7 +258,7 @@ def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None =
)


class MQTTSourceWrite(SourceWrite):
class _MQTTSourceWrite(SourceWrite, ABC):
def __init__(
self,
external_id: str,
Expand All @@ -199,7 +278,7 @@ def __init__(
self.auth_certificate = auth_certificate

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
def _load_source(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
return cls(
external_id=resource["externalId"],
host=resource["host"],
Expand Down Expand Up @@ -229,13 +308,37 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:
class MQTTAuthentication(CogniteObject, ABC):
_type: ClassVar[str]

@classmethod
@abstractmethod
def _load_authentication(cls, resource: dict[str, Any]) -> Self:
raise NotImplementedError()

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
type_ = resource.get("type")
if type_ is None and hasattr(cls, "_type"):
type_ = cls._type
else:
raise KeyError("type")

authentication_class = _MQTTAUTHENTICATION_CLASS_BY_TYPE.get(type_)
if authentication_class is None:
return UnknownCogniteObject(resource) # type: ignore[return-value]
return cast(Self, authentication_class._load_authentication(resource))

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output = super().dump(camel_case)
output["type"] = self._type
return output


@dataclass
class BasicMQTTAuthentication(MQTTAuthentication):
_type = "basic"
username: str

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
def _load_authentication(cls, resource: dict[str, Any]) -> Self:
return cls(username=resource["username"])


Expand All @@ -259,7 +362,7 @@ def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None =
return cls(thumbprint=resource["thumbprint"], expires_at=resource["expiresAt"])


class MQTTSource(Source):
class _MQTTSource(Source, ABC):
def __init__(
self,
external_id: str,
Expand All @@ -283,7 +386,7 @@ def __init__(
self.last_updated_time = last_updated_time

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
def _load_source(cls, resource: dict[str, Any]) -> Self:
return cls(
external_id=resource["externalId"],
host=resource["host"],
Expand All @@ -300,7 +403,7 @@ def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None =
last_updated_time=resource["lastUpdatedTime"],
)

def as_write(self) -> MQTTSourceWrite:
def as_write(self) -> _MQTTSourceWrite:
raise TypeError(f"{type(self).__name__} cannot be converted to write as id does not contain the secrets")

def dump(self, camel_case: bool = True) -> dict[str, Any]:
Expand All @@ -314,6 +417,22 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:
return output


class MQTT3SourceWrite(_MQTTSourceWrite):
_type = "mqtt3"


class MQTT5SourceWrite(_MQTTSourceWrite):
_type = "mqtt5"


class MQTT3Source(_MQTTSource):
_type = "mqtt3"


class MQTT5Source(_MQTTSource):
_type = "mqtt5"


class SourceWriteList(CogniteResourceList[SourceWrite], ExternalIDTransformerMixin):
_RESOURCE = SourceWrite

Expand All @@ -325,3 +444,28 @@ def as_write(
self,
) -> SourceWriteList:
raise TypeError(f"{type(self).__name__} cannot be converted to write")


_SOURCE_WRITE_CLASS_BY_TYPE: dict[str, type[SourceWrite]] = {
subclass._type: subclass # type: ignore[type-abstract, misc]
for subclass in itertools.chain(SourceWrite.__subclasses__(), _MQTTSourceWrite.__subclasses__())
if hasattr(subclass, "_type")
}

_SOURCE_CLASS_BY_TYPE: dict[str, type[Source]] = {
subclass._type: subclass # type: ignore[type-abstract, misc]
for subclass in itertools.chain(Source.__subclasses__(), _MQTTSource.__subclasses__())
if hasattr(subclass, "_type")
}

_MQTTAUTHENTICATION_WRITE_CLASS_BY_TYPE: dict[str, type[MQTTAuthenticationWrite]] = {
subclass._type: subclass # type: ignore[type-abstract]
for subclass in MQTTAuthenticationWrite.__subclasses__()
if hasattr(subclass, "_type")
}

_MQTTAUTHENTICATION_CLASS_BY_TYPE: dict[str, type[MQTTAuthentication]] = {
subclass._type: subclass # type: ignore[type-abstract]
for subclass in MQTTAuthentication.__subclasses__()
if hasattr(subclass, "_type")
}

0 comments on commit 1318e92

Please sign in to comment.