From 501bef75cf0adbe6f0953dc87352b81d30cf7759 Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 6 Jun 2024 17:31:54 +0200 Subject: [PATCH] PR feedback --- .../tools/openapi/{openapi.py => _openapi.py} | 469 +++++++++--------- ...d_extraction.py => _payload_extraction.py} | 5 + ...ma_conversion.py => _schema_conversion.py} | 35 +- .../tools/openapi/generator_factory.py | 105 ---- .../components/tools/openapi/openapi_tool.py | 53 +- test/components/tools/openapi/conftest.py | 2 +- .../tools/openapi/test_openapi_client.py | 2 +- .../tools/openapi/test_openapi_client_auth.py | 43 +- ...est_openapi_client_complex_request_body.py | 2 +- ...enapi_client_complex_request_body_mixed.py | 2 +- .../openapi/test_openapi_client_edge_cases.py | 2 +- .../test_openapi_client_error_handling.py | 2 +- .../tools/openapi/test_openapi_client_live.py | 22 +- .../test_openapi_client_live_anthropic.py | 11 +- .../test_openapi_client_live_cohere.py | 7 +- .../test_openapi_client_live_openai.py | 2 +- .../openapi/test_openapi_cohere_conversion.py | 2 +- .../openapi/test_openapi_openai_conversion.py | 2 +- .../tools/openapi/test_openapi_spec.py | 17 +- 19 files changed, 325 insertions(+), 460 deletions(-) rename haystack_experimental/components/tools/openapi/{openapi.py => _openapi.py} (50%) rename haystack_experimental/components/tools/openapi/{payload_extraction.py => _payload_extraction.py} (90%) rename haystack_experimental/components/tools/openapi/{schema_conversion.py => _schema_conversion.py} (85%) delete mode 100644 haystack_experimental/components/tools/openapi/generator_factory.py diff --git a/haystack_experimental/components/tools/openapi/openapi.py b/haystack_experimental/components/tools/openapi/_openapi.py similarity index 50% rename from haystack_experimental/components/tools/openapi/openapi.py rename to haystack_experimental/components/tools/openapi/_openapi.py index f8f60bdb..e171c35d 100644 --- a/haystack_experimental/components/tools/openapi/openapi.py +++ b/haystack_experimental/components/tools/openapi/_openapi.py @@ -5,8 +5,8 @@ import json import logging import os -from base64 import b64encode from dataclasses import dataclass, field +from enum import Enum from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Union from urllib.parse import urlparse @@ -14,10 +14,10 @@ import requests import yaml -from haystack_experimental.components.tools.openapi.payload_extraction import ( +from haystack_experimental.components.tools.openapi._payload_extraction import ( create_function_payload_extractor, ) -from haystack_experimental.components.tools.openapi.schema_conversion import ( +from haystack_experimental.components.tools.openapi._schema_conversion import ( anthropic_converter, cohere_converter, openai_converter, @@ -37,27 +37,69 @@ logger = logging.getLogger(__name__) -class AuthenticationStrategy: +class LLMProvider(Enum): """ - Represents an authentication strategy that can be applied to an HTTP request. + Enum for the supported LLM providers. """ + OPENAI = "openai" + ANTHROPIC = "anthropic" + COHERE = "cohere" - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): - """ - Apply the authentication strategy to the given request. - :param security_scheme: the security scheme from the OpenAPI spec. - :param request: the request to apply the authentication to. - """ +def is_valid_http_url(url: str) -> bool: + """ + Check if a URL is a valid HTTP/HTTPS URL. + + :param url: The URL to check. + :return: True if the URL is a valid HTTP/HTTPS URL, False otherwise. + """ + r = urlparse(url) + return all([r.scheme in ["http", "https"], r.netloc]) -@dataclass -class ApiKeyAuthentication(AuthenticationStrategy): - """API key authentication strategy.""" +def send_request(request: Dict[str, Any]) -> Dict[str, Any]: + """ + Send an HTTP request and return the response. - api_key: Optional[str] = None + :param request: The request to send. + :return: The response from the server. + """ + url = request["url"] + headers = {**request.get("headers", {})} + try: + response = requests.request( + request["method"], + url, + headers=headers, + params=request.get("params", {}), + json=request.get("json"), + auth=request.get("auth"), + timeout=10, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + logger.warning("HTTP error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except requests.exceptions.RequestException as e: + logger.warning("Request error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"HTTP error occurred: {e}") from e + except Exception as e: + logger.warning("An error occurred: %s while sending request to %s", e, url) + raise HttpClientError(f"An error occurred: {e}") from e + + +# Authentication strategies +def create_api_key_auth_function(api_key: str): + """ + Create a function that applies the API key authentication strategy to a given request. - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): + :param api_key: the API key to use for authentication. + :return: a function that applies the API key authentication to a request + at the schema specified location. + """ + + def apply_api_key_auth(security_scheme: Dict[str, Any], request: Dict[str, Any]): """ Apply the API key authentication strategy to the given request. @@ -65,67 +107,18 @@ def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): :param request: the request to apply the authentication to. """ if security_scheme["in"] == "header": - request.setdefault("headers", {})[security_scheme["name"]] = self.api_key + request.setdefault("headers", {})[security_scheme["name"]] = api_key elif security_scheme["in"] == "query": - request.setdefault("params", {})[security_scheme["name"]] = self.api_key + request.setdefault("params", {})[security_scheme["name"]] = api_key elif security_scheme["in"] == "cookie": - request.setdefault("cookies", {})[security_scheme["name"]] = self.api_key + request.setdefault("cookies", {})[security_scheme["name"]] = api_key else: raise ValueError( f"Unsupported apiKey authentication location: {security_scheme['in']}, " f"must be one of 'header', 'query', or 'cookie'" ) - -@dataclass -class HTTPAuthentication(AuthenticationStrategy): - """HTTP authentication strategy.""" - - username: Optional[str] = None - password: Optional[str] = None - token: Optional[str] = None - - def __post_init__(self): - if not self.token and (not self.username or not self.password): - raise ValueError( - "For HTTP Basic Auth, both username and password must be provided. " - "For Bearer Auth, a token must be provided." - ) - - def apply_auth(self, security_scheme: Dict[str, Any], request: Dict[str, Any]): - """ - Apply the HTTP authentication strategy to the given request. - - :param security_scheme: the security scheme from the OpenAPI spec. - :param request: the request to apply the authentication to. - """ - if security_scheme["type"] == "http": - if security_scheme["scheme"].lower() == "basic": - if not self.username or not self.password: - raise ValueError( - "Username and password must be provided for Basic Auth." - ) - credentials = f"{self.username}:{self.password}" - encoded_credentials = b64encode(credentials.encode("utf-8")).decode( - "utf-8" - ) - request.setdefault("headers", {})[ - "Authorization" - ] = f"Basic {encoded_credentials}" - elif security_scheme["scheme"].lower() == "bearer": - if not self.token: - raise ValueError("Token must be provided for Bearer Auth.") - request.setdefault("headers", {})[ - "Authorization" - ] = f"Bearer {self.token}" - else: - raise ValueError( - f"Unsupported HTTP authentication scheme: {security_scheme['scheme']}" - ) - else: - raise ValueError( - "HTTPAuthentication strategy received a non-HTTP security scheme." - ) + return apply_api_key_auth class HttpClientError(Exception): @@ -134,7 +127,21 @@ class HttpClientError(Exception): @dataclass class Operation: - """Represents an operation in an OpenAPI specification.""" + """ + Represents an operation in an OpenAPI specification + + See https://spec.openapis.org/oas/latest.html#paths-object for details. + Path objects can contain multiple operations, each with a unique combination of path and method. + + Attributes: + path (str): Path of the operation. + method (str): HTTP method of the operation. + operation_dict (Dict[str, Any]): Operation details from OpenAPI spec. + spec_dict (Dict[str, Any]): The encompassing OpenAPI specification. + security_requirements (List[Dict[str, List[str]]]): Security requirements for the operation. + request_body (Dict[str, Any]): Request body details. + parameters (List[Dict[str, Any]]): Parameters for the operation. + """ path: str method: str @@ -161,23 +168,41 @@ def get_parameters( ) -> List[Dict[str, Any]]: """ Get the parameters for the operation. + + :param location: The location of the parameters to get. + :return: The parameters for the operation as a list of dictionaries. """ if location: return [param for param in self.parameters if param["in"] == location] return self.parameters - def get_server(self) -> str: + def get_server(self, server_index: int = 0) -> str: """ Get the servers for the operation. + + :param server_index: The index of the server to use. + :return: The server URL. + :raises ValueError: If no servers are found in the specification. """ servers = self.operation_dict.get("servers", []) or self.spec_dict.get( "servers", [] ) - return servers[0].get("url", "") # just use the first server from the list + if not servers: + raise ValueError("No servers found in the provided specification.") + if server_index >= len(servers): + raise ValueError( + f"Server index {server_index} is out of bounds. " + f"Only {len(servers)} servers found." + ) + return servers[server_index].get( + "url" + ) # just use the first server from the list class OpenAPISpecification: - """Represents an OpenAPI specification.""" + """ + Represents an OpenAPI specification. See https://spec.openapis.org/oas/latest.html for details. + """ def __init__(self, spec_dict: Dict[str, Any]): if not isinstance(spec_dict, Dict): @@ -196,18 +221,13 @@ def __init__(self, spec_dict: Dict[str, Any]): ) self.spec_dict = spec_dict - @classmethod - def from_dict(cls, spec_dict: Dict[str, Any]) -> "OpenAPISpecification": - """ - Create an OpenAPISpecification instance from a dictionary. - """ - parser = cls(spec_dict) - return parser - @classmethod def from_str(cls, content: str) -> "OpenAPISpecification": """ Create an OpenAPISpecification instance from a string. + + :param content: The string content of the OpenAPI specification. + :return: The OpenAPISpecification instance. """ try: loaded_spec = json.loads(content) @@ -224,6 +244,9 @@ def from_str(cls, content: str) -> "OpenAPISpecification": def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": """ Create an OpenAPISpecification instance from a file. + + :param spec_file: The file path to the OpenAPI specification. + :return: The OpenAPISpecification instance. """ with open(spec_file, encoding="utf-8") as file: content = file.read() @@ -233,6 +256,9 @@ def from_file(cls, spec_file: Union[str, Path]) -> "OpenAPISpecification": def from_url(cls, url: str) -> "OpenAPISpecification": """ Create an OpenAPISpecification instance from a URL. + + :param url: The URL to fetch the OpenAPI specification from. + :return: The OpenAPISpecification instance. """ try: response = requests.get(url, timeout=10) @@ -248,23 +274,31 @@ def find_operation_by_id( self, op_id: str, method: Optional[str] = None ) -> Operation: """ - Find an operation by operationId. + Find an Operation by operationId. + + :param op_id: The operationId of the operation. + :param method: The HTTP method of the operation. + :return: The matching operation + :raises ValueError: If no operation is found with the given operationId. """ for path, path_item in self.spec_dict.get("paths", {}).items(): op: Operation = self.get_operation_item(path, path_item, method) if op_id in op.operation_dict.get("operationId", ""): return self.get_operation_item(path, path_item, method) - raise ValueError(f"No operation found with operationId {op_id}") + raise ValueError( + f"No operation found with operationId {op_id}, method {method}" + ) def get_operation_item( self, path: str, path_item: Dict[str, Any], method: Optional[str] = None ) -> Operation: """ - Get an operation item from the OpenAPI specification. + Gets a particular Operation item from the OpenAPI specification given the path and method. :param path: The path of the operation. :param path_item: The path item from the OpenAPI specification. :param method: The HTTP method of the operation. + :return: The operation """ if method: operation_dict = path_item.get(method.lower(), {}) @@ -280,11 +314,13 @@ def get_operation_item( raise ValueError( f"Multiple operations found at path {path}, method parameter is required." ) - raise ValueError(f"No operations found at path {path}.") + raise ValueError(f"No operations found at path {path} and method {method}") def get_security_schemes(self) -> Dict[str, Dict[str, Any]]: """ Get the security schemes from the OpenAPI specification. + + :return: The security schemes as a dictionary. """ return self.spec_dict.get("components", {}).get("securitySchemes", {}) @@ -294,19 +330,15 @@ class ClientConfiguration: def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments self, - openapi_spec: Union[str, Path, Dict[str, Any]], - credentials: Optional[ - Union[str, Dict[str, Any], AuthenticationStrategy] - ] = None, + openapi_spec: Union[str, Path], + credentials: Optional[str] = None, request_sender: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, - llm_provider: Optional[str] = None, + llm_provider: Optional[LLMProvider] = None, ): # noqa: PLR0913 if isinstance(openapi_spec, (str, Path)) and os.path.isfile(openapi_spec): self.openapi_spec = OpenAPISpecification.from_file(openapi_spec) - elif isinstance(openapi_spec, dict): - self.openapi_spec = OpenAPISpecification.from_dict(openapi_spec) elif isinstance(openapi_spec, str): - if self.is_valid_http_url(openapi_spec): + if is_valid_http_url(openapi_spec): self.openapi_spec = OpenAPISpecification.from_url(openapi_spec) else: self.openapi_spec = OpenAPISpecification.from_str(openapi_spec) @@ -316,40 +348,46 @@ def __init__( # noqa: PLR0913 pylint: disable=too-many-arguments ) self.credentials = credentials - self.request_sender = request_sender - self.llm_provider = llm_provider or "openai" + self.request_sender = request_sender or send_request + self.llm_provider: LLMProvider = llm_provider or LLMProvider.OPENAI - def get_auth_config(self) -> AuthenticationStrategy: + def get_auth_function(self) -> Callable[[dict[str, Any], dict[str, Any]], Any]: """ - Get the authentication configuration. + Get the authentication function that sets a schema specified authentication to the request. + + The function takes a security scheme and a request as arguments: + `security_scheme: Dict[str, Any] - The security scheme from the OpenAPI spec.` + `request: Dict[str, Any] - The request to apply the authentication to.` + :return: The authentication function. """ - if not self.credentials: - return AuthenticationStrategy() - if isinstance(self.credentials, AuthenticationStrategy): - return self.credentials security_schemes = self.openapi_spec.get_security_schemes() + if not self.credentials: + return lambda security_scheme, request: None # No-op function if isinstance(self.credentials, str): return self._create_authentication_from_string( self.credentials, security_schemes ) - if isinstance(self.credentials, dict): - return self._create_authentication_from_dict(self.credentials) raise ValueError(f"Unsupported credentials type: {type(self.credentials)}") def get_tools_definitions(self) -> List[Dict[str, Any]]: """ Get the tools definitions used as tools LLM parameter. + + :return: The tools definitions passed to the LLM as tools parameter. """ provider_to_converter = { "anthropic": anthropic_converter, "cohere": cohere_converter, } - converter = provider_to_converter.get(self.llm_provider, openai_converter) + converter = provider_to_converter.get(self.llm_provider.value, openai_converter) return converter(self.openapi_spec) def get_payload_extractor(self): """ Get the payload extractor for the LLM provider. + + This function knows how to extract the exact function payload from the LLM generated function calling payload. + :return: The payload extractor function. """ provider_to_arguments_field_name = { "anthropic": "input", @@ -357,45 +395,102 @@ def get_payload_extractor(self): } # add more providers here # default to OpenAI "arguments" arguments_field_name = provider_to_arguments_field_name.get( - self.llm_provider, "arguments" + self.llm_provider.value, "arguments" ) return create_function_payload_extractor(arguments_field_name) def _create_authentication_from_string( self, credentials: str, security_schemes: Dict[str, Any] - ) -> AuthenticationStrategy: + ) -> Callable[[dict[str, Any], dict[str, Any]], Any]: for scheme in security_schemes.values(): if scheme["type"] == "apiKey": - return ApiKeyAuthentication(api_key=credentials) + return create_api_key_auth_function(api_key=credentials) if scheme["type"] == "http": - return HTTPAuthentication(token=credentials) + raise NotImplementedError("HTTP authentication is not yet supported.") if scheme["type"] == "oauth2": raise NotImplementedError("OAuth2 authentication is not yet supported.") raise ValueError( f"Unable to create authentication from provided credentials: {credentials}" ) - def _create_authentication_from_dict( - self, credentials: Dict[str, Any] - ) -> AuthenticationStrategy: - if "username" in credentials and "password" in credentials: - return HTTPAuthentication( - username=credentials["username"], password=credentials["password"] - ) - if "api_key" in credentials: - return ApiKeyAuthentication(api_key=credentials["api_key"]) - if "token" in credentials: - return HTTPAuthentication(token=credentials["token"]) - if "access_token" in credentials: - raise NotImplementedError("OAuth2 authentication is not yet supported.") - raise ValueError( - "Unable to create authentication from provided credentials: {credentials}" - ) - def is_valid_http_url(self, url: str) -> bool: - """Check if a URL is a valid HTTP/HTTPS URL.""" - r = urlparse(url) - return all([r.scheme in ["http", "https"], r.netloc]) +def build_request(operation: Operation, **kwargs) -> Dict[str, Any]: + """ + Build an HTTP request for the operation. + + :param operation: The operation to build the request for. + :param kwargs: The arguments to use for building the request. + :return: The HTTP request as a dictionary. + """ + path = operation.path + for parameter in operation.get_parameters("path"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + path = path.replace(f"{{{parameter['name']}}}", str(param_value)) + elif parameter.get("required", False): + raise ValueError(f"Missing required path parameter: {parameter['name']}") + url = operation.get_server() + path + # method + method = operation.method.lower() + # headers + headers = {} + for parameter in operation.get_parameters("header"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + headers[parameter["name"]] = str(param_value) + elif parameter.get("required", False): + raise ValueError(f"Missing required header parameter: {parameter['name']}") + # query params + query_params = {} + for parameter in operation.get_parameters("query"): + param_value = kwargs.get(parameter["name"], None) + if param_value: + query_params[parameter["name"]] = param_value + elif parameter.get("required", False): + raise ValueError(f"Missing required query parameter: {parameter['name']}") + + json_payload = None + request_body = operation.request_body + if request_body: + content = request_body.get("content", {}) + if "application/json" in content: + json_payload = {**kwargs} + else: + raise NotImplementedError("Request body content type not supported") + return { + "url": url, + "method": method, + "headers": headers, + "params": query_params, + "json": json_payload, + } + + +def apply_authentication( + auth_strategy: Callable[[Dict[str, Any], Dict[str, Any]], Any], + operation: Operation, + request: Dict[str, Any], +): + """ + Apply the authentication strategy to the given request. + + :param auth_strategy: The authentication strategy to apply. + This is a function that takes a security scheme and a request as arguments (at runtime) + and applies the authentication + :param operation: The operation to apply the authentication to. + :param request: The request to apply the authentication to. + """ + security_requirements = operation.security_requirements + security_schemes = operation.spec_dict.get("components", {}).get( + "securitySchemes", {} + ) + if security_requirements: + for requirement in security_requirements: + for scheme_name in requirement: + if scheme_name in security_schemes: + security_scheme = security_schemes[scheme_name] + auth_strategy(security_scheme, request) + break class OpenAPIServiceClient: @@ -405,7 +500,7 @@ class OpenAPIServiceClient: def __init__(self, client_config: ClientConfiguration): self.client_config = client_config - self.request_sender = client_config.request_sender or self._request_sender() + self.request_sender = client_config.request_sender def invoke(self, function_payload: Any) -> Any: """ @@ -423,118 +518,12 @@ def invoke(self, function_payload: Any) -> Any: f"Failed to extract function invocation payload from {function_payload}" ) # fn_invocation_payload, if not empty, guaranteed to have "name" and "arguments" keys from here on - operation = self.client_config.openapi_spec.find_operation_by_id(fn_invocation_payload.get("name")) - request = self._build_request(operation, **fn_invocation_payload.get("arguments")) - self._apply_authentication(self.client_config.get_auth_config(), operation, request) - return self.request_sender(request) - - def _request_sender(self) -> Callable[[Dict[str, Any]], Dict[str, Any]]: - """ - Returns a callable that sends the request using the HTTP client. - """ - - def send_request(request: Dict[str, Any]) -> Dict[str, Any]: - url = request["url"] - headers = {**request.get("headers", {})} - try: - response = requests.request( - request["method"], - url, - headers=headers, - params=request.get("params", {}), - json=request.get("json"), - auth=request.get("auth"), - timeout=10, - ) - response.raise_for_status() - return response.json() - except requests.exceptions.HTTPError as e: - logger.warning( - "HTTP error occurred: %s while sending request to %s", e, url - ) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except requests.exceptions.RequestException as e: - logger.warning( - "Request error occurred: %s while sending request to %s", e, url - ) - raise HttpClientError(f"HTTP error occurred: {e}") from e - except Exception as e: - logger.warning( - "An error occurred: %s while sending request to %s", e, url - ) - raise HttpClientError(f"An error occurred: {e}") from e - - return send_request - - def _build_request(self, operation: Operation, **kwargs) -> Any: - # url - path = operation.path - for parameter in operation.get_parameters("path"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - path = path.replace(f"{{{parameter['name']}}}", str(param_value)) - elif parameter.get("required", False): - raise ValueError( - f"Missing required path parameter: {parameter['name']}" - ) - url = operation.get_server() + path - # method - method = operation.method.lower() - # headers - headers = {} - for parameter in operation.get_parameters("header"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - headers[parameter["name"]] = str(param_value) - elif parameter.get("required", False): - raise ValueError( - f"Missing required header parameter: {parameter['name']}" - ) - # query params - query_params = {} - for parameter in operation.get_parameters("query"): - param_value = kwargs.get(parameter["name"], None) - if param_value: - query_params[parameter["name"]] = param_value - elif parameter.get("required", False): - raise ValueError( - f"Missing required query parameter: {parameter['name']}" - ) - - json_payload = None - request_body = operation.request_body - if request_body: - content = request_body.get("content", {}) - if "application/json" in content: - json_payload = {**kwargs} - else: - raise NotImplementedError("Request body content type not supported") - return { - "url": url, - "method": method, - "headers": headers, - "params": query_params, - "json": json_payload, - } - - def _apply_authentication( - self, - auth: AuthenticationStrategy, - operation: Operation, - request: Dict[str, Any], - ): - auth_config = auth or AuthenticationStrategy() - security_requirements = operation.security_requirements - security_schemes = operation.spec_dict.get("components", {}).get( - "securitySchemes", {} + operation = self.client_config.openapi_spec.find_operation_by_id( + fn_invocation_payload.get("name") ) - if security_requirements: - for requirement in security_requirements: - for scheme_name in requirement: - if scheme_name in security_schemes: - security_scheme = security_schemes[scheme_name] - auth_config.apply_auth(security_scheme, request) - break + request = build_request(operation, **fn_invocation_payload.get("arguments")) + apply_authentication(self.client_config.get_auth_function(), operation, request) + return self.request_sender(request) class OpenAPIClientError(Exception): diff --git a/haystack_experimental/components/tools/openapi/payload_extraction.py b/haystack_experimental/components/tools/openapi/_payload_extraction.py similarity index 90% rename from haystack_experimental/components/tools/openapi/payload_extraction.py rename to haystack_experimental/components/tools/openapi/_payload_extraction.py index 416841bf..6247c56a 100644 --- a/haystack_experimental/components/tools/openapi/payload_extraction.py +++ b/haystack_experimental/components/tools/openapi/_payload_extraction.py @@ -12,11 +12,16 @@ def create_function_payload_extractor( ) -> Callable[[Any], Dict[str, Any]]: """ Extracts invocation payload from a given LLM completion containing function invocation. + + :param arguments_field_name: The name of the field containing the function arguments. + :return: A function that extracts the function invocation details from the LLM payload. """ def _extract_function_invocation(payload: Any) -> Dict[str, Any]: """ Extract the function invocation details from the payload. + + :param payload: The LLM fc payload to extract the function invocation details from. """ fields_and_values = _search(payload, arguments_field_name) if fields_and_values: diff --git a/haystack_experimental/components/tools/openapi/schema_conversion.py b/haystack_experimental/components/tools/openapi/_schema_conversion.py similarity index 85% rename from haystack_experimental/components/tools/openapi/schema_conversion.py rename to haystack_experimental/components/tools/openapi/_schema_conversion.py index b74bc1cf..abc23c9c 100644 --- a/haystack_experimental/components/tools/openapi/schema_conversion.py +++ b/haystack_experimental/components/tools/openapi/_schema_conversion.py @@ -5,9 +5,6 @@ import logging from typing import Any, Callable, Dict, List, Optional -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 import jsonref MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3 @@ -19,8 +16,9 @@ def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # """ Converts OpenAPI specification to a list of function suitable for OpenAI LLM function calling. + See https://platform.openai.com/docs/guides/function-calling for more information about OpenAI's function schema. :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. + :return: A list of dictionaries, each dictionary representing an OpenAI function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) fn_definitions = _openapi_to_functions( @@ -33,8 +31,10 @@ def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: """ Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling. + See https://docs.anthropic.com/en/docs/tool-use for more information about Anthropic's function schema. + :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. + :return: A list of dictionaries, each dictionary representing Anthropic function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) return _openapi_to_functions( @@ -46,8 +46,10 @@ def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: # """ Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling. + See https://docs.cohere.com/docs/tool-use for more information about Cohere's function schema. + :param schema: The OpenAPI specification to convert. - :return: A list of dictionaries, each representing a function definition. + :return: A list of dictionaries, each representing a Cohere style function definition. """ resolved_schema = jsonref.replace_refs(schema.spec_dict) return _openapi_to_functions( @@ -62,6 +64,10 @@ def _openapi_to_functions( ) -> List[Dict[str, Any]]: """ Extracts functions from the OpenAPI specification, converts them into a function schema. + + :param service_openapi_spec: The OpenAPI specification to extract functions from. + :param parameters_name: The name of the parameters field in the function schema. + :param parse_endpoint_fn: The function to parse the endpoint specification. """ # Doesn't enforce rigid spec validation because that would require a lot of dependencies @@ -92,6 +98,9 @@ def _parse_endpoint_spec_openai( ) -> Dict[str, Any]: """ Parses an OpenAPI endpoint specification for OpenAI. + + :param resolved_spec: The resolved OpenAPI specification. + :param parameters_name: The name of the parameters field in the function schema. """ if not isinstance(resolved_spec, dict): logger.warning( @@ -145,6 +154,9 @@ def _parse_property_attributes( ) -> Dict[str, Any]: """ Recursively parses the attributes of a property schema. + + :param property_schema: The property schema to parse. + :param include_attributes: The attributes to include in the parsed schema. """ include_attributes = include_attributes or ["description", "pattern", "enum"] schema_type = property_schema.get("type") @@ -172,6 +184,9 @@ def _parse_endpoint_spec_cohere( ) -> Dict[str, Any]: """ Parses an endpoint specification for Cohere. + + :param operation: The operation specification to parse. + :param ignored_param: ignored, left for compatibility with the OpenAI converter. """ function_name = operation.get("operationId") description = operation.get("description") or operation.get("summary", "") @@ -189,6 +204,9 @@ def _parse_endpoint_spec_cohere( def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: """ Parses the parameters from an operation specification. + + :param operation: The operation specification to parse. + :return: A dictionary containing the parsed parameters. """ parameters = {} for param in operation.get("parameters", []): @@ -217,6 +235,11 @@ def _parse_schema( ) -> Dict[str, Any]: # noqa: FBT001 """ Parses a schema part of an operation specification. + + :param schema: The schema to parse. + :param required: Whether the schema is required. + :param description: The description of the schema. + :return: A dictionary containing the parsed schema. """ schema_type = _get_type(schema) if schema_type == "object": diff --git a/haystack_experimental/components/tools/openapi/generator_factory.py b/haystack_experimental/components/tools/openapi/generator_factory.py deleted file mode 100644 index 28401863..00000000 --- a/haystack_experimental/components/tools/openapi/generator_factory.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-FileCopyrightText: 2022-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - -import importlib -import re -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple - - -@dataclass -class ChatGeneratorDescriptor: - """ - Dataclass to describe a Chat Generator - """ - - class_path: str - patterns: List[re.Pattern] - name: str - model_name: str - - -class ChatGeneratorDescriptorManager: - """ - Class to manage Chat Generator Descriptors - """ - - def __init__(self): - self._descriptors: Dict[str, ChatGeneratorDescriptor] = {} - self._register_default_descriptors() - - def _register_default_descriptors(self): - """ - Register default Chat Generator Descriptors. - """ - default_descriptors = [ - ChatGeneratorDescriptor( - class_path="haystack.components.generators.chat.openai.OpenAIChatGenerator", - patterns=[re.compile(r"^gpt.*")], - name="openai", - model_name="gpt-3.5-turbo", - ), - ChatGeneratorDescriptor( - class_path="haystack_integrations.components.generators.anthropic.AnthropicChatGenerator", - patterns=[re.compile(r"^claude.*")], - name="anthropic", - model_name="claude-1", - ), - ChatGeneratorDescriptor( - class_path="haystack_integrations.components.generators.cohere.CohereChatGenerator", - patterns=[re.compile(r"^command-r.*")], - name="cohere", - model_name="command-r", - ), - ] - - for descriptor in default_descriptors: - self.register_descriptor(descriptor) - - def _load_class(self, full_class_path: str): - """ - Load a class from a string representation of its path e.g. "module.submodule.class_name" - """ - module_path, _, class_name = full_class_path.rpartition(".") - module = importlib.import_module(module_path) - return getattr(module, class_name) - - def register_descriptor(self, descriptor: ChatGeneratorDescriptor): - """ - Register a new Chat Generator Descriptor. - """ - if descriptor.name in self._descriptors: - raise ValueError(f"Descriptor {descriptor.name} already exists.") - - self._descriptors[descriptor.name] = descriptor - - def _infer_descriptor(self, model_name: str) -> Optional[ChatGeneratorDescriptor]: - """ - Infer the descriptor based on the model name. - """ - for descriptor in self._descriptors.values(): - if any(pattern.match(model_name) for pattern in descriptor.patterns): - return descriptor - return None - - def create_generator( - self, model_name: str, descriptor_name: Optional[str] = None, **model_kwargs - ) -> Tuple[ChatGeneratorDescriptor, Any]: - """ - Create ChatGenerator instance based on the model name and descriptor. - """ - if descriptor_name: - descriptor = self._descriptors.get(descriptor_name) - if not descriptor: - raise ValueError(f"Invalid descriptor name: {descriptor_name}") - else: - descriptor = self._infer_descriptor(model_name) - if not descriptor: - raise ValueError( - f"Could not infer descriptor for model name: {model_name}" - ) - - return descriptor, self._load_class(descriptor.class_path)( - model=model_name, **(model_kwargs or {}) - ) diff --git a/haystack_experimental/components/tools/openapi/openapi_tool.py b/haystack_experimental/components/tools/openapi/openapi_tool.py index b0ea10a2..cb26cc5e 100644 --- a/haystack_experimental/components/tools/openapi/openapi_tool.py +++ b/haystack_experimental/components/tools/openapi/openapi_tool.py @@ -7,13 +7,13 @@ from typing import Any, Dict, List, Optional, Union from haystack import component, logging +from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage, ChatRole +from haystack.lazy_imports import LazyImport -from haystack_experimental.components.tools.openapi.generator_factory import ( - ChatGeneratorDescriptorManager, -) -from haystack_experimental.components.tools.openapi.openapi import ( +from haystack_experimental.components.tools.openapi._openapi import ( ClientConfiguration, + LLMProvider, OpenAPIServiceClient, ) @@ -46,30 +46,28 @@ class OpenAPITool: def __init__( self, - model: str, + generator_api: LLMProvider, + generator_api_params: Dict[str, Any], tool_spec: Optional[Union[str, Path]] = None, - tool_credentials: Optional[Union[str, Dict[str, Any]]] = None, - model_kwargs: Optional[Dict[str, Any]] = None, + tool_credentials: Optional[str] = None, ): """ Initialize the OpenAPITool component. - :param model: Name of the chat generator model to use. + :param generator_api: The API provider for the chat generator. + :param generator_api_params: Parameters for the chat generator. :param tool_spec: OpenAPI specification for the tool/service. :param tool_credentials: Credentials for the tool/service. - :param model_kwargs: Additional arguments for the chat generator model. """ - manager = ChatGeneratorDescriptorManager() - self.descriptor, self.chat_generator = manager.create_generator( - model_name=model, **(model_kwargs or {}) - ) + self.generator_api = generator_api + self.chat_generator = self._init_generator(generator_api, generator_api_params) self.config_openapi: Optional[ClientConfiguration] = None self.open_api_service: Optional[OpenAPIServiceClient] = None if tool_spec: self.config_openapi = ClientConfiguration( openapi_spec=tool_spec, credentials=tool_credentials, - llm_provider=self.descriptor.name, + llm_provider=generator_api ) self.open_api_service = OpenAPIServiceClient(self.config_openapi) @@ -78,8 +76,8 @@ def run( self, messages: List[ChatMessage], fc_generator_kwargs: Optional[Dict[str, Any]] = None, - tool_spec: Optional[Union[str, Path, Dict[str, Any]]] = None, - tool_credentials: Optional[Union[dict, str]] = None, + tool_spec: Optional[Union[str, Path]] = None, + tool_credentials: Optional[str] = None, ) -> Dict[str, List[ChatMessage]]: """ Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator. @@ -104,7 +102,7 @@ def run( config_openapi = ClientConfiguration( openapi_spec=tool_spec, credentials=tool_credentials, - llm_provider=self.descriptor.name, + llm_provider=self.generator_api, ) openapi_service = OpenAPIServiceClient(config_openapi) @@ -134,4 +132,25 @@ def run( logger.error("Error invoking OpenAPI endpoint. Error: {e}", e=str(e)) service_response = {"error": str(e)} response_messages = [ChatMessage.from_user(json.dumps(service_response))] + return {"service_response": response_messages} + + def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict[str, Any]): + """ + Initialize the chat generator based on the specified API provider and parameters. + """ + if generator_api == LLMProvider.OPENAI: + return OpenAIChatGenerator(**generator_api_params) + if generator_api == LLMProvider.ANTHROPIC: + with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import: + anthropic_import.check() + # pylint: disable=import-outside-toplevel + from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator + return AnthropicChatGenerator(**generator_api_params) + if generator_api == LLMProvider.COHERE: + with LazyImport("Run 'pip install cohere-haystack'") as cohere_import: + cohere_import.check() + # pylint: disable=import-outside-toplevel + from haystack_integrations.components.generators.cohere import CohereChatGenerator + return CohereChatGenerator(**generator_api_params) + raise ValueError(f"Unsupported generator API: {generator_api}") diff --git a/test/components/tools/openapi/conftest.py b/test/components/tools/openapi/conftest.py index 2df4f76c..89ec74d4 100644 --- a/test/components/tools/openapi/conftest.py +++ b/test/components/tools/openapi/conftest.py @@ -10,7 +10,7 @@ from fastapi import FastAPI from fastapi.testclient import TestClient -from haystack_experimental.components.tools.openapi.openapi import HttpClientError +from haystack_experimental.components.tools.openapi._openapi import HttpClientError @pytest.fixture() diff --git a/test/components/tools/openapi/test_openapi_client.py b/test/components/tools/openapi/test_openapi_client.py index 4842b75e..c622642e 100644 --- a/test/components/tools/openapi/test_openapi_client.py +++ b/test/components/tools/openapi/test_openapi_client.py @@ -7,7 +7,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient """ diff --git a/test/components/tools/openapi/test_openapi_client_auth.py b/test/components/tools/openapi/test_openapi_client_auth.py index de91855e..ab6205e8 100644 --- a/test/components/tools/openapi/test_openapi_client_auth.py +++ b/test/components/tools/openapi/test_openapi_client_auth.py @@ -15,8 +15,7 @@ HTTPBearer, ) -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ApiKeyAuthentication, \ - HTTPAuthentication, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient API_KEY = "secret_api_key" @@ -140,7 +139,7 @@ class TestOpenAPIAuth: def test_greet_api_key_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", request_sender=FastAPITestClient(create_greet_api_key_auth_app()), - credentials=ApiKeyAuthentication(API_KEY)) + credentials=API_KEY) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -153,26 +152,10 @@ def test_greet_api_key_auth(self, test_files_path): response = client.invoke(payload) assert response == {"greeting": "Hello, John from api_key_auth, using secret_api_key"} - def test_greet_basic_auth(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - request_sender=FastAPITestClient(create_greet_basic_auth_app()), - credentials=HTTPAuthentication(BASIC_AUTH_USERNAME, BASIC_AUTH_PASSWORD)) - client = OpenAPIServiceClient(config) - payload = { - "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", - "function": { - "arguments": '{"name": "John"}', - "name": "greetBasicAuth", - }, - "type": "function", - } - response = client.invoke(payload) - assert response == {"greeting": "Hello, John from basic_auth, using admin"} - def test_greet_api_key_query_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", request_sender=FastAPITestClient(create_greet_api_key_query_app()), - credentials=ApiKeyAuthentication(API_KEY_QUERY)) + credentials=API_KEY_QUERY) client = OpenAPIServiceClient(config) payload = { "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", @@ -189,7 +172,7 @@ def test_greet_api_key_cookie_auth(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", request_sender=FastAPITestClient(create_greet_api_key_cookie_app()), - credentials=ApiKeyAuthentication(API_KEY_COOKIE)) + credentials=API_KEY_COOKIE) client = OpenAPIServiceClient(config) payload = { @@ -201,20 +184,4 @@ def test_greet_api_key_cookie_auth(self, test_files_path): "type": "function", } response = client.invoke(payload) - assert response == {"greeting": "Hello, John from api_key_cookie_auth, using secret_api_key_cookie"} - - def test_greet_bearer_auth(self, test_files_path): - config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "openapi_greeting_service.yml", - request_sender=FastAPITestClient(create_greet_bearer_auth_app()), - credentials=HTTPAuthentication(token=BEARER_TOKEN)) - client = OpenAPIServiceClient(config) - payload = { - "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", - "function": { - "arguments": '{"name": "John"}', - "name": "greetBearerAuth", - }, - "type": "function", - } - response = client.invoke(payload) - assert response == {"greeting": "Hello, John from bearer_auth, using secret_bearer_token"} + assert response == {"greeting": "Hello, John from api_key_cookie_auth, using secret_api_key_cookie"} \ No newline at end of file diff --git a/test/components/tools/openapi/test_openapi_client_complex_request_body.py b/test/components/tools/openapi/test_openapi_client_complex_request_body.py index e6007efb..4b4b5a20 100644 --- a/test/components/tools/openapi/test_openapi_client_complex_request_body.py +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body.py @@ -11,7 +11,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient diff --git a/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py b/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py index 33e95387..bcb5cf48 100644 --- a/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py +++ b/test/components/tools/openapi/test_openapi_client_complex_request_body_mixed.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient diff --git a/test/components/tools/openapi/test_openapi_client_edge_cases.py b/test/components/tools/openapi/test_openapi_client_edge_cases.py index 912fe7d5..4dfe7a06 100644 --- a/test/components/tools/openapi/test_openapi_client_edge_cases.py +++ b/test/components/tools/openapi/test_openapi_client_edge_cases.py @@ -5,7 +5,7 @@ import pytest -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient diff --git a/test/components/tools/openapi/test_openapi_client_error_handling.py b/test/components/tools/openapi/test_openapi_client_error_handling.py index 5b6e8dc4..a1d730aa 100644 --- a/test/components/tools/openapi/test_openapi_client_error_handling.py +++ b/test/components/tools/openapi/test_openapi_client_error_handling.py @@ -8,7 +8,7 @@ import pytest from fastapi import FastAPI, HTTPException -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, HttpClientError, \ +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, HttpClientError, \ ClientConfiguration from test.components.tools.openapi.conftest import FastAPITestClient diff --git a/test/components/tools/openapi/test_openapi_client_live.py b/test/components/tools/openapi/test_openapi_client_live.py index 1ee5b9f4..02ae0b74 100644 --- a/test/components/tools/openapi/test_openapi_client_live.py +++ b/test/components/tools/openapi/test_openapi_client_live.py @@ -7,7 +7,7 @@ import pytest import yaml -from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration +from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration class TestClientLive: @@ -28,26 +28,6 @@ def test_serperdev(self, test_files_path): response = serper_api.invoke(payload) assert "invention" in str(response) - @pytest.mark.skipif("SERPERDEV_API_KEY" not in os.environ, reason="SERPERDEV_API_KEY not set") - @pytest.mark.integration - def test_serperdev_load_spec_first(self, test_files_path): - with open(test_files_path / "yaml" / "serper.yml") as file: - loaded_spec = yaml.safe_load(file) - - # use builder with dict spec - config = ClientConfiguration(openapi_spec=loaded_spec, credentials=os.getenv("SERPERDEV_API_KEY")) - serper_api = OpenAPIServiceClient(config) - payload = { - "id": "call_NJr1NBz2Th7iUWJpRIJZoJIA", - "function": { - "arguments": '{"q": "Who was Nikola Tesla?"}', - "name": "serperdev_search", - }, - "type": "function", - } - response = serper_api.invoke(payload) - assert "invention" in str(response) - @pytest.mark.integration def test_github(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml") diff --git a/test/components/tools/openapi/test_openapi_client_live_anthropic.py b/test/components/tools/openapi/test_openapi_client_live_anthropic.py index ede1503a..91ca6334 100644 --- a/test/components/tools/openapi/test_openapi_client_live_anthropic.py +++ b/test/components/tools/openapi/test_openapi_client_live_anthropic.py @@ -7,7 +7,8 @@ import anthropic import pytest -from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient, \ + LLMProvider class TestClientLiveAnthropic: @@ -18,9 +19,9 @@ class TestClientLiveAnthropic: def test_serperdev(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", credentials=os.getenv("SERPERDEV_API_KEY"), - llm_provider="anthropic") + llm_provider=LLMProvider.ANTHROPIC) client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) - response = client.beta.tools.messages.create( + response = client.messages.create( model="claude-3-opus-20240229", max_tokens=1024, tools=config.get_tools_definitions(), @@ -41,10 +42,10 @@ def test_serperdev(self, test_files_path): @pytest.mark.integration def test_github(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml", - llm_provider="anthropic") + llm_provider=LLMProvider.ANTHROPIC) client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) - response = client.beta.tools.messages.create( + response = client.messages.create( model="claude-3-opus-20240229", max_tokens=1024, tools=config.get_tools_definitions(), diff --git a/test/components/tools/openapi/test_openapi_client_live_cohere.py b/test/components/tools/openapi/test_openapi_client_live_cohere.py index 4cd87631..891bb5fa 100644 --- a/test/components/tools/openapi/test_openapi_client_live_cohere.py +++ b/test/components/tools/openapi/test_openapi_client_live_cohere.py @@ -6,7 +6,8 @@ import cohere import pytest -from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient, \ + LLMProvider # Copied from Cohere's documentation preamble = """ @@ -30,7 +31,7 @@ class TestClientLiveCohere: def test_serperdev(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "serper.yml", credentials=os.getenv("SERPERDEV_API_KEY"), - llm_provider="cohere") + llm_provider=LLMProvider.COHERE) client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) response = client.chat( model="command-r", @@ -53,7 +54,7 @@ def test_serperdev(self, test_files_path): @pytest.mark.integration def test_github(self, test_files_path): config = ClientConfiguration(openapi_spec=test_files_path / "yaml" / "github_compare.yml", - llm_provider="cohere") + llm_provider=LLMProvider.COHERE) client = cohere.Client(api_key=os.getenv("COHERE_API_KEY")) response = client.chat( diff --git a/test/components/tools/openapi/test_openapi_client_live_openai.py b/test/components/tools/openapi/test_openapi_client_live_openai.py index 4f2e2c39..c95f6614 100644 --- a/test/components/tools/openapi/test_openapi_client_live_openai.py +++ b/test/components/tools/openapi/test_openapi_client_live_openai.py @@ -7,7 +7,7 @@ import pytest from openai import OpenAI -from haystack_experimental.components.tools.openapi.openapi import ClientConfiguration, OpenAPIServiceClient +from haystack_experimental.components.tools.openapi._openapi import ClientConfiguration, OpenAPIServiceClient class TestClientLiveOpenAPI: diff --git a/test/components/tools/openapi/test_openapi_cohere_conversion.py b/test/components/tools/openapi/test_openapi_cohere_conversion.py index dd84b9b5..5837c040 100644 --- a/test/components/tools/openapi/test_openapi_cohere_conversion.py +++ b/test/components/tools/openapi/test_openapi_cohere_conversion.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from haystack_experimental.components.tools.openapi.openapi import OpenAPISpecification, cohere_converter +from haystack_experimental.components.tools.openapi._openapi import OpenAPISpecification, cohere_converter class TestOpenAPISchemaConversion: diff --git a/test/components/tools/openapi/test_openapi_openai_conversion.py b/test/components/tools/openapi/test_openapi_openai_conversion.py index 3bf7cc19..090f9c4a 100644 --- a/test/components/tools/openapi/test_openapi_openai_conversion.py +++ b/test/components/tools/openapi/test_openapi_openai_conversion.py @@ -4,7 +4,7 @@ import pytest -from haystack_experimental.components.tools.openapi.openapi import openai_converter, anthropic_converter, OpenAPISpecification +from haystack_experimental.components.tools.openapi._openapi import openai_converter, anthropic_converter, OpenAPISpecification class TestOpenAPISchemaConversion: diff --git a/test/components/tools/openapi/test_openapi_spec.py b/test/components/tools/openapi/test_openapi_spec.py index 93e2a972..4e38de2d 100644 --- a/test/components/tools/openapi/test_openapi_spec.py +++ b/test/components/tools/openapi/test_openapi_spec.py @@ -4,26 +4,11 @@ import pytest -from haystack_experimental.components.tools.openapi.openapi import OpenAPISpecification +from haystack_experimental.components.tools.openapi._openapi import OpenAPISpecification class TestOpenAPISpecification: - # can be initialized from a dictionary - def test_initialized_from_dictionary(self): - spec_dict = { - "openapi": "3.0.0", - "info": {"title": "Test API", "version": "1.0.0"}, - "servers": [{"url": "https://api.example.com"}], - "paths": { - "/users": { - "get": {"summary": "Get all users", "responses": {"200": {"description": "Successful response"}}} - } - }, - } - openapi_spec = OpenAPISpecification.from_dict(spec_dict) - assert openapi_spec.spec_dict == spec_dict - # can be initialized from a string def test_initialized_from_string(self): content = """