Pass additional model params to the LLM (#191)
* Initial work on passing additional params to the LLM

* Improve merging of request params with config params. Add request_params_passthrough config option

* Add remaining OpenAI LLM params to ChatRequest

* Add ability to override LLM API keys for single requests

* Add tests for model param and api key overriding

* Fix reference to old config value name

* Move changes related to API key overriding to another branch

* Default allow_model_params_override to false

* Fix tests

* Fix lint and style issues

* [chat] Bug fix after wrong conflict resolution

* [llm] Explicit model= param in chat_completion()

I believe this makes the code more readable and explicit

* [test] test_chat_engine - parameterize instead of code duplication

This avoids code duplication while still checking all possible cases

* [test] test_openai_llm - added test case for passing model as param

Now that we support forwarding params from the API, this covers an important case that was previously untested

* [models] Restructured the ChatRequest model a bit

Mostly for readability

---------

Co-authored-by: ilai <[email protected]>
malexw and igiloh-pinecone authored Jan 17, 2024
1 parent 10cae12 commit b42f06c
Showing 8 changed files with 136 additions and 32 deletions.
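In short: API callers can now forward OpenAI-style generation parameters through ChatEngine.chat(), but only when allow_model_params_override is enabled. A minimal sketch of the calling pattern follows (editor's illustration, not part of the diff); the chat_engine instance is assumed to be built elsewhere, and the UserMessage import path is an assumption.

# Sketch only: assumes a ChatEngine configured with allow_model_params_override=True
# (see the config/config.yaml change below).
from canopy.models.data_models import UserMessage  # assumed import path

def ask_with_overrides(chat_engine, question: str):
    """Send one question, overriding a few LLM params for this call only."""
    messages = [UserMessage(content=question)]
    # Per-request params; ignored by the engine unless allow_model_params_override is true.
    overrides = {"temperature": 0.2, "top_p": 0.9, "max_tokens": 256}
    return chat_engine.chat(messages, stream=False, model_params=overrides)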
9 changes: 5 additions & 4 deletions config/config.yaml
@@ -37,6 +37,7 @@ chat_engine:
max_generated_tokens: null # Leaving `null` will use the default of the underlying LLM
max_context_tokens: null # Leaving `null` will use 70% of `max_prompt_tokens`
system_prompt: *system_prompt # The chat engine's system prompt for calling the LLM
allow_model_params_override: false # Whether to allow overriding the LLM's parameters in an API call

history_pruner: # How to prune messages if chat history is too long. Options: [RecentHistoryPruner, RaisingHistoryPruner]
type: RecentHistoryPruner
@@ -51,10 +52,10 @@ chat_engine:
type: OpenAILLM # Options: [OpenAILLM, AzureOpenAILLM]
params:
model_name: gpt-3.5-turbo # The name of the model to use.
# You can add any additional parameters which are supported by the LLM's `ChatCompletion()` API. The values set
# here will be used in every LLM API call. For example:
# temperature: 0.7
# top_p: 0.9
# You can add any additional parameters which are supported by the LLM's `ChatCompletion()` API. The values
# set here will be used in every LLM API call, but may be overridden if `allow_model_params_override` is true.
# temperature: 0.7
# top_p: 0.9

query_builder:
# -------------------------------------------------------------------------------------------------------------
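The keys under llm: params: in this config map to keyword arguments of the LLM class; anything beyond model_name is stored as a default model param and sent on every call (see src/canopy/llm/openai.py below). A programmatic sketch of the same setup, assuming the canopy.llm import path and an OPENAI_API_KEY in the environment:

from canopy.llm import OpenAILLM  # assumed import path

# Extra kwargs become default model params, applied to every chat_completion() call.
llm = OpenAILLM(model_name="gpt-3.5-turbo", temperature=0.7, top_p=0.9)
# Note: passing model= (instead of model_name=) now raises a ValueError.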
14 changes: 12 additions & 2 deletions src/canopy/chat_engine/chat_engine.py
@@ -90,6 +90,7 @@ def __init__(self,
max_context_tokens: Optional[int] = None,
query_builder: Optional[QueryGenerator] = None,
system_prompt: Optional[str] = None,
allow_model_params_override: bool = False,
history_pruner: Optional[HistoryPruner] = None,
):
"""
@@ -103,6 +104,7 @@ def __init__(self,
max_context_tokens: The maximum number of tokens to use for the context to prompt the LLM. Defaults to be 70% of the max_prompt_tokens.
query_builder: An instance of a query generator to use for generating queries from the chat history. Defaults to FunctionCallingQueryGenerator.
system_prompt: The system prompt to use for the LLM. Defaults to a generic prompt that is suitable for most use cases.
allow_model_params_override: Whether to allow individual `chat()` calls to override the pre-configured LLM params. Defaults to False.
history_pruner: The history pruner to use for pruning the chat history before prompting the LLM. Defaults to None, which means no pruning will be done.
""" # noqa: E501
if not isinstance(context_engine, ContextEngine):
@@ -161,6 +163,8 @@ def __init__(self,
)
self.max_context_tokens = max_context_tokens

self.allow_model_params_override = allow_model_params_override

def chat(self,
messages: Messages,
*,
@@ -205,13 +209,19 @@ def chat(self,
system_prompt=self.system_prompt,
context=context
)
model_params_dict = {}
if self.allow_model_params_override and model_params:
model_params_dict = {
k: v for k, v in model_params.items() if v is not None
}
if model_params_dict.get("max_tokens", None) is None:
model_params_dict["max_tokens"] = self.max_generated_tokens

llm_response = self.llm.chat_completion(system_prompt=self.system_prompt,
chat_history=llm_messages,
context=context,
max_tokens=self.max_generated_tokens,
stream=stream,
model_params=model_params)
model_params=model_params_dict)
debug_info = {}
if CE_DEBUG_INFO:
debug_info['context'] = context.dict()
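The new logic in chat() reduces to a small merge step: request params are honoured only when overriding is allowed, None values are dropped, and max_tokens falls back to the engine's max_generated_tokens. A standalone sketch of that behaviour (editor's illustration, not repository code):

from typing import Any, Dict, Optional

def resolve_model_params(allow_override: bool,
                         request_params: Optional[Dict[str, Any]],
                         max_generated_tokens: int) -> Dict[str, Any]:
    """Mirror the merging done in ChatEngine.chat()."""
    params: Dict[str, Any] = {}
    if allow_override and request_params:
        params = {k: v for k, v in request_params.items() if v is not None}
    if params.get("max_tokens") is None:
        params["max_tokens"] = max_generated_tokens
    return params

# Override temperature only; max_tokens falls back to the engine default.
assert resolve_model_params(True, {"temperature": 0.2, "max_tokens": None}, 200) == \
       {"temperature": 0.2, "max_tokens": 200}
# Overrides are dropped entirely when the flag is off.
assert resolve_model_params(False, {"temperature": 0.2}, 200) == {"max_tokens": 200}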
2 changes: 1 addition & 1 deletion src/canopy/llm/base.py
@@ -31,7 +31,7 @@ def enforced_function_call(self,
function: Function,
*,
max_tokens: Optional[int] = None,
model_params: Optional[dict] = None
model_params: Optional[dict] = None,
) -> dict:
pass

28 changes: 18 additions & 10 deletions src/canopy/llm/openai.py
@@ -61,6 +61,11 @@ def __init__(self,
)

self.default_model_params = kwargs
if "model" in self.default_model_params:
raise ValueError(
"The 'model' parameter is not allowed in the default model params. "
"Please use the 'model_name' argument instead."
)

@property
def available_models(self):
@@ -87,7 +92,7 @@ def chat_completion(self,
stream: Whether to stream the response or not.
max_tokens: Maximum number of tokens to generate. Defaults to None (generates until stop sequence or until hitting max context size).
model_params: Model parameters to use for this request. Defaults to None (uses the default model parameters).
Dictonary of parametrs to override the default model parameters if set on initialization.
Dictionary of parameters to override the default model parameters if set on initialization.
For example, you can pass: {"temperature": 0.9, "top_p": 1.0} to override the default temperature and top_p.
see: https://platform.openai.com/docs/api-reference/chat/create
Returns:
@@ -106,9 +111,11 @@
""" # noqa: E501

model_params_dict: Dict[str, Any] = deepcopy(self.default_model_params)
model_params_dict.update(
model_params or {}
)
model_params_dict.update(model_params or {})
if max_tokens is not None:
model_params_dict["max_tokens"] = max_tokens

model = model_params_dict.pop("model", self.model_name)

if context is None:
system_message = system_prompt
@@ -117,10 +124,9 @@
messages = [SystemMessage(content=system_message).dict()
] + [m.dict() for m in chat_history]
try:
response = self._client.chat.completions.create(model=self.model_name,
response = self._client.chat.completions.create(model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
**model_params_dict)
except openai.OpenAIError as e:
self._handle_chat_error(e)
@@ -193,9 +199,11 @@ def enforced_function_call(self,
""" # noqa: E501

model_params_dict: Dict[str, Any] = deepcopy(self.default_model_params)
model_params_dict.update(
model_params or {}
)
model_params_dict.update(model_params or {})
if max_tokens is not None:
model_params_dict["max_tokens"] = max_tokens

model = model_params_dict.pop("model", self.model_name)

function_dict = cast(ChatCompletionToolParam,
{"type": "function", "function": function.dict()})
@@ -204,8 +212,8 @@
] + [m.dict() for m in chat_history]
try:
chat_completion = self._client.chat.completions.create(
model=model,
messages=messages,
model=self.model_name,
tools=[function_dict],
tool_choice={"type": "function",
"function": {"name": function.name}},
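The precedence implemented in chat_completion() and enforced_function_call() is: per-request model_params override the defaults given at init, an explicit max_tokens argument overrides both, and a per-request "model" entry replaces the configured model_name. A short sketch of just the merge (editor's illustration; no API call is made):

from copy import deepcopy
from typing import Any, Dict, Optional, Tuple

def merge_openai_params(default_model_params: Dict[str, Any],
                        model_params: Optional[Dict[str, Any]],
                        max_tokens: Optional[int],
                        model_name: str) -> Tuple[str, Dict[str, Any]]:
    """Mirror the parameter merging in OpenAILLM.chat_completion()."""
    params: Dict[str, Any] = deepcopy(default_model_params)
    params.update(model_params or {})        # per-request overrides win
    if max_tokens is not None:
        params["max_tokens"] = max_tokens    # explicit argument wins over both
    model = params.pop("model", model_name)  # "model" replaces the configured name
    return model, params

model, params = merge_openai_params({"temperature": 0.7, "top_p": 0.9},
                                    {"model": "gpt-3.5-turbo-1106", "temperature": 0.2},
                                    max_tokens=128,
                                    model_name="gpt-3.5-turbo")
assert model == "gpt-3.5-turbo-1106"
assert params == {"temperature": 0.2, "top_p": 0.9, "max_tokens": 128}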
5 changes: 4 additions & 1 deletion src/canopy_server/app.py
@@ -118,16 +118,19 @@ async def chat(
OpenAI compatible chat response
""" # noqa: E501

try:
logger.debug(f"The namespace is {namespace}")
session_id = request.user or "None" # noqa: F841
question_id = str(uuid.uuid4())
logger.debug(f"Received chat request: {request.messages[-1].content}")
model_params = request.dict(exclude={"messages", "stream"})
answer = await run_in_threadpool(
chat_engine.chat,
messages=request.messages,
stream=request.stream,
namespace=namespace
namespace=namespace,
model_params=model_params,
)

if request.stream:
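On the server side, every ChatRequest field except messages and stream is serialized into model_params and forwarded to the chat engine (the extended ChatRequest model follows in api_models.py below). A sketch of what a request turns into; the import path mirrors the file location, and the plain-dict message is assumed to validate against Canopy's Messages type:

from canopy_server.models.v1.api_models import ChatRequest  # path as in this diff

request = ChatRequest(
    messages=[{"role": "user", "content": "What is Canopy?"}],  # assumed to coerce via Pydantic
    stream=False,
    temperature=0.2,
    top_p=0.9,
)
model_params = request.dict(exclude={"messages", "stream"})
# model_params now holds model, temperature, top_p, max_tokens, ... (unset fields are None);
# ChatEngine.chat() drops the None values when allow_model_params_override is enabled.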
57 changes: 52 additions & 5 deletions src/canopy_server/models/v1/api_models.py
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field

@@ -8,17 +8,64 @@


class ChatRequest(BaseModel):
model: str = Field(
default="",
description="The ID of the model to use. This field is ignored; instead, configure this field in the Canopy config.", # noqa: E501
)
messages: Messages = Field(
description="A list of messages comprising the conversation so far."
)
stream: bool = Field(
default=False,
description="""Whether or not to stream the chatbot's response. If set, the response is server-sent events containing [chat.completion.chunk](https://platform.openai.com/docs/api-reference/chat/streaming) objects""", # noqa: E501
)

# -------------------------------------------------------------------------------
# Optional params. The params below are ignored by default, and should usually be
# configured in the Canopy config file instead.
# You can allow passing these params to the API by setting the
# `allow_model_params_override` flag in the ChatEngine's config.
# -------------------------------------------------------------------------------
model: str = Field(
default="",
description="The name of the model to use." # noqa: E501
)
frequency_penalty: Optional[float] = Field(
default=None,
description="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.", # noqa: E501
)
logit_bias: Optional[Dict[str, float]] = Field(
default=None,
description="A map of tokens to an associated bias value between -100 and 100.",
)
max_tokens: Optional[int] = Field(
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
n: Optional[int] = Field(
default=None,
description="How many chat completion choices to generate for each input message.", # noqa: E501
)
presence_penalty: Optional[float] = Field(
default=None,
description="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.", # noqa: E501
)
response_format: Optional[Dict[str, str]] = Field(
default=None,
description="The format of the returned response.",
)
seed: Optional[int] = Field(
default=None,
description="When provided, OpenAI will make a best effort to sample results deterministically.", # noqa: E501
)
stop: Optional[Union[List[str], str]] = Field(
default=None,
description="One or more sequences where the API will stop generating further tokens.", # noqa: E501
)
temperature: Optional[float] = Field(
default=None,
description="What sampling temperature to use.",
)
top_p: Optional[float] = Field(
default=None,
description="What nucleus sampling probability to use.",
)
user: Optional[str] = Field(
default=None,
description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. Unused, reserved for future extensions", # noqa: E501
Expand Down
17 changes: 17 additions & 0 deletions tests/system/llm/test_openai.py
@@ -174,6 +174,23 @@ def test_enforced_function_call_low_temperature(openai_llm,
assert_function_call_format(result)


def test_chat_completion_with_model_name(openai_llm, messages):
if isinstance(openai_llm, AzureOpenAILLM):
pytest.skip("In Azure the model name has to be a valid deployment")

new_model_name = "gpt-3.5-turbo-1106"
assert new_model_name != openai_llm.model_name, (
"The new model name should be different from the default one. Please change it."
)
response = openai_llm.chat_completion(
system_prompt=SYSTEM_PROMPT,
chat_history=messages,
model_params={"model": new_model_name}
)

assert response.model == new_model_name


def test_chat_streaming(openai_llm, messages):
stream = True
response = openai_llm.chat_completion(system_prompt=SYSTEM_PROMPT,
36 changes: 27 additions & 9 deletions tests/unit/chat_engine/test_chat_engine.py
@@ -130,18 +130,28 @@ def test_chat(self, namespace, history_length=5, snippet_length=10):
system_prompt=expected['prompt'],
context=expected['context'],
chat_history=messages,
max_tokens=200,
stream=False,
model_params=None
model_params={'max_tokens': 200}
)

assert response == expected['response']

@pytest.mark.parametrize("namespace", [
None, TEST_NAMESPACE
])
@pytest.mark.parametrize("allow_model_params_override,params_override",
[("False", None),
("False", {'temperature': 0.99, 'top_p': 0.5}),
("True", {'temperature': 0.99, 'top_p': 0.5}),
("True", {'temperature': 0.99, 'max_tokens': 200}),],
ids=["no_override",
"override_not_allowed",
"valid_override",
"valid_override_with_max_tokens"])
def test_chat_engine_params(self,
namespace,
allow_model_params_override,
params_override,
system_prompt_length=10,
max_prompt_tokens=80,
max_context_tokens=60,
Expand All @@ -152,10 +162,13 @@ def test_chat_engine_params(self,
):

system_prompt = self._generate_text(system_prompt_length)
chat_engine = self._init_chat_engine(system_prompt=system_prompt,
max_prompt_tokens=max_prompt_tokens,
max_context_tokens=max_context_tokens,
max_generated_tokens=max_generated_tokens)
chat_engine = self._init_chat_engine(
system_prompt=system_prompt,
max_prompt_tokens=max_prompt_tokens,
max_context_tokens=max_context_tokens,
max_generated_tokens=max_generated_tokens,
allow_model_params_override=allow_model_params_override
)

# Mock input and expected output
messages, expected = self._get_inputs_and_expected(history_length,
@@ -168,7 +181,13 @@
chat_engine.chat(messages)
return

response = chat_engine.chat(messages, namespace=namespace)
response = chat_engine.chat(messages,
namespace=namespace,
model_params=params_override)

expected_model_params = {'max_tokens': max_generated_tokens}
if allow_model_params_override and params_override is not None:
expected_model_params.update(params_override)

# Assertions
self.mock_query_builder.generate.assert_called_once_with(
@@ -184,9 +203,8 @@ def test_chat_engine_params(self,
system_prompt=expected['prompt'],
context=expected['context'],
chat_history=messages,
max_tokens=max_generated_tokens,
stream=False,
model_params=None
model_params=expected_model_params
)

assert response == expected['response']