diff --git a/config/config.yaml b/config/config.yaml
index f290c3ad..83055e0e 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -37,6 +37,7 @@ chat_engine:
     max_generated_tokens: null   # Leaving `null` will use the default of the underlying LLM
     max_context_tokens: null   # Leaving `null` will use 70% of `max_prompt_tokens`
     system_prompt: *system_prompt   # The chat engine's system prompt for calling the LLM
+    allow_model_params_override: false   # Whether to allow overriding the LLM's parameters in an API call
   history_pruner:   # How to prune messages if chat history is too long. Options: [RecentHistoryPruner, RaisingHistoryPruner]
     type: RecentHistoryPruner
@@ -51,10 +52,10 @@ chat_engine:
     type: OpenAILLM   # Options: [OpenAILLM, AzureOpenAILLM]
     params:
       model_name: gpt-3.5-turbo   # The name of the model to use.
-      # You can add any additional parameters which are supported by the LLM's `ChatCompletion()` API. The values set
-      # here will be used in every LLM API call. For example:
-#      temperature: 0.7
-#      top_p: 0.9
+      # You can add any additional parameters which are supported by the LLM's `ChatCompletion()` API. The values
+      # set here will be used in every LLM API call, but may be overridden if `allow_model_params_override` is true.
+#      temperature: 0.7
+#      top_p: 0.9

   query_builder:
   # -------------------------------------------------------------------------------------------------------------
diff --git a/src/canopy/chat_engine/chat_engine.py b/src/canopy/chat_engine/chat_engine.py
index edf381d3..c6090021 100644
--- a/src/canopy/chat_engine/chat_engine.py
+++ b/src/canopy/chat_engine/chat_engine.py
@@ -90,6 +90,7 @@ def __init__(self,
                  max_context_tokens: Optional[int] = None,
                  query_builder: Optional[QueryGenerator] = None,
                  system_prompt: Optional[str] = None,
+                 allow_model_params_override: bool = False,
                  history_pruner: Optional[HistoryPruner] = None,
                  ):
         """
@@ -103,6 +104,7 @@ def __init__(self,
             max_context_tokens: The maximum number of tokens to use for the context to prompt the LLM. Defaults to be 70% of the max_prompt_tokens.
             query_builder: An instance of a query generator to use for generating queries from the chat history. Defaults to FunctionCallingQueryGenerator.
             system_prompt: The system prompt to use for the LLM. Defaults to a generic prompt that is suitable for most use cases.
+            allow_model_params_override: Whether to allow individual `chat()` calls to override the pre-configured LLM params. Defaults to False.
             history_pruner: The history pruner to use for pruning the chat history before prompting the LLM. Defaults to None, which means no pruning will be done.
""" # noqa: E501 if not isinstance(context_engine, ContextEngine): @@ -161,6 +163,8 @@ def __init__(self, ) self.max_context_tokens = max_context_tokens + self.allow_model_params_override = allow_model_params_override + def chat(self, messages: Messages, *, @@ -205,13 +209,19 @@ def chat(self, system_prompt=self.system_prompt, context=context ) + model_params_dict = {} + if self.allow_model_params_override and model_params: + model_params_dict = { + k: v for k, v in model_params.items() if v is not None + } + if model_params_dict.get("max_tokens", None) is None: + model_params_dict["max_tokens"] = self.max_generated_tokens llm_response = self.llm.chat_completion(system_prompt=self.system_prompt, chat_history=llm_messages, context=context, - max_tokens=self.max_generated_tokens, stream=stream, - model_params=model_params) + model_params=model_params_dict) debug_info = {} if CE_DEBUG_INFO: debug_info['context'] = context.dict() diff --git a/src/canopy/llm/base.py b/src/canopy/llm/base.py index 1a59e287..06d1efad 100644 --- a/src/canopy/llm/base.py +++ b/src/canopy/llm/base.py @@ -31,7 +31,7 @@ def enforced_function_call(self, function: Function, *, max_tokens: Optional[int] = None, - model_params: Optional[dict] = None + model_params: Optional[dict] = None, ) -> dict: pass diff --git a/src/canopy/llm/openai.py b/src/canopy/llm/openai.py index 32464af1..55079f98 100644 --- a/src/canopy/llm/openai.py +++ b/src/canopy/llm/openai.py @@ -61,6 +61,11 @@ def __init__(self, ) self.default_model_params = kwargs + if "model" in self.default_model_params: + raise ValueError( + "The 'model' parameter is not allowed in the default model params. " + "Please use the 'model_name' argument instead." + ) @property def available_models(self): @@ -87,7 +92,7 @@ def chat_completion(self, stream: Whether to stream the response or not. max_tokens: Maximum number of tokens to generate. Defaults to None (generates until stop sequence or until hitting max context size). model_params: Model parameters to use for this request. Defaults to None (uses the default model parameters). - Dictonary of parametrs to override the default model parameters if set on initialization. + Dictonary of parameters to override the default model parameters if set on initialization. For example, you can pass: {"temperature": 0.9, "top_p": 1.0} to override the default temperature and top_p. 
                           see: https://platform.openai.com/docs/api-reference/chat/create
         Returns:
@@ -106,9 +111,11 @@ def chat_completion(self,
         """  # noqa: E501
         model_params_dict: Dict[str, Any] = deepcopy(self.default_model_params)
-        model_params_dict.update(
-            model_params or {}
-        )
+        model_params_dict.update(model_params or {})
+        if max_tokens is not None:
+            model_params_dict["max_tokens"] = max_tokens
+
+        model = model_params_dict.pop("model", None) or self.model_name

         if context is None:
             system_message = system_prompt
@@ -117,10 +124,9 @@ def chat_completion(self,
         messages = [SystemMessage(content=system_message).dict()
                     ] + [m.dict() for m in chat_history]
         try:
-            response = self._client.chat.completions.create(model=self.model_name,
+            response = self._client.chat.completions.create(model=model,
                                                             messages=messages,
                                                             stream=stream,
-                                                             max_tokens=max_tokens,
                                                             **model_params_dict)
         except openai.OpenAIError as e:
             self._handle_chat_error(e)
@@ -193,9 +199,11 @@ def enforced_function_call(self,
         """  # noqa: E501
         model_params_dict: Dict[str, Any] = deepcopy(self.default_model_params)
-        model_params_dict.update(
-            model_params or {}
-        )
+        model_params_dict.update(model_params or {})
+        if max_tokens is not None:
+            model_params_dict["max_tokens"] = max_tokens
+
+        model = model_params_dict.pop("model", None) or self.model_name

         function_dict = cast(ChatCompletionToolParam,
                              {"type": "function", "function": function.dict()})
@@ -204,8 +212,8 @@ def enforced_function_call(self,
                     ] + [m.dict() for m in chat_history]
         try:
             chat_completion = self._client.chat.completions.create(
+                model=model,
                 messages=messages,
-                model=self.model_name,
                 tools=[function_dict],
                 tool_choice={"type": "function",
                              "function": {"name": function.name}},
diff --git a/src/canopy_server/app.py b/src/canopy_server/app.py
index e5ce8b59..f4572857 100644
--- a/src/canopy_server/app.py
+++ b/src/canopy_server/app.py
@@ -118,16 +118,19 @@ async def chat(
         OpenAI compatible chat response
     """  # noqa: E501
+
     try:
         logger.debug(f"The namespace is {namespace}")
         session_id = request.user or "None"  # noqa: F841
         question_id = str(uuid.uuid4())
         logger.debug(f"Received chat request: {request.messages[-1].content}")
+        model_params = request.dict(exclude={"messages", "stream"})
         answer = await run_in_threadpool(
             chat_engine.chat,
             messages=request.messages,
             stream=request.stream,
-            namespace=namespace
+            namespace=namespace,
+            model_params=model_params,
         )

         if request.stream:
diff --git a/src/canopy_server/models/v1/api_models.py b/src/canopy_server/models/v1/api_models.py
index 224fbf8c..197babf4 100644
--- a/src/canopy_server/models/v1/api_models.py
+++ b/src/canopy_server/models/v1/api_models.py
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Dict, List, Optional, Union

 from pydantic import BaseModel, Field

@@ -8,10 +8,6 @@

 class ChatRequest(BaseModel):
-    model: str = Field(
-        default="",
-        description="The ID of the model to use. This field is ignored; instead, configure this field in the Canopy config.",  # noqa: E501
-    )
     messages: Messages = Field(
         description="A list of messages comprising the conversation so far."
     )
@@ -19,6 +15,57 @@ class ChatRequest(BaseModel):
         default=False,
         description="""Whether or not to stream the chatbot's response. If set, the response is server-sent events containing [chat.completion.chunk](https://platform.openai.com/docs/api-reference/chat/streaming) objects""",  # noqa: E501
     )
+
+    # -------------------------------------------------------------------------------
+    # Optional params. The params below are ignored by default and should usually be
+    # configured in the Canopy config file instead.
+    # You can allow passing these params to the API by setting the
+    # `allow_model_params_override` flag in the ChatEngine's config.
+    # -------------------------------------------------------------------------------
+    model: str = Field(
+        default="",
+        description="The name of the model to use."  # noqa: E501
+    )
+    frequency_penalty: Optional[float] = Field(
+        default=None,
+        description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",  # noqa: E501
+    )
+    logit_bias: Optional[Dict[str, float]] = Field(
+        default=None,
+        description="A map of tokens to an associated bias value between -100 and 100.",
+    )
+    max_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of tokens to generate in the chat completion.",
+    )
+    n: Optional[int] = Field(
+        default=None,
+        description="How many chat completion choices to generate for each input message.",  # noqa: E501
+    )
+    presence_penalty: Optional[float] = Field(
+        default=None,
+        description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",  # noqa: E501
+    )
+    response_format: Optional[Dict[str, str]] = Field(
+        default=None,
+        description="The format of the returned response.",
+    )
+    seed: Optional[int] = Field(
+        default=None,
+        description="When provided, OpenAI will make a best effort to sample results deterministically.",  # noqa: E501
+    )
+    stop: Optional[Union[List[str], str]] = Field(
+        default=None,
+        description="One or more sequences where the API will stop generating further tokens.",  # noqa: E501
+    )
+    temperature: Optional[float] = Field(
+        default=None,
+        description="What sampling temperature to use.",
+    )
+    top_p: Optional[float] = Field(
+        default=None,
+        description="What nucleus sampling probability to use.",
+    )
     user: Optional[str] = Field(
         default=None,
         description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Unused, reserved for future extensions",  # noqa: E501
diff --git a/tests/system/llm/test_openai.py b/tests/system/llm/test_openai.py
index 7e28c40a..7e34984a 100644
--- a/tests/system/llm/test_openai.py
+++ b/tests/system/llm/test_openai.py
@@ -174,6 +174,23 @@ def test_enforced_function_call_low_temperature(openai_llm,
     assert_function_call_format(result)


+def test_chat_completion_with_model_name(openai_llm, messages):
+    if isinstance(openai_llm, AzureOpenAILLM):
+        pytest.skip("In Azure the model name has to be a valid deployment")
+
+    new_model_name = "gpt-3.5-turbo-1106"
+    assert new_model_name != openai_llm.model_name, (
+        "The new model name should be different from the default one. Please change it."
+    )
+    response = openai_llm.chat_completion(
+        system_prompt=SYSTEM_PROMPT,
+        chat_history=messages,
+        model_params={"model": new_model_name}
+    )
+
+    assert response.model == new_model_name
+
+
 def test_chat_streaming(openai_llm, messages):
     stream = True
     response = openai_llm.chat_completion(system_prompt=SYSTEM_PROMPT,
diff --git a/tests/unit/chat_engine/test_chat_engine.py b/tests/unit/chat_engine/test_chat_engine.py
index f5aefa86..d4091bd5 100644
--- a/tests/unit/chat_engine/test_chat_engine.py
+++ b/tests/unit/chat_engine/test_chat_engine.py
@@ -130,9 +130,8 @@ def test_chat(self, namespace, history_length=5, snippet_length=10):
             system_prompt=expected['prompt'],
             context=expected['context'],
             chat_history=messages,
-            max_tokens=200,
             stream=False,
-            model_params=None
+            model_params={'max_tokens': 200}
         )

         assert response == expected['response']
@@ -140,8 +139,19 @@ def test_chat(self, namespace, history_length=5, snippet_length=10):
     @pytest.mark.parametrize("namespace", [
         None, TEST_NAMESPACE
     ])
+    @pytest.mark.parametrize("allow_model_params_override,params_override",
+                             [(False, None),
+                              (False, {'temperature': 0.99, 'top_p': 0.5}),
+                              (True, {'temperature': 0.99, 'top_p': 0.5}),
+                              (True, {'temperature': 0.99, 'max_tokens': 200})],
+                             ids=["no_override",
+                                  "override_not_allowed",
+                                  "valid_override",
+                                  "valid_override_with_max_tokens"])
     def test_chat_engine_params(self,
                                 namespace,
+                                allow_model_params_override,
+                                params_override,
                                 system_prompt_length=10,
                                 max_prompt_tokens=80,
                                 max_context_tokens=60,
@@ -152,10 +162,13 @@ def test_chat_engine_params(self,
                                 ):
         system_prompt = self._generate_text(system_prompt_length)

-        chat_engine = self._init_chat_engine(system_prompt=system_prompt,
-                                             max_prompt_tokens=max_prompt_tokens,
-                                             max_context_tokens=max_context_tokens,
-                                             max_generated_tokens=max_generated_tokens)
+        chat_engine = self._init_chat_engine(
+            system_prompt=system_prompt,
+            max_prompt_tokens=max_prompt_tokens,
+            max_context_tokens=max_context_tokens,
+            max_generated_tokens=max_generated_tokens,
+            allow_model_params_override=allow_model_params_override
+        )

         # Mock input and expected output
         messages, expected = self._get_inputs_and_expected(history_length,
@@ -168,7 +181,13 @@ def test_chat_engine_params(self,
                 chat_engine.chat(messages)
             return

-        response = chat_engine.chat(messages, namespace=namespace)
+        response = chat_engine.chat(messages,
+                                    namespace=namespace,
+                                    model_params=params_override)
+
+        expected_model_params = {'max_tokens': max_generated_tokens}
+        if allow_model_params_override and params_override is not None:
+            expected_model_params.update(params_override)

         # Assertions
         self.mock_query_builder.generate.assert_called_once_with(
@@ -184,9 +203,8 @@ def test_chat_engine_params(self,
             system_prompt=expected['prompt'],
             context=expected['context'],
             chat_history=messages,
-            max_tokens=max_generated_tokens,
             stream=False,
-            model_params=None
+            model_params=expected_model_params
         )

         assert response == expected['response']
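
For reference, a minimal usage sketch of the new flag follows (not part of the diff). It is illustrative only: it assumes the standard Canopy setup with a populated Pinecone index wrapped by KnowledgeBase and ContextEngine, that the import paths below match the installed Canopy version, and that "my-index" is a placeholder index name.

# Illustrative sketch only. Assumptions: the index "my-index" already exists and is
# populated, and OPENAI_API_KEY / PINECONE_API_KEY are set in the environment.
from canopy.tokenizer import Tokenizer
from canopy.knowledge_base import KnowledgeBase
from canopy.context_engine import ContextEngine
from canopy.chat_engine import ChatEngine
from canopy.models.data_models import UserMessage

Tokenizer.initialize()

kb = KnowledgeBase(index_name="my-index")  # placeholder index name
kb.connect()
context_engine = ContextEngine(kb)

# Opt in to per-call parameter overrides; the default (False) keeps the old behavior
# of ignoring any model params sent along with a request.
chat_engine = ChatEngine(context_engine, allow_model_params_override=True)

# These values override the configured LLM defaults for this call only;
# `max_tokens` still falls back to `max_generated_tokens` when not supplied.
response = chat_engine.chat(
    messages=[UserMessage(content="What is Canopy?")],
    model_params={"temperature": 0.2, "top_p": 0.9},
)
print(response.choices[0].message.content)  # ChatResponse mirrors the OpenAI shape

Over HTTP the same gating applies: with `allow_model_params_override: true` in the config, optional fields such as `temperature`, `top_p`, or `model` in the chat request body are forwarded to the LLM; otherwise they are ignored, as before.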