diff --git a/agents-api/agents_api/worker/codec.py b/agents-api/agents_api/worker/codec.py index cdd7e3448..d56b81de1 100644 --- a/agents-api/agents_api/worker/codec.py +++ b/agents-api/agents_api/worker/codec.py @@ -1,12 +1,20 @@ +### +### NOTE: Working with temporal's codec is really really weird +### This is a workaround to use pydantic models with temporal +### The codec is used to serialize/deserialize the data +### But this code is quite brittle. Be careful when changing it + + import dataclasses -import json +import logging +import pickle from typing import Any, Optional, Type -import openai.types as openai_types -import openai.types.chat as openai_chat_types import temporalio.converter -from litellm.utils import ModelResponse -from pydantic import BaseModel + +# from beartype import BeartypeConf +# from beartype.door import is_bearable, is_subhint +from lz4.frame import compress, decompress from temporalio.api.common.v1 import Payload from temporalio.converter import ( CompositePayloadConverter, @@ -14,63 +22,76 @@ EncodingPayloadConverter, ) -import agents_api.autogen.openapi_model as openapi_model -import agents_api.common.protocol.tasks as tasks -from agents_api.common.utils.json import dumps as json_dumps - -# Map of model name to class so that we can look up the class when deserializing -model_class_map: dict = { - subclass.__module__ + "." + subclass.__name__: subclass - for subclass in { - # All the models we want to support - **openai_types.__dict__, - **openai_chat_types.__dict__, - **openapi_model.__dict__, - **tasks.__dict__, - }.values() - # - # Filter out the ones that aren't pydantic models - if isinstance(subclass, type) and issubclass(subclass, BaseModel) -} -# Also include dict -model_class_map["builtins.dict"] = dict -model_class_map["litellm.utils.ModelResponse"] = ModelResponse +def serialize(x: Any) -> bytes: + return compress(pickle.dumps(x)) + + +def deserialize(b: bytes) -> Any: + return pickle.loads(decompress(b)) + + +def from_payload_data(data: bytes, type_hint: Optional[Type] = None) -> Any: + decoded = deserialize(data) + + if type_hint is None: + return decoded + + decoded_type = type(decoded) + + # FIXME: Enable this check when temporal's codec stuff is fixed + # + # # Otherwise, check if the decoded value is bearable to the type hint + # if not is_bearable( + # decoded, + # type_hint, + # conf=BeartypeConf( + # is_pep484_tower=True + # ), # Check PEP 484 type hints. (be more lax on numeric types) + # ): + # logging.warning( + # f"WARNING: Decoded value {decoded_type} is not bearable to {type_hint}" + # ) + + # FIXME: Enable this check when temporal's codec stuff is fixed + # + # If the decoded value is a BaseModel and the type hint is a subclass of BaseModel + # and the decoded value's class is a subclass of the type hint, then promote the decoded value + # to the type hint. + if ( + type_hint != decoded_type + and hasattr(type_hint, "model_construct") + and hasattr(decoded, "model_dump") + # + # FIXME: Enable this check when temporal's codec stuff is fixed + # + # and is_subhint(type_hint, decoded_type) + ): + try: + decoded = type_hint(**decoded.model_dump()) + except Exception as e: + logging.warning( + f"WARNING: Could not promote {decoded_type} to {type_hint}: {e}" + ) + + return decoded class PydanticEncodingPayloadConverter(EncodingPayloadConverter): - @property - def encoding(self) -> str: - return "text/pydantic-json" + encoding = "text/pickle+lz4" + b_encoding = encoding.encode() def to_payload(self, value: Any) -> Optional[Payload]: - data: str = ( - value.model_dump_json() - if hasattr(value, "model_dump_json") - else json_dumps(value) - ) - return Payload( metadata={ - "encoding": self.encoding.encode(), - "model_name": value.__class__.__name__.encode(), - "model_module": value.__class__.__module__.encode(), + "encoding": self.b_encoding, }, - data=data.encode(), + data=serialize(value), ) def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any: - data = json.loads(payload.data.decode()) - - if not isinstance(data, dict): - return data - - # Otherwise, we have a model - model_name = payload.metadata["model_name"].decode() - model_module = payload.metadata["model_module"].decode() - model_class = model_class_map[model_module + "." + model_name] - - return model_class(**data) + assert payload.metadata["encoding"] == self.b_encoding + return from_payload_data(payload.data, type_hint) class PydanticPayloadConverter(CompositePayloadConverter):