Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add agent tools to completion data before sending to litellm in prompt #498

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
Binary file added agents-api/.DS_Store
Binary file not shown.
25 changes: 24 additions & 1 deletion agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)
from ...common.protocol.tasks import StepContext, StepOutcome
from ...common.utils.template import render_template
from ...models.tools.list_tools import list_tools


@activity.defn
Expand Down Expand Up @@ -34,6 +35,28 @@ async def prompt_step(context: StepContext) -> StepOutcome:
else "gpt-4o"
)

agent_tools = list_tools(
developer_id=context.execution_input.developer_id,
agent_id=context.execution_input.agent.id,
limit=128, # Max number of supported functions in OpenAI. See https://platform.openai.com/docs/api-reference/chat/create
offset=0,
sort_by="created_at",
direction="desc",
)

# Format agent_tools for litellm
formatted_agent_tools = [
{
"type": tool.type,
"function": {
"name": tool.function.name,
"description": tool.function.description,
"parameters": tool.function.parameters,
},
}
for tool in agent_tools
]

if context.current_step.settings:
passed_settings: dict = context.current_step.settings.model_dump(
exclude_unset=True
Expand All @@ -43,11 +66,11 @@ async def prompt_step(context: StepContext) -> StepOutcome:

completion_data: dict = {
"model": agent_model,
"tools": formatted_agent_tools or None,
("messages" if isinstance(prompt, list) else "prompt"): prompt,
**agent_default_settings,
**passed_settings,
}

response = await litellm.acompletion(
**completion_data,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
from temporalio import activity

from ...autogen.openapi_model import CreateTransitionRequest
from ...common.protocol.tasks import (
StepContext,
StepOutcome,
)
from .transition_step import original_transition_step


@activity.defn
async def raise_complete_async() -> None:
async def raise_complete_async(context: StepContext, output: StepOutcome) -> None:
# TODO: Create a transtition to "wait" and save the captured_token to the transition

captured_token = activity.info().task_token
captured_token = captured_token.decode('latin-1')
transition_info = CreateTransitionRequest(
current=context.cursor,
type="wait",
next=None,
output=output,
task_token=captured_token,
)

await original_transition_step(context, transition_info)

# await transition(context, output=output, type="wait", next=None, task_token=captured_token)

print("transition to wait called")
activity.raise_complete_async()
12 changes: 3 additions & 9 deletions agents-api/agents_api/activities/task_steps/transition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,21 @@ async def transition_step(
context: StepContext,
transition_info: CreateTransitionRequest,
) -> Transition:
need_to_wait = transition_info.type == "wait"

# Get task token if it's a waiting step
if need_to_wait:
task_token = activity.info().task_token
transition_info.task_token = task_token

# Create transition
transition = create_execution_transition(
developer_id=context.execution_input.developer_id,
execution_id=context.execution_input.execution.id,
task_id=context.execution_input.task.id,
data=transition_info,
task_token=transition_info.task_token,
update_execution_status=True,
)

return transition


original_transition_step = transition_step
mock_transition_step = transition_step

transition_step = activity.defn(name="transition_step")(
transition_step if not testing else mock_transition_step
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ async def wait_for_input_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, WaitForInputStep)

exprs = context.current_step.wait_for_input
exprs = context.current_step.wait_for_input.info
output = await base_evaluate(exprs, context.model_dump())

result = StepOutcome(output=output)
Expand Down
1 change: 1 addition & 0 deletions agents-api/agents_api/autogen/openapi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class CreateTransitionRequest(Transition):
created_at: AwareDatetime | None = None
updated_at: AwareDatetime | None = None
metadata: dict[str, Any] | None = None
task_token: str | None = None


class CreateEntryRequest(BaseEntry):
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"error": [],
"cancelled": [],
# Intermediate states
"wait": ["resume", "cancelled"],
"wait": ["resume", "cancelled", "finish", "finish_branch"],
"resume": [
"wait",
"error",
Expand Down Expand Up @@ -100,7 +100,7 @@
"queued": [],
"awaiting_input": ["starting", "running"],
"cancelled": ["queued", "starting", "awaiting_input", "running"],
"succeeded": ["starting", "running"],
"succeeded": ["starting", "awaiting_input", "running"],
"failed": ["starting", "running"],
} # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def create_execution_transition(
task_id: UUID | None = None,
) -> tuple[list[str | None], dict]:
transition_id = transition_id or uuid4()

data.metadata = data.metadata or {}
data.execution_id = execution_id

Expand All @@ -111,7 +110,7 @@ def create_execution_transition(
columns, transition_values = cozo_process_mutate_data(
{
**transition_data,
"task_token": task_token,
"task_token": str(task_token), # Converting to str for JSON serialisation
"transition_id": str(transition_id),
"execution_id": str(execution_id),
}
Expand Down
5 changes: 3 additions & 2 deletions agents-api/agents_api/routers/tasks/update_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ async def update_execution(
await wf_handle.cancel()

case ResumeExecutionRequest():

token_data = get_paused_execution_token(
developer_id=x_developer_id, execution_id=execution_id
)
act_handle = temporal_client.get_async_activity_handle(
token_data["task_token"]
task_token=str.encode(token_data["task_token"], encoding="latin-1")
)
await act_handle.complete(data.input)

print("Resumed execution successfully")
case _:
raise HTTPException(status_code=400, detail="Invalid request data")
34 changes: 30 additions & 4 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,12 @@ async def run(

case WaitForInputStep(), StepOutcome(output=output):
workflow.logger.info("Wait for input step: Waiting for external input")
await transition(context, output=output, type="wait", next=None)

await transition(context, type="wait", output=output)

result = await workflow.execute_activity(
task_steps.raise_complete_async,
args=[context, output],
schedule_to_close_timeout=timedelta(days=31),
)

Expand All @@ -391,8 +393,33 @@ async def run(
output=response
): # FIXME: if not response.choices[0].tool_calls:
# SCRUM-15
workflow.logger.debug("Prompt step: Received response")
state = PartialTransition(output=response)
workflow.logger.debug(f"Prompt step: Received response: {response}")
if response["choices"][0]["finish_reason"] != "tool_calls":
workflow.logger.debug("Prompt step: Received response")
state = PartialTransition(output=response)
else:
workflow.logger.debug("Prompt step: Received tool call")
message = response["choices"][0]["message"]
tool_calls_input = message["tool_calls"]

# Enter a wait-for-input step to ask the developer to run the tool calls
tool_calls_results = await workflow.execute_activity(
task_steps.raise_complete_async,
args=[context, tool_calls_input],
schedule_to_close_timeout=timedelta(days=31),
)
# Feed the tool call results back to the model
# context.inputs.append(tool_calls_results)
context.current_step.prompt.append(message)
context.current_step.prompt.append(tool_calls_results)
new_response = await workflow.execute_activity(
task_steps.prompt_step,
context,
schedule_to_close_timeout=timedelta(
seconds=30 if debug or testing else 600),
)
state = PartialTransition(
output=new_response.output, type="resume")

# case PromptStep(), StepOutcome(
# output=response
Expand Down Expand Up @@ -453,7 +480,6 @@ async def run(

# 4. Transition to the next step
workflow.logger.info(f"Transitioning after step {context.cursor.step}")

# The returned value is the transition finally created
final_state = await transition(context, state)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ async def transition(
state.type = "finish_branch"
case _, _:
state.type = "step"

transition_request = CreateTransitionRequest(
current=context.cursor,
**{
Expand Down
2 changes: 1 addition & 1 deletion agents-api/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ x--shared-environment: &shared-environment
AGENTS_API_KEY: ${AGENTS_API_KEY}
AGENTS_API_KEY_HEADER_NAME: ${AGENTS_API_KEY_HEADER_NAME:-Authorization}
AGENTS_API_HOSTNAME: ${AGENTS_API_HOSTNAME:-localhost}
AGENTS_API_PROTOCOL: ${AGENTS_API_PROTOCOL:-http}
AGENTS_API_PUBLIC_PORT: ${AGENTS_API_PUBLIC_PORT:-80}
AGENTS_API_PROTOCOL: ${AGENTS_API_PROTOCOL:-http}
AGENTS_API_URL: ${AGENTS_API_URL:-http://agents-api:8080}
COZO_AUTH_TOKEN: ${COZO_AUTH_TOKEN}
COZO_HOST: ${COZO_HOST:-http://memory-store:9070}
Expand Down
11 changes: 10 additions & 1 deletion agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ async def _(
mock_run_task_execution_workflow.assert_called_once()

# Let it run for a bit
await asyncio.sleep(1)
await asyncio.sleep(3)

# Get the history
history = await handle.fetch_history()
Expand All @@ -497,6 +497,15 @@ async def _(
activity for activity in activities_scheduled if activity
]

try:
future = handle.result()
breakpoint()
await future
except BaseException as exc:
print("exc", exc)
breakpoint()
raise

assert "wait_for_input_step" in activities_scheduled


Expand Down
Loading