Commit

Merge remote-tracking branch 'origin/feat/partial_context_updates' into feat/partial_context_updates

RLKRo committed Jan 17, 2025
2 parents fcf3739 + cdca730 commit 3d73555
Showing 10 changed files with 363 additions and 61 deletions.
108 changes: 102 additions & 6 deletions chatsky/context_storages/database.py
@@ -35,6 +35,11 @@


class NameConfig:
"""
Configuration of the names of different database parts:
table names, column names, field names, etc.
"""

_main_table: Literal["main"] = "main"
_turns_table: Literal["turns"] = "turns"
_key_column: Literal["key"] = "key"
@@ -50,6 +55,13 @@ class NameConfig:


class ContextInfo(BaseModel):
"""
Main context fields that are stored in the `MAIN` table.
For most database backends, this model is serialized to JSON.
For SQL backends, its fields are written to separate table columns.
The in-memory context storage does not serialize it at all.
"""

turn_id: int
created_at: int = Field(default_factory=time_ns)
updated_at: int = Field(default_factory=time_ns)
@@ -70,7 +82,7 @@ def _serialize_misc(self, misc: Dict[str, Any]) -> bytes:
return self._misc_adaptor.dump_json(misc)

@field_serializer("framework_data", when_used="always")
- def serialize_courses_in_order(self, framework_data: FrameworkData) -> bytes:
+ def _serialize_framework_data(self, framework_data: FrameworkData) -> bytes:
return framework_data.model_dump_json().encode()

def __eq__(self, other: Any) -> bool:
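For reference, the `field_serializer` pattern used by `ContextInfo` can be reproduced in a small self-contained sketch (the models below are simplified stand-ins, not chatsky's actual classes):

```python
from pydantic import BaseModel, field_serializer

class DemoFrameworkData(BaseModel):  # stand-in for chatsky's FrameworkData
    stats: dict = {}

class DemoInfo(BaseModel):  # stand-in for ContextInfo
    turn_id: int
    framework_data: DemoFrameworkData = DemoFrameworkData()

    @field_serializer("framework_data", when_used="always")
    def _serialize_framework_data(self, fd: DemoFrameworkData) -> bytes:
        # Dump the nested model to JSON bytes, as ContextInfo does.
        return fd.model_dump_json().encode()

print(DemoInfo(turn_id=1).model_dump())  # framework_data comes out as bytes
```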
@@ -92,6 +104,15 @@ async def wrapped(self: DBContextStorage, *args, **kwargs):


class DBContextStorage(ABC):
"""
Base context storage class.
Includes a set of methods for storing and reading different context parts.
:param path: Path to the storage instance.
:param rewrite_existing: Whether locally modified `TURNS` data should be written back to the storage.
:param partial_read_config: Dictionary of subscripts for all possible turn items.
"""

_default_subscript_value: int = 3

def __init__(
@@ -102,15 +123,42 @@ def __init__(
):
_, _, file_path = path.partition("://")
configuration = partial_read_config if partial_read_config is not None else dict()

self.full_path = path
"""
Full path to access the context storage, as provided by the user.
"""

self.path = Path(file_path)
"""
`full_path` without the prefix that defines the db used.
"""

self.rewrite_existing = rewrite_existing
"""
Whether to rewrite existing data in the storage.
"""

self._subscripts = dict()
"""
Subscripts control how many elements will be loaded from the database.
A subscript can be an integer (the number of *last* elements to load),
the special value "__all__" (load all elements at once),
or a set of keys that should be loaded.
"""

self._sync_lock = Lock()
"""
Synchronization lock for the databases that don't support
asynchronous atomic reads and writes.
"""

self.connected = False
"""
Flag that marks whether the storage is connected to the backend.
Should be set in `pipeline.run` or later (lazily).
"""

for field in (NameConfig._labels_field, NameConfig._requests_field, NameConfig._responses_field):
value = configuration.get(field, self._default_subscript_value)
if (not isinstance(value, int)) or value >= 1:
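As an illustration of the subscript values described above, a hypothetical partial-read configuration could look like this (assuming `MemoryContextStorage` keeps the base constructor signature and that the `NameConfig` field names are `labels`, `requests` and `responses`):

```python
from chatsky.context_storages.memory import MemoryContextStorage

storage = MemoryContextStorage(
    path="",
    partial_read_config={
        "labels": 3,            # load only the three most recent labels
        "requests": "__all__",  # load every request
        "responses": {1, 2, 3}, # load responses for these turns only
    },
)
```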
@@ -121,6 +169,10 @@ def __init__(
@property
@abstractmethod
def is_concurrent(self) -> bool:
"""
Whether the database backend supports asynchronous IO.
"""

raise NotImplementedError

@classmethod
@@ -135,6 +187,10 @@ async def _connect(self) -> None:
raise NotImplementedError

async def connect(self) -> None:
"""
Connect to the backend context storage.
"""

logger.info(f"Connecting to context storage {type(self).__name__} ...")
await self._connect()
self.connected = True
@@ -147,7 +203,11 @@ async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]:
async def load_main_info(self, ctx_id: str) -> Optional[ContextInfo]:
"""
Load main information about the context.
:param ctx_id: Context identifier.
:return: Context main information (from `MAIN` table).
"""

if not self.connected:
await self.connect()
logger.debug(f"Loading main info for {ctx_id}...")
@@ -163,7 +223,11 @@ async def _update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None:
async def update_main_info(self, ctx_id: str, ctx_info: ContextInfo) -> None:
"""
Update main information about the context.
:param ctx_id: Context identifier.
:param ctx_info: New context information (will be written to `MAIN` table).
"""

if not self.connected:
await self.connect()
logger.debug(f"Updating main info for {ctx_id}...")
@@ -178,7 +242,10 @@ async def _delete_context(self, ctx_id: str) -> None:
async def delete_context(self, ctx_id: str) -> None:
"""
Delete context from context storage.
:param ctx_id: Context identifier.
"""

if not self.connected:
await self.connect()
logger.debug(f"Deleting context {ctx_id}...")
@@ -192,8 +259,13 @@ async def _load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[i
@_lock
async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]:
"""
Load the latest field data (specified by the `subscript` value).
:param ctx_id: Context identifier.
:param field_name: Field name to load from `TURNS` table.
:return: List of tuples (step number, serialized value).
"""

if not self.connected:
await self.connect()
logger.debug(f"Loading latest items for {ctx_id}, {field_name}...")
@@ -209,7 +281,12 @@ async def _load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
"""
Load all field keys.
:param ctx_id: Context identifier.
:param field_name: Field name to load from `TURNS` table.
:return: List of all the step numbers.
"""

if not self.connected:
await self.connect()
logger.debug(f"Loading field keys for {ctx_id}, {field_name}...")
@@ -224,8 +301,15 @@ async def _load_field_items(self, ctx_id: str, field_name: str, keys: List[int])
@_lock
async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]:
"""
Load field items (specified by the key list).
The items that are equal to `None` will be ignored.
:param ctx_id: Context identifier.
:param field_name: Field name to load from `TURNS` table.
:param keys: List of keys to load.
:return: List of tuples (step number, serialized value).
"""

if not self.connected:
await self.connect()
logger.debug(f"Loading field items for {ctx_id}, {field_name} ({collapse_num_list(keys)})...")
@@ -241,7 +325,12 @@ async def _update_field_items(self, ctx_id: str, field_name: str, items: List[Tu
async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tuple[int, Optional[bytes]]]) -> None:
"""
Update field items.
:param ctx_id: Context identifier.
:param field_name: Field name in the `TURNS` table to write to.
:param items: List of tuples that will be written (step number, serialized value or `None`).
"""

if len(items) == 0:
logger.debug(f"No fields to update in {ctx_id}, {field_name}!")
return
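A sketch of a batched write; per the signature, a `None` value marks a deletion:

```python
async def store_responses(storage) -> None:
    await storage.update_field_items(
        "example_ctx",
        "responses",
        [(0, b"first reply"), (1, b"second reply"), (2, None)],  # turn 2 is erased
    )
```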
@@ -258,7 +347,12 @@ async def _delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]
async def delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) -> None:
"""
Delete field keys.
:param ctx_id: Context identifier.
:param field_name: Field name in the `TURNS` table to delete from.
:param keys: List of keys to delete (they are simply overwritten with `None`).
"""

if len(keys) == 0:
logger.debug(f"No fields to delete in {ctx_id}, {field_name}!")
return
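Deletion is thus a soft overwrite; a hypothetical call:

```python
async def forget_turns(storage) -> None:
    # Overwrites the stored values for turns 2 and 3 with None.
    await storage.delete_field_keys("example_ctx", "responses", [2, 3])
```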
@@ -277,6 +371,7 @@ async def clear_all(self) -> None:
"""
Clear all the chatsky tables and records.
"""

if not self.connected:
await self.connect()
logger.debug("Clearing all")
@@ -323,6 +418,7 @@ def context_storage_factory(path: str, **kwargs) -> DBContextStorage:
:param path: Path to the file.
"""

if path == "":
module = "memory"
_class = "MemoryContextStorage"
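A hedged usage sketch of the factory (assuming `context_storage_factory` is re-exported from the `chatsky.context_storages` package; the empty-path branch is visible right above):

```python
from chatsky.context_storages import context_storage_factory

memory_storage = context_storage_factory("")  # empty path -> MemoryContextStorage
json_storage = context_storage_factory("json://contexts.json")
```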
37 changes: 33 additions & 4 deletions chatsky/context_storages/file.py
@@ -28,17 +28,22 @@


class SerializableStorage(BaseModel):
"""
A special serializable database implementation.
A single instance of this class stores all the contexts; it is read from and written to the file on every turn.
"""

main: Dict[str, ContextInfo] = Field(default_factory=dict)
turns: List[Tuple[str, str, int, Optional[bytes]]] = Field(default_factory=list)


class FileContextStorage(DBContextStorage, ABC):
"""
- Implements :py:class:`.DBContextStorage` with `json` as the storage format.
+ Implements :py:class:`.DBContextStorage` with any file-based storage format.

- :param context_schema: Context schema for this storage.
- :param serializer: Serializer that will be used for serializing contexts.
- :param path: Target file URI.
+ :param path: Target file URI. Example: `json://file.json`.
+ :param rewrite_existing: Whether locally modified `TURNS` data should be written back to the storage.
+ :param partial_read_config: Dictionary of subscripts for all possible turn items.
"""

is_concurrent: bool = False
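Judging from the three subclasses below, a new file backend only has to provide `_save` and `_load`; a minimal sketch (the YAML format here is a hypothetical example, not part of chatsky):

```python
class YamlContextStorage(FileContextStorage):
    async def _save(self, data: SerializableStorage) -> None:
        ...  # write data.model_dump() to self.path as YAML

    async def _load(self) -> SerializableStorage:
        ...  # parse self.path and validate it into a SerializableStorage
```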
@@ -116,6 +121,14 @@ async def _clear_all(self) -> None:


class JSONContextStorage(FileContextStorage):
"""
Implements :py:class:`.DBContextStorage` with `json` as the storage format.
:param path: Target file URI. Example: `json://file.json`.
:param rewrite_existing: Whether locally modified `TURNS` data should be written back to the storage.
:param partial_read_config: Dictionary of subscripts for all possible turn items.
"""

async def _save(self, data: SerializableStorage) -> None:
if not await isfile(self.path) or (await stat(self.path)).st_size == 0:
await makedirs(self.path.parent, exist_ok=True)
@@ -133,6 +146,14 @@ async def _load(self) -> SerializableStorage:


class PickleContextStorage(FileContextStorage):
"""
Implements :py:class:`.DBContextStorage` with `pickle` as the storage format.
:param path: Target file URI. Example: `pickle://file.pkl`.
:param rewrite_existing: Whether locally modified `TURNS` data should be written back to the storage.
:param partial_read_config: Dictionary of subscripts for all possible turn items.
"""

async def _save(self, data: SerializableStorage) -> None:
if not await isfile(self.path) or (await stat(self.path)).st_size == 0:
await makedirs(self.path.parent, exist_ok=True)
@@ -150,6 +171,14 @@ async def _load(self) -> SerializableStorage:


class ShelveContextStorage(FileContextStorage):
"""
Implements :py:class:`.DBContextStorage` with `shelve` as the storage format.
:param path: Target file URI. Example: `shelve://file.shlv`.
:param rewrite_existing: Whether locally modified `TURNS` data should be written back to the storage.
:param partial_read_config: Dictionary of subscripts for all possible turn items.
"""

_SHELVE_ROOT = "root"

def __init__(
14 changes: 8 additions & 6 deletions chatsky/context_storages/memory.py
@@ -6,14 +6,16 @@
class MemoryContextStorage(DBContextStorage):
"""
Implements :py:class:`.DBContextStorage` storing contexts in memory, without file backend.

- Uses :py:class:`.JsonSerializer` as the default serializer.
- By default it sets path to an empty string.
+ Does not serialize any data. By default it sets path to an empty string.

- Keeps data in a dictionary and two lists:
- - `main`: {context_id: [created_at, turn_id, updated_at, framework_data]}
- - `turns`: [context_id, turn_number, label, request, response]
- - `misc`: [context_id, turn_number, misc]
+ Keeps data in two dictionaries:
+ - `main`: {context_id: context_info}
+ - `turns`: {context_id: {labels, requests, responses}}

+ :param path: Any string, won't be used.
+ :param rewrite_existing: Whether locally modified `TURNS` data should be written back to the storage.
+ :param partial_read_config: Dictionary of subscripts for all possible turn items.
"""

is_concurrent: bool = True
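A short usage sketch; nothing is persisted, so the data lives only for the lifetime of the process:

```python
from chatsky.context_storages.memory import MemoryContextStorage

storage = MemoryContextStorage()  # path defaults to an empty string
assert storage.is_concurrent      # safe for asynchronous access
```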
4 changes: 2 additions & 2 deletions chatsky/context_storages/mongo.py
@@ -35,8 +35,8 @@ class MongoContextStorage(DBContextStorage):
LOGS table is stored as `COLLECTION_PREFIX_logs` collection.
:param path: Database URI. Example: `mongodb://user:password@host:port/dbname`.
- :param context_schema: Context schema for this storage.
- :param serializer: Serializer that will be used for serializing contexts.
+ :param rewrite_existing: Whether locally modified `TURNS` data should be written back to the database.
+ :param partial_read_config: Dictionary of subscripts for all possible turn items.
:param collection_prefix: "namespace" prefix for the two collections created for context storing.
"""

Expand Down
17 changes: 8 additions & 9 deletions chatsky/context_storages/redis.py
@@ -31,18 +31,17 @@ class RedisContextStorage(DBContextStorage):
"""
Implements :py:class:`.DBContextStorage` with `redis` as the database backend.

- The relations between primary identifiers and active context storage keys are stored
- as a redis hash ("KEY_PREFIX:index:general").
- The keys of active contexts are stored as redis sets ("KEY_PREFIX:index:subindex:PRIMARY_ID").
+ The main context info is stored in redis hashes, one for each context.
+ The `TURNS` table values are stored in redis hashes, one for each field.

- That's how CONTEXT table fields are stored:
- `"KEY_PREFIX:contexts:PRIMARY_ID:FIELD": "DATA"`
- That's how LOGS table fields are stored:
- `"KEY_PREFIX:logs:PRIMARY_ID:FIELD": "DATA"`
+ That's how MAIN table fields are stored:
+ `"KEY_PREFIX:main:ctx_id": "DATA"`
+ That's how TURNS table fields are stored:
+ `"KEY_PREFIX:turns:ctx_id:FIELD_NAME": "DATA"`

:param path: Database URI string. Example: `redis://user:password@host:port`.
- :param context_schema: Context schema for this storage.
- :param serializer: Serializer that will be used for serializing contexts.
+ :param rewrite_existing: Whether locally modified `TURNS` data should be written back to the database.
+ :param partial_read_config: Dictionary of subscripts for all possible turn items.
:param key_prefix: "namespace" prefix for all keys, should be set for efficient clearing of all data.
"""

Expand Down
(Diff for the remaining 5 changed files not shown.)