diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index d252e401d..50d7fb312 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -10,6 +10,7 @@ from litellm.types.utils import Choices, ModelResponse from temporalio import activity from temporalio.exceptions import ApplicationError +from pydantic import BaseModel from ...autogen.Tools import ( BraveIntegrationDef, @@ -28,6 +29,7 @@ Tool, WeatherIntegrationDef, WikipediaIntegrationDef, + BaseIntegrationDef, ) from ...clients import ( litellm, # We dont directly import `acompletion` so we can mock it @@ -63,20 +65,33 @@ def _get_integration_arguments(tool: Tool): "remote_browser": RemoteBrowserIntegrationDef, } - integration = providers_map.get(tool.integration.provider) + integration: BaseIntegrationDef | dict[str, BaseIntegrationDef] = providers_map.get(tool.integration.provider) if isinstance(integration, dict): - integration = integration.get(tool.integration.method) + integration: BaseIntegrationDef = integration.get(tool.integration.method) - return integration.model_fields["arguments"].annotation if integration else None + properties = { + "type": "object", + "properties": {}, + "required": [], + } + + arguments: BaseModel | Any | None = integration.arguments + if not arguments: + return properties + if isinstance(arguments, BaseModel): + for fld_name, fld_annotation in arguments.model_fields.items(): + properties["properties"][fld_name] = { + "type": fld_annotation.annotation, + "description": fld_name, + } + if fld_annotation.is_required: + properties["required"].append(fld_name) -def _annotation_input_schema(annotation: Any) -> dict: - # TODO: implement - def _tool(x: annotation): - pass + elif isinstance(arguments, dict): + properties["properties"] = arguments - lc_tool: BaseTool = tool_decorator(_tool) - return lc_tool.get_input_jsonschema() + return properties def format_tool(tool: Tool) -> dict: @@ -125,8 +140,7 @@ def format_tool(tool: Tool) -> dict: formatted["function"]["parameters"] = json_schema elif tool.type == "integration" and tool.integration: - if annotation := _get_integration_arguments(tool): - formatted["function"]["parameters"] = _annotation_input_schema(annotation) + formatted["function"]["parameters"] = _get_integration_arguments(tool) elif tool.type == "api_call" and tool.api_call: formatted["function"]["parameters"] = tool.api_call.schema_