
Commit

Async chat_completion
albemordo committed Mar 19, 2024
1 parent 598b8a0 commit 6c7bca6
Showing 7 changed files with 122 additions and 80 deletions.
30 changes: 17 additions & 13 deletions prem_utils/connectors/anthropic.py
@@ -6,6 +6,7 @@
APIResponseValidationError,
APIStatusError,
APITimeoutError,
AsyncAnthropic,
AuthenticationError,
BadRequestError,
ConflictError,
@@ -27,6 +28,7 @@ class AnthropicConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = Anthropic(api_key=api_key)
self.async_client = AsyncAnthropic(api_key=api_key)
self.exception_mapping = {
PermissionDeniedError: errors.PremProviderPermissionDeniedError,
UnprocessableEntityError: errors.PremProviderUnprocessableEntityError,
@@ -80,7 +82,7 @@ def preprocess_messages(self, messages):
filtered_messages.append(message)
return system_prompt, filtered_messages

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -100,16 +102,21 @@ def chat_completion(
if max_tokens is None:
max_tokens = 4096

request_data = dict(
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
model=model,
top_p=top_p,
temperature=temperature,
stream=stream,
stop_sequences=stop,
)
try:
response = self.client.messages.create(
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
model=model,
top_p=top_p,
temperature=temperature,
stream=stream,
)
if stream:
return await self.async_client.messages.create(**request_data)

response = self.client.messages.create(**request_data)
except (
NotFoundError,
APIResponseValidationError,
@@ -127,9 +134,6 @@ def chat_completion(
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="anthropic", model=model, provider_message=str(error))

if stream:
return response

plain_response = {
"choices": [
{
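
After this change chat_completion is a coroutine on every connector, so callers await it in both the streaming and the non-streaming branch. A minimal caller-side sketch, assuming the import path from this file, a placeholder API key and model id, and the normalized plain_response shape visible in the Mistral hunk further down (none of which are asserted by this diff):

```python
import asyncio

from prem_utils.connectors.anthropic import AnthropicConnector


async def main():
    connector = AnthropicConnector(api_key="<anthropic-api-key>")  # placeholder key

    # Non-streaming: awaiting the coroutine yields the normalized response dict.
    response = await connector.chat_completion(
        model="claude-3-sonnet-20240229",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
    )
    print(response["choices"][0]["message"]["content"])

    # Streaming: the awaited coroutine resolves to an async stream of raw events.
    stream = await connector.chat_completion(
        model="claude-3-sonnet-20240229",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    async for event in stream:
        ...  # e.g. hand each event to the connector's chunk parsing


if __name__ == "__main__":
    asyncio.run(main())
```
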
21 changes: 17 additions & 4 deletions prem_utils/connectors/cohere.py
@@ -12,6 +12,7 @@ class CohereConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = cohere.Client(api_key)
self.async_client = cohere.AsyncClient(api_key)
self.exception_mapping = {
CohereAPIError: errors.PremProviderAPIErrror,
CohereConnectionError: errors.PremProviderAPIConnectionError,
@@ -49,7 +50,7 @@ def parse_chunk(self, chunk):
],
}

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -66,19 +67,31 @@ def chat_completion(
):
chat_history, message = self.preprocess_messages(messages)
try:
if stream:
return await self.async_client.chat(
chat_history=chat_history,
max_tokens=max_tokens,
message=message,
model=model,
p=top_p,
temperature=temperature,
stream=stream,
)

response = self.client.chat(
chat_history=chat_history,
max_tokens=max_tokens,
message=message,
model="command",
model=model,
p=top_p,
temperature=temperature,
stream=stream,
)

except (CohereAPIError, CohereConnectionError) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="cohere", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"choices": [
{
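
Besides going async, this hunk also stops hard-coding model="command" and forwards the caller's model to Cohere. For orientation, a rough sketch of the split that preprocess_messages performs before the chat call; the chat_history/message shapes below are assumptions based on Cohere's v4 chat API, not something shown in this diff:

```python
# Hypothetical illustration only: an OpenAI-style messages list is assumed to be
# split into Cohere's chat_history (prior turns) plus a final message string.
messages = [
    {"role": "user", "content": "What is Prem?"},
    {"role": "assistant", "content": "Prem is an AI platform."},
    {"role": "user", "content": "Does it support Cohere?"},
]

chat_history = [
    {"role": "USER", "message": "What is Prem?"},
    {"role": "CHATBOT", "message": "Prem is an AI platform."},
]
message = "Does it support Cohere?"
```
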
38 changes: 22 additions & 16 deletions prem_utils/connectors/groq.py
@@ -1,6 +1,6 @@
import logging

from groq import Groq
from groq import AsyncGroq, Groq
from groq._exceptions import (
APIConnectionError,
APIError,
@@ -43,6 +43,7 @@ def __init__(self, api_key: str, prompt_template: str = None):
APIResponseValidationError: errors.PremProviderAPIResponseValidationError,
}
self.client = Groq(api_key=api_key)
self.async_client = AsyncGroq(api_key=api_key)

def parse_chunk(self, chunk):
return {
@@ -62,7 +63,7 @@ def parse_chunk(self, chunk):
],
}

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -83,19 +84,26 @@ def chat_completion(
if "groq" in model:
model = model.replace("groq/", "", 1)

request_data = dict(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
logprobs=log_probs,
logit_bias=logit_bias,
top_p=top_p,
)

try:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
if stream:
return await self.async_client.chat.completions.create(**request_data)

response = self.client.chat.completions.create(**request_data)
except (
APIConnectionError,
APIError,
@@ -114,8 +122,6 @@ def chat_completion(
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"id": response.id,
"choices": [
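
A hedged sketch of consuming the streaming branch together with the connector's parse_chunk helper. The GroqConnector class name is inferred from the file path, the model id is a placeholder, and the parsed chunk is assumed to carry OpenAI-style delta payloads:

```python
import asyncio

from prem_utils.connectors.groq import GroqConnector  # class name assumed from the file path


async def stream_demo():
    connector = GroqConnector(api_key="<groq-api-key>")  # placeholder key
    stream = await connector.chat_completion(
        model="groq/llama3-70b-8192",  # placeholder id; the "groq/" prefix is stripped by the connector
        messages=[{"role": "user", "content": "Stream a short answer."}],
        stream=True,
    )
    async for raw_chunk in stream:
        parsed = connector.parse_chunk(raw_chunk)
        for choice in parsed["choices"]:
            # Assumed OpenAI-style delta payload inside each parsed chunk.
            print(choice.get("delta", {}).get("content") or "", end="", flush=True)
    print()


asyncio.run(stream_demo())
```
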
81 changes: 41 additions & 40 deletions prem_utils/connectors/mistral.py
@@ -1,5 +1,6 @@
from collections.abc import Sequence

from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.exceptions import MistralAPIException, MistralConnectionException
from mistralai.models.chat_completion import ChatMessage
@@ -13,6 +14,7 @@ class MistralConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = MistralClient(api_key=api_key)
self.async_client = MistralAsyncClient(api_key=api_key)
self.exception_mapping = {
MistralAPIException: errors.PremProviderAPIStatusError,
MistralConnectionException: errors.PremProviderAPIConnectionError,
@@ -44,7 +46,7 @@ def build_messages(self, messages):
chat_messages.append(chat_message)
return chat_messages

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -60,47 +62,46 @@ def chat_completion(
top_p: float = 1,
):
messages = self.build_messages(messages)

request_data = dict(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
random_seed=seed,
)
try:
if stream:
response = self.client.chat_stream(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
return response
else:
response = self.client.chat(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
plain_response = {
"choices": [
{
"finish_reason": str(choice.finish_reason),
"index": choice.index,
"message": {
"content": choice.message.content,
"role": choice.message.role,
},
}
for choice in response.choices
],
"created": connector_utils.default_chatcompletion_response_created(),
"model": response.model,
"provider_name": "Mistral",
"provider_id": "mistralai",
"usage": {
"completion_tokens": response.usage.completion_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"total_tokens": response.usage.total_tokens,
},
}
return plain_response
# Don't know why, but it actually returns an AsyncIterator,
# not a coroutine, so there's no need to await it
return self.async_client.chat_stream(**request_data)

response = self.client.chat(**request_data)

plain_response = {
"choices": [
{
"finish_reason": str(choice.finish_reason),
"index": choice.index,
"message": {
"content": choice.message.content,
"role": choice.message.role,
},
}
for choice in response.choices
],
"created": connector_utils.default_chatcompletion_response_created(),
"model": response.model,
"provider_name": "Mistral",
"provider_id": "mistralai",
"usage": {
"completion_tokens": response.usage.completion_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"total_tokens": response.usage.total_tokens,
},
}
return plain_response
except (MistralAPIException, MistralConnectionException) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="mistralai", model=model, provider_message=str(error))
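
The in-code comment above makes sense if chat_stream is an async generator: calling an async def function that contains yield returns an async iterator immediately rather than a coroutine, so there is nothing to await until iteration starts. A standalone illustration (toy code, not part of prem_utils):

```python
import asyncio


async def ticker(n: int):
    # An async generator: calling ticker(3) returns an async iterator right away,
    # with no await needed, just like MistralAsyncClient.chat_stream in the hunk above.
    for i in range(n):
        yield i
        await asyncio.sleep(0)


async def main():
    stream = ticker(3)          # no await here
    async for value in stream:  # awaiting happens per item during iteration
        print(value)


asyncio.run(main())
```
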
22 changes: 19 additions & 3 deletions prem_utils/connectors/openai.py
@@ -8,6 +8,7 @@
APIResponseValidationError,
APIStatusError,
APITimeoutError,
AsyncOpenAI,
AuthenticationError,
BadRequestError,
ConflictError,
@@ -43,8 +44,10 @@ def __init__(self, api_key: str = None, base_url: str = None, prompt_template: s
}
if base_url is not None:
self.client = OpenAI(api_key=api_key, base_url=base_url)
self.async_client = AsyncOpenAI(api_key=api_key, base_url=base_url)
else:
self.client = OpenAI(api_key=api_key)
self.async_client = AsyncOpenAI(api_key=api_key)

def parse_chunk(self, chunk):
return {
@@ -64,7 +67,7 @@ def parse_chunk(self, chunk):
],
}

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -83,6 +86,20 @@ def chat_completion(
messages = self.apply_prompt_template(messages)

try:
if stream:
return await self.async_client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)

response = self.client.chat.completions.create(
model=model,
messages=messages,
Expand All @@ -95,6 +112,7 @@ def chat_completion(
temperature=temperature,
top_p=top_p,
)

except (
NotFoundError,
APIResponseValidationError,
@@ -112,8 +130,6 @@ def chat_completion(
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"id": response.id,
"choices": [
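
Because base_url is now mirrored onto AsyncOpenAI, the streaming branch also talks to custom OpenAI-compatible endpoints. A brief sketch; the OpenAIConnector class name is inferred from the file path, and the endpoint URL and model id are placeholders:

```python
import asyncio

from prem_utils.connectors.openai import OpenAIConnector  # class name assumed from the file path


async def main():
    # Both the sync and the new async client receive the same base_url,
    # so stream=True requests also go to the custom endpoint.
    connector = OpenAIConnector(
        api_key="<api-key>",
        base_url="https://my-openai-compatible-host/v1",  # placeholder endpoint
    )
    stream = await connector.chat_completion(
        model="gpt-4o-mini",  # placeholder model id
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    async for chunk in stream:
        ...


asyncio.run(main())
```
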
6 changes: 4 additions & 2 deletions prem_utils/connectors/openrouter.py
@@ -10,7 +10,7 @@ def __init__(
) -> None:
super().__init__(prompt_template=prompt_template, base_url=base_url, api_key=api_key)

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -28,7 +28,7 @@ def chat_completion(
if "openrouter" in model:
model = model.replace("openrouter/", "", 1)

return super().chat_completion(
return await super().chat_completion(
model=model,
messages=messages,
stream=stream,
@@ -39,6 +39,8 @@ def chat_completion(
stop=stop,
temperature=temperature,
top_p=top_p,
log_probs=log_probs,
logit_bias=logit_bias,
)

def embeddings(
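
With the parent OpenAI connector's chat_completion now a coroutine, the override must await it (and it now forwards log_probs and logit_bias as well). A toy illustration of why the await is required, independent of prem_utils:

```python
import asyncio


class Base:
    async def chat_completion(self, model: str) -> str:
        return f"handled {model}"


class OpenRouterLike(Base):
    async def chat_completion(self, model: str) -> str:
        # Without the await this would return a coroutine object, not the result.
        return await super().chat_completion(model.replace("openrouter/", "", 1))


print(asyncio.run(OpenRouterLike().chat_completion("openrouter/some-model")))
# -> handled some-model
```
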