Skip to content

Commit

Permalink
fix(agents-api): Fix the codec, which was causing a lot of bugs
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 20, 2024
1 parent bd83b5c commit b2fb5a9
Showing 1 changed file with 71 additions and 50 deletions.
121 changes: 71 additions & 50 deletions agents-api/agents_api/worker/codec.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,97 @@
###
### NOTE: Working with Temporal's codec is awkward.
### This is a workaround for using Pydantic models with Temporal.
### The codec is used to serialize/deserialize the data,
### but this code is quite brittle. Be careful when changing it.


import dataclasses
import json
import logging
import pickle
from typing import Any, Optional, Type

import openai.types as openai_types
import openai.types.chat as openai_chat_types
import temporalio.converter
from litellm.utils import ModelResponse
from pydantic import BaseModel

# from beartype import BeartypeConf
# from beartype.door import is_bearable, is_subhint
from lz4.frame import compress, decompress
from temporalio.api.common.v1 import Payload
from temporalio.converter import (
CompositePayloadConverter,
DefaultPayloadConverter,
EncodingPayloadConverter,
)

import agents_api.autogen.openapi_model as openapi_model
import agents_api.common.protocol.tasks as tasks
from agents_api.common.utils.json import dumps as json_dumps

# Registry keyed by fully-qualified class name ("<module>.<name>") so that the
# matching class can be looked up by name when deserializing
model_class_map: dict = {
    ".".join((candidate.__module__, candidate.__name__)): candidate
    for candidate in {
        # Namespaces of every module whose models we want to support.
        # They are merged into a single dict, so when two modules expose the
        # same attribute name the module listed later wins.
        **openai_types.__dict__,
        **openai_chat_types.__dict__,
        **openapi_model.__dict__,
        **tasks.__dict__,
    }.values()
    # Keep only classes that derive from pydantic's BaseModel
    if isinstance(candidate, type) and issubclass(candidate, BaseModel)
}

# Two non-pydantic-module entries are registered explicitly as well
model_class_map.update(
    {
        "builtins.dict": dict,
        "litellm.utils.ModelResponse": ModelResponse,
    }
)
def serialize(x: Any) -> bytes:
    """Pickle ``x`` and compress the pickled bytes with ``lz4.frame.compress``.

    Args:
        x: The object to serialize; it must be picklable.

    Returns:
        The compressed pickle of ``x``.
    """
    pickled = pickle.dumps(x)
    return compress(pickled)


def deserialize(b: bytes) -> Any:
    """Inverse of ``serialize``: decompress ``b`` and unpickle the result.

    Args:
        b: Bytes in the format produced by ``serialize``.

    Returns:
        The unpickled object.
    """
    pickled = decompress(b)
    # SECURITY: unpickling can execute arbitrary code. Only feed this function
    # bytes that come from a trusted source.
    return pickle.loads(pickled)


def from_payload_data(data: bytes, type_hint: Optional[Type] = None) -> Any:
    """Deserialize ``data`` and, where possible, promote it to ``type_hint``.

    Args:
        data: Bytes in the format accepted by ``deserialize``.
        type_hint: Optional target type. When omitted, the deserialized value
            is returned untouched.

    Returns:
        The deserialized value. If ``type_hint`` differs from the value's type,
        ``type_hint`` has a ``model_construct`` attribute and the value has a
        ``model_dump`` attribute, the value is rebuilt as
        ``type_hint(**value.model_dump())``. If that rebuild raises, a warning
        is logged and the original deserialized value is returned instead.
    """
    value = deserialize(data)

    # No hint to honour: hand back whatever was stored.
    if type_hint is None:
        return value

    value_type = type(value)

    # FIXME: Enable these checks when temporal's codec stuff is fixed:
    #   * warn when the value is not bearable to the hint, i.e.
    #     `is_bearable(value, type_hint, conf=BeartypeConf(is_pep484_tower=True))`
    #     (PEP 484 numeric tower, to be more lax on numeric types)
    #   * additionally require `is_subhint(type_hint, value_type)` before promoting
    #
    # Until then promotion is attempted on duck-typed evidence only: the hint
    # looks like a pydantic model class and the value looks like a model instance.
    should_promote = (
        type_hint != value_type
        and hasattr(type_hint, "model_construct")
        and hasattr(value, "model_dump")
    )
    if not should_promote:
        return value

    try:
        return type_hint(**value.model_dump())
    except Exception as e:
        # Best effort: fall back to the un-promoted value instead of failing.
        logging.warning(
            f"WARNING: Could not promote {value_type} to {type_hint}: {e}"
        )
        return value


class PydanticEncodingPayloadConverter(EncodingPayloadConverter):
    """Payload converter that stores values as LZ4-compressed pickles.

    Values are encoded with ``serialize`` and decoded with
    ``from_payload_data``, which also tries to promote the decoded value to
    the requested type hint.
    """

    # Encoding name advertised in the payload metadata, plus its bytes form
    # (metadata values are bytes), computed once at class-definition time.
    encoding = "text/pickle+lz4"
    b_encoding = encoding.encode()

    def to_payload(self, value: Any) -> Optional[Payload]:
        """Wrap ``value`` in a ``Payload`` tagged with this converter's encoding.

        Args:
            value: The object to encode; it must be picklable.

        Returns:
            A ``Payload`` whose data is ``serialize(value)``.
        """
        return Payload(
            metadata={"encoding": self.b_encoding},
            data=serialize(value),
        )

    def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
        """Decode ``payload`` back into a Python object.

        Args:
            payload: A payload whose ``encoding`` metadata matches this converter.
            type_hint: Optional type to promote the decoded value to.

        Returns:
            The result of ``from_payload_data(payload.data, type_hint)``.

        Raises:
            AssertionError: If the payload's encoding is not this converter's.
        """
        # Invariant check: only payloads written by `to_payload` are expected here.
        assert (
            payload.metadata["encoding"] == self.b_encoding
        ), "payload encoding does not match PydanticEncodingPayloadConverter"
        return from_payload_data(payload.data, type_hint)


class PydanticPayloadConverter(CompositePayloadConverter):
Expand Down

0 comments on commit b2fb5a9

Please sign in to comment.