Skip to content

Commit

Permalink
fix: Align all the functions according to the new models
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Jul 30, 2024
1 parent 4808a00 commit 9fd1a49
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 45 deletions.
19 changes: 9 additions & 10 deletions agents-api/agents_api/activities/task_steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
from uuid import uuid4

from openai.types.chat.chat_completion import ChatCompletion

from simpleeval import simple_eval
from temporalio import activity

from ...autogen.openapi_model import (
PromptWorkflowStep,
EvaluateWorkflowStep,
ToolCallWorkflowStep,
EvaluateStep,
# ErrorWorkflowStep,
IfElseWorkflowStep,
InputChatMLMessage,
YieldWorkflowStep,
PromptStep,
ToolCallStep,
YieldStep,
)
from ...clients.worker.types import ChatML
from ...common.protocol.tasks import (
Expand All @@ -22,15 +21,15 @@
)
from ...common.utils.template import render_template
from ...models.execution.create_execution_transition import (
create_execution_transition_query,
create_execution_transition as create_execution_transition_query,
)
from ...routers.sessions.protocol import Settings
from ...routers.sessions.session import llm_generate


@activity.defn
async def prompt_step(context: StepContext) -> dict:
assert isinstance(context.definition, PromptWorkflowStep)
assert isinstance(context.definition, PromptStep)

# Get context data
context_data: dict = context.model_dump()
Expand Down Expand Up @@ -63,7 +62,7 @@ async def prompt_step(context: StepContext) -> dict:

@activity.defn
async def evaluate_step(context: StepContext) -> dict:
assert isinstance(context.definition, EvaluateWorkflowStep)
assert isinstance(context.definition, EvaluateStep)

# FIXME: set the field to keep source code
source: str = context.definition.evaluate
Expand All @@ -76,7 +75,7 @@ async def evaluate_step(context: StepContext) -> dict:

@activity.defn
async def yield_step(context: StepContext) -> dict:
if not isinstance(context.definition, YieldWorkflowStep):
if not isinstance(context.definition, YieldStep):
return {}

# TODO: implement
Expand All @@ -86,7 +85,7 @@ async def yield_step(context: StepContext) -> dict:

@activity.defn
async def tool_call_step(context: StepContext) -> dict:
assert isinstance(context.definition, ToolCallWorkflowStep)
assert isinstance(context.definition, ToolCallStep)

context.definition.tool_id
context.definition.arguments
Expand Down
7 changes: 7 additions & 0 deletions agents-api/agents_api/models/task/create_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import sys
from datetime import datetime
from uuid import UUID, uuid4

from fastapi import HTTPException
Expand Down Expand Up @@ -51,8 +52,11 @@ class TaskSpec(TypedDict):
tools: NotRequired[list[TaskToolDef]]
metadata: dict
workflows: list[Workflow]
created_at: datetime


# FIXME: resolve this typing issue
# pytype: disable=bad-return-type
def task_to_spec(
task: Task | CreateTaskRequest | UpdateTaskRequest, **model_opts
) -> TaskSpec:
Expand All @@ -78,6 +82,9 @@ def task_to_spec(
)


# pytype: enable=bad-return-type


def spec_to_task_data(spec: dict) -> dict:
task_id = spec.pop("task_id", None)

Expand Down
47 changes: 27 additions & 20 deletions agents-api/agents_api/routers/tasks/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
from starlette.status import HTTP_201_CREATED

from agents_api.autogen.openapi_model import (
CreateExecution,
CreateTask,
CreateExecutionRequest,
CreateTaskRequest,
Execution,
ExecutionTransition,
ResourceCreatedResponse,
ResourceUpdatedResponse,
Task,
Transition,
UpdateExecutionRequest,
)
from agents_api.clients.cozo import client as cozo_client
from agents_api.clients.temporal import run_task_execution_workflow
Expand All @@ -29,11 +30,13 @@
get_execution_transition_query,
)
from agents_api.models.execution.list_execution_transitions import (
list_execution_transitions_query,
list_execution_transitions as list_execution_transitions_query,
)
from agents_api.models.execution.list_executions import list_task_executions_query
from agents_api.models.execution.update_execution_status import (
update_execution_status_query,
from agents_api.models.execution.list_executions import (
list_executions as list_task_executions_query,
)
from agents_api.models.execution.update_execution import (
update_execution as update_execution_status_query,
)
from agents_api.models.execution.update_execution_transition import (
update_execution_transition_query,
Expand All @@ -55,7 +58,7 @@ class ExecutionList(BaseModel):


class ExecutionTransitionList(BaseModel):
items: list[ExecutionTransition]
items: list[Transition]


router = APIRouter()
Expand Down Expand Up @@ -88,7 +91,7 @@ async def list_tasks(

@router.post("/agents/{agent_id}/tasks", status_code=HTTP_201_CREATED, tags=["tasks"])
async def create_task(
request: CreateTask,
request: CreateTaskRequest,
agent_id: UUID4,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
Expand All @@ -107,7 +110,7 @@ async def create_task(
name=request.name,
description=request.description,
input_schema=request.input_schema or {},
tools_available=request.tools_available or [],
tools_available=request.tools or [],
workflows=workflows,
)

Expand Down Expand Up @@ -158,7 +161,7 @@ async def get_task(
async def create_task_execution(
agent_id: UUID4,
task_id: UUID4,
request: CreateExecution,
request: CreateExecutionRequest,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
) -> ResourceCreatedResponse:
try:
Expand All @@ -169,7 +172,7 @@ async def create_task_execution(
).iterrows()
][0]

validate(request.arguments, task["input_schema"])
validate(request.input, task["input_schema"])
except ValidationError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
Expand All @@ -194,7 +197,7 @@ async def create_task_execution(
task_id=task_id,
execution_id=execution_id,
developer_id=x_developer_id,
arguments=request.arguments,
arguments=request.input,
)

execution_input = ExecutionInput.fetch(
Expand All @@ -213,9 +216,10 @@ async def create_task_execution(
logger.exception(e)

update_execution_status_query(
developer_id=x_developer_id,
task_id=task_id,
execution_id=execution_id,
status="failed",
data=UpdateExecutionRequest(status="failed"),
)

raise HTTPException(
Expand All @@ -230,11 +234,14 @@ async def create_task_execution(

@router.get("/agents/{agent_id}/tasks/{task_id}/executions", tags=["tasks"])
async def list_task_executions(
agent_id: UUID4,
task_id: UUID4,
x_developer_id: Annotated[UUID4, Depends(get_developer_id)],
limit: int = 100,
offset: int = 0,
) -> ExecutionList:
res = list_task_executions_query(agent_id, task_id, x_developer_id)
res = list_task_executions_query(
task_id=task_id, developer_id=x_developer_id, limit=limit, offse=offset
)
return ExecutionList(
items=[Execution(**row.to_dict()) for _, row in res.iterrows()]
)
Expand All @@ -259,15 +266,15 @@ async def get_execution(task_id: UUID4, execution_id: UUID4) -> Execution:
async def get_execution_transition(
execution_id: UUID4,
transition_id: UUID4,
) -> ExecutionTransition:
) -> Transition:
try:
res = [
row.to_dict()
for _, row in get_execution_transition_query(
execution_id, transition_id
).iterrows()
][0]
return ExecutionTransition(**res)
return Transition(**res)
except (IndexError, KeyError):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -281,7 +288,7 @@ async def get_execution_transition(
async def update_execution_transition(
execution_id: UUID4,
transition_id: UUID4,
request: ExecutionTransition,
request: Transition,
) -> ResourceUpdatedResponse:
try:
resp = update_execution_transition_query(
Expand Down Expand Up @@ -309,5 +316,5 @@ async def list_execution_transitions(
execution_id=execution_id, limit=limit, offset=offset
)
return ExecutionTransitionList(
items=[ExecutionTransition(**row.to_dict()) for _, row in res.iterrows()]
items=[Transition(**row.to_dict()) for _, row in res.iterrows()]
)
6 changes: 3 additions & 3 deletions agents-api/agents_api/worker/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from ..activities.salient_questions import salient_questions
from ..activities.summarization import summarization
from ..activities.task_steps import (
# tool_call_step,
evaluate_step,
# error_step,
if_else_step,
prompt_step,
tool_call_step,
transition_step,
# evaluate_step,
yield_step,
)
from ..activities.truncation import truncation
Expand Down Expand Up @@ -76,7 +76,7 @@ async def main():
prompt_step,
evaluate_step,
yield_step,
# tool_call_step,
tool_call_step,
# error_step,
if_else_step,
transition_step,
Expand Down
23 changes: 11 additions & 12 deletions agents-api/agents_api/workflows/task_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,22 @@

with workflow.unsafe.imports_passed_through():
from ..activities.task_steps import (
evaluate_step,
if_else_step,
prompt_step,
transition_step,
evaluate_step,
tool_call_step,
transition_step,
)
from ..common.protocol.tasks import (
EvaluateStep,
ExecutionInput,
PromptWorkflowStep,
EvaluateWorkflowStep,
ToolCallWorkflowStep,
# ErrorWorkflowStep,
IfElseWorkflowStep,
PromptWorkflowStep,
PromptStep,
StepContext,
ToolCallStep,
TransitionInfo,
YieldWorkflowStep,
YieldStep,
)


Expand Down Expand Up @@ -59,7 +58,7 @@ async def run(
should_wait = False
# Run the step
match step:
case PromptWorkflowStep():
case PromptStep():
outputs = await workflow.execute_activity(
prompt_step,
context,
Expand All @@ -70,18 +69,18 @@ async def run(
# if outputs.tool_calls is not None:
# should_wait = True

case EvaluateWorkflowStep():
case EvaluateStep():
outputs = await workflow.execute_activity(
evaluate_step,
context,
schedule_to_close_timeout=timedelta(seconds=600),
)
case YieldWorkflowStep():
case YieldStep():
outputs = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
args=[execution_input, (step.workflow, 0), previous_inputs],
)
case ToolCallWorkflowStep():
case ToolCallStep():
outputs = await workflow.execute_activity(
tool_call_step,
context,
Expand All @@ -99,7 +98,7 @@ async def run(
context,
schedule_to_close_timeout=timedelta(seconds=600),
)
workflow_step = YieldWorkflowStep(**outputs["goto_workflow"])
workflow_step = YieldStep(**outputs["goto_workflow"])

outputs = await workflow.execute_child_workflow(
TaskExecutionWorkflow.run,
Expand Down

0 comments on commit 9fd1a49

Please sign in to comment.