
Commit

PR feedback
vblagoje committed Jun 6, 2024
1 parent 91fa306 commit 501bef7
Showing 19 changed files with 325 additions and 460 deletions.


@@ -12,11 +12,16 @@ def create_function_payload_extractor(
) -> Callable[[Any], Dict[str, Any]]:
"""
Extracts the invocation payload from a given LLM completion containing a function invocation.
:param arguments_field_name: The name of the field containing the function arguments.
:return: A function that extracts the function invocation details from the LLM payload.
"""

def _extract_function_invocation(payload: Any) -> Dict[str, Any]:
"""
Extract the function invocation details from the payload.
:param payload: The LLM function-calling payload to extract the function invocation details from.
"""
fields_and_values = _search(payload, arguments_field_name)
if fields_and_values:
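
To make the extractor's behavior concrete, here is a small self-contained sketch of the idea behind _extract_function_invocation: recursively walk the LLM function-calling payload until the configured arguments field is found. This is an illustration written for this note, not the repository's implementation, and the payload shape is a hypothetical OpenAI-style tool call.

from typing import Any, Dict, Optional

def find_field(payload: Any, field_name: str) -> Optional[Dict[str, Any]]:
    # Recursively search nested dicts/lists and return the dict that contains
    # `field_name` (illustrative re-implementation, not the code from this PR).
    if isinstance(payload, dict):
        if field_name in payload:
            return payload
        for value in payload.values():
            found = find_field(value, field_name)
            if found is not None:
                return found
    elif isinstance(payload, list):
        for item in payload:
            found = find_field(item, field_name)
            if found is not None:
                return found
    return None

# Hypothetical OpenAI-style function-calling payload.
payload = {"tool_calls": [{"function": {"name": "getWeather", "arguments": {"city": "Berlin"}}}]}
print(find_field(payload, "arguments"))
# -> {'name': 'getWeather', 'arguments': {'city': 'Berlin'}}
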
@@ -5,9 +5,6 @@
import logging
from typing import Any, Callable, Dict, List, Optional

# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import jsonref

MIN_REQUIRED_OPENAPI_SPEC_VERSION = 3
@@ -19,8 +16,9 @@ def openai_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: #
"""
Converts an OpenAPI specification to a list of functions suitable for OpenAI LLM function calling.
See https://platform.openai.com/docs/guides/function-calling for more information about OpenAI's function schema.
:param schema: The OpenAPI specification to convert.
:return: A list of dictionaries, each representing a function definition.
:return: A list of dictionaries, each dictionary representing an OpenAI function definition.
"""
resolved_schema = jsonref.replace_refs(schema.spec_dict)
fn_definitions = _openapi_to_functions(
@@ -33,8 +31,10 @@ def anthropic_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]:
"""
Converts an OpenAPI specification to a list of function definitions for Anthropic LLM function calling.
See https://docs.anthropic.com/en/docs/tool-use for more information about Anthropic's function schema.
:param schema: The OpenAPI specification to convert.
:return: A list of dictionaries, each representing a function definition.
:return: A list of dictionaries, each dictionary representing an Anthropic function definition.
"""
resolved_schema = jsonref.replace_refs(schema.spec_dict)
return _openapi_to_functions(
@@ -46,8 +46,10 @@ def cohere_converter(schema: "OpenAPISpecification") -> List[Dict[str, Any]]: #
"""
Converts an OpenAPI specification to a list of function definitions for Cohere LLM function calling.
See https://docs.cohere.com/docs/tool-use for more information about Cohere's function schema.
:param schema: The OpenAPI specification to convert.
:return: A list of dictionaries, each representing a function definition.
:return: A list of dictionaries, each representing a Cohere-style function definition.
"""
resolved_schema = jsonref.replace_refs(schema.spec_dict)
return _openapi_to_functions(
@@ -62,6 +64,10 @@ def _openapi_to_functions(
) -> List[Dict[str, Any]]:
"""
Extracts functions from the OpenAPI specification and converts them into function schemas.
:param service_openapi_spec: The OpenAPI specification to extract functions from.
:param parameters_name: The name of the parameters field in the function schema.
:param parse_endpoint_fn: The function to parse the endpoint specification.
"""

# Doesn't enforce rigid spec validation because that would require a lot of dependencies
@@ -92,6 +98,9 @@ def _parse_endpoint_spec_openai(
) -> Dict[str, Any]:
"""
Parses an OpenAPI endpoint specification for OpenAI.
:param resolved_spec: The resolved OpenAPI specification.
:param parameters_name: The name of the parameters field in the function schema.
"""
if not isinstance(resolved_spec, dict):
logger.warning(
@@ -145,6 +154,9 @@ def _parse_property_attributes(
) -> Dict[str, Any]:
"""
Recursively parses the attributes of a property schema.
:param property_schema: The property schema to parse.
:param include_attributes: The attributes to include in the parsed schema.
"""
include_attributes = include_attributes or ["description", "pattern", "enum"]
schema_type = property_schema.get("type")
@@ -172,6 +184,9 @@ def _parse_endpoint_spec_cohere(
) -> Dict[str, Any]:
"""
Parses an endpoint specification for Cohere.
:param operation: The operation specification to parse.
:param ignored_param: ignored, left for compatibility with the OpenAI converter.
"""
function_name = operation.get("operationId")
description = operation.get("description") or operation.get("summary", "")
@@ -189,6 +204,9 @@ def _parse_endpoint_spec_cohere(
def _parse_parameters(operation: Dict[str, Any]) -> Dict[str, Any]:
"""
Parses the parameters from an operation specification.
:param operation: The operation specification to parse.
:return: A dictionary containing the parsed parameters.
"""
parameters = {}
for param in operation.get("parameters", []):
@@ -217,6 +235,11 @@ def _parse_schema(
) -> Dict[str, Any]: # noqa: FBT001
"""
Parses a schema part of an operation specification.
:param schema: The schema to parse.
:param required: Whether the schema is required.
:param description: The description of the schema.
:return: A dictionary containing the parsed schema.
"""
schema_type = _get_type(schema)
if schema_type == "object":
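
For orientation, this is the target shape of a single entry produced by openai_converter, following OpenAI's published function-calling schema (linked in the docstring above). The concrete operation and its fields are made up for illustration and do not come from this diff; Anthropic and Cohere use analogous but differently named envelopes, which is why each provider has its own _parse_endpoint_spec_* helper.

# Illustrative example of one OpenAI function definition; the operation,
# description, and parameters below are placeholders, not values from this PR.
openai_function_definition = {
    "name": "getWeather",                      # typically derived from the OpenAPI operationId
    "description": "Get the current weather for a city.",
    "parameters": {                            # JSON Schema describing the call arguments
        "type": "object",
        "properties": {
            "city": {"type": "string", "description": "Name of the city."},
        },
        "required": ["city"],
    },
}
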
105 changes: 0 additions & 105 deletions haystack_experimental/components/tools/openapi/generator_factory.py

This file was deleted.

53 changes: 36 additions & 17 deletions haystack_experimental/components/tools/openapi/openapi_tool.py
@@ -7,13 +7,13 @@
from typing import Any, Dict, List, Optional, Union

from haystack import component, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, ChatRole
from haystack.lazy_imports import LazyImport

from haystack_experimental.components.tools.openapi.generator_factory import (
ChatGeneratorDescriptorManager,
)
from haystack_experimental.components.tools.openapi.openapi import (
from haystack_experimental.components.tools.openapi._openapi import (
ClientConfiguration,
LLMProvider,
OpenAPIServiceClient,
)

@@ -46,30 +46,28 @@ class OpenAPITool:

def __init__(
self,
model: str,
generator_api: LLMProvider,
generator_api_params: Dict[str, Any],
tool_spec: Optional[Union[str, Path]] = None,
tool_credentials: Optional[Union[str, Dict[str, Any]]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tool_credentials: Optional[str] = None,
):
"""
Initialize the OpenAPITool component.
:param model: Name of the chat generator model to use.
:param generator_api: The API provider for the chat generator.
:param generator_api_params: Parameters for the chat generator.
:param tool_spec: OpenAPI specification for the tool/service.
:param tool_credentials: Credentials for the tool/service.
:param model_kwargs: Additional arguments for the chat generator model.
"""
manager = ChatGeneratorDescriptorManager()
self.descriptor, self.chat_generator = manager.create_generator(
model_name=model, **(model_kwargs or {})
)
self.generator_api = generator_api
self.chat_generator = self._init_generator(generator_api, generator_api_params)
self.config_openapi: Optional[ClientConfiguration] = None
self.open_api_service: Optional[OpenAPIServiceClient] = None
if tool_spec:
self.config_openapi = ClientConfiguration(
openapi_spec=tool_spec,
credentials=tool_credentials,
llm_provider=self.descriptor.name,
llm_provider=generator_api
)
self.open_api_service = OpenAPIServiceClient(self.config_openapi)

@@ -78,8 +76,8 @@ def run(
self,
messages: List[ChatMessage],
fc_generator_kwargs: Optional[Dict[str, Any]] = None,
tool_spec: Optional[Union[str, Path, Dict[str, Any]]] = None,
tool_credentials: Optional[Union[dict, str]] = None,
tool_spec: Optional[Union[str, Path]] = None,
tool_credentials: Optional[str] = None,
) -> Dict[str, List[ChatMessage]]:
"""
Invokes the underlying OpenAPI service/tool with the function calling payload generated by the chat generator.
Expand All @@ -104,7 +102,7 @@ def run(
config_openapi = ClientConfiguration(
openapi_spec=tool_spec,
credentials=tool_credentials,
llm_provider=self.descriptor.name,
llm_provider=self.generator_api,
)
openapi_service = OpenAPIServiceClient(config_openapi)

@@ -134,4 +132,25 @@ def run(
logger.error("Error invoking OpenAPI endpoint. Error: {e}", e=str(e))
service_response = {"error": str(e)}
response_messages = [ChatMessage.from_user(json.dumps(service_response))]

return {"service_response": response_messages}

def _init_generator(self, generator_api: LLMProvider, generator_api_params: Dict[str, Any]):
"""
Initialize the chat generator based on the specified API provider and parameters.
"""
if generator_api == LLMProvider.OPENAI:
return OpenAIChatGenerator(**generator_api_params)
if generator_api == LLMProvider.ANTHROPIC:
with LazyImport("Run 'pip install anthropic-haystack'") as anthropic_import:
anthropic_import.check()
# pylint: disable=import-outside-toplevel
from haystack_integrations.components.generators.anthropic import AnthropicChatGenerator
return AnthropicChatGenerator(**generator_api_params)
if generator_api == LLMProvider.COHERE:
with LazyImport("Run 'pip install cohere-haystack'") as cohere_import:
cohere_import.check()
# pylint: disable=import-outside-toplevel
from haystack_integrations.components.generators.cohere import CohereChatGenerator
return CohereChatGenerator(**generator_api_params)
raise ValueError(f"Unsupported generator API: {generator_api}")
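
A minimal usage sketch of the new constructor and run() signature. The import paths follow the file layout shown in this commit; the model name, spec path, and credentials are placeholders, generator_api_params is forwarded to the selected chat generator's constructor (OpenAIChatGenerator here), and OPENAI_API_KEY is assumed to be set in the environment.

# Sketch only: spec path, model, and credentials below are placeholders.
from haystack.dataclasses import ChatMessage

from haystack_experimental.components.tools.openapi._openapi import LLMProvider
from haystack_experimental.components.tools.openapi.openapi_tool import OpenAPITool

tool = OpenAPITool(
    generator_api=LLMProvider.OPENAI,
    generator_api_params={"model": "gpt-4o-mini"},  # passed through as OpenAIChatGenerator(**params)
    tool_spec="path/to/openapi.yaml",               # str or Path, per the signature above
    tool_credentials="YOUR_SERVICE_API_KEY",
)

result = tool.run(
    messages=[ChatMessage.from_user("Use the service to look something up for me")],
)
print(result["service_response"])  # a list with one ChatMessage carrying the JSON service response
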
2 changes: 1 addition & 1 deletion test/components/tools/openapi/conftest.py
@@ -10,7 +10,7 @@
from fastapi import FastAPI
from fastapi.testclient import TestClient

from haystack_experimental.components.tools.openapi.openapi import HttpClientError
from haystack_experimental.components.tools.openapi._openapi import HttpClientError


@pytest.fixture()
2 changes: 1 addition & 1 deletion test/components/tools/openapi/test_openapi_client.py
@@ -7,7 +7,7 @@
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from haystack_experimental.components.tools.openapi.openapi import OpenAPIServiceClient, ClientConfiguration
from haystack_experimental.components.tools.openapi._openapi import OpenAPIServiceClient, ClientConfiguration
from test.components.tools.openapi.conftest import FastAPITestClient

"""
