From 52fd53088f2650f53fd50bec2ca93e8417669cf9 Mon Sep 17 00:00:00 2001 From: = <=> Date: Fri, 24 Jan 2025 14:24:18 +0100 Subject: [PATCH] generic removed at all :( --- chatsky/core/ctx_dict.py | 127 ++++++++++++++++++++++++++++++--------- 1 file changed, 97 insertions(+), 30 deletions(-) diff --git a/chatsky/core/ctx_dict.py b/chatsky/core/ctx_dict.py index c04a99896..b513d517e 100644 --- a/chatsky/core/ctx_dict.py +++ b/chatsky/core/ctx_dict.py @@ -7,7 +7,6 @@ Any, Callable, Dict, - Generic, Iterable, List, Mapping, @@ -16,7 +15,6 @@ Set, Tuple, Type, - TypeVar, Union, overload, TYPE_CHECKING, @@ -31,9 +29,6 @@ if TYPE_CHECKING: from chatsky.context_storages.database import DBContextStorage -K = TypeVar("K", bound=int) -V = TypeVar("V") - logger = logging.getLogger(__name__) @@ -41,34 +36,34 @@ def _get_hash(string: bytes) -> bytes: return sha256(string).digest() -class ContextDict(ABC, BaseModel, Generic[K, V]): +class ContextDict(ABC, BaseModel): """ Dictionary-like structure for storing different dialog types in a context storage. It holds all the possible keys, but may not store all the values locally. Some of them might be loaded lazily upon querying. """ - _items: Dict[K, V] = PrivateAttr(default_factory=dict) + _items: Dict[int, BaseModel] = PrivateAttr(default_factory=dict) """ Already loaded from storage items collection. """ - _hashes: Dict[K, int] = PrivateAttr(default_factory=dict) + _hashes: Dict[int, int] = PrivateAttr(default_factory=dict) """ Hashes of the loaded items (as they were upon loading), only populated if `rewrite_existing` flag is enabled. """ - _keys: Set[K] = PrivateAttr(default_factory=set) + _keys: Set[int] = PrivateAttr(default_factory=set) """ All the item keys available in the storage. """ - _added: Set[K] = PrivateAttr(default_factory=set) + _added: Set[int] = PrivateAttr(default_factory=set) """ Keys added localy (need to be synchronized with the storage). """ - _removed: Set[K] = PrivateAttr(default_factory=set) + _removed: Set[int] = PrivateAttr(default_factory=set) """ Keys removed localy (need to be synchronized with the storage). """ @@ -90,11 +85,11 @@ class ContextDict(ABC, BaseModel, Generic[K, V]): @property @abstractmethod - def _value_type(self) -> TypeAdapter[Type[V]]: + def _value_type(self) -> TypeAdapter[Type[BaseModel]]: raise NotImplementedError @classmethod - async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict[K, V]": + async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict": """ Create a new context dict, without connecting it to the context storage. No keys or items will be loaded, but any newly added items will be available for synchronization. @@ -114,7 +109,7 @@ async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDi return instance @classmethod - async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict[K, V]": + async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict": """ Create a new context dict, connecting it to the context storage. All the keys and some items will be loaded, all the other items will be available for synchronization. @@ -139,7 +134,7 @@ async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "Con instance._hashes = {k: _get_hash(v) for k, v in val_key_items} if storage.rewrite_existing else dict() return instance - async def _load_items(self, keys: List[K]) -> None: + async def _load_items(self, keys: List[int]) -> None: """ Load items for the given keys from the connected context storage. Update the `_items` and `_hashes` fields if necessary. @@ -162,10 +157,10 @@ async def _load_items(self, keys: List[K]) -> None: self._hashes[key] = _get_hash(value) @overload - async def __getitem__(self, key: K) -> V: ... # noqa: E704 + async def __getitem__(self, key: int) -> BaseModel: ... # noqa: E704 @overload - async def __getitem__(self, key: slice) -> List[V]: ... # noqa: E704 + async def __getitem__(self, key: slice) -> List[BaseModel]: ... # noqa: E704 async def __getitem__(self, key): if isinstance(key, int) and key < 0: @@ -182,7 +177,7 @@ async def __getitem__(self, key): else: return self._items[key] - def __setitem__(self, key: Union[K, slice], value: Union[V, Sequence[V]]) -> None: + def __setitem__(self, key: Union[int, slice], value: Union[BaseModel, Sequence[BaseModel]]) -> None: if isinstance(key, int) and key < 0: key = self.keys()[key] if isinstance(key, slice): @@ -200,7 +195,7 @@ def __setitem__(self, key: Union[K, slice], value: Union[V, Sequence[V]]) -> Non self._removed.discard(key) self._items[key] = self._value_type.validate_python(value) - def __delitem__(self, key: Union[K, slice]) -> None: + def __delitem__(self, key: Union[int, slice]) -> None: if isinstance(key, int) and key < 0: key = self.keys()[key] if isinstance(key, slice): @@ -212,17 +207,17 @@ def __delitem__(self, key: Union[K, slice]) -> None: self._keys.discard(key) del self._items[key] - def __iter__(self) -> Sequence[K]: + def __iter__(self) -> Sequence[int]: return iter(self.keys() if self._storage is not None else self._items.keys()) def __len__(self) -> int: return len(self.keys() if self._storage is not None else self._items.keys()) @overload - async def get(self, key: K, default=None) -> V: ... # noqa: E704 + async def get(self, key: int, default=None) -> BaseModel: ... # noqa: E704 @overload - async def get(self, key: Iterable[K], default=None) -> List[V]: ... # noqa: E704 + async def get(self, key: Iterable[int], default=None) -> List[BaseModel]: ... # noqa: E704 async def get(self, key, default=None): """ @@ -243,19 +238,19 @@ async def get(self, key, default=None): else: return default - def __contains__(self, key: K) -> bool: + def __contains__(self, key: int) -> bool: return key in self.keys() - def keys(self) -> List[K]: + def keys(self) -> List[int]: return sorted(self._keys) - async def values(self) -> List[V]: + async def values(self) -> List[BaseModel]: return await self[:] - async def items(self) -> List[Tuple[K, V]]: + async def items(self) -> List[Tuple[int, BaseModel]]: return [(k, v) for k, v in zip(self.keys(), await self.values())] - async def pop(self, key: K, default=None) -> V: + async def pop(self, key: int, default=None) -> BaseModel: try: value = await self[key] except KeyError: @@ -264,7 +259,7 @@ async def pop(self, key: K, default=None) -> V: del self[key] return value - async def popitem(self) -> Tuple[K, V]: + async def popitem(self) -> Tuple[int, BaseModel]: try: key = next(iter(self)) except StopIteration: @@ -291,7 +286,7 @@ async def update(self, other: Any = (), /, **kwds) -> None: for key, value in kwds.items(): self[key] = value - async def setdefault(self, key: K, default=None) -> V: + async def setdefault(self, key: int, default=None) -> BaseModel: try: return await self[key] except KeyError: @@ -345,7 +340,7 @@ def _validate_model(value: Any, handler: Callable[[Any], "ContextDict"], _) -> " raise ValueError(f"Unknown type of ContextDict value: {type(value).__name__}!") @model_serializer() - def _serialize_model(self) -> Dict[K, V]: + def _serialize_model(self) -> Dict[int, BaseModel]: if self._storage is None: return self._items elif not self._storage.rewrite_existing: @@ -389,16 +384,88 @@ class LabelContextDict(ContextDict[int, AbsoluteNodeLabel]): Context dictionary for storing `AbsoluteNodeLabel` types. """ + _items: Dict[int, AbsoluteNodeLabel] + @property def _value_type(self) -> TypeAdapter[Type[AbsoluteNodeLabel]]: return TypeAdapter(AbsoluteNodeLabel) + @overload + async def __getitem__(self, key: int) -> AbsoluteNodeLabel: ... # noqa: E704 + + @overload + async def __getitem__(self, key: slice) -> List[AbsoluteNodeLabel]: ... # noqa: E704 + + def __setitem__(self, key: Union[int, slice], value: Union[AbsoluteNodeLabel, Sequence[AbsoluteNodeLabel]]) -> None: + return super().__setitem__(key, value) + + @overload + async def get(self, key: int, default=None) -> AbsoluteNodeLabel: ... # noqa: E704 + + @overload + async def get(self, key: Iterable[int], default=None) -> List[AbsoluteNodeLabel]: ... # noqa: E704 + + async def values(self) -> List[AbsoluteNodeLabel]: + return super().values() + + async def items(self) -> List[Tuple[int, AbsoluteNodeLabel]]: + return super().items() + + async def pop(self, key: int, default=None) -> AbsoluteNodeLabel: + return super().pop(key, default) + + async def popitem(self) -> Tuple[int, AbsoluteNodeLabel]: + return super().popitem() + + async def setdefault(self, key: int, default=None) -> AbsoluteNodeLabel: + return super().setdefault(key, default) + + @model_serializer() + def _serialize_model(self) -> Dict[int, AbsoluteNodeLabel]: + return super()._serialize_model() + class MessageContextDict(ContextDict[int, Message]): """ Context dictionary for storing `Message` types. """ + _items: Dict[int, Message] + @property def _value_type(self) -> TypeAdapter[Type[Message]]: return TypeAdapter(Message) + + @overload + async def __getitem__(self, key: int) -> Message: ... # noqa: E704 + + @overload + async def __getitem__(self, key: slice) -> List[Message]: ... # noqa: E704 + + def __setitem__(self, key: Union[int, slice], value: Union[Message, Sequence[Message]]) -> None: + return super().__setitem__(key, value) + + @overload + async def get(self, key: int, default=None) -> Message: ... # noqa: E704 + + @overload + async def get(self, key: Iterable[int], default=None) -> List[Message]: ... # noqa: E704 + + async def values(self) -> List[Message]: + return super().values() + + async def items(self) -> List[Tuple[int, Message]]: + return super().items() + + async def pop(self, key: int, default=None) -> Message: + return super().pop(key, default) + + async def popitem(self) -> Tuple[int, Message]: + return super().popitem() + + async def setdefault(self, key: int, default=None) -> Message: + return super().setdefault(key, default) + + @model_serializer() + def _serialize_model(self) -> Dict[int, Message]: + return super()._serialize_model()