Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Add label lookup for inputs and outputs #906

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
@@ -156,6 +156,7 @@ class StepContext(BaseRemoteModel):
execution_input: ExecutionInput | RemoteObject
inputs: list[Any] | RemoteObject
cursor: TransitionTarget
labels: dict[str, Any] | RemoteObject

@computed_field
@property
@@ -244,13 +245,22 @@ def model_dump(self, *args, **kwargs) -> dict[str, Any]:
def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]:
current_input = self.current_input
inputs = self.inputs

if activity.in_activity():
inputs = [load_from_blob_store_if_remote(input) for input in inputs]
current_input = load_from_blob_store_if_remote(current_input)

inputs = {i: input for i, input in enumerate(inputs)}
inputs.update(self.labels)
Ahmad-mtos marked this conversation as resolved.
Show resolved Hide resolved

outputs = {i: output for i, output in enumerate(self.outputs)}
outputs.update(self.labels)

# Merge execution inputs into the dump dict
dump = self.model_dump(*args, **kwargs)
dump["inputs"] = inputs
dump["outputs"] = outputs

prepared = dump | {"_": current_input}

for i, input in enumerate(inputs):
9 changes: 9 additions & 0 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
@@ -128,6 +128,7 @@ async def run(
execution_input: ExecutionInput,
start: TransitionTarget = TransitionTarget(workflow="main", step=0),
previous_inputs: RemoteList | None = None,
previous_labels: dict[str, Any] | None = None,
) -> Any:
workflow.logger.info(
f"TaskExecutionWorkflow for task {execution_input.task.id}"
@@ -136,11 +137,13 @@ async def run(

# 0. Prepare context
previous_inputs = previous_inputs or RemoteList([execution_input.arguments])
previous_labels = previous_labels or {}

context = StepContext(
execution_input=execution_input,
inputs=previous_inputs,
cursor=start,
labels=previous_labels,
)

step_type = type(context.current_step)
@@ -670,12 +673,18 @@ async def run(
retry_policy=DEFAULT_RETRY_POLICY,
)

current_label = context.current_step.label

if current_label:
previous_labels[current_label] = final_output

previous_inputs.append(final_output)

# Continue as a child workflow
return await continue_as_child(
context.execution_input,
start=final_state.next,
previous_inputs=previous_inputs,
previous_labels=previous_labels,
user_state=state.user_state,
)
2 changes: 2 additions & 0 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ async def continue_as_child(
execution_input: ExecutionInput,
start: TransitionTarget,
previous_inputs: RemoteList | list[Any],
previous_labels: dict[str, Any] | None = None,
user_state: dict[str, Any] = {},
) -> Any:
info = workflow.info()
@@ -45,6 +46,7 @@ async def continue_as_child(
execution_input,
start,
previous_inputs,
previous_labels,
],
retry_policy=DEFAULT_RETRY_POLICY,
memo=workflow.memo() | user_state,