From 9a3cfd07f5668c743561f61c5f0fccd397cd80ac Mon Sep 17 00:00:00 2001 From: Ahmad Haidar Date: Thu, 28 Nov 2024 18:43:12 +0300 Subject: [PATCH 1/2] feat(agents-api): Add label lookup for inputs and outputs --- agents-api/agents_api/common/protocol/tasks.py | 10 ++++++++++ .../agents_api/workflows/task_execution/__init__.py | 9 +++++++++ .../agents_api/workflows/task_execution/helpers.py | 2 ++ 3 files changed, 21 insertions(+) diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index b4e37f892..55f89d22d 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -156,6 +156,7 @@ class StepContext(BaseRemoteModel): execution_input: ExecutionInput | RemoteObject inputs: list[Any] | RemoteObject cursor: TransitionTarget + labels: dict[str, Any] | RemoteObject @computed_field @property @@ -244,13 +245,22 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]: def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]: current_input = self.current_input inputs = self.inputs + if activity.in_activity(): inputs = [load_from_blob_store_if_remote(input) for input in inputs] current_input = load_from_blob_store_if_remote(current_input) + inputs = {i: input for i, input in enumerate(inputs)} + inputs.update(self.labels) + + outputs = {i: output for i, output in enumerate(self.outputs)} + outputs.update(self.labels) + # Merge execution inputs into the dump dict dump = self.model_dump(*args, **kwargs) dump["inputs"] = inputs + dump["outputs"] = outputs + prepared = dump | {"_": current_input} for i, input in enumerate(inputs): diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index 5d0300e8f..cc7c1df60 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -128,6 +128,7 @@ async def run( execution_input: ExecutionInput, start: TransitionTarget = TransitionTarget(workflow="main", step=0), previous_inputs: RemoteList | None = None, + previous_labels: dict[str, Any] | None = None, ) -> Any: workflow.logger.info( f"TaskExecutionWorkflow for task {execution_input.task.id}" @@ -136,11 +137,13 @@ async def run( # 0. Prepare context previous_inputs = previous_inputs or RemoteList([execution_input.arguments]) + previous_labels = previous_labels or {} context = StepContext( execution_input=execution_input, inputs=previous_inputs, cursor=start, + labels=previous_labels, ) step_type = type(context.current_step) @@ -670,6 +673,11 @@ async def run( retry_policy=DEFAULT_RETRY_POLICY, ) + current_label = context.current_step.label + + if current_label: + previous_labels[current_label] = final_output + previous_inputs.append(final_output) # Continue as a child workflow @@ -677,5 +685,6 @@ async def run( context.execution_input, start=final_state.next, previous_inputs=previous_inputs, + previous_labels=previous_labels, user_state=state.user_state, ) diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 09ecb6150..e7ca2369e 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -29,6 +29,7 @@ async def continue_as_child( execution_input: ExecutionInput, start: TransitionTarget, previous_inputs: RemoteList | list[Any], + previous_labels: dict[str, Any] | None = None, user_state: dict[str, Any] = {}, ) -> Any: info = workflow.info() @@ -45,6 +46,7 @@ async def continue_as_child( execution_input, start, previous_inputs, + previous_labels, ], retry_policy=DEFAULT_RETRY_POLICY, memo=workflow.memo() | user_state, From 02a0886439a3cddbf045639647534579bbd5fd46 Mon Sep 17 00:00:00 2001 From: Ahmad-mtos Date: Thu, 28 Nov 2024 15:44:14 +0000 Subject: [PATCH 2/2] refactor: Lint agents-api (CI) --- agents-api/agents_api/activities/task_steps/prompt_step.py | 2 -- agents-api/agents_api/routers/sessions/chat.py | 5 ----- 2 files changed, 7 deletions(-) diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index 5e58a5d08..b6ed96c52 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -5,11 +5,9 @@ from langchain_core.tools import BaseTool from langchain_core.tools.convert import tool as tool_decorator from litellm.types.utils import ModelResponse -from litellm.types.utils import ModelResponse from temporalio import activity from temporalio.exceptions import ApplicationError -from ...autogen.openapi_model import Tool from ...autogen.openapi_model import Tool from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index cbfab4fdd..ecd94686d 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -1,10 +1,7 @@ from datetime import datetime from typing import Annotated, Callable, Optional -from datetime import datetime -from typing import Annotated, Callable, Optional from uuid import UUID, uuid4 -from fastapi import BackgroundTasks, Depends, Header, HTTPException, status from fastapi import BackgroundTasks, Depends, Header, HTTPException, status from starlette.status import HTTP_201_CREATED @@ -22,12 +19,10 @@ from ...common.utils.template import render_template from ...dependencies.developer_id import get_developer_data from ...env import max_free_sessions -from ...env import max_free_sessions from ...models.chat.gather_messages import gather_messages from ...models.chat.prepare_chat_context import prepare_chat_context from ...models.entry.create_entries import create_entries from ...models.session.count_sessions import count_sessions as count_sessions_query -from ...models.session.count_sessions import count_sessions as count_sessions_query from .metrics import total_tokens_per_user from .router import router