Merge pull request #463 from julep-ai/x/fix-codec
x/fix codec
whiterabbit1983 authored Aug 20, 2024
2 parents 95a8889 + b2fb5a9 commit b355113
Showing 11 changed files with 94 additions and 101 deletions.
4 changes: 1 addition & 3 deletions agents-api/agents_api/activities/task_steps/evaluate_step.py
@@ -1,5 +1,3 @@
import logging

from beartype import beartype
from temporalio import activity

@@ -23,7 +21,7 @@ async def evaluate_step(context: StepContext) -> StepOutcome:
return result

except BaseException as e:
logging.error(f"Error in evaluate_step: {e}")
activity.logger.error(f"Error in evaluate_step: {e}")
return StepOutcome(error=str(e))


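The same fix repeats across the task-step activities below: the module-level logging import is dropped and errors are reported through activity.logger, Temporal's activity-scoped logger, so each record carries workflow and activity context. A minimal sketch of the resulting pattern, with a simplified stand-in for the project's StepOutcome model and an illustrative evaluate helper (neither is taken from this PR):

from dataclasses import dataclass
from typing import Any, Optional

from temporalio import activity


@dataclass
class StepOutcome:
    # Simplified stand-in for the project's StepOutcome model.
    output: Optional[Any] = None
    error: Optional[str] = None


def evaluate(expression: str) -> Any:
    # Illustrative placeholder for the real evaluation logic.
    return eval(expression, {"__builtins__": {}}, {})


@activity.defn
async def example_step(expression: str) -> StepOutcome:
    try:
        return StepOutcome(output=evaluate(expression))
    except BaseException as e:
        # activity.logger attaches workflow/activity metadata to every record,
        # which a bare logging.error call would not.
        activity.logger.error(f"Error in example_step: {e}")
        return StepOutcome(error=str(e))
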
4 changes: 1 addition & 3 deletions agents-api/agents_api/activities/task_steps/if_else_step.py
@@ -1,5 +1,3 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity
@@ -27,7 +25,7 @@ async def if_else_step(context: StepContext) -> StepOutcome:
return result

except BaseException as e:
logging.error(f"Error in if_else_step: {e}")
activity.logger.error(f"Error in if_else_step: {e}")
return StepOutcome(error=str(e))


4 changes: 1 addition & 3 deletions agents-api/agents_api/activities/task_steps/log_step.py
@@ -1,5 +1,3 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity
@@ -26,7 +24,7 @@ async def log_step(context: StepContext) -> StepOutcome:
return result

except BaseException as e:
logging.error(f"Error in log_step: {e}")
activity.logger.error(f"Error in log_step: {e}")
return StepOutcome(error=str(e))


@@ -3,5 +3,4 @@

@activity.defn
async def raise_complete_async() -> None:
activity.heartbeat("Starting to wait")
activity.raise_complete_async()
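
For context on the line kept above: activity.raise_complete_async() marks the activity for asynchronous completion, meaning Temporal will not treat it as finished until an external caller completes it via the activity's task token. The sketch below shows that pairing as I understand the temporalio client API; the function names and the dict result are illustrative, not part of this PR:

from temporalio import activity
from temporalio.client import Client


@activity.defn
async def wait_for_external_result() -> None:
    # Tell Temporal that this activity will be completed later by an external
    # caller holding its task token (available via activity.info().task_token).
    activity.raise_complete_async()


async def complete_activity(client: Client, task_token: bytes, result: dict) -> None:
    # External completion path: look up the pending activity by task token
    # and deliver its result.
    handle = client.get_async_activity_handle(task_token=task_token)
    await handle.complete(result)
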
4 changes: 1 addition & 3 deletions agents-api/agents_api/activities/task_steps/return_step.py
@@ -1,5 +1,3 @@
import logging

from temporalio import activity

from ...activities.task_steps.utils import simple_eval_dict
@@ -24,7 +22,7 @@ async def return_step(context: StepContext) -> StepOutcome:
return result

except BaseException as e:
logging.error(f"Error in log_step: {e}")
activity.logger.error(f"Error in log_step: {e}")
return StepOutcome(error=str(e))


4 changes: 1 addition & 3 deletions agents-api/agents_api/activities/task_steps/switch_step.py
@@ -1,5 +1,3 @@
import logging

from beartype import beartype
from simpleeval import simple_eval
from temporalio import activity
@@ -35,7 +33,7 @@ async def switch_step(context: StepContext) -> StepOutcome:
return result

except BaseException as e:
logging.error(f"Error in switch_step: {e}")
activity.logger.error(f"Error in switch_step: {e}")
return StepOutcome(error=str(e))


@@ -22,7 +22,6 @@ async def transition_step(
transition_info.task_token = task_token

# Create transition
activity.heartbeat("Creating transition in db")
create_execution_transition_query(
developer_id=context.developer_id,
execution_id=context.execution.id,
3 changes: 1 addition & 2 deletions agents-api/agents_api/activities/task_steps/yield_step.py
@@ -1,4 +1,3 @@
import logging
from typing import Callable

from beartype import beartype
@@ -38,7 +37,7 @@ async def yield_step(context: StepContext) -> StepOutcome:
return StepOutcome(output=arguments, transition_to=("step", transition_target))

except BaseException as e:
logging.error(f"Error in log_step: {e}")
activity.logger.error(f"Error in yield_step: {e}")
return StepOutcome(error=str(e))


121 changes: 71 additions & 50 deletions agents-api/agents_api/worker/codec.py
@@ -1,76 +1,97 @@
###
### NOTE: Working with temporal's codec is really really weird
### This is a workaround to use pydantic models with temporal
### The codec is used to serialize/deserialize the data
### But this code is quite brittle. Be careful when changing it


import dataclasses
import json
import logging
import pickle
from typing import Any, Optional, Type

import openai.types as openai_types
import openai.types.chat as openai_chat_types
import temporalio.converter
from litellm.utils import ModelResponse
from pydantic import BaseModel

# from beartype import BeartypeConf
# from beartype.door import is_bearable, is_subhint
from lz4.frame import compress, decompress
from temporalio.api.common.v1 import Payload
from temporalio.converter import (
CompositePayloadConverter,
DefaultPayloadConverter,
EncodingPayloadConverter,
)

import agents_api.autogen.openapi_model as openapi_model
import agents_api.common.protocol.tasks as tasks
from agents_api.common.utils.json import dumps as json_dumps

# Map of model name to class so that we can look up the class when deserializing
model_class_map: dict = {
subclass.__module__ + "." + subclass.__name__: subclass
for subclass in {
# All the models we want to support
**openai_types.__dict__,
**openai_chat_types.__dict__,
**openapi_model.__dict__,
**tasks.__dict__,
}.values()
#
# Filter out the ones that aren't pydantic models
if isinstance(subclass, type) and issubclass(subclass, BaseModel)
}

# Also include dict
model_class_map["builtins.dict"] = dict
model_class_map["litellm.utils.ModelResponse"] = ModelResponse
def serialize(x: Any) -> bytes:
return compress(pickle.dumps(x))


def deserialize(b: bytes) -> Any:
return pickle.loads(decompress(b))


def from_payload_data(data: bytes, type_hint: Optional[Type] = None) -> Any:
decoded = deserialize(data)

if type_hint is None:
return decoded

decoded_type = type(decoded)

# FIXME: Enable this check when temporal's codec stuff is fixed
#
# # Otherwise, check if the decoded value is bearable to the type hint
# if not is_bearable(
# decoded,
# type_hint,
# conf=BeartypeConf(
# is_pep484_tower=True
# ), # Check PEP 484 type hints. (be more lax on numeric types)
# ):
# logging.warning(
# f"WARNING: Decoded value {decoded_type} is not bearable to {type_hint}"
# )

# FIXME: Enable this check when temporal's codec stuff is fixed
#
# If the decoded value is a BaseModel and the type hint is a subclass of BaseModel
# and the decoded value's class is a subclass of the type hint, then promote the decoded value
# to the type hint.
if (
type_hint != decoded_type
and hasattr(type_hint, "model_construct")
and hasattr(decoded, "model_dump")
#
# FIXME: Enable this check when temporal's codec stuff is fixed
#
# and is_subhint(type_hint, decoded_type)
):
try:
decoded = type_hint(**decoded.model_dump())
except Exception as e:
logging.warning(
f"WARNING: Could not promote {decoded_type} to {type_hint}: {e}"
)

return decoded


class PydanticEncodingPayloadConverter(EncodingPayloadConverter):
@property
def encoding(self) -> str:
return "text/pydantic-json"
encoding = "text/pickle+lz4"
b_encoding = encoding.encode()

def to_payload(self, value: Any) -> Optional[Payload]:
data: str = (
value.model_dump_json()
if hasattr(value, "model_dump_json")
else json_dumps(value)
)

return Payload(
metadata={
"encoding": self.encoding.encode(),
"model_name": value.__class__.__name__.encode(),
"model_module": value.__class__.__module__.encode(),
"encoding": self.b_encoding,
},
data=data.encode(),
data=serialize(value),
)

def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
data = json.loads(payload.data.decode())

if not isinstance(data, dict):
return data

# Otherwise, we have a model
model_name = payload.metadata["model_name"].decode()
model_module = payload.metadata["model_module"].decode()
model_class = model_class_map[model_module + "." + model_name]

return model_class(**data)
assert payload.metadata["encoding"] == self.b_encoding
return from_payload_data(payload.data, type_hint)


class PydanticPayloadConverter(CompositePayloadConverter):
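The rewritten codec above replaces the pydantic-JSON converter (which had to keep a model_class_map to re-hydrate models by name) with a generic pickle+lz4 converter, plus a best-effort promotion of decoded BaseModel values to the activity's declared type hint. Below is a hedged sketch of how such an EncodingPayloadConverter is typically registered with a Temporal client; the class names and the commented-out connection call are illustrative, and the ordering relative to the default converters is a design choice rather than something this diff shows:

import pickle
from typing import Any, Optional, Type

from lz4.frame import compress, decompress
from temporalio.api.common.v1 import Payload
from temporalio.converter import (
    CompositePayloadConverter,
    DataConverter,
    DefaultPayloadConverter,
    EncodingPayloadConverter,
)


class PickleLz4Converter(EncodingPayloadConverter):
    # Trimmed-down version of the converter introduced in this PR.
    encoding = "text/pickle+lz4"

    def to_payload(self, value: Any) -> Optional[Payload]:
        return Payload(
            metadata={"encoding": self.encoding.encode()},
            data=compress(pickle.dumps(value)),
        )

    def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
        assert payload.metadata["encoding"] == self.encoding.encode()
        return pickle.loads(decompress(payload.data))


class CustomPayloadConverter(CompositePayloadConverter):
    def __init__(self) -> None:
        # Custom converter first, default converters as a fallback.
        super().__init__(
            PickleLz4Converter(),
            *DefaultPayloadConverter.default_encoding_payload_converters,
        )


# Sketch of wiring it into a client:
# client = await Client.connect(
#     "localhost:7233",
#     data_converter=DataConverter(payload_converter_class=CustomPayloadConverter),
# )

Pickling trades human-readable payloads for the ability to round-trip arbitrary Python objects without the name-based class lookup, and the lz4 step keeps payload sizes down.
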
27 changes: 3 additions & 24 deletions agents-api/agents_api/worker/worker.py
@@ -1,4 +1,5 @@
from datetime import timedelta
from inspect import getmembers, isfunction
from typing import Any

from temporalio.client import Client
@@ -11,23 +12,12 @@ def create_worker(client: Client) -> Any:
then create a worker to listen for tasks on the configured task queue.
"""

from ..activities import task_steps
from ..activities.demo import demo_activity
from ..activities.embed_docs import embed_docs
from ..activities.mem_mgmt import mem_mgmt
from ..activities.mem_rating import mem_rating
from ..activities.summarization import summarization
from ..activities.task_steps import (
evaluate_step,
if_else_step,
log_step,
prompt_step,
return_step,
switch_step,
tool_call_step,
transition_step,
wait_for_input_step,
yield_step,
)
from ..activities.truncation import truncation
from ..env import (
temporal_task_queue,
@@ -40,18 +30,7 @@ def create_worker(client: Client) -> Any:
from ..workflows.task_execution import TaskExecutionWorkflow
from ..workflows.truncation import TruncationWorkflow

task_activities = [
evaluate_step,
if_else_step,
log_step,
prompt_step,
return_step,
switch_step,
tool_call_step,
transition_step,
wait_for_input_step,
yield_step,
]
task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction))

# Initialize the worker with the specified task queue, workflows, and activities
worker = Worker(
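The worker change above swaps the hand-maintained list of task-step activities for inspect.getmembers(task_steps, isfunction), so newly added step activities are registered automatically. A small sketch of that pattern (collect_activities is an illustrative helper, not a name from the PR):

from inspect import getmembers, isfunction
from types import ModuleType
from typing import Callable


def collect_activities(module: ModuleType) -> list[Callable]:
    # Treat every function in the module's namespace as an activity.
    # Note: this also picks up any plain helpers imported into the module,
    # so the package is expected to export only activity functions.
    return [fn for _name, fn in getmembers(module, isfunction)]


# Mirroring the PR's usage:
# task_activity_names, task_activities = zip(*getmembers(task_steps, isfunction))
# worker = Worker(
#     client,
#     task_queue=temporal_task_queue,
#     workflows=[TaskExecutionWorkflow, ...],
#     activities=[demo_activity, *task_activities, ...],
# )
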
22 changes: 14 additions & 8 deletions agents-api/agents_api/workflows/task_execution.py
@@ -40,20 +40,26 @@
# PromptStep: prompt_step,
# ToolCallStep: tool_call_step,
WaitForInputStep: task_steps.wait_for_input_step,
LogStep: task_steps.log_step,
SwitchStep: task_steps.switch_step,
}

# Use few local activities (currently experimental)
STEP_TO_LOCAL_ACTIVITY = {
# NOTE: local activities are directly called in the workflow executor
# They MUST NOT FAIL, otherwise they will crash the workflow
# FIXME: These should be moved to local activities
# once temporal has fixed error handling for local activities
LogStep: task_steps.log_step,
EvaluateStep: task_steps.evaluate_step,
ReturnStep: task_steps.return_step,
YieldStep: task_steps.yield_step,
IfElseWorkflowStep: task_steps.if_else_step,
}

# TODO: Avoid local activities for now (currently experimental)
STEP_TO_LOCAL_ACTIVITY = {
# # NOTE: local activities are directly called in the workflow executor
# # They MUST NOT FAIL, otherwise they will crash the workflow
# EvaluateStep: task_steps.evaluate_step,
# ReturnStep: task_steps.return_step,
# YieldStep: task_steps.yield_step,
# IfElseWorkflowStep: task_steps.if_else_step,
}


@workflow.defn
class TaskExecutionWorkflow:
@@ -131,7 +137,7 @@ async def transition(**kwargs) -> None:
# Handle errors (activity returns None)
case step, StepOutcome(error=error) if error is not None:
raise ApplicationError(
f"{step.__class__.__name__} step threw error: {error}"
f"{type(step).__name__} step threw error: {error}"
)

case LogStep(), StepOutcome(output=output):
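The workflow changes above move every step type into STEP_TO_ACTIVITY (leaving STEP_TO_LOCAL_ACTIVITY empty until local-activity error handling is sorted out) and report failures using the step's class name. A hedged sketch of how that dispatch and error check fit together, with simplified stand-ins for the project's step and outcome types and an illustrative timeout:

from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Callable, Dict, Optional, Type

from temporalio import workflow
from temporalio.exceptions import ApplicationError


@dataclass
class StepOutcome:
    # Simplified stand-in for the project's StepOutcome model.
    output: Optional[Any] = None
    error: Optional[str] = None


# Populated elsewhere, e.g. {LogStep: task_steps.log_step, ...}.
STEP_TO_ACTIVITY: Dict[Type, Callable] = {}


async def run_step(step: Any, context: Any) -> Any:
    # Every step currently runs as a regular (non-local) activity.
    activity_fn = STEP_TO_ACTIVITY[type(step)]
    outcome: StepOutcome = await workflow.execute_activity(
        activity_fn,
        context,
        start_to_close_timeout=timedelta(minutes=10),  # illustrative value
    )

    # Activities report failures by returning StepOutcome(error=...), so the
    # workflow converts that into an ApplicationError named after the step class.
    if outcome.error is not None:
        raise ApplicationError(f"{type(step).__name__} step threw error: {outcome.error}")

    return outcome.output
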
