This repository has been archived by the owner on Aug 12, 2024. It is now read-only.

Make chat_completion async #75

Merged · 1 commit · Mar 21, 2024
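This change turns each connector's chat_completion into a coroutine: with stream=True the request is sent through a newly added async client and the awaited result is an async iterator of provider chunks; otherwise the existing synchronous client is still used and the usual normalized response dict is returned. A minimal caller-side sketch of the new contract, using the Anthropic connector from the diff below (the API key and model id are placeholders; the response shape follows the plain_response dict built in the connector):

import asyncio

from prem_utils.connectors.anthropic import AnthropicConnector


async def main():
    connector = AnthropicConnector(api_key="sk-ant-...")  # placeholder key
    response = await connector.chat_completion(
        model="claude-3-sonnet-20240229",  # illustrative model id
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response["choices"][0]["message"]["content"])


asyncio.run(main())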
30 changes: 17 additions & 13 deletions prem_utils/connectors/anthropic.py
@@ -6,6 +6,7 @@
APIResponseValidationError,
APIStatusError,
APITimeoutError,
AsyncAnthropic,
AuthenticationError,
BadRequestError,
ConflictError,
@@ -27,6 +28,7 @@ class AnthropicConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = Anthropic(api_key=api_key)
self.async_client = AsyncAnthropic(api_key=api_key)
self.exception_mapping = {
PermissionDeniedError: errors.PremProviderPermissionDeniedError,
UnprocessableEntityError: errors.PremProviderUnprocessableEntityError,
@@ -80,7 +82,7 @@ def preprocess_messages(self, messages):
filtered_messages.append(message)
return system_prompt, filtered_messages

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -100,16 +102,21 @@ def chat_completion(
if max_tokens is None:
max_tokens = 4096

request_data = dict(
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
model=model,
top_p=top_p,
temperature=temperature,
stream=stream,
stop_sequences=stop,
)
try:
response = self.client.messages.create(
max_tokens=max_tokens,
system=system_prompt,
messages=messages,
model=model,
top_p=top_p,
temperature=temperature,
stream=stream,
)
if stream:
return await self.async_client.messages.create(**request_data)

response = self.client.messages.create(**request_data)
except (
NotFoundError,
APIResponseValidationError,
@@ -127,9 +134,6 @@ def chat_completion(
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="anthropic", model=model, provider_message=str(error))

if stream:
return response

plain_response = {
"choices": [
{
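With stream=True the awaited value here is the anthropic SDK's async event stream rather than OpenAI-style chunks, so a consumer pulls text out of the delta events. A rough consumption sketch, reusing the connector from the sketch above and assuming the SDK's raw event types (not part of this diff); it runs inside an async function:

stream = await connector.chat_completion(
    model="claude-3-sonnet-20240229",  # illustrative model id
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
async for event in stream:
    # Incremental text arrives on content_block_delta events.
    if event.type == "content_block_delta" and hasattr(event.delta, "text"):
        print(event.delta.text, end="")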
44 changes: 28 additions & 16 deletions prem_utils/connectors/azure.py
@@ -9,6 +9,7 @@
APIResponseValidationError,
APIStatusError,
APITimeoutError,
AsyncAzureOpenAI,
AuthenticationError,
AzureOpenAI,
BadRequestError,
@@ -24,14 +25,21 @@
from prem_utils.connectors import utils as connector_utils
from prem_utils.connectors.base import BaseConnector

API_VERSION = "2023-10-01-preview"


class AzureOpenAIConnector(BaseConnector):
def __init__(self, api_key: str, base_url: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = AzureOpenAI(
api_key=api_key,
azure_endpoint=base_url,
api_version="2023-10-01-preview",
api_version=API_VERSION,
)
self.async_client = AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=base_url,
api_version=API_VERSION,
)
self.exception_mapping = {
APIError: errors.PremProviderAPIErrror,
@@ -67,7 +75,7 @@ def parse_chunk(self, chunk):
],
}

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -82,19 +90,25 @@ def chat_completion(
temperature: float = 1,
top_p: float = 1,
):
request_data = dict(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
logprobs=log_probs,
logit_bias=logit_bias,
)
try:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
if stream:
return await self.async_client.chat.completions.create(**request_data)

response = self.client.chat.completions.create(**request_data)
except (
NotFoundError,
APIResponseValidationError,
@@ -112,8 +126,6 @@
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="azure", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"choices": [
{
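Streaming through the Azure connector now goes through AsyncAzureOpenAI and yields OpenAI-style chunk objects. A sketch of draining them into text, inside an async function; the endpoint, deployment name, and chunk attributes are assumptions based on the openai SDK, not this diff:

from prem_utils.connectors.azure import AzureOpenAIConnector

connector = AzureOpenAIConnector(
    api_key="...",  # placeholder key
    base_url="https://my-resource.openai.azure.com",  # placeholder endpoint
)
stream = await connector.chat_completion(
    model="my-gpt-4-deployment",  # placeholder Azure deployment name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
async for chunk in stream:
    # Azure can emit chunks with an empty choices list (e.g. content-filter frames).
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")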
21 changes: 17 additions & 4 deletions prem_utils/connectors/cohere.py
@@ -12,6 +12,7 @@ class CohereConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = cohere.Client(api_key)
self.async_client = cohere.AsyncClient(api_key)
self.exception_mapping = {
CohereAPIError: errors.PremProviderAPIErrror,
CohereConnectionError: errors.PremProviderAPIConnectionError,
@@ -49,7 +50,7 @@ def parse_chunk(self, chunk):
],
}

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -66,19 +67,31 @@ def chat_completion(
):
chat_history, message = self.preprocess_messages(messages)
try:
if stream:
return await self.async_client.chat(
chat_history=chat_history,
max_tokens=max_tokens,
message=message,
model=model,
p=top_p,
temperature=temperature,
stream=stream,
)

response = self.client.chat(
chat_history=chat_history,
max_tokens=max_tokens,
message=message,
model="command",
model=model,
p=top_p,
temperature=temperature,
stream=stream,
)

except (CohereAPIError, CohereConnectionError) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="cohere", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"choices": [
{
38 changes: 22 additions & 16 deletions prem_utils/connectors/groq.py
@@ -1,6 +1,6 @@
import logging

from groq import Groq
from groq import AsyncGroq, Groq
from groq._exceptions import (
APIConnectionError,
APIError,
@@ -43,6 +43,7 @@ def __init__(self, api_key: str, prompt_template: str = None):
APIResponseValidationError: errors.PremProviderAPIResponseValidationError,
}
self.client = Groq(api_key=api_key)
self.async_client = AsyncGroq(api_key=api_key)

def parse_chunk(self, chunk):
return {
@@ -62,7 +63,7 @@ def parse_chunk(self, chunk):
],
}

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -83,19 +84,26 @@ def chat_completion(
if "groq" in model:
model = model.replace("groq/", "", 1)

request_data = dict(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
logprobs=log_probs,
logit_bias=logit_bias,
top_p=top_p,
)

try:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
if stream:
return await self.async_client.chat.completions.create(**request_data)

response = self.client.chat.completions.create(**request_data)
except (
APIConnectionError,
APIError,
@@ -114,8 +122,6 @@ def chat_completion(
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"id": response.id,
"choices": [
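The Groq connector strips an optional "groq/" routing prefix before the request is built, so both prefixed and bare model ids work. A quick usage sketch inside an async function; the connector class name, key format, and model id are assumptions, and the response shape is taken to match the other connectors' plain_response dicts:

from prem_utils.connectors.groq import GroqConnector  # class name assumed from groq.py

connector = GroqConnector(api_key="gsk_...")  # placeholder key
result = await connector.chat_completion(
    model="groq/mixtral-8x7b-32768",  # "groq/" is stripped to "mixtral-8x7b-32768"
    messages=[{"role": "user", "content": "Say hi"}],
)
print(result["choices"][0]["message"]["content"])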
82 changes: 42 additions & 40 deletions prem_utils/connectors/mistral.py
@@ -1,5 +1,6 @@
from collections.abc import Sequence

from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.exceptions import MistralAPIException, MistralConnectionException
from mistralai.models.chat_completion import ChatMessage
@@ -13,6 +14,7 @@ class MistralConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = MistralClient(api_key=api_key)
self.async_client = MistralAsyncClient(api_key=api_key)
self.exception_mapping = {
MistralAPIException: errors.PremProviderAPIStatusError,
MistralConnectionException: errors.PremProviderAPIConnectionError,
@@ -44,7 +46,7 @@ def build_messages(self, messages):
chat_messages.append(chat_message)
return chat_messages

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -60,47 +62,46 @@ def chat_completion(
top_p: float = 1,
):
messages = self.build_messages(messages)

request_data = dict(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
random_seed=seed,
)
try:
if stream:
response = self.client.chat_stream(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
return response
else:
response = self.client.chat(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
)
plain_response = {
"choices": [
{
"finish_reason": str(choice.finish_reason),
"index": choice.index,
"message": {
"content": choice.message.content,
"role": choice.message.role,
},
}
for choice in response.choices
],
"created": connector_utils.default_chatcompletion_response_created(),
"model": response.model,
"provider_name": "Mistral",
"provider_id": "mistralai",
"usage": {
"completion_tokens": response.usage.completion_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"total_tokens": response.usage.total_tokens,
},
}
return plain_response
# Client actually returns an AsyncIterator,
# not a coroutine, so there's no need to await it
return self.async_client.chat_stream(**request_data)

response = self.client.chat(**request_data)

plain_response = {
"choices": [
{
"finish_reason": str(choice.finish_reason),
"index": choice.index,
"message": {
"content": choice.message.content,
"role": choice.message.role,
},
}
for choice in response.choices
],
"created": connector_utils.default_chatcompletion_response_created(),
"model": response.model,
"provider_name": "Mistral",
"provider_id": "mistralai",
"usage": {
"completion_tokens": response.usage.completion_tokens,
"prompt_tokens": response.usage.prompt_tokens,
"total_tokens": response.usage.total_tokens,
},
}
return plain_response
except (MistralAPIException, MistralConnectionException) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="mistralai", model=model, provider_message=str(error))
@@ -137,6 +138,7 @@ class MistralAzureConnector(MistralConnector):
def __init__(self, api_key: str, endpoint: str, prompt_template: str = None):
super().__init__(api_key=api_key, prompt_template=prompt_template)
self.client = MistralClient(endpoint=endpoint, api_key=api_key)
self.async_client = MistralAsyncClient(endpoint=endpoint, api_key=api_key)
self.exception_mapping = {
MistralAPIException: errors.PremProviderAPIStatusError,
MistralConnectionException: errors.PremProviderAPIConnectionError,
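As the inline comment in mistral.py notes, MistralAsyncClient.chat_stream already returns an async iterator, so the awaited chat_completion hands the stream straight back. Consuming it looks roughly like this inside an async function; the model name and chunk attributes are assumptions based on the mistralai 0.x SDK, not this diff:

from prem_utils.connectors.mistral import MistralConnector

connector = MistralConnector(api_key="...")  # placeholder key
stream = await connector.chat_completion(
    model="mistral-small-latest",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
async for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="")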