Skip to content

Commit

Permalink
fix(agents-api): Fix execution input query
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Tomer <[email protected]>
  • Loading branch information
Diwank Tomer committed Aug 2, 2024
1 parent 9683d78 commit 52c6d2c
Show file tree
Hide file tree
Showing 56 changed files with 796 additions and 1,081 deletions.
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Annotated, Literal
from uuid import UUID

from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, RootModel
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field

from .Docs import DocReference
from .Entries import ChatMLMessage
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/autogen/Entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Annotated, Literal
from uuid import UUID

from pydantic import AnyUrl, AwareDatetime, BaseModel, ConfigDict, Field, RootModel
from pydantic import AnyUrl, AwareDatetime, BaseModel, ConfigDict, Field

from .Tools import ChosenToolCall, Tool, ToolResponse

Expand Down
88 changes: 88 additions & 0 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def make_session(
Create a new session object.
"""
cls, participants = None, {}

match (len(agents), len(users)):
case (0, _):
raise ValueError("At least one agent must be provided.")
Expand All @@ -70,3 +71,90 @@ def make_session(
participants = {"agents": agents, "users": users}

return cls(**{**data, **participants})


WorkflowStep = (
PromptStep
| EvaluateStep
| YieldStep
| ToolCallStep
| ErrorWorkflowStep
| IfElseWorkflowStep
)


class Workflow(BaseModel):
name: str
steps: list[WorkflowStep]


class TaskToolDef(BaseModel):
type: str
name: str
spec: dict
inherited: bool = False


_Task = Task


class TaskSpec(_Task):
model_config = ConfigDict(extra="ignore")

workflows: list[Workflow]
main: list[WorkflowStep] | None = None


class TaskSpecDef(TaskSpec):
id: UUID | None = None
created_at: AwareDatetime | None = None
updated_at: AwareDatetime | None = None


class PartialTaskSpecDef(TaskSpecDef):
name: str | None = None


class Task(_Task):
model_config = ConfigDict(
**{
**_Task.model_config,
"extra": "allow",
}
)


_CreateTaskRequest = CreateTaskRequest


class CreateTaskRequest(_CreateTaskRequest):
model_config = ConfigDict(
**{
**_CreateTaskRequest.model_config,
"extra": "allow",
}
)


_PatchTaskRequest = PatchTaskRequest


class PatchTaskRequest(_PatchTaskRequest):
model_config = ConfigDict(
**{
**_PatchTaskRequest.model_config,
"extra": "allow",
}
)


_UpdateTaskRequest = UpdateTaskRequest


class UpdateTaskRequest(_UpdateTaskRequest):
model_config = ConfigDict(
**{
**_UpdateTaskRequest.model_config,
"extra": "allow",
}
)
1 change: 0 additions & 1 deletion agents-api/agents_api/common/protocol/entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from ...autogen.openapi_model import (
ChatMLImageContentPart,
ChatMLRole,
ChatMLTextContentPart,
)
from ...autogen.openapi_model import (
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/common/protocol/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
GenerationPresetSettings,
OpenAISettings,
Session,
Settings,
Tool,
User,
VLLMSettings,
Expand Down
211 changes: 70 additions & 141 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
@@ -1,160 +1,36 @@
from datetime import datetime
from typing import Annotated, Any, List, Literal, Tuple
from typing import Annotated, Any, List, Tuple
from uuid import UUID

from pydantic import UUID4, BaseModel, Field, computed_field
from pydantic import BaseModel, Field

from ...autogen.openapi_model import (
Agent,
ErrorWorkflowStep,
EvaluateStep,
CreateTaskRequest,
Execution,
FunctionDef,
IfElseWorkflowStep,
PromptStep,
PartialTaskSpecDef,
PatchTaskRequest,
Session,
Task,
TaskSpec,
TaskSpecDef,
TaskToolDef,
Tool,
ToolCallStep,
UpdateTaskRequest,
User,
YieldStep,
Workflow,
WorkflowStep,
)
from ...models.execution.prepare_execution_data import prepare_execution_data
from ..utils.cozo import uuid_int_list_to_uuid4

WorkflowStep = (
PromptStep
| EvaluateStep
| YieldStep
| ToolCallStep
| ErrorWorkflowStep
| IfElseWorkflowStep
)


# Make Task serializable (created_at is a datetime)
class SerializableTask(Task):
def model_dump(self, *args, **kwargs) -> dict[str, Any]:
dump = super().model_dump(*args, **kwargs)
dump["created_at"] = self.created_at.isoformat()

return dump

# And load it back
@classmethod
def model_load(cls, data: dict[str, Any], *args, **kwargs) -> "SerializableTask":
data["created_at"] = datetime.fromisoformat(data["created_at"])
return super().model_load(data, *args, **kwargs)


class TaskWorkflow(BaseModel):
name: str
steps: list[WorkflowStep]


class TaskSpec(BaseModel):
name: str | None
description: str | None
tools_available: list[str] | Literal["all"] | None = "all"
input_schema: dict[str, Any] | None = {}
workflows: list[TaskWorkflow]


class TaskProtocol(SerializableTask):
@classmethod
def from_cozo_data(cls, task_data: dict[str, Any]) -> "SerializableTask":
workflows = task_data.pop("workflows")
assert len(workflows) > 0

main_wf_idx, main_wf = next(
(i, wf) for i, wf in enumerate(workflows) if wf["name"] == "main"
)

task_data["main"] = main_wf["steps"]
workflows.pop(main_wf_idx)

for workflow in workflows:
task_data[workflow["name"]] = workflow["steps"]

return cls(**task_data)

@computed_field
@property
def spec(self) -> TaskSpec:
other_workflows = {
workflow_name: getattr(self, workflow_name)
for workflow_name in self.model_extra.keys()
if workflow_name not in Task.model_fields.keys() and workflow_name != "spec"
}

workflows = [
TaskWorkflow(name="main", steps=self.main),
# ... others
] + [
TaskWorkflow(name=workflow_name, steps=workflow_steps)
for workflow_name, workflow_steps in other_workflows.items()
]

return TaskSpec(
name=self.name,
description=self.description,
tools_available=self.tools_available,
input_schema=self.input_schema,
workflows=workflows,
)


class ExecutionInput(BaseModel):
developer_id: UUID4
developer_id: UUID
execution: Execution
task: TaskProtocol
task: TaskSpec
agent: Agent
user: User | None
session: Session | None
tools: list[Tool]
arguments: dict[str, Any]

@classmethod
def fetch(
cls, *, developer_id: UUID4, task_id: UUID4, execution_id: UUID4, client: Any
) -> "ExecutionInput":
[data] = prepare_execution_data(
task_id=task_id,
execution_id=execution_id,
client=client,
).to_dict(orient="records")

# FIXME: Need to manually convert id from list of int to UUID4
# because cozo has a bug with UUID4
# See: https://github.com/cozodb/cozo/issues/269
for kind in ["task", "execution", "agent", "user", "session"]:
if not data[kind]:
continue

for key in data[kind]:
if key == "id" or key.endswith("_id") and data[kind][key] is not None:
data[kind][key] = uuid_int_list_to_uuid4(data[kind][key])

agent = Agent(**data["agent"])
task = TaskProtocol.from_cozo_data(data["task"])
execution = Execution(**data["execution"])
user = User(**data["user"]) if data["user"] else None
session = Session(**data["session"]) if data["session"] else None
tools = [
Tool(type="function", id=function["id"], function=FunctionDef(**function))
for function in data["tools"]
]
arguments = execution.arguments

return cls(
developer_id=developer_id,
execution=execution,
task=task,
agent=agent,
user=user,
session=session,
tools=tools,
arguments=arguments,
)
user: User | None = None
session: Session | None = None


class StepContext(ExecutionInput):
Expand All @@ -164,7 +40,7 @@ class StepContext(ExecutionInput):
def model_dump(self, *args, **kwargs) -> dict[str, Any]:
dump = super().model_dump(*args, **kwargs)

dump["$"] = self.inputs[-1]
dump["_"] = self.inputs[-1]
dump["outputs"] = self.inputs[1:]

return dump
Expand All @@ -175,3 +51,56 @@ class TransitionInfo(BaseModel):
to: List[str | int] | None = None
type: Annotated[str, Field(pattern="^(finish|wait|error|step)$")]
outputs: dict[str, Any] | None = None


def task_to_spec(
task: Task | CreateTaskRequest | UpdateTaskRequest | PatchTaskRequest, **model_opts
) -> TaskSpecDef | PartialTaskSpecDef:
task_data = task.model_dump(**model_opts)
workflows = [Workflow(name="main", steps=task_data.pop("main"))]

for k in list(task_data.keys()):
if k in TaskSpec.model_fields.keys():
continue

steps = task_data.pop(k)
workflows.append(Workflow(name=k, steps=steps))

tools = task_data.pop("tools", [])
tools = [TaskToolDef(spec=tool.pop(tool["type"]), **tool) for tool in tools]

cls = PartialTaskSpecDef if isinstance(task, PatchTaskRequest) else TaskSpecDef
return cls(
workflows=workflows,
tools=tools,
**task_data,
)


def spec_to_task_data(spec: dict) -> dict:
task_id = spec.pop("task_id", None)

workflows = spec.pop("workflows")
workflows_dict = {workflow["name"]: workflow["steps"] for workflow in workflows}

tools = spec.pop("tools", [])
tools = [{tool["type"]: tool.pop("spec"), **tool} for tool in tools]

return {
"id": task_id,
"tools": tools,
**spec,
**workflows_dict,
}


def spec_to_task(**spec) -> Task | CreateTaskRequest:
if not spec.get("id"):
spec["id"] = spec.pop("task_id", None)

if not spec.get("updated_at"):
[updated_at_ms, _] = spec.pop("updated_at_ms", None)
spec["updated_at"] = updated_at_ms and (updated_at_ms / 1000)

cls = Task if spec["id"] else CreateTaskRequest
return cls(**spec_to_task_data(spec))
7 changes: 2 additions & 5 deletions agents-api/agents_api/models/agent/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_agent(
developer_id: UUID,
agent_id: UUID | None = None,
data: CreateAgentRequest,
) -> tuple[str, dict]:
) -> tuple[list[str], dict]:
"""
Constructs and executes a datalog query to create a new agent in the database.
Expand Down Expand Up @@ -114,11 +114,8 @@ def create_agent(
agent_query,
]

query = "}\n\n{\n".join(queries)
query = f"{{ {query} }}"

return (
query,
queries,
{
"settings_vals": settings_vals,
"agent_id": str(agent_id),
Expand Down
Loading

0 comments on commit 52c6d2c

Please sign in to comment.