From a4151d0bb2644c69f0bc59abd601828a8f01088f Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Wed, 25 Sep 2024 10:08:09 -0400 Subject: [PATCH] feat: Add basic support for integration tools to ToolStep (#519) Signed-off-by: Diwank Singh Tomer ---- > [!IMPORTANT] > This PR updates the handling of integrations and systems by adding new models, updating workflows, and modifying session options, along with dependency updates and a migration script. > > - **Behavior**: > - Adds `execute_integration` function in `execute_integration.py` to handle integration tool calls. > - Updates `prompt_step.py` to handle unwrapping of prompt responses and tool call results. > - Modifies `tool_call_step.py` to handle tool calls using `Tool` model. > - **Models**: > - Adds `IntegrationDef` and `SystemDef` models in `Tools.py`. > - Updates `CreateToolRequest`, `PatchToolRequest`, `UpdateToolRequest`, and `Tool` to use `IntegrationDef` and `SystemDef`. > - Adds `forward_tool_results` option to session models in `Sessions.py`. > - **Workflow**: > - Updates `TaskExecutionWorkflow` in `task_execution/__init__.py` to handle integration tool calls. > - **Dependencies**: > - Updates `@typespec/*` dependencies in `package.json` to version `0.60.x`. > - **Migration**: > - Adds migration script `migrate_1727235852_add_forward_tool_calls_option.py` to add `forward_tool_calls` option to sessions. > > This description was created by [Ellipsis](https://www.ellipsis.dev?ref=julep-ai%2Fjulep&utm_source=github&utm_medium=referral) for a49aa123e96125fc45e33839bd7e90a708088b70. It will automatically update as commits are pushed. --------- Signed-off-by: Diwank Singh Tomer --- .../activities/execute_integration.py | 55 ++++ .../activities/task_steps/prompt_step.py | 16 +- .../activities/task_steps/tool_call_step.py | 22 +- agents-api/agents_api/autogen/Common.py | 10 + agents-api/agents_api/autogen/Sessions.py | 50 +++ agents-api/agents_api/autogen/Tasks.py | 292 +++++++++++++++++- agents-api/agents_api/autogen/Tools.py | 168 ++++++++-- .../agents_api/autogen/openapi_model.py | 1 + .../agents_api/common/protocol/tasks.py | 19 +- .../execution/prepare_execution_input.py | 7 +- agents-api/agents_api/models/task/get_task.py | 19 +- .../agents_api/models/task/list_tasks.py | 19 +- .../agents_api/models/tools/__init__.py | 1 + .../tools/get_tool_args_from_metadata.py | 126 ++++++++ agents-api/agents_api/worker/worker.py | 6 +- .../workflows/task_execution/__init__.py | 100 ++++-- ...727235852_add_forward_tool_calls_option.py | 87 ++++++ agents-api/tests/fixtures.py | 2 + agents-api/tests/test_execution_workflow.py | 118 ++++++- agents-api/tests/utils.py | 2 +- typespec/common/scalars.tsp | 19 +- typespec/package-lock.json | 97 +++--- typespec/package.json | 14 +- typespec/sessions/models.tsp | 9 + typespec/tasks/models.tsp | 18 ++ typespec/tasks/steps.tsp | 22 +- typespec/tools/models.tsp | 52 +++- 27 files changed, 1156 insertions(+), 195 deletions(-) create mode 100644 agents-api/agents_api/activities/execute_integration.py create mode 100644 agents-api/agents_api/models/tools/get_tool_args_from_metadata.py create mode 100644 agents-api/migrations/migrate_1727235852_add_forward_tool_calls_option.py diff --git a/agents-api/agents_api/activities/execute_integration.py b/agents-api/agents_api/activities/execute_integration.py new file mode 100644 index 000000000..9183eca2b --- /dev/null +++ b/agents-api/agents_api/activities/execute_integration.py @@ -0,0 +1,55 @@ +from typing import Any + +from beartype import beartype +from temporalio import activity + +from ..autogen.openapi_model import IntegrationDef +from ..common.protocol.tasks import StepContext +from ..env import testing +from ..models.tools import get_tool_args_from_metadata + + +@beartype +async def execute_integration( + context: StepContext, + tool_name: str, + integration: IntegrationDef, + arguments: dict[str, Any], +) -> Any: + developer_id = context.execution_input.developer_id + agent_id = context.execution_input.agent.id + task_id = context.execution_input.task.id + + merged_tool_args = get_tool_args_from_metadata( + developer_id=developer_id, agent_id=agent_id, task_id=task_id + ) + + arguments = merged_tool_args.get(tool_name, {}) | arguments + + try: + if integration.provider == "dummy": + return arguments + + else: + raise NotImplementedError( + f"Unknown integration provider: {integration.provider}" + ) + except BaseException as e: + if activity.in_activity(): + activity.logger.error(f"Error in execute_integration: {e}") + + raise + + +async def mock_execute_integration( + context: StepContext, + tool_name: str, + integration: IntegrationDef, + arguments: dict[str, Any], +) -> Any: + return arguments + + +execute_integration = activity.defn(name="execute_integration")( + execute_integration if not testing else mock_execute_integration +) diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index 3eecc2b55..e9b4daeb3 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -1,5 +1,6 @@ from beartype import beartype from temporalio import activity +from temporalio.exceptions import ApplicationError from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it @@ -63,18 +64,29 @@ async def prompt_step(context: StepContext) -> StepOutcome: else: passed_settings: dict = {} + # Wrap the prompt in a list if it is not already + if isinstance(prompt, str): + prompt = [{"role": "user", "content": prompt}] + completion_data: dict = { "model": agent_model, "tools": formatted_agent_tools or None, - ("messages" if isinstance(prompt, list) else "prompt"): prompt, + "messages": prompt, **agent_default_settings, **passed_settings, } + response = await litellm.acompletion( **completion_data, ) + if context.current_step.unwrap: + if response.choices[0].finish_reason == "tool_calls": + raise ApplicationError("Tool calls cannot be unwrapped") + + response = response.choices[0].message.content + return StepOutcome( - output=response.model_dump(), + output=response.model_dump() if hasattr(response, "model_dump") else response, next=None, ) diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py index a576bbbeb..3082d8706 100644 --- a/agents-api/agents_api/activities/task_steps/tool_call_step.py +++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py @@ -3,9 +3,10 @@ from beartype import beartype from temporalio import activity +from temporalio.exceptions import ApplicationError -from ...activities.task_steps import base_evaluate -from ...autogen.openapi_model import ToolCallStep +from ...activities.task_steps.base_evaluate import base_evaluate +from ...autogen.openapi_model import Tool, ToolCallStep from ...common.protocol.tasks import ( StepContext, StepOutcome, @@ -26,24 +27,27 @@ def generate_call_id(): async def tool_call_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, ToolCallStep) - tool_type, tool_name = context.current_step.tool.split(".") + tools: list[Tool] = context.tools + tool_name = context.current_step.tool + + tool = next((t for t in tools if t.name == tool_name), None) + + if tool is None: + raise ApplicationError(f"Tool {tool_name} not found in the toolset") + arguments = await base_evaluate( context.current_step.arguments, context.model_dump() ) - tools = context.execution_input.tools - - assert tool_name in [tool.name for tool in tools], f"Tool {tool_name} not found" - call_id = generate_call_id() tool_call = { - tool_type: { + tool.type: { "arguments": arguments, "name": tool_name, }, "id": call_id, - "type": tool_type, + "type": tool.type, } return StepOutcome(output=tool_call) diff --git a/agents-api/agents_api/autogen/Common.py b/agents-api/agents_api/autogen/Common.py index cb8d88035..a2d6b99a6 100644 --- a/agents-api/agents_api/autogen/Common.py +++ b/agents-api/agents_api/autogen/Common.py @@ -9,6 +9,16 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, RootModel +class JinjaTemplate(RootModel[str]): + model_config = ConfigDict( + populate_by_name=True, + ) + root: str + """ + A valid jinja template. + """ + + class Limit(RootModel[int]): model_config = ConfigDict( populate_by_name=True, diff --git a/agents-api/agents_api/autogen/Sessions.py b/agents-api/agents_api/autogen/Sessions.py index f167efec5..75e5252cd 100644 --- a/agents-api/agents_api/autogen/Sessions.py +++ b/agents-api/agents_api/autogen/Sessions.py @@ -43,6 +43,16 @@ class CreateSessionRequest(BaseModel): """ Action to start on context window overflow """ + forward_tool_results: StrictBool | None = None + """ + Whether to forward the tool results to the model when available. + "true" => always forward + "false" => never forward + null => forward if applicable (default) + + If a tool call is made, the tool's output will be sent back to the model as the model's input. + If a tool call is not made, the model's output will be returned as is. + """ metadata: dict[str, Any] | None = None @@ -70,6 +80,16 @@ class PatchSessionRequest(BaseModel): """ Action to start on context window overflow """ + forward_tool_results: StrictBool | None = None + """ + Whether to forward the tool results to the model when available. + "true" => always forward + "false" => never forward + null => forward if applicable (default) + + If a tool call is made, the tool's output will be sent back to the model as the model's input. + If a tool call is not made, the model's output will be returned as is. + """ metadata: dict[str, Any] | None = None @@ -97,6 +117,16 @@ class Session(BaseModel): """ Action to start on context window overflow """ + forward_tool_results: StrictBool | None = None + """ + Whether to forward the tool results to the model when available. + "true" => always forward + "false" => never forward + null => forward if applicable (default) + + If a tool call is made, the tool's output will be sent back to the model as the model's input. + If a tool call is not made, the model's output will be returned as is. + """ id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})] @@ -160,6 +190,16 @@ class UpdateSessionRequest(BaseModel): """ Action to start on context window overflow """ + forward_tool_results: StrictBool | None = None + """ + Whether to forward the tool results to the model when available. + "true" => always forward + "false" => never forward + null => forward if applicable (default) + + If a tool call is made, the tool's output will be sent back to the model as the model's input. + If a tool call is not made, the model's output will be returned as is. + """ metadata: dict[str, Any] | None = None @@ -194,6 +234,16 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): """ Action to start on context window overflow """ + forward_tool_results: StrictBool | None = None + """ + Whether to forward the tool results to the model when available. + "true" => always forward + "false" => never forward + null => forward if applicable (default) + + If a tool call is made, the tool's output will be sent back to the model as the model's input. + If a tool call is not made, the model's output will be returned as is. + """ metadata: dict[str, Any] | None = None diff --git a/agents-api/agents_api/autogen/Tasks.py b/agents-api/agents_api/autogen/Tasks.py index c1e69d492..48dba4ad7 100644 --- a/agents-api/agents_api/autogen/Tasks.py +++ b/agents-api/agents_api/autogen/Tasks.py @@ -15,7 +15,7 @@ TextOnlyDocSearchRequest, VectorDocSearchRequest, ) -from .Tools import CreateToolRequest +from .Tools import CreateToolRequest, NamedToolChoice class CaseThen(BaseModel): @@ -46,6 +46,34 @@ class CaseThen(BaseModel): """ +class CaseThenUpdateItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + case: Literal["_"] | str + """ + The condition to evaluate + """ + then: ( + EvaluateStep + | ToolCallStep + | PromptStepUpdateItem + | GetStep + | SetStep + | LogStep + | EmbedStep + | SearchStep + | ReturnStep + | SleepStep + | ErrorWorkflowStep + | YieldStep + | WaitForInputStep + ) + """ + The steps to run if the condition is true + """ + + class Content(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -74,6 +102,14 @@ class ContentModel(BaseModel): """ +class ContentModel1(Content): + pass + + +class ContentModel2(ContentModel): + pass + + class CreateTaskRequest(BaseModel): """ Payload for creating a task @@ -197,6 +233,30 @@ class ForeachDo(BaseModel): """ +class ForeachDoUpdateItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + in_: Annotated[str, Field(alias="in")] + """ + The variable to iterate over. + VALIDATION: Should NOT return more than 1000 elements. + """ + do: ( + EvaluateStep + | ToolCallStep + | PromptStepUpdateItem + | GetStep + | SetStep + | LogStep + | EmbedStep + | SearchStep + ) + """ + The steps to run for each iteration + """ + + class ForeachStep(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -213,6 +273,20 @@ class ForeachStep(BaseModel): """ +class ForeachStepUpdateItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + kind_: str | None = None + """ + Discriminator property for BaseWorkflowStep. + """ + foreach: ForeachDoUpdateItem + """ + The steps to run for each iteration + """ + + class GetStep(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -281,6 +355,58 @@ class IfElseWorkflowStep(BaseModel): """ +class IfElseWorkflowStepUpdateItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + kind_: str | None = None + """ + Discriminator property for BaseWorkflowStep. + """ + if_: Annotated[str, Field(alias="if")] + """ + The condition to evaluate + """ + then: ( + EvaluateStep + | ToolCallStep + | PromptStepUpdateItem + | GetStep + | SetStep + | LogStep + | EmbedStep + | SearchStep + | ReturnStep + | SleepStep + | ErrorWorkflowStep + | YieldStep + | WaitForInputStep + ) + """ + The steps to run if the condition is true + """ + else_: Annotated[ + EvaluateStep + | ToolCallStep + | PromptStepUpdateItem + | GetStep + | SetStep + | LogStep + | EmbedStep + | SearchStep + | ReturnStep + | SleepStep + | ErrorWorkflowStep + | YieldStep + | WaitForInputStep + | None, + Field(None, alias="else"), + ] + """ + The steps to run if the condition is false + """ + + class ImageUrl(BaseModel): """ The image URL @@ -371,7 +497,7 @@ class MainModel(BaseModel): map: ( EvaluateStep | ToolCallStep - | PromptStep + | PromptStepUpdateItem | GetStep | SetStep | LogStep @@ -425,6 +551,32 @@ class ParallelStep(BaseModel): """ +class ParallelStepUpdateItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + kind_: str | None = None + """ + Discriminator property for BaseWorkflowStep. + """ + parallel: Annotated[ + list[ + EvaluateStep + | ToolCallStep + | PromptStepUpdateItem + | GetStep + | SetStep + | LogStep + | EmbedStep + | SearchStep + ], + Field(max_length=100), + ] + """ + The steps to run in parallel. Max concurrency will depend on the platform. + """ + + class PatchTaskRequest(BaseModel): """ Payload for patching a task @@ -438,7 +590,7 @@ class PatchTaskRequest(BaseModel): list[ EvaluateStep | ToolCallStep - | PromptStep + | PromptStepUpdateItem | GetStep | SetStep | LogStep @@ -449,10 +601,10 @@ class PatchTaskRequest(BaseModel): | ErrorWorkflowStep | YieldStep | WaitForInputStep - | IfElseWorkflowStep - | SwitchStep - | ForeachStep - | ParallelStep + | IfElseWorkflowStepUpdateItem + | SwitchStepUpdateItem + | ForeachStepUpdateItem + | ParallelStepUpdateItem | MainModel ] | None, @@ -520,10 +672,72 @@ class PromptStep(BaseModel): """ The prompt to run """ + tools: Literal["all"] | list[ToolRef | CreateToolRequest] = [] + """ + The tools to use for the prompt + """ + tool_choice: Literal["auto", "none"] | NamedToolChoice | None = None + """ + The tool choice for the prompt + """ + settings: ChatSettings | None = None + """ + Settings for the prompt + """ + unwrap: StrictBool = False + """ + Whether to unwrap the output of the prompt step, equivalent to `response.choices[0].message.content` + """ + forward_tool_results: StrictBool | None = None + """ + Whether to forward the tool results to the model when available. + "true" => always forward + "false" => never forward + null => forward if applicable (default) + + If a tool call is made, the tool's output will be used as the model's input. + If a tool call is not made, the model's output will be used as the next step's input. + """ + + +class PromptStepUpdateItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + kind_: str | None = None + """ + Discriminator property for BaseWorkflowStep. + """ + prompt: list[PromptItem] | str + """ + The prompt to run + """ + tools: Literal["all"] | list[ToolRefUpdateItem | CreateToolRequest] = [] + """ + The tools to use for the prompt + """ + tool_choice: Literal["auto", "none"] | NamedToolChoice | None = None + """ + The tool choice for the prompt + """ settings: ChatSettings | None = None """ Settings for the prompt """ + unwrap: StrictBool = False + """ + Whether to unwrap the output of the prompt step, equivalent to `response.choices[0].message.content` + """ + forward_tool_results: StrictBool | None = None + """ + Whether to forward the tool results to the model when available. + "true" => always forward + "false" => never forward + null => forward if applicable (default) + + If a tool call is made, the tool's output will be used as the model's input. + If a tool call is not made, the model's output will be used as the next step's input. + """ class ReturnStep(BaseModel): @@ -626,6 +840,20 @@ class SwitchStep(BaseModel): """ +class SwitchStepUpdateItem(BaseModel): + model_config = ConfigDict( + populate_by_name=True, + ) + kind_: str | None = None + """ + Discriminator property for BaseWorkflowStep. + """ + switch: Annotated[list[CaseThenUpdateItem], Field(min_length=1)] + """ + The cond tree + """ + + class Task(BaseModel): """ Object describing a Task @@ -706,9 +934,7 @@ class ToolCallStep(BaseModel): """ The kind of step """ - tool: Annotated[ - str, Field(pattern="^(function|integration|system|api_call)\\.(\\w+)$") - ] + tool: Annotated[str, Field(max_length=40, pattern="^[^\\W0-9]\\w*$")] """ The tool to run """ @@ -718,6 +944,52 @@ class ToolCallStep(BaseModel): """ +class ToolRef(BaseModel): + """ + Reference to a tool + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + ref: ToolRefById | ToolRefByName + + +class ToolRefById(BaseModel): + """ + Reference to a tool by id + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + id: UUID | None = None + + +class ToolRefByName(BaseModel): + """ + Reference to a tool by name + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + name: Annotated[str | None, Field(None, max_length=40, pattern="^[^\\W0-9]\\w*$")] + """ + Valid python identifier names + """ + + +class ToolRefUpdateItem(BaseModel): + """ + Reference to a tool + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + + class UpdateTaskRequest(BaseModel): """ Payload for updating a task diff --git a/agents-api/agents_api/autogen/Tools.py b/agents-api/agents_api/autogen/Tools.py index 7664af1af..7b7f38214 100644 --- a/agents-api/agents_api/autogen/Tools.py +++ b/agents-api/agents_api/autogen/Tools.py @@ -22,9 +22,6 @@ class ChosenToolCall(BaseModel): Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now) """ function: FunctionCallOption | None = None - integration: Any | None = None - system: Any | None = None - api_call: Any | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] @@ -44,13 +41,12 @@ class CreateToolRequest(BaseModel): """ Name of the tool (must be unique for this agent and a valid python identifier string ) """ - function: FunctionDef + function: FunctionDef | None = None """ The function to call """ - integration: Any | None = None - system: Any | None = None - api_call: Any | None = None + integration: IntegrationDef | None = None + system: SystemDef | None = None class FunctionCallOption(BaseModel): @@ -75,14 +71,7 @@ class FunctionDef(BaseModel): """ DO NOT USE: This will be overriden by the tool name. Here only for compatibility reasons. """ - description: Annotated[ - str | None, - Field( - None, - max_length=120, - pattern="^[\\p{L}\\p{Nl}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]+[\\p{ID_Start}\\p{Mn}\\p{Mc}\\p{Nd}\\p{Pc}\\p{Pattern_Syntax}\\p{Pattern_White_Space}]*$", - ), - ] + description: str | None = None """ Description of the function """ @@ -92,6 +81,89 @@ class FunctionDef(BaseModel): """ +class IntegrationDef(BaseModel): + """ + Integration definition + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + provider: Literal[ + "dummy", + "dall-e", + "duckduckgo", + "hackernews", + "weather", + "wikipedia", + "twitter", + "webpage", + "requests", + ] + """ + The provider of the integration + """ + method: str | None = None + """ + The specific method of the integration to call + """ + description: str | None = None + """ + Optional description of the integration + """ + setup: dict[str, Any] | None = None + """ + The setup parameters the integration accepts + """ + arguments: dict[str, Any] | None = None + """ + The arguments to pre-apply to the integration call + """ + + +class IntegrationDefUpdate(BaseModel): + """ + Integration definition + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + provider: ( + Literal[ + "dummy", + "dall-e", + "duckduckgo", + "hackernews", + "weather", + "wikipedia", + "twitter", + "webpage", + "requests", + ] + | None + ) = None + """ + The provider of the integration + """ + method: str | None = None + """ + The specific method of the integration to call + """ + description: str | None = None + """ + Optional description of the integration + """ + setup: dict[str, Any] | None = None + """ + The setup parameters the integration accepts + """ + arguments: dict[str, Any] | None = None + """ + The arguments to pre-apply to the integration call + """ + + class NamedToolChoice(BaseModel): model_config = ConfigDict( populate_by_name=True, @@ -101,9 +173,6 @@ class NamedToolChoice(BaseModel): Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now) """ function: FunctionCallOption | None = None - integration: Any | None = None - system: Any | None = None - api_call: Any | None = None class PatchToolRequest(BaseModel): @@ -126,9 +195,52 @@ class PatchToolRequest(BaseModel): """ The function to call """ - integration: Any | None = None - system: Any | None = None - api_call: Any | None = None + integration: IntegrationDefUpdate | None = None + system: SystemDefUpdate | None = None + + +class SystemDef(BaseModel): + """ + System definition + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + call: str + """ + The name of the system call + """ + description: str | None = None + """ + Optional description of the system call + """ + arguments: dict[str, Any] | None = None + """ + The arguments to pre-apply to the system call + """ + + +class SystemDefUpdate(BaseModel): + """ + System definition + """ + + model_config = ConfigDict( + populate_by_name=True, + ) + call: str | None = None + """ + The name of the system call + """ + description: str | None = None + """ + Optional description of the system call + """ + arguments: dict[str, Any] | None = None + """ + The arguments to pre-apply to the system call + """ class Tool(BaseModel): @@ -143,13 +255,12 @@ class Tool(BaseModel): """ Name of the tool (must be unique for this agent and a valid python identifier string ) """ - function: FunctionDef + function: FunctionDef | None = None """ The function to call """ - integration: Any | None = None - system: Any | None = None - api_call: Any | None = None + integration: IntegrationDef | None = None + system: SystemDef | None = None created_at: Annotated[AwareDatetime, Field(json_schema_extra={"readOnly": True})] """ When this resource was created as UTC date-time @@ -188,13 +299,12 @@ class UpdateToolRequest(BaseModel): """ Name of the tool (must be unique for this agent and a valid python identifier string ) """ - function: FunctionDef + function: FunctionDef | None = None """ The function to call """ - integration: Any | None = None - system: Any | None = None - api_call: Any | None = None + integration: IntegrationDef | None = None + system: SystemDef | None = None class ChosenFunctionCall(ChosenToolCall): diff --git a/agents-api/agents_api/autogen/openapi_model.py b/agents-api/agents_api/autogen/openapi_model.py index 3e329581c..5c5a8c86f 100644 --- a/agents-api/agents_api/autogen/openapi_model.py +++ b/agents-api/agents_api/autogen/openapi_model.py @@ -200,6 +200,7 @@ class TaskSpec(_Task): model_config = ConfigDict(extra="ignore") workflows: list[Workflow] + tools: list[TaskToolDef] # Remove main field from the model main: None = None diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index e50b5c2ea..6892d7098 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -126,7 +126,7 @@ class ExecutionInput(BaseModel): execution: Execution task: TaskSpecDef agent: Agent - tools: list[Tool] + agent_tools: list[Tool] arguments: dict[str, Any] # Not used at the moment @@ -139,6 +139,23 @@ class StepContext(BaseModel): inputs: list[Any] cursor: TransitionTarget + @computed_field + @property + def tools(self) -> list[Tool]: + execution_input = self.execution_input + task = execution_input.task + agent_tools = execution_input.agent_tools + + if not task.inherit_tools: + return task.tools + + # Remove duplicates from agent_tools + filtered_tools = [ + t for t in agent_tools if t.name not in map(lambda x: x.name, task.tools) + ] + + return filtered_tools + task.tools + @computed_field @property def outputs(self) -> list[dict[str, Any]]: # included in dump diff --git a/agents-api/agents_api/models/execution/prepare_execution_input.py b/agents-api/agents_api/models/execution/prepare_execution_input.py index 5bc14939a..b763a7508 100644 --- a/agents-api/agents_api/models/execution/prepare_execution_input.py +++ b/agents-api/agents_api/models/execution/prepare_execution_input.py @@ -39,13 +39,10 @@ transform=lambda d: { **d, "task": { - "tools": [ - {tool["type"]: tool.pop("spec"), **tool} - for tool in map(fix_uuid_if_present, d["task"].pop("tools")) - ], + "tools": [*map(fix_uuid_if_present, d["task"].pop("tools"))], **d["task"], }, - "tools": [ + "agent_tools": [ {tool["type"]: tool.pop("spec"), **tool} for tool in map(fix_uuid_if_present, d["tools"]) ], diff --git a/agents-api/agents_api/models/task/get_task.py b/agents-api/agents_api/models/task/get_task.py index 4bb3d06b6..4e5b2fb97 100644 --- a/agents-api/agents_api/models/task/get_task.py +++ b/agents-api/agents_api/models/task/get_task.py @@ -68,20 +68,6 @@ def get_task( }, updated_at = to_int(updated_at_ms) / 1000 - tool_data[collect(tool_def)] := - task_data[_, agent_id, _, _, _, _, _, _, _, _, _], - *tools { - agent_id, - type, - name, - spec, - }, tool_def = { - "type": type, - "name": name, - "spec": spec, - "inherited": true, - } - ?[ id, agent_id, @@ -95,20 +81,19 @@ def get_task( updated_at, metadata, ] := - tool_data[inherited_tools], task_data[ id, agent_id, name, description, input_schema, - task_tools, + tools, inherit_tools, workflows, created_at, updated_at, metadata, - ], tools = task_tools ++ if(inherit_tools, inherited_tools, []) + ] :limit 1 """ diff --git a/agents-api/agents_api/models/task/list_tasks.py b/agents-api/agents_api/models/task/list_tasks.py index 35c52d184..a443ce9e2 100644 --- a/agents-api/agents_api/models/task/list_tasks.py +++ b/agents-api/agents_api/models/task/list_tasks.py @@ -74,20 +74,6 @@ def list_tasks( }}, updated_at = to_int(updated_at_ms) / 1000 - tool_data[collect(tool_def)] := - input[agent_id], - *tools {{ - agent_id, - type, - name, - spec, - }}, tool_def = {{ - "type": type, - "name": name, - "spec": spec, - "inherited": true, - }} - ?[ task_id, agent_id, @@ -101,20 +87,19 @@ def list_tasks( updated_at, metadata, ] := - tool_data[inherited_tools], task_data[ task_id, agent_id, name, description, input_schema, - task_tools, + tools, inherit_tools, workflows, created_at, updated_at, metadata, - ], tools = task_tools ++ if(inherit_tools, inherited_tools, []) + ] :limit $limit :offset $offset diff --git a/agents-api/agents_api/models/tools/__init__.py b/agents-api/agents_api/models/tools/__init__.py index 98f3a5e3a..b1775f1a9 100644 --- a/agents-api/agents_api/models/tools/__init__.py +++ b/agents-api/agents_api/models/tools/__init__.py @@ -14,6 +14,7 @@ from .create_tools import create_tools from .delete_tool import delete_tool from .get_tool import get_tool +from .get_tool_args_from_metadata import get_tool_args_from_metadata from .list_tools import list_tools from .patch_tool import patch_tool from .update_tool import update_tool diff --git a/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py new file mode 100644 index 000000000..08882ae6f --- /dev/null +++ b/agents-api/agents_api/models/tools/get_tool_args_from_metadata.py @@ -0,0 +1,126 @@ +from uuid import UUID + +from beartype import beartype +from fastapi import HTTPException +from pycozo.client import QueryException +from pydantic import ValidationError + +from ..utils import ( + cozo_query, + partialclass, + rewrap_exceptions, + verify_developer_id_query, + verify_developer_owns_resource_query, + wrap_in_class, +) + + +def tool_args_for_task( + *, + developer_id: UUID, + agent_id: UUID, + task_id: UUID, +) -> tuple[list[str], dict]: + agent_id = str(agent_id) + task_id = str(task_id) + + get_query = """ + input[agent_id, task_id] <- [[to_uuid($agent_id), to_uuid($task_id)]] + + ?[args] := + input[agent_id, task_id], + *tasks { + task_id, + metadata: task_metadata, + }, + *agents { + agent_id, + metadata: agent_metadata, + }, + task_args = get(task_metadata, "x-tool-args", {}), + agent_args = get(agent_metadata, "x-tool-args", {}), + + # Right values overwrite left values + # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat + args = concat(agent_args, task_args), + """ + + queries = [ + verify_developer_id_query(developer_id), + verify_developer_owns_resource_query( + developer_id, "tasks", task_id=task_id, parents=[("agents", "agent_id")] + ), + get_query, + ] + + return (queries, {"agent_id": agent_id, "task_id": task_id}) + + +def tool_args_for_session( + *, + developer_id: UUID, + session_id: UUID, + agent_id: UUID, +) -> tuple[list[str], dict]: + session_id = str(session_id) + + get_query = """ + input[session_id, agent_id] <- [[to_uuid($session_id), to_uuid($agent_id)]] + + ?[args] := + input[session_id, agent_id], + *sessions { + session_id, + metadata: session_metadata, + }, + *agents { + agent_id, + metadata: agent_metadata, + }, + session_args = get(session_metadata, "x-tool-args"), + agent_args = get(agent_metadata, "x-tool-args"), + + # Right values overwrite left values + # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat + args = concat(agent_args, session_args), + """ + + queries = [ + verify_developer_id_query(developer_id), + verify_developer_owns_resource_query( + developer_id, "sessions", session_id=session_id + ), + get_query, + ] + + return (queries, {"agent_id": agent_id, "session_id": session_id}) + + +@rewrap_exceptions( + { + QueryException: partialclass(HTTPException, status_code=400), + ValidationError: partialclass(HTTPException, status_code=400), + TypeError: partialclass(HTTPException, status_code=400), + } +) +@wrap_in_class(dict, transform=lambda x: x["args"], one=True) +@cozo_query +@beartype +def get_tool_args_from_metadata( + *, + developer_id: UUID, + agent_id: UUID, + session_id: UUID | None = None, + task_id: UUID | None = None, +) -> tuple[list[str], dict]: + match session_id, task_id: + case (None, task_id) if task_id is not None: + return tool_args_for_task( + developer_id=developer_id, agent_id=agent_id, task_id=task_id + ) + case (session_id, None) if session_id is not None: + return tool_args_for_session( + developer_id=developer_id, agent_id=agent_id, session_id=session_id + ) + case (_, _): + raise ValueError("Either session_id or task_id must be provided") diff --git a/agents-api/agents_api/worker/worker.py b/agents-api/agents_api/worker/worker.py index 65e813023..77698364d 100644 --- a/agents-api/agents_api/worker/worker.py +++ b/agents-api/agents_api/worker/worker.py @@ -15,6 +15,7 @@ def create_worker(client: Client) -> Any: from ..activities import task_steps from ..activities.demo import demo_activity from ..activities.embed_docs import embed_docs + from ..activities.execute_integration import execute_integration from ..activities.mem_mgmt import mem_mgmt from ..activities.mem_rating import mem_rating from ..activities.summarization import summarization @@ -49,10 +50,11 @@ def create_worker(client: Client) -> Any: activities=[ *task_activities, demo_activity, - summarization, + embed_docs, + execute_integration, mem_mgmt, mem_rating, - embed_docs, + summarization, truncation, ], ) diff --git a/agents-api/agents_api/workflows/task_execution/__init__.py b/agents-api/agents_api/workflows/task_execution/__init__.py index d1f13585f..3c8197e29 100644 --- a/agents-api/agents_api/workflows/task_execution/__init__.py +++ b/agents-api/agents_api/workflows/task_execution/__init__.py @@ -11,6 +11,7 @@ # Import necessary modules and types with workflow.unsafe.imports_passed_through(): from ...activities import task_steps + from ...activities.execute_integration import execute_integration from ...autogen.openapi_model import ( EmbedStep, ErrorWorkflowStep, @@ -389,34 +390,47 @@ async def run( state = PartialTransition(type="resume", output=result) - case PromptStep(), StepOutcome(output=response): + case PromptStep(unwrap=True), StepOutcome(output=response): workflow.logger.debug(f"Prompt step: Received response: {response}") - if response["choices"][0]["finish_reason"] != "tool_calls": - workflow.logger.debug("Prompt step: Received response") - state = PartialTransition(output=response) - else: - workflow.logger.debug("Prompt step: Received tool call") - message = response["choices"][0]["message"] - tool_calls_input = message["tool_calls"] - - # Enter a wait-for-input step to ask the developer to run the tool calls - tool_calls_results = await workflow.execute_activity( - task_steps.raise_complete_async, - args=[context, tool_calls_input], - schedule_to_close_timeout=timedelta(days=31), - ) - # Feed the tool call results back to the model - # context.inputs.append(tool_calls_results) - context.current_step.prompt.append(message) - context.current_step.prompt.append(tool_calls_results) - new_response = await workflow.execute_activity( - task_steps.prompt_step, - context, - schedule_to_close_timeout=timedelta( - seconds=30 if debug or testing else 600 - ), - ) - state = PartialTransition(output=new_response.output, type="resume") + state = PartialTransition(output=response) + + case PromptStep(forward_tool_results=False, unwrap=False), StepOutcome( + output=response + ): + workflow.logger.debug(f"Prompt step: Received response: {response}") + state = PartialTransition(output=response) + + case PromptStep(unwrap=False), StepOutcome(output=response) if response[ + "choices" + ][0]["finish_reason"] != "tool_calls": + workflow.logger.debug(f"Prompt step: Received response: {response}") + state = PartialTransition(output=response) + + case PromptStep(unwrap=False), StepOutcome(output=response) if response[ + "choices" + ][0]["finish_reason"] == "tool_calls": + workflow.logger.debug("Prompt step: Received tool call") + message = response["choices"][0]["message"] + tool_calls_input = message["tool_calls"] + + # Enter a wait-for-input step to ask the developer to run the tool calls + tool_calls_results = await workflow.execute_activity( + task_steps.raise_complete_async, + args=[context, tool_calls_input], + schedule_to_close_timeout=timedelta(days=31), + ) + + # Feed the tool call results back to the model + context.current_step.prompt.append(message) + context.current_step.prompt.append(tool_calls_results) + new_response = await workflow.execute_activity( + task_steps.prompt_step, + context, + schedule_to_close_timeout=timedelta( + seconds=30 if debug or testing else 600 + ), + ) + state = PartialTransition(output=new_response.output, type="resume") case SetStep(), StepOutcome(output=evaluated_output): workflow.logger.info("Set step: Updating user state") @@ -450,7 +464,9 @@ async def run( workflow.logger.error("ParallelStep not yet implemented") raise ApplicationError("Not implemented") - case ToolCallStep(), StepOutcome(output=tool_call): + case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ + "type" + ] == "function": # Enter a wait-for-input step to ask the developer to run the tool calls tool_call_response = await workflow.execute_activity( task_steps.raise_complete_async, @@ -460,6 +476,33 @@ async def run( state = PartialTransition(output=tool_call_response, type="resume") + case ToolCallStep(), StepOutcome(output=tool_call) if tool_call[ + "type" + ] == "integration": + call = tool_call["integration"] + tool_name = call["name"] + arguments = call["arguments"] + integration = next( + (t for t in context.tools if t.name == tool_name), None + ) + + if integration is None: + raise ApplicationError(f"Integration {tool_name} not found") + + tool_call_response = await workflow.execute_activity( + execute_integration, + args=[context, tool_name, integration, arguments], + schedule_to_close_timeout=timedelta( + seconds=30 if debug or testing else 600 + ), + ) + + state = PartialTransition(output=tool_call_response) + + case ToolCallStep(), StepOutcome(output=_): + # FIXME: Handle system/api_call tool_calls + raise ApplicationError("Not implemented") + case _: workflow.logger.error( f"Unhandled step type: {type(context.current_step).__name__}" @@ -468,6 +511,7 @@ async def run( # 4. Transition to the next step workflow.logger.info(f"Transitioning after step {context.cursor.step}") + # The returned value is the transition finally created final_state = await transition(context, state) diff --git a/agents-api/migrations/migrate_1727235852_add_forward_tool_calls_option.py b/agents-api/migrations/migrate_1727235852_add_forward_tool_calls_option.py new file mode 100644 index 000000000..ad1aab998 --- /dev/null +++ b/agents-api/migrations/migrate_1727235852_add_forward_tool_calls_option.py @@ -0,0 +1,87 @@ +#/usr/bin/env python3 + +MIGRATION_ID = "add_forward_tool_calls_option" +CREATED_AT = 1727235852.744035 + + +def run(client, queries): + joiner = "}\n\n{" + + query = joiner.join(queries) + query = f"{{\n{query}\n}}" + client.run(query) + + +add_forward_tool_calls_option_to_session_query = dict( + up=""" + ?[forward_tool_calls, token_budget, context_overflow, developer_id, session_id, updated_at, situation, summary, created_at, metadata, render_templates, token_budget, context_overflow] := *sessions{ + developer_id, + session_id, + updated_at, + situation, + summary, + created_at, + metadata, + render_templates, + token_budget, + context_overflow, + }, + forward_tool_calls = null + + :replace sessions { + developer_id: Uuid, + session_id: Uuid, + updated_at: Validity default [floor(now()), true], + => + situation: String, + summary: String? default null, + created_at: Float default now(), + metadata: Json default {}, + render_templates: Bool default false, + token_budget: Int? default null, + context_overflow: String? default null, + forward_tool_calls: Bool? default null, + } + """, + down=""" + ?[token_budget, context_overflow, developer_id, session_id, updated_at, situation, summary, created_at, metadata, render_templates, token_budget, context_overflow] := *sessions{ + developer_id, + session_id, + updated_at, + situation, + summary, + created_at, + metadata, + render_templates, + token_budget, + context_overflow, + } + + :replace sessions { + developer_id: Uuid, + session_id: Uuid, + updated_at: Validity default [floor(now()), true], + => + situation: String, + summary: String? default null, + created_at: Float default now(), + metadata: Json default {}, + render_templates: Bool default false, + token_budget: Int? default null, + context_overflow: String? default null, + } + """, +) + + +queries = [ + add_forward_tool_calls_option_to_session_query, +] + + +def up(client): + run(client, [q["up"] for q in queries]) + + +def down(client): + run(client, [q["down"] for q in reversed(queries)]) diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 4cc5c2674..d5b032311 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -48,6 +48,8 @@ def cozo_client(migrations_dir: str = "./migrations"): # and initialize the schema. client = CozoClient() + setattr(app.state, "cozo_client", client) + init(client) apply(client, migrations_dir=migrations_dir, all_=True) diff --git a/agents-api/tests/test_execution_workflow.py b/agents-api/tests/test_execution_workflow.py index 44f6d6ed4..cddc65666 100644 --- a/agents-api/tests/test_execution_workflow.py +++ b/agents-api/tests/test_execution_workflow.py @@ -338,7 +338,7 @@ async def _( assert result["value"] == data.input["test"] -# @test("workflow: log step") +@test("workflow: log step") async def _( client=cozo_client, developer_id=test_developer_id, @@ -441,7 +441,61 @@ async def _( assert result["hello"] == data.input["test"] -@test("workflow: wait for input step start") +@test("workflow: tool call integration type step") +async def _( + client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, +): + data = CreateExecutionRequest(input={"test": "input"}) + + task = create_task( + developer_id=developer_id, + agent_id=agent.id, + data=CreateTaskRequest( + **{ + "name": "test task", + "description": "test task about", + "input_schema": {"type": "object", "additionalProperties": True}, + "tools": [ + { + "type": "integration", + "name": "hello", + "integration": { + "provider": "dummy", + }, + } + ], + "main": [ + { + "tool": "hello", + "arguments": {"test": "_.test"}, + }, + ], + } + ), + client=client, + ) + + async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): + execution, handle = await start_execution( + developer_id=developer_id, + task_id=task.id, + data=data, + client=client, + ) + + assert handle is not None + assert execution.task_id == task.id + assert execution.input == data.input + mock_run_task_execution_workflow.assert_called_once() + + result = await handle.result() + assert result["test"] == data.input["test"] + + +# FIXME: This test is not working. It gets stuck +# @test("workflow: wait for input step start") async def _( client=cozo_client, developer_id=test_developer_id, @@ -818,6 +872,66 @@ async def _( assert result["role"] == "assistant" +@test("workflow: prompt step unwrap") +async def _( + client=cozo_client, + developer_id=test_developer_id, + agent=test_agent, +): + mock_model_response = ModelResponse( + id="fake_id", + choices=[Choices(message={"role": "assistant", "content": "Hello, world!"})], + created=0, + object="text_completion", + ) + + with patch("agents_api.clients.litellm.acompletion") as acompletion: + acompletion.return_value = mock_model_response + data = CreateExecutionRequest(input={"test": "input"}) + + task = create_task( + developer_id=developer_id, + agent_id=agent.id, + data=CreateTaskRequest( + **{ + "name": "test task", + "description": "test task about", + "input_schema": {"type": "object", "additionalProperties": True}, + "main": [ + { + "prompt": [ + { + "role": "user", + "content": "message", + }, + ], + "unwrap": True, + "settings": {}, + }, + ], + } + ), + client=client, + ) + + async with patch_testing_temporal() as (_, mock_run_task_execution_workflow): + execution, handle = await start_execution( + developer_id=developer_id, + task_id=task.id, + data=data, + client=client, + ) + + assert handle is not None + assert execution.task_id == task.id + assert execution.input == data.input + + mock_run_task_execution_workflow.assert_called_once() + + result = await handle.result() + assert result == "Hello, world!" + + @test("workflow: set and get steps") async def _( client=cozo_client, diff --git a/agents-api/tests/utils.py b/agents-api/tests/utils.py index 83b864472..e54dabe17 100644 --- a/agents-api/tests/utils.py +++ b/agents-api/tests/utils.py @@ -84,7 +84,7 @@ def patch_embed_acompletion(output={"role": "assistant", "content": "Hello, worl object="text_completion", ) - with patch("agents_api.clients.embed.embed") as embed, patch( + with patch("agents_api.clients.litellm.aembedding") as embed, patch( "agents_api.clients.litellm.acompletion" ) as acompletion: embed.return_value = [[1.0] * EMBEDDING_SIZE] diff --git a/typespec/common/scalars.tsp b/typespec/common/scalars.tsp index 79eda2d99..97fe37706 100644 --- a/typespec/common/scalars.tsp +++ b/typespec/common/scalars.tsp @@ -48,12 +48,21 @@ alias eventStream = "text/event-stream"; /** Different possible sources that can produce new entries */ alias entrySource = "api_request" | "api_response" | "tool_response" | "internal" | "summarizer" | "meta"; -/** Naming convention for tool references. Tools are resolved in order: `step-settings` -> `task` -> `agent` */ -@pattern("^(function|integration|system|api_call)\\.(\\w+)$") -scalar toolRef extends string; - /** A simple python expression compatible with SimpleEval. */ scalar PyExpression extends string; /** A valid jinja template. */ -scalar JinjaTemplate extends string; \ No newline at end of file +scalar JinjaTemplate extends string; + +/** Integration provider name */ +alias integrationProvider = ( + | "dummy" + | "dall-e" + | "duckduckgo" + | "hackernews" + | "weather" + | "wikipedia" + | "twitter" + | "webpage" + | "requests" +); diff --git a/typespec/package-lock.json b/typespec/package-lock.json index 2eb61490e..0ddfdb155 100644 --- a/typespec/package-lock.json +++ b/typespec/package-lock.json @@ -1,19 +1,19 @@ { "name": "julep-typespec", - "version": "0.3.0", + "version": "0.4.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "julep-typespec", - "version": "0.3.0", + "version": "0.4.0", "dependencies": { - "@typespec/compiler": "^0.59.1", - "@typespec/http": "^0.59.1", - "@typespec/openapi": "^0.59.0", - "@typespec/openapi3": "^0.59.1", - "@typespec/rest": "^0.59.1", - "@typespec/versioning": "^0.59.0" + "@typespec/compiler": "^0.60.1", + "@typespec/http": "^0.60.0", + "@typespec/openapi": "^0.60.0", + "@typespec/openapi3": "^0.60.0", + "@typespec/rest": "^0.60.0", + "@typespec/versioning": "^0.60.1" } }, "node_modules/@apidevtools/swagger-methods": { @@ -60,9 +60,9 @@ } }, "node_modules/@babel/runtime": { - "version": "7.24.8", - "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.24.8.tgz", - "integrity": "sha512-5F7SDGs1T72ZczbRwbGO9lQi0NLjQxzl6i4lJxLxfW9U5UluCSyEJeniWvnhl3/euNiqQVbo8zruhsDfid0esA==", + "version": "7.25.6", + "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.25.6.tgz", + "integrity": "sha512-VBj9MYyDb9tuLq7yzqjgzt6Q+IBQLrGZfdjOekyEirZPHxXWoTSGUTMrpsfi58Up73d13NfYLv8HT9vmznjzhQ==", "license": "MIT", "dependencies": { "regenerator-runtime": "^0.14.0" @@ -274,9 +274,9 @@ "license": "MIT" }, "node_modules/@typespec/compiler": { - "version": "0.59.1", - "resolved": "https://registry.npmjs.org/@typespec/compiler/-/compiler-0.59.1.tgz", - "integrity": "sha512-O2ljgr6YoFaIH6a8lWc90/czdv4B2X6N9wz4WsnQnVvgO0Tj0s+3xkvp4Tv59RKMhT0f3fK6dL8oEGO32FYk1A==", + "version": "0.60.1", + "resolved": "https://registry.npmjs.org/@typespec/compiler/-/compiler-0.60.1.tgz", + "integrity": "sha512-I6Vcpvd7mBP7SI5vCBh9rZGXAtVy95BKhAd33Enw32psswiSzRpA7zdyZhOMekTOGVXNS/+E5l2PGGCzQddB4w==", "license": "MIT", "dependencies": { "@babel/code-frame": "~7.24.7", @@ -303,34 +303,34 @@ } }, "node_modules/@typespec/http": { - "version": "0.59.1", - "resolved": "https://registry.npmjs.org/@typespec/http/-/http-0.59.1.tgz", - "integrity": "sha512-Ai8oCAO+Bw1HMSZ9gOI5Od4fNn/ul4HrVtTB01xFuLK6FQj854pxhzao8ylPnr7gIRQ327FV12/QfXR87yCiYQ==", + "version": "0.60.0", + "resolved": "https://registry.npmjs.org/@typespec/http/-/http-0.60.0.tgz", + "integrity": "sha512-ktfS9vpHfltyeAaQLNAZdqrn6Per3vmB/HDH/iyudYLA5wWblT1siKvpFCMWq53CJorRO7yeOKv+Q/M26zwEtg==", "license": "MIT", "engines": { "node": ">=18.0.0" }, "peerDependencies": { - "@typespec/compiler": "~0.59.0" + "@typespec/compiler": "~0.60.0" } }, "node_modules/@typespec/openapi": { - "version": "0.59.0", - "resolved": "https://registry.npmjs.org/@typespec/openapi/-/openapi-0.59.0.tgz", - "integrity": "sha512-do1Dm5w0MuK3994gYTBg6qMfgeIxmmsDqnz3zimYKMPpbnUBi4F6/o4iCfn0Fn9kaNl+H6UlOzZpsZW9xHui1Q==", + "version": "0.60.0", + "resolved": "https://registry.npmjs.org/@typespec/openapi/-/openapi-0.60.0.tgz", + "integrity": "sha512-YVwLppgHY8r/MudHNSLSUXzdw+CIpjmb31gI2a0KDGnI6sWDwY7LSWfjGU4TY/ubt0+X0Tjoy330mTvw71YBTg==", "license": "MIT", "engines": { "node": ">=18.0.0" }, "peerDependencies": { - "@typespec/compiler": "~0.59.0", - "@typespec/http": "~0.59.0" + "@typespec/compiler": "~0.60.0", + "@typespec/http": "~0.60.0" } }, "node_modules/@typespec/openapi3": { - "version": "0.59.1", - "resolved": "https://registry.npmjs.org/@typespec/openapi3/-/openapi3-0.59.1.tgz", - "integrity": "sha512-89VbUbkWKxeFgE0w0hpVyk1UZ6ZHRxOhcAHvF5MgxQxEhs2ALXKAqapWjFQsYrLBhAUoWzdPFrJJUMbwF9kX0Q==", + "version": "0.60.0", + "resolved": "https://registry.npmjs.org/@typespec/openapi3/-/openapi3-0.60.0.tgz", + "integrity": "sha512-gvrTHZACdeQtV7GfhVOHqkyTgMFyM2nKAIiz2P83LIncMCDUc00bGKGmaBk+xpuwKtCJyxBeVpCbID31YAq96g==", "license": "MIT", "dependencies": { "@readme/openapi-parser": "~2.6.0", @@ -343,35 +343,35 @@ "node": ">=18.0.0" }, "peerDependencies": { - "@typespec/compiler": "~0.59.0", - "@typespec/http": "~0.59.1", - "@typespec/openapi": "~0.59.0", - "@typespec/versioning": "~0.59.0" + "@typespec/compiler": "~0.60.0", + "@typespec/http": "~0.60.0", + "@typespec/openapi": "~0.60.0", + "@typespec/versioning": "~0.60.0" } }, "node_modules/@typespec/rest": { - "version": "0.59.1", - "resolved": "https://registry.npmjs.org/@typespec/rest/-/rest-0.59.1.tgz", - "integrity": "sha512-uKU431jBYL2tVQWG5THA75+OtXDa1e8cMAafYK/JJRRiVRd8D/Epd8fp07dzlB8tFGrhCaGlekRMqFPFrHh2/A==", + "version": "0.60.0", + "resolved": "https://registry.npmjs.org/@typespec/rest/-/rest-0.60.0.tgz", + "integrity": "sha512-mHYubyuBvwdV2xkHrJfPwV7b/Ksyb9lA1Q/AQwpVFa7Qu1X075TBVALmH+hK3V0EdUG1CGJZ5Sw4BWgl8ZS0BA==", "license": "MIT", "engines": { "node": ">=18.0.0" }, "peerDependencies": { - "@typespec/compiler": "~0.59.0", - "@typespec/http": "~0.59.1" + "@typespec/compiler": "~0.60.0", + "@typespec/http": "~0.60.0" } }, "node_modules/@typespec/versioning": { - "version": "0.59.0", - "resolved": "https://registry.npmjs.org/@typespec/versioning/-/versioning-0.59.0.tgz", - "integrity": "sha512-aihO/ux0lLmsuYAdGVkiBflSudcZokYG42SELk1FtMFo609G3Pd7ep7hau6unBnMIceQZejB0ow5UGRupK4X5A==", + "version": "0.60.1", + "resolved": "https://registry.npmjs.org/@typespec/versioning/-/versioning-0.60.1.tgz", + "integrity": "sha512-HogYL7P9uOPoSvkLLDjF22S6E9td6EY3c6TcIHhCzDTAQoi54csikD0gNrtcCkFG0UeQk29HgQymV397j+vp4g==", "license": "MIT", "engines": { "node": ">=18.0.0" }, "peerDependencies": { - "@typespec/compiler": "~0.59.0" + "@typespec/compiler": "~0.60.0" } }, "node_modules/ajv": { @@ -514,9 +514,9 @@ "license": "MIT" }, "node_modules/escalade": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.2.tgz", - "integrity": "sha512-ErCHMCae19vR8vQGe50xIsVomy19rg6gFu3+r3jkEO46suLMWBksvVyoGgQV+jOfl84ZSOSlmv6Gxa89PmTGmA==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", "license": "MIT", "engines": { "node": ">=6" @@ -637,9 +637,9 @@ } }, "node_modules/ignore": { - "version": "5.3.1", - "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.1.tgz", - "integrity": "sha512-5Fytz/IraMjqpwfd34ke28PTVMjZjJG2MPn5t7OE4eUCUNf8BAa7b5WUS9/Qvr6mwOQS7Mk6vdsMno5he+T8Xw==", + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz", + "integrity": "sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==", "license": "MIT", "engines": { "node": ">= 4" @@ -761,6 +761,7 @@ "version": "4.0.8", "resolved": "https://registry.npmjs.org/micromatch/-/micromatch-4.0.8.tgz", "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", + "license": "MIT", "dependencies": { "braces": "^3.0.3", "picomatch": "^2.3.1" @@ -1059,9 +1060,9 @@ } }, "node_modules/vscode-languageserver-textdocument": { - "version": "1.0.11", - "resolved": "https://registry.npmjs.org/vscode-languageserver-textdocument/-/vscode-languageserver-textdocument-1.0.11.tgz", - "integrity": "sha512-X+8T3GoiwTVlJbicx/sIAF+yuJAqz8VvwJyoMVhwEMoEKE/fkDmrqUgDMyBECcM2A2frVZIUj5HI/ErRXCfOeA==", + "version": "1.0.12", + "resolved": "https://registry.npmjs.org/vscode-languageserver-textdocument/-/vscode-languageserver-textdocument-1.0.12.tgz", + "integrity": "sha512-cxWNPesCnQCcMPeenjKKsOCKQZ/L6Tv19DTRIGuLWe32lyzWhihGVJ/rcckZXJxfdKCFvRLS3fpBIsV/ZGX4zA==", "license": "MIT" }, "node_modules/vscode-languageserver-types": { diff --git a/typespec/package.json b/typespec/package.json index a7b3d7ee4..d424d67dc 100644 --- a/typespec/package.json +++ b/typespec/package.json @@ -1,14 +1,14 @@ { "name": "julep-typespec", - "version": "0.3.0", + "version": "0.4.0", "type": "module", "dependencies": { - "@typespec/compiler": "^0.59.1", - "@typespec/http": "^0.59.1", - "@typespec/openapi": "^0.59.0", - "@typespec/openapi3": "^0.59.1", - "@typespec/rest": "^0.59.1", - "@typespec/versioning": "^0.59.0" + "@typespec/compiler": "^0.60.1", + "@typespec/http": "^0.60.0", + "@typespec/openapi": "^0.60.0", + "@typespec/openapi3": "^0.60.0", + "@typespec/rest": "^0.60.0", + "@typespec/versioning": "^0.60.1" }, "private": true } diff --git a/typespec/sessions/models.tsp b/typespec/sessions/models.tsp index aecec4ba1..dfbb6ea41 100644 --- a/typespec/sessions/models.tsp +++ b/typespec/sessions/models.tsp @@ -62,6 +62,15 @@ model Session { /** Action to start on context window overflow */ context_overflow: ContextOverflowType | null = null; + /** Whether to forward the tool results to the model when available. + * "true" => always forward + * "false" => never forward + * null => forward if applicable (default) + * + * If a tool call is made, the tool's output will be sent back to the model as the model's input. + * If a tool call is not made, the model's output will be returned as is. */ + forward_tool_results: boolean | null = null; + ...HasId; ...HasMetadata; ...HasTimestamps; diff --git a/typespec/tasks/models.tsp b/typespec/tasks/models.tsp index 4ec23f31b..b8b115861 100644 --- a/typespec/tasks/models.tsp +++ b/typespec/tasks/models.tsp @@ -30,6 +30,24 @@ model TaskTool extends CreateToolRequest { inherited?: boolean = false; } +/** Reference to a tool by id */ +model ToolRefById { + @visibility("read", "create") + id?: uuid; +} + +/** Reference to a tool by name */ +model ToolRefByName { + @visibility("read", "create") + name?: validPythonIdentifier; +} + +/** Reference to a tool */ +model ToolRef { + @visibility("read", "create") + ref: ToolRefById | ToolRefByName; +} + /** Object describing a Task */ model Task { @visibility("read", "create") diff --git a/typespec/tasks/steps.tsp b/typespec/tasks/steps.tsp index d9be46577..3495def1b 100644 --- a/typespec/tasks/steps.tsp +++ b/typespec/tasks/steps.tsp @@ -4,6 +4,7 @@ import "../chat"; import "../common"; import "../docs"; import "../entries"; +import "../tools"; import "./step_kind.tsp"; @@ -13,6 +14,7 @@ using Chat; using Common; using Docs; using Entries; +using Tools; namespace Tasks; @@ -76,7 +78,7 @@ model ToolCallStep extends BaseWorkflowStep<"tool_call"> { model ToolCallStepDef { /** The tool to run */ - tool: toolRef; + tool: validPythonIdentifier; /** The input parameters for the tool (defaults to last step output) */ arguments: ExpressionObject | "_" = "_"; @@ -93,8 +95,26 @@ model PromptStepDef { /** The prompt to run */ prompt: JinjaTemplate | InputChatMLMessage[]; + /** The tools to use for the prompt */ + tools: "all" | (ToolRef | CreateToolRequest)[] = #[]; + + /** The tool choice for the prompt */ + tool_choice?: ToolChoiceOption; + /** Settings for the prompt */ settings?: ChatSettings; + + /** Whether to unwrap the output of the prompt step, equivalent to `response.choices[0].message.content` */ + unwrap?: boolean = false; + + /** Whether to forward the tool results to the model when available. + * "true" => always forward + * "false" => never forward + * null => forward if applicable (default) + * + * If a tool call is made, the tool's output will be used as the model's input. + * If a tool call is not made, the model's output will be used as the next step's input. */ + forward_tool_results: boolean | null = null; } model EvaluateStep extends BaseWorkflowStep<"evaluate"> { diff --git a/typespec/tools/models.tsp b/typespec/tools/models.tsp index bab70b643..b7478ee24 100644 --- a/typespec/tools/models.tsp +++ b/typespec/tools/models.tsp @@ -35,13 +35,43 @@ model FunctionDef { name?: null = null; /** Description of the function */ - description?: identifierSafeUnicode; + description?: string; /** The parameters the function accepts */ parameters?: FunctionParameters; } +/** Integration definition */ +model IntegrationDef { + /** The provider of the integration */ + provider: integrationProvider; + + /** The specific method of the integration to call */ + method?: string; + + /** Optional description of the integration */ + description?: string; + + /** The setup parameters the integration accepts */ + setup?: FunctionParameters; + + /** The arguments to pre-apply to the integration call */ + arguments?: FunctionParameters; +} + +/** System definition */ +model SystemDef { + /** The name of the system call */ + call: string; + + /** Optional description of the system call */ + description?: string; + + /** The arguments to pre-apply to the system call */ + arguments?: FunctionParameters; +} + // TODO: We should use this model for all tools, not just functions and discriminate on the type model Tool { /** Whether this tool is a `function`, `api_call`, `system` etc. (Only `function` tool supported right now) */ @@ -51,10 +81,10 @@ model Tool { name: validPythonIdentifier; /** The function to call */ - function: FunctionDef; - integration?: unknown; - system?: unknown; - api_call?: unknown; + function?: FunctionDef; + integration?: IntegrationDef; + system?: SystemDef; + api_call?: never; // TODO: Implement ...HasTimestamps; ...HasId; @@ -71,9 +101,9 @@ model NamedToolChoice { type: ToolType; function?: FunctionCallOption; - integration?: unknown; - system?: unknown; - api_call?: unknown; + integration?: never; // TODO: Implement + system?: never; // TODO: Implement + api_call?: never; // TODO: Implement } model NamedFunctionChoice extends NamedToolChoice { @@ -112,9 +142,9 @@ model ChosenToolCall { type: ToolType; function?: FunctionCallOption; - integration?: unknown; - system?: unknown; - api_call?: unknown; + integration?: never; // TODO: Implement + system?: never; // TODO: Implement + api_call?: never; // TODO: Implement ...HasId; }