Skip to content

Commit

Permalink
Implement ToolCallStep & Fix transition after PromptStep (#513)
Browse files Browse the repository at this point in the history
<!-- ELLIPSIS_HIDDEN -->



> [!IMPORTANT]
> Implement `ToolCallStep` and fix transition logic after `PromptStep`
in workflow execution.
> 
>   - **ToolCallStep Implementation**:
> - Implements `tool_call_step()` in `tool_call_step.py` to handle tool
calls, including generating a call ID and validating tool names.
> - Updates `STEP_TO_ACTIVITY` in `task_execution/__init__.py` to map
`ToolCallStep` to `tool_call_step()`.
>   - **PromptStep Transition Fix**:
> - Updates transition logic in `task_execution/__init__.py` to handle
tool calls after a `PromptStep`.
> - Removes unused code related to tool calls in `PromptStep` handling.
>   - **State Machine Update**:
> - Updates `valid_transitions` in `tasks.py` to allow 'wait' to 'step'
transition.
> 
> <sup>This description was created by </sup>[<img alt="Ellipsis"
src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral)<sup>
for 5ab9e3a. It will automatically
update as commits are pushed.</sup>

<!-- ELLIPSIS_HIDDEN -->
  • Loading branch information
HamadaSalhab authored and creatorrr committed Sep 30, 2024
1 parent 416459c commit 8ae6038
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 34 deletions.
1 change: 0 additions & 1 deletion agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ async def prompt_step(context: StepContext) -> StepOutcome:
context_data,
skip_vars=["developer_id"],
)

# Get settings and run llm
agent_default_settings: dict = (
context.execution_input.agent.default_settings.model_dump()
Expand Down
45 changes: 37 additions & 8 deletions agents-api/agents_api/activities/task_steps/tool_call_step.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,49 @@
import base64
import secrets

from beartype import beartype
from temporalio import activity

from ...activities.task_steps import base_evaluate
from ...autogen.openapi_model import ToolCallStep
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)


def generate_call_id():
# Generate 18 random bytes (which will result in 24 base64 characters)
random_bytes = secrets.token_bytes(18)
# Encode to base64 and remove padding
base64_string = base64.urlsafe_b64encode(random_bytes).decode("ascii").rstrip("=")
# Add the "call_" prefix
return f"call_{base64_string}"


@activity.defn
@beartype
async def tool_call_step(context: StepContext) -> dict:
raise NotImplementedError()
# assert isinstance(context.current_step, ToolCallStep)
async def tool_call_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ToolCallStep)

tool_type, tool_name = context.current_step.tool.split(".")
arguments = await base_evaluate(
context.current_step.arguments, context.model_dump()
)

tools = context.execution_input.tools

assert tool_name in [tool.name for tool in tools], f"Tool {tool_name} not found"

call_id = generate_call_id()

# context.current_step.tool_id
# context.current_step.arguments
# # get tool by id
# # call tool
tool_call = {
tool_type: {
"arguments": arguments,
"name": tool_name,
},
"id": call_id,
"type": tool_type,
}

# return {}
return StepOutcome(output=tool_call)
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"error": [],
"cancelled": [],
# Intermediate states
"wait": ["resume", "cancelled", "finish", "finish_branch"],
"wait": ["resume", "step", "cancelled", "finish", "finish_branch"],
"resume": [
"wait",
"error",
Expand Down
36 changes: 12 additions & 24 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
# Mapping of step types to their corresponding activities
STEP_TO_ACTIVITY = {
PromptStep: task_steps.prompt_step,
# ToolCallStep: tool_call_step,
ToolCallStep: task_steps.tool_call_step,
WaitForInputStep: task_steps.wait_for_input_step,
SwitchStep: task_steps.switch_step,
LogStep: task_steps.log_step,
Expand Down Expand Up @@ -389,10 +389,7 @@ async def run(

state = PartialTransition(type="resume", output=result)

case PromptStep(), StepOutcome(
output=response
): # FIXME: if not response.choices[0].tool_calls:
# SCRUM-15
case PromptStep(), StepOutcome(output=response):
workflow.logger.debug(f"Prompt step: Received response: {response}")
if response["choices"][0]["finish_reason"] != "tool_calls":
workflow.logger.debug("Prompt step: Received response")
Expand Down Expand Up @@ -421,19 +418,6 @@ async def run(
)
state = PartialTransition(output=new_response.output, type="resume")

# case PromptStep(), StepOutcome(
# output=response
# ): # FIXME: if response.choices[0].tool_calls:
# # SCRUM-15
# workflow.logger.debug("Prompt step: Received response")
#
# ## First, enter a wait-for-input step and ask developer to run the tool calls
# ## Then, continue the workflow with the input received from the developer
# ## This will be a dict with the tool call name as key and the tool call arguments as value
# ## The prompt is run again with the tool call arguments as input
# ## And the result is returned
# ## If model asks for more tool calls, repeat the process

case SetStep(), StepOutcome(output=evaluated_output):
workflow.logger.info("Set step: Updating user state")
self.update_user_state(evaluated_output)
Expand Down Expand Up @@ -466,11 +450,15 @@ async def run(
workflow.logger.error("ParallelStep not yet implemented")
raise ApplicationError("Not implemented")

case ToolCallStep(), _:
# FIXME: Implement ToolCallStep
# SCRUM-16
workflow.logger.error("ToolCallStep not yet implemented")
raise ApplicationError("Not implemented")
case ToolCallStep(), StepOutcome(output=tool_call):
# Enter a wait-for-input step to ask the developer to run the tool calls
tool_call_response = await workflow.execute_activity(
task_steps.raise_complete_async,
args=[context, tool_call],
schedule_to_close_timeout=timedelta(days=31),
)

state = PartialTransition(output=tool_call_response, type="resume")

case _:
workflow.logger.error(
Expand Down Expand Up @@ -502,7 +490,7 @@ async def run(

# Continue as a child workflow
return await continue_as_child(
context,
context.execution_input,
start=final_state.next,
previous_inputs=previous_inputs + [final_state.output],
user_state=self.user_state,
Expand Down

0 comments on commit 8ae6038

Please sign in to comment.