Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agents-api): Fix task execution logical errors #483

Merged
merged 7 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions agents-api/agents_api/activities/task_steps/base_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ async def base_evaluate(
k: v.model_dump() if isinstance(v, BaseModel) else v for k, v in values.items()
}

# TODO: We should make this frozen_box=True, but we need to make sure that
# we don't break anything
values = Box(values, frozen_box=False, conversion_box=False)
# frozen_box doesn't work coz we need some mutability in the values
values = Box(values, frozen_box=False, conversion_box=True)

evaluator = get_evaluator(names=values)

Expand Down
15 changes: 3 additions & 12 deletions agents-api/agents_api/activities/task_steps/transition_step.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from uuid import uuid4

from beartype import beartype
from temporalio import activity

from ...autogen.openapi_model import CreateTransitionRequest, Transition
from ...common.protocol.tasks import StepContext
from ...env import testing
from ...models.execution.create_execution_transition import (
create_execution_transition as create_execution_transition_query,
)
from ...models.execution.create_execution_transition import create_execution_transition


@beartype
Expand All @@ -24,7 +20,7 @@ async def transition_step(
transition_info.task_token = task_token

# Create transition
transition = create_execution_transition_query(
transition = create_execution_transition(
developer_id=context.execution_input.developer_id,
execution_id=context.execution_input.execution.id,
task_id=context.execution_input.task.id,
Expand All @@ -34,13 +30,8 @@ async def transition_step(

return transition

async def mock_transition_step(
context: StepContext,
transition_info: CreateTransitionRequest,
) -> None:
# Does nothing
return None

mock_transition_step = transition_step

transition_step = activity.defn(name="transition_step")(
transition_step if not testing else mock_transition_step
Expand Down
22 changes: 20 additions & 2 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,31 @@
from simpleeval import EvalWithCompoundTypes, SimpleEval
from yaml import CSafeLoader

# TODO: We need to make sure that we dont expose any security issues
ALLOWED_FUNCTIONS = {
"zip": zip,
"abs": abs,
"all": all,
"any": any,
"bool": bool,
"dict": dict,
"enumerate": enumerate,
"float": float,
"frozenset": frozenset,
"int": int,
"len": len,
"list": list,
"load_json": json.loads,
"load_yaml": lambda string: yaml.load(string, Loader=CSafeLoader),
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
"max": max,
"min": min,
"round": round,
"search_regex": lambda pattern, string: re2.search(pattern, string),
"load_json": json.loads,
"set": set,
"str": str,
"sum": sum,
"tuple": tuple,
"zip": zip,
}


Expand Down
116 changes: 60 additions & 56 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ruff: noqa: F401, F403, F405
from typing import Annotated, Any, Generic, Literal, Self, Type, TypeVar
from typing import Annotated, Any, Generic, Literal, Self, Type, TypeVar, get_args
from uuid import UUID

from litellm.utils import _select_tokenizer as select_tokenizer
Expand Down Expand Up @@ -32,24 +32,45 @@ class ListResponse(BaseModel, Generic[DataT]):
# Aliases
# -------

CreateToolRequest = UpdateToolRequest
CreateOrUpdateAgentRequest = UpdateAgentRequest
CreateOrUpdateUserRequest = UpdateUserRequest
CreateOrUpdateSessionRequest = CreateSessionRequest

class CreateToolRequest(UpdateToolRequest):
pass


class CreateOrUpdateAgentRequest(UpdateAgentRequest):
pass


class CreateOrUpdateUserRequest(UpdateUserRequest):
pass


class CreateOrUpdateSessionRequest(CreateSessionRequest):
pass


ChatResponse = ChunkChatResponse | MessageChatResponse

# TODO: Figure out wtf... 🤷‍♂️
MapReduceStep = Main
ChatMLTextContentPart = Content
ChatMLImageContentPart = ContentModel
InputChatMLMessage = Message

class MapReduceStep(Main):
pass


class ChatMLTextContentPart(Content):
pass


class ChatMLImageContentPart(ContentModel):
pass


class InputChatMLMessage(Message):
pass


# Custom types (not generated correctly)
# --------------------------------------

# TODO: Remove these when auto-population is fixed

ChatMLContent = (
list[ChatMLTextContentPart | ChatMLImageContentPart]
| Tool
Expand All @@ -65,48 +86,23 @@ class ListResponse(BaseModel, Generic[DataT]):
]
)

ChatMLRole = Literal[
"user",
"assistant",
"system",
"function",
"function_response",
"function_call",
"auto",
]
assert BaseEntry.model_fields["role"].annotation == ChatMLRole

ChatMLSource = Literal[
"api_request", "api_response", "tool_response", "internal", "summarizer", "meta"
]
assert BaseEntry.model_fields["source"].annotation == ChatMLSource


ExecutionStatus = Literal[
"queued",
"starting",
"running",
"awaiting_input",
"succeeded",
"failed",
"cancelled",
]
assert Execution.model_fields["status"].annotation == ExecutionStatus


TransitionType = Literal[
"init",
"init_branch",
"finish",
"finish_branch",
"wait",
"resume",
"error",
"step",
"cancelled",
]

assert Transition.model_fields["type"].annotation == TransitionType
# Extract ChatMLRole
ChatMLRole = BaseEntry.model_fields["role"].annotation

# Extract ChatMLSource
ChatMLSource = BaseEntry.model_fields["source"].annotation

# Extract ExecutionStatus
ExecutionStatus = Execution.model_fields["status"].annotation

# Extract TransitionType
TransitionType = Transition.model_fields["type"].annotation

# Assertions to ensure consistency (optional, but recommended for runtime checks)
assert ChatMLRole == BaseEntry.model_fields["role"].annotation
assert ChatMLSource == BaseEntry.model_fields["source"].annotation
assert ExecutionStatus == Execution.model_fields["status"].annotation
assert TransitionType == Transition.model_fields["type"].annotation


# Create models
Expand Down Expand Up @@ -155,8 +151,8 @@ def from_model_input(
)


# Task related models
# -------------------
# Workflow related models
# -----------------------

WorkflowStep = (
EvaluateStep
Expand Down Expand Up @@ -185,6 +181,10 @@ class Workflow(BaseModel):
steps: list[WorkflowStep]


# Task spec helper models
# ----------------------


class TaskToolDef(BaseModel):
type: str
name: str
Expand Down Expand Up @@ -223,6 +223,10 @@ class Task(_Task):
)


# Patch some models to allow extra fields
# --------------------------------------


_CreateTaskRequest = CreateTaskRequest


Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/clients/cozo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Dict

from pycozo.client import Client

Expand All @@ -10,7 +10,7 @@
options.update({"auth": cozo_auth})


def get_cozo_client() -> Any:
def get_cozo_client() -> Client:
client = getattr(app.state, "cozo_client", Client("http", options=options))
if not hasattr(app.state, "cozo_client"):
app.state.cozo_client = client
Expand Down
70 changes: 62 additions & 8 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,53 @@
WorkflowStep,
)

### NOTE: Here, "init" is NOT a real state, but a placeholder for the start state of the state machine
# TODO: Maybe we should use a library for this

# State Machine
#
# init -> wait | error | step | cancelled | init_branch | finish
# init_branch -> wait | error | step | cancelled | finish_branch
# wait -> resume | error | cancelled
# resume -> wait | error | cancelled | step | finish | finish_branch | init_branch
# step -> wait | error | cancelled | step | finish | finish_branch | init_branch
# finish_branch -> wait | error | cancelled | step | finish | init_branch
# error ->

## Mermaid Diagram
# ```mermaid
# ---
# title: Execution state machine
# ---
# stateDiagram-v2
# [*] --> queued
# queued --> starting
# queued --> cancelled
# starting --> cancelled
# starting --> failed
# starting --> running
# running --> running
# running --> awaiting_input
# running --> cancelled
# running --> failed
# running --> succeeded
# awaiting_input --> running
# awaiting_input --> cancelled
# cancelled --> [*]
# succeeded --> [*]
# failed --> [*]

# ```
# TODO: figure out how to type this
valid_transitions: dict[TransitionType, list[TransitionType]] = {
# Start state
"init": ["wait", "error", "step", "cancelled", "init_branch"],
"init_branch": ["wait", "error", "step", "cancelled"],
"init": ["wait", "error", "step", "cancelled", "init_branch", "finish"],
"init_branch": ["wait", "error", "step", "cancelled", "finish_branch"],
# End states
"finish": [],
"error": [],
"cancelled": [],
# Intermediate states
"wait": ["resume", "error", "cancelled"],
"wait": ["resume", "cancelled"],
"resume": [
"wait",
"error",
Expand All @@ -59,8 +95,13 @@
} # type: ignore

valid_previous_statuses: dict[ExecutionStatus, list[ExecutionStatus]] = {
"running": ["queued", "starting", "awaiting_input"],
"running": ["starting", "awaiting_input", "running"],
"starting": ["queued"],
"queued": [],
"awaiting_input": ["starting", "running"],
"cancelled": ["queued", "starting", "awaiting_input", "running"],
"succeeded": ["starting", "running"],
"failed": ["starting", "running"],
} # type: ignore

transition_to_execution_status: dict[TransitionType | None, ExecutionStatus] = {
Expand Down Expand Up @@ -100,12 +141,12 @@ class StepContext(BaseModel):

@computed_field
@property
def outputs(self) -> Annotated[list[dict[str, Any]], Field(exclude=True)]:
def outputs(self) -> list[dict[str, Any]]: # included in dump
return self.inputs[1:]

@computed_field
@property
def current_input(self) -> Annotated[dict[str, Any], Field(exclude=True)]:
def current_input(self) -> dict[str, Any]: # included in dump
return self.inputs[-1]

@computed_field
Expand All @@ -130,9 +171,22 @@ def is_last_step(self) -> Annotated[bool, Field(exclude=True)]:
def is_first_step(self) -> Annotated[bool, Field(exclude=True)]:
return self.cursor.step == 0

@computed_field
@property
def is_main(self) -> Annotated[bool, Field(exclude=True)]:
return self.cursor.workflow == "main"

def model_dump(self, *args, **kwargs) -> dict[str, Any]:
dump = super().model_dump(*args, **kwargs)
dump["_"] = self.current_input

# Merge execution inputs into the dump dict
execution_input: dict = dump.pop("execution_input")
current_input: Any = dump.pop("current_input")
dump = {
**dump,
**execution_input,
"_": current_input,
}

return dump

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def create_execution_transition(
?[valid] :=
matched[prev_transitions],
found = length(prev_transitions),
valid = assert(found > 0, "Invalid transition"),
valid = if($next_type == "init", found == 0, found > 0),
assert(valid, "Invalid transition"),
"""

# Prepare the insert query
Expand Down
Loading
Loading