Skip to content

Commit

Permalink
generic removed at all :(
Browse files Browse the repository at this point in the history
  • Loading branch information
= committed Jan 24, 2025
1 parent 1441814 commit 52fd530
Showing 1 changed file with 97 additions and 30 deletions.
127 changes: 97 additions & 30 deletions chatsky/core/ctx_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Any,
Callable,
Dict,
Generic,
Iterable,
List,
Mapping,
Expand All @@ -16,7 +15,6 @@
Set,
Tuple,
Type,
TypeVar,
Union,
overload,
TYPE_CHECKING,
Expand All @@ -31,44 +29,41 @@
if TYPE_CHECKING:
from chatsky.context_storages.database import DBContextStorage

K = TypeVar("K", bound=int)
V = TypeVar("V")

logger = logging.getLogger(__name__)


def _get_hash(string: bytes) -> bytes:
return sha256(string).digest()


class ContextDict(ABC, BaseModel, Generic[K, V]):
class ContextDict(ABC, BaseModel):
"""
Dictionary-like structure for storing different dialog types in a context storage.
It holds all the possible keys, but may not store all the values locally.
Some of them might be loaded lazily upon querying.
"""

_items: Dict[K, V] = PrivateAttr(default_factory=dict)
_items: Dict[int, BaseModel] = PrivateAttr(default_factory=dict)
"""
Already loaded from storage items collection.
"""

_hashes: Dict[K, int] = PrivateAttr(default_factory=dict)
_hashes: Dict[int, int] = PrivateAttr(default_factory=dict)
"""
Hashes of the loaded items (as they were upon loading), only populated if `rewrite_existing` flag is enabled.
"""

_keys: Set[K] = PrivateAttr(default_factory=set)
_keys: Set[int] = PrivateAttr(default_factory=set)
"""
All the item keys available in the storage.
"""

_added: Set[K] = PrivateAttr(default_factory=set)
_added: Set[int] = PrivateAttr(default_factory=set)
"""
Keys added localy (need to be synchronized with the storage).
"""

_removed: Set[K] = PrivateAttr(default_factory=set)
_removed: Set[int] = PrivateAttr(default_factory=set)
"""
Keys removed localy (need to be synchronized with the storage).
"""
Expand All @@ -90,11 +85,11 @@ class ContextDict(ABC, BaseModel, Generic[K, V]):

@property
@abstractmethod
def _value_type(self) -> TypeAdapter[Type[V]]:
def _value_type(self) -> TypeAdapter[Type[BaseModel]]:
raise NotImplementedError

@classmethod
async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict[K, V]":
async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict":
"""
Create a new context dict, without connecting it to the context storage.
No keys or items will be loaded, but any newly added items will be available for synchronization.
Expand All @@ -114,7 +109,7 @@ async def new(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDi
return instance

@classmethod
async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict[K, V]":
async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "ContextDict":
"""
Create a new context dict, connecting it to the context storage.
All the keys and some items will be loaded, all the other items will be available for synchronization.
Expand All @@ -139,7 +134,7 @@ async def connected(cls, storage: DBContextStorage, id: str, field: str) -> "Con
instance._hashes = {k: _get_hash(v) for k, v in val_key_items} if storage.rewrite_existing else dict()
return instance

async def _load_items(self, keys: List[K]) -> None:
async def _load_items(self, keys: List[int]) -> None:
"""
Load items for the given keys from the connected context storage.
Update the `_items` and `_hashes` fields if necessary.
Expand All @@ -162,10 +157,10 @@ async def _load_items(self, keys: List[K]) -> None:
self._hashes[key] = _get_hash(value)

@overload
async def __getitem__(self, key: K) -> V: ... # noqa: E704
async def __getitem__(self, key: int) -> BaseModel: ... # noqa: E704

@overload
async def __getitem__(self, key: slice) -> List[V]: ... # noqa: E704
async def __getitem__(self, key: slice) -> List[BaseModel]: ... # noqa: E704

async def __getitem__(self, key):
if isinstance(key, int) and key < 0:
Expand All @@ -182,7 +177,7 @@ async def __getitem__(self, key):
else:
return self._items[key]

def __setitem__(self, key: Union[K, slice], value: Union[V, Sequence[V]]) -> None:
def __setitem__(self, key: Union[int, slice], value: Union[BaseModel, Sequence[BaseModel]]) -> None:
if isinstance(key, int) and key < 0:
key = self.keys()[key]
if isinstance(key, slice):
Expand All @@ -200,7 +195,7 @@ def __setitem__(self, key: Union[K, slice], value: Union[V, Sequence[V]]) -> Non
self._removed.discard(key)
self._items[key] = self._value_type.validate_python(value)

def __delitem__(self, key: Union[K, slice]) -> None:
def __delitem__(self, key: Union[int, slice]) -> None:
if isinstance(key, int) and key < 0:
key = self.keys()[key]
if isinstance(key, slice):
Expand All @@ -212,17 +207,17 @@ def __delitem__(self, key: Union[K, slice]) -> None:
self._keys.discard(key)
del self._items[key]

def __iter__(self) -> Sequence[K]:
def __iter__(self) -> Sequence[int]:
return iter(self.keys() if self._storage is not None else self._items.keys())

def __len__(self) -> int:
return len(self.keys() if self._storage is not None else self._items.keys())

@overload
async def get(self, key: K, default=None) -> V: ... # noqa: E704
async def get(self, key: int, default=None) -> BaseModel: ... # noqa: E704

@overload
async def get(self, key: Iterable[K], default=None) -> List[V]: ... # noqa: E704
async def get(self, key: Iterable[int], default=None) -> List[BaseModel]: ... # noqa: E704

async def get(self, key, default=None):
"""
Expand All @@ -243,19 +238,19 @@ async def get(self, key, default=None):
else:
return default

def __contains__(self, key: K) -> bool:
def __contains__(self, key: int) -> bool:
return key in self.keys()

def keys(self) -> List[K]:
def keys(self) -> List[int]:
return sorted(self._keys)

async def values(self) -> List[V]:
async def values(self) -> List[BaseModel]:
return await self[:]

async def items(self) -> List[Tuple[K, V]]:
async def items(self) -> List[Tuple[int, BaseModel]]:
return [(k, v) for k, v in zip(self.keys(), await self.values())]

async def pop(self, key: K, default=None) -> V:
async def pop(self, key: int, default=None) -> BaseModel:
try:
value = await self[key]
except KeyError:
Expand All @@ -264,7 +259,7 @@ async def pop(self, key: K, default=None) -> V:
del self[key]
return value

async def popitem(self) -> Tuple[K, V]:
async def popitem(self) -> Tuple[int, BaseModel]:
try:
key = next(iter(self))
except StopIteration:
Expand All @@ -291,7 +286,7 @@ async def update(self, other: Any = (), /, **kwds) -> None:
for key, value in kwds.items():
self[key] = value

async def setdefault(self, key: K, default=None) -> V:
async def setdefault(self, key: int, default=None) -> BaseModel:
try:
return await self[key]
except KeyError:
Expand Down Expand Up @@ -345,7 +340,7 @@ def _validate_model(value: Any, handler: Callable[[Any], "ContextDict"], _) -> "
raise ValueError(f"Unknown type of ContextDict value: {type(value).__name__}!")

@model_serializer()
def _serialize_model(self) -> Dict[K, V]:
def _serialize_model(self) -> Dict[int, BaseModel]:
if self._storage is None:
return self._items
elif not self._storage.rewrite_existing:
Expand Down Expand Up @@ -389,16 +384,88 @@ class LabelContextDict(ContextDict[int, AbsoluteNodeLabel]):
Context dictionary for storing `AbsoluteNodeLabel` types.
"""

_items: Dict[int, AbsoluteNodeLabel]

@property
def _value_type(self) -> TypeAdapter[Type[AbsoluteNodeLabel]]:
return TypeAdapter(AbsoluteNodeLabel)

@overload
async def __getitem__(self, key: int) -> AbsoluteNodeLabel: ... # noqa: E704

@overload
async def __getitem__(self, key: slice) -> List[AbsoluteNodeLabel]: ... # noqa: E704

def __setitem__(self, key: Union[int, slice], value: Union[AbsoluteNodeLabel, Sequence[AbsoluteNodeLabel]]) -> None:
return super().__setitem__(key, value)

@overload
async def get(self, key: int, default=None) -> AbsoluteNodeLabel: ... # noqa: E704

@overload
async def get(self, key: Iterable[int], default=None) -> List[AbsoluteNodeLabel]: ... # noqa: E704

async def values(self) -> List[AbsoluteNodeLabel]:
return super().values()

async def items(self) -> List[Tuple[int, AbsoluteNodeLabel]]:
return super().items()

async def pop(self, key: int, default=None) -> AbsoluteNodeLabel:
return super().pop(key, default)

async def popitem(self) -> Tuple[int, AbsoluteNodeLabel]:
return super().popitem()

async def setdefault(self, key: int, default=None) -> AbsoluteNodeLabel:
return super().setdefault(key, default)

@model_serializer()
def _serialize_model(self) -> Dict[int, AbsoluteNodeLabel]:
return super()._serialize_model()


class MessageContextDict(ContextDict[int, Message]):
"""
Context dictionary for storing `Message` types.
"""

_items: Dict[int, Message]

@property
def _value_type(self) -> TypeAdapter[Type[Message]]:
return TypeAdapter(Message)

@overload
async def __getitem__(self, key: int) -> Message: ... # noqa: E704

@overload
async def __getitem__(self, key: slice) -> List[Message]: ... # noqa: E704

def __setitem__(self, key: Union[int, slice], value: Union[Message, Sequence[Message]]) -> None:
return super().__setitem__(key, value)

@overload
async def get(self, key: int, default=None) -> Message: ... # noqa: E704

@overload
async def get(self, key: Iterable[int], default=None) -> List[Message]: ... # noqa: E704

async def values(self) -> List[Message]:
return super().values()

async def items(self) -> List[Tuple[int, Message]]:
return super().items()

async def pop(self, key: int, default=None) -> Message:
return super().pop(key, default)

async def popitem(self) -> Tuple[int, Message]:
return super().popitem()

async def setdefault(self, key: int, default=None) -> Message:
return super().setdefault(key, default)

@model_serializer()
def _serialize_model(self) -> Dict[int, Message]:
return super()._serialize_model()

0 comments on commit 52fd530

Please sign in to comment.