Adding function calling to AnyscaleLLM (#227)
* Adding function calling to AnyscaleLLM

* remove unused imports

* Add unit test and anyscale.yaml

* add model check

* update test_anyscale

* flake8

* [llm] Improve error handling for function calling

If an LLM doesn't support function calling, it raises NotImplementedError, which is caught by FunctionCallingQueryGenerator and re-raised with actionable guidance (see the sketch after the function_calling.py diff below).

* [test] Test Anyscale as another parameter in test_openai

Instead of duplicating the full test suite, all OpenAI-like models now run parameterized in test_openai.py (a sketch of the pattern follows the deleted test_anyscale.py at the end of this page).

* anyscale.yaml config file improvements

---------

Co-authored-by: ilai <[email protected]>
kylehh and igiloh-pinecone authored Jan 21, 2024
1 parent 44cfad1 commit ec5d304
Showing 9 changed files with 97 additions and 202 deletions.
14 changes: 7 additions & 7 deletions config/anyscale.yaml
@@ -19,16 +19,16 @@ chat_engine:
   params:
     max_prompt_tokens: 2048 # The maximum number of tokens to use for input prompt to the LLM.
   llm: &llm
-    type: AnyscaleLLM # Options: [OpenAILLM, AnyscaleLLM]
+    type: AnyscaleLLM
     params:
       model_name: meta-llama/Llama-2-7b-chat-hf # The name of the model to use.
 
   query_builder:
     # --------------------------------------------------------------------
     # Configuration for the QueryBuilder subcomponent of the chat engine.
-    # Since Anyscale's LLM endpoint currently doesn't support function calling, we will use the InstructionQueryGenerator
     # --------------------------------------------------------------------
-    type: InstructionQueryGenerator # Options: [InstructionQueryGenerator, LastMessageQueryGenerator]
+    type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator, InstructionQueryGenerator]
+    llm:
+      type: AnyscaleLLM
+      params:
+        model_name: mistralai/Mistral-7B-Instruct-v0.1
 
   context_engine:
     # -------------------------------------------------------------------------------------------------------------
@@ -43,7 +43,7 @@ chat_engine:
       # Configuration for the RecordEncoder subcomponent of the knowledge base.
       # Use Anyscale's Embedding endpoint for dense encoding
       # --------------------------------------------------------------------------
-      type: AnyscaleRecordEncoder # Options: [OpenAIRecordEncoder, AnyscaleRecordEncoder]
+      type: AnyscaleRecordEncoder
      params:
        model_name: # The name of the model to use for encoding
          thenlper/gte-large
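The headline change in this config is the query_builder section: it now wires a dedicated AnyscaleLLM (running Mistral-7B-Instruct, one of the models in FUNCTION_MODEL_LIST below) into FunctionCallingQueryGenerator. In Python, the equivalent construction is roughly the following sketch; the import paths are inferred from this commit's file layout, and the `llm` constructor keyword is an assumption, not something this diff shows:

# A minimal sketch of what the new query_builder config assembles.
# Assumptions: FunctionCallingQueryGenerator accepts its LLM via an `llm`
# keyword, and both classes are importable from these package paths.
from canopy.chat_engine.query_generator import FunctionCallingQueryGenerator
from canopy.llm import AnyscaleLLM

query_builder = FunctionCallingQueryGenerator(
    llm=AnyscaleLLM(model_name="mistralai/Mistral-7B-Instruct-v0.1"),
)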
17 changes: 7 additions & 10 deletions src/canopy/chat_engine/query_generator/function_calling.py
@@ -43,16 +43,13 @@ def generate(self,
                 chat_history=messages,
                 function=self._function
             )
-        except RuntimeError as e:
-            if "function calling" in str(e):
-                raise RuntimeError(
-                    "FunctionCallingQueryGenerator requires an LLM that supports "
-                    "function calling. Please provide a different LLM, "
-                    "or alternatively select a different QueryGenerator class. "
-                    f"Received the following error from LLM:\n{e}"
-                ) from e
-
-            raise
+        except NotImplementedError as e:
+            raise RuntimeError(
+                "FunctionCallingQueryGenerator requires an LLM that supports "
+                "function calling. Please provide a different LLM, "
+                "or alternatively select a different QueryGenerator class. "
+                f"Received the following error from LLM:\n{e}"
+            ) from e
 
         return [Query(text=q)
                 for q in arguments["queries"]]
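Read together with the AnyscaleLLM change below, this establishes a simple contract: an LLM that cannot do function calling raises NotImplementedError, and the query generator converts it into an actionable RuntimeError. A self-contained sketch of that contract, using stub types rather than the real canopy classes:

# Sketch of the error-handling contract. StubLLM stands in for any LLM
# without function-calling support; all names here are illustrative.
class StubLLM:
    def enforced_function_call(self, *args, **kwargs) -> dict:
        raise NotImplementedError("model X doesn't support function calling")


def generate_queries(llm: StubLLM) -> list:
    try:
        arguments = llm.enforced_function_call()
    except NotImplementedError as e:
        # Translate the capability error into configuration advice.
        raise RuntimeError(
            "FunctionCallingQueryGenerator requires an LLM that supports "
            f"function calling. Received the following error from LLM:\n{e}"
        ) from e
    return arguments["queries"]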
20 changes: 19 additions & 1 deletion src/canopy/llm/anyscale.py
@@ -5,6 +5,10 @@
 from canopy.models.data_models import Messages
 
 ANYSCALE_BASE_URL = "https://api.endpoints.anyscale.com/v1"
+FUNCTION_MODEL_LIST = [
+    "mistralai/Mistral-7B-Instruct-v0.1",
+    "mistralai/Mixtral-8x7B-Instruct-v0.1",
+]
 
 
 class AnyscaleLLM(OpenAILLM):
@@ -42,7 +46,21 @@ def enforced_function_call(
                               max_tokens: Optional[int] = None,
                               model_params: Optional[dict] = None,
                               ) -> dict:
-        raise NotImplementedError()
+        model = self.model_name
+        if model_params and "model" in model_params:
+            model = model_params["model"]
+        if model not in FUNCTION_MODEL_LIST:
+            raise NotImplementedError(
+                f"Model {model} doesn't support function calling. "
+                "To use function calling capability, please select a different model.\n"
+                "Please check the following link for details: "
+                "https://docs.endpoints.anyscale.com/guides/function-calling"
+            )
+        else:
+            return super().enforced_function_call(
+                system_prompt, chat_history, function,
+                max_tokens=max_tokens, model_params=model_params
+            )
 
     def aenforced_function_call(self,
                                 system_prompt: str,
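A usage sketch for the new code path. The Function/FunctionParameters/FunctionArrayProperty construction mirrors what FunctionCallingQueryGenerator builds internally, but the field values below are illustrative, and an ANYSCALE_API_KEY is assumed to be set in the environment:

# Hedged usage sketch for AnyscaleLLM.enforced_function_call.
# Assumes ANYSCALE_API_KEY is set and that these classes are importable
# from the paths shown in this commit's file tree.
from canopy.llm import AnyscaleLLM
from canopy.llm.models import (Function, FunctionParameters,
                               FunctionArrayProperty)
from canopy.models.data_models import UserMessage

llm = AnyscaleLLM(model_name="mistralai/Mistral-7B-Instruct-v0.1")
function = Function(
    name="queries",  # illustrative name
    description="Search queries to run against the knowledge base.",
    parameters=FunctionParameters(
        required_properties=[
            FunctionArrayProperty(
                name="queries",
                items_type="string",
                description="List of search queries.",
            ),
        ],
    ),
)
arguments = llm.enforced_function_call(
    system_prompt="Generate search queries for the user's question.",
    chat_history=[UserMessage(content="How do I deploy on Anyscale?")],
    function=function,
)
# With a model outside FUNCTION_MODEL_LIST this raises NotImplementedError.
print(arguments["queries"])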
14 changes: 7 additions & 7 deletions src/canopy/llm/azure_openai_llm.py
@@ -72,34 +72,34 @@ def available_models(self):
             "Azure OpenAI LLM does not support listing available models"
         )
 
-    def _handle_chat_error(self, e):
+    def _handle_chat_error(self, e, is_function_call=False):
         if isinstance(e, openai.AuthenticationError):
             raise RuntimeError(
                 "Failed to connect to Azure OpenAI, please make sure that the "
                 "AZURE_OPENAI_API_KEY environment variable is set correctly. "
                 f"Underlying Error:\n{self._format_openai_error(e)}"
-            )
+            ) from e
         elif isinstance(e, openai.APIConnectionError):
             raise RuntimeError(
                 f"Failed to connect to your Azure OpenAI endpoint, please make sure "
                 f"that the provided endpoint {os.getenv('AZURE_OPENAI_ENDPOINT')} "
                 f"is correct. Underlying Error:\n{self._format_openai_error(e)}"
-            )
+            ) from e
         elif isinstance(e, openai.NotFoundError):
-            if e.type and 'invalid' in e.type:
-                raise RuntimeError(
+            if e.type and 'invalid' in e.type and is_function_call:
+                raise NotImplementedError(
                     f"It seems that you are trying to use OpenAI's `function calling` "
                     f"or `tools` features. Please note that Azure OpenAI only supports "
                     f"function calling for specific models and API versions. More "
                     f"details in: https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling. "  # noqa: E501
                     f"Underlying Error:\n{self._format_openai_error(e)}"
-                )
+                ) from e
             else:
                 raise RuntimeError(
                     f"Failed to connect to your Azure OpenAI. Please make sure that "
                     f"you have provided the correct deployment name: {self.model_name} "
                     f"and API version: {self._client._api_version}. "
                     f"Underlying Error:\n{self._format_openai_error(e)}"
-                )
+                ) from e
         else:
             super()._handle_chat_error(e)
1 change: 1 addition & 0 deletions src/canopy/llm/models.py
@@ -20,6 +20,7 @@ class FunctionArrayProperty(BaseModel):
     def dict(self, *args, **kwargs):
         super_dict = super().dict(*args, **kwargs)
         if "items_type" in super_dict:
+            super_dict["type"] = "array"
             super_dict["items"] = {"type": super_dict.pop("items_type")}
         return super_dict
 
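The added line makes the serialized property a complete JSON-Schema array (previously only `items` was emitted, with no `type`). A standalone sketch of the resulting behavior, re-declaring a minimal FunctionArrayProperty since the full class isn't shown in this hunk (pydantic v1 style, matching the `dict()` signature above):

# Minimal reproduction of the dict() override above; pydantic v1 API.
from pydantic import BaseModel


class FunctionArrayProperty(BaseModel):
    name: str
    items_type: str

    def dict(self, *args, **kwargs):
        super_dict = super().dict(*args, **kwargs)
        if "items_type" in super_dict:
            super_dict["type"] = "array"  # the line this commit adds
            super_dict["items"] = {"type": super_dict.pop("items_type")}
        return super_dict


prop = FunctionArrayProperty(name="queries", items_type="string")
assert prop.dict() == {
    "name": "queries",
    "type": "array",
    "items": {"type": "string"},
}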
12 changes: 10 additions & 2 deletions src/canopy/llm/openai.py
@@ -221,7 +221,7 @@ def enforced_function_call(self,
                 **model_params_dict
             )
         except openai.OpenAIError as e:
-            self._handle_chat_error(e)
+            self._handle_chat_error(e, is_function_call=True)
 
         result = chat_completion.choices[0].message.tool_calls[0].function.arguments
         arguments = json.loads(result)
@@ -262,7 +262,15 @@ def _format_openai_error(e):
         except Exception:
             return str(e)
 
-    def _handle_chat_error(self, e):
+    def _handle_chat_error(self, e, is_function_call=False):
+        if isinstance(e, openai.NotFoundError) and is_function_call:
+            if e.type and 'invalid' in e.type:
+                raise NotImplementedError(
+                    f"The selected model ({self.model_name}) does not support "
+                    f"function calling. "
+                    f"Underlying Error:\n{self._format_openai_error(e)}"
+                ) from e
+
         provider_name = self.__class__.__name__.replace("LLM", "")
         raise RuntimeError(
             f"Failed to use {provider_name}'s {self.model_name} model for chat "
3 changes: 2 additions & 1 deletion tests/system/llm/conftest.py
@@ -8,5 +8,6 @@ def messages():
     # Create a list of MessageBase objects
     return [
         UserMessage(content="Hello, assistant."),
-        AssistantMessage(content="Hello, user. How can I assist you?")
+        AssistantMessage(content="Hello, user. How can I assist you?"),
+        UserMessage(content="Just checking in. Be concise."),
     ]
164 changes: 0 additions & 164 deletions tests/system/llm/test_anyscale.py

This file was deleted.
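Per the commit message, this suite wasn't lost: all OpenAI-like LLMs now run through a parameterized tests/system/llm/test_openai.py (the ninth changed file, whose diff is not rendered above). A sketch of that pattern; the fixture name, constructor arguments, and assertions are assumptions, since the actual test diff isn't shown:

# Hypothetical sketch of parameterizing OpenAI-like LLMs in test_openai.py.
# Only the idea (one suite covering OpenAILLM and AnyscaleLLM) comes from
# the commit message; every name below is illustrative.
import pytest

from canopy.llm import AnyscaleLLM, OpenAILLM


@pytest.fixture(params=[OpenAILLM, AnyscaleLLM])
def llm(request):
    if request.param is AnyscaleLLM:
        return AnyscaleLLM(model_name="mistralai/Mistral-7B-Instruct-v0.1")
    return OpenAILLM()


def test_chat_completion_returns_text(llm, messages):
    # `messages` comes from the conftest.py fixture shown above.
    response = llm.chat_completion(system_prompt="Be helpful.",
                                   chat_history=messages)
    assert response.choices[0].message.content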
