Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add label lookup for inputs and outputs #906

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
from langchain_core.tools import BaseTool
from langchain_core.tools.convert import tool as tool_decorator
from litellm.types.utils import ModelResponse
from litellm.types.utils import ModelResponse
from temporalio import activity
from temporalio.exceptions import ApplicationError

from ...autogen.openapi_model import Tool
from ...autogen.openapi_model import Tool
from ...clients import (
litellm, # We dont directly import `acompletion` so we can mock it
Expand Down
10 changes: 10 additions & 0 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class StepContext(BaseRemoteModel):
execution_input: ExecutionInput | RemoteObject
inputs: list[Any] | RemoteObject
cursor: TransitionTarget
labels: dict[str, Any] | RemoteObject

@computed_field
@property
Expand Down Expand Up @@ -244,13 +245,22 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:
def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]:
current_input = self.current_input
inputs = self.inputs

if activity.in_activity():
inputs = [load_from_blob_store_if_remote(input) for input in inputs]
current_input = load_from_blob_store_if_remote(current_input)

inputs = {i: input for i, input in enumerate(inputs)}
inputs.update(self.labels)
Ahmad-mtos marked this conversation as resolved.
Show resolved Hide resolved

outputs = {i: output for i, output in enumerate(self.outputs)}
outputs.update(self.labels)

# Merge execution inputs into the dump dict
dump = self.model_dump(*args, **kwargs)
dump["inputs"] = inputs
dump["outputs"] = outputs

prepared = dump | {"_": current_input}

for i, input in enumerate(inputs):
Expand Down
5 changes: 0 additions & 5 deletions agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from datetime import datetime
from typing import Annotated, Callable, Optional
from datetime import datetime
from typing import Annotated, Callable, Optional
from uuid import UUID, uuid4

from fastapi import BackgroundTasks, Depends, Header, HTTPException, status
from fastapi import BackgroundTasks, Depends, Header, HTTPException, status
from starlette.status import HTTP_201_CREATED

Expand All @@ -22,12 +19,10 @@
from ...common.utils.template import render_template
from ...dependencies.developer_id import get_developer_data
from ...env import max_free_sessions
from ...env import max_free_sessions
from ...models.chat.gather_messages import gather_messages
from ...models.chat.prepare_chat_context import prepare_chat_context
from ...models.entry.create_entries import create_entries
from ...models.session.count_sessions import count_sessions as count_sessions_query
from ...models.session.count_sessions import count_sessions as count_sessions_query
from .metrics import total_tokens_per_user
from .router import router

Expand Down
9 changes: 9 additions & 0 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ async def run(
execution_input: ExecutionInput,
start: TransitionTarget = TransitionTarget(workflow="main", step=0),
previous_inputs: RemoteList | None = None,
previous_labels: dict[str, Any] | None = None,
) -> Any:
workflow.logger.info(
f"TaskExecutionWorkflow for task {execution_input.task.id}"
Expand All @@ -136,11 +137,13 @@ async def run(

# 0. Prepare context
previous_inputs = previous_inputs or RemoteList([execution_input.arguments])
previous_labels = previous_labels or {}

context = StepContext(
execution_input=execution_input,
inputs=previous_inputs,
cursor=start,
labels=previous_labels,
)

step_type = type(context.current_step)
Expand Down Expand Up @@ -670,12 +673,18 @@ async def run(
retry_policy=DEFAULT_RETRY_POLICY,
)

current_label = context.current_step.label

if current_label:
previous_labels[current_label] = final_output

previous_inputs.append(final_output)

# Continue as a child workflow
return await continue_as_child(
context.execution_input,
start=final_state.next,
previous_inputs=previous_inputs,
previous_labels=previous_labels,
user_state=state.user_state,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ async def continue_as_child(
execution_input: ExecutionInput,
start: TransitionTarget,
previous_inputs: RemoteList | list[Any],
previous_labels: dict[str, Any] | None = None,
user_state: dict[str, Any] = {},
) -> Any:
info = workflow.info()
Expand All @@ -45,6 +46,7 @@ async def continue_as_child(
execution_input,
start,
previous_inputs,
previous_labels,
],
retry_policy=DEFAULT_RETRY_POLICY,
memo=workflow.memo() | user_state,
Expand Down