This repository has been archived by the owner on Aug 12, 2024. It is now read-only.

Commit

Merge pull request #75 from premAI-io/fix/async-chat-74
Make chat_completion async
allemonta authored Mar 21, 2024
2 parents 9f09b93 + 14941b8 commit e9fa7ff
Showing 9 changed files with 157 additions and 109 deletions.
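With this commit, chat_completion on the connectors shown below becomes a coroutine, so callers must await it; with stream=True the connectors now hand back the provider's async-client stream instead of a plain dict. A minimal caller-side sketch (not part of this commit), assuming an ANTHROPIC_API_KEY environment variable and a hypothetical claude-3-haiku-20240307 model id; the non-streaming response fields follow the normalized plain_response shape visible in the Mistral diff below:

import asyncio
import os

from prem_utils.connectors.anthropic import AnthropicConnector


async def main() -> None:
    connector = AnthropicConnector(api_key=os.environ["ANTHROPIC_API_KEY"])
    messages = [{"role": "user", "content": "Say hello in one short sentence."}]

    # Non-streaming: chat_completion is now a coroutine and must be awaited.
    response = await connector.chat_completion(
        model="claude-3-haiku-20240307",  # hypothetical model id, not taken from this diff
        messages=messages,
    )
    print(response["choices"][0]["message"]["content"])

    # Streaming: the awaited coroutine resolves to the async client's stream,
    # which is consumed with `async for`.
    stream = await connector.chat_completion(
        model="claude-3-haiku-20240307",
        messages=messages,
        stream=True,
    )
    async for chunk in stream:
        print(chunk)  # raw provider chunk; normalization via parse_chunk is connector-specific


asyncio.run(main())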
30 changes: 17 additions & 13 deletions prem_utils/connectors/anthropic.py
@@ -6,6 +6,7 @@
APIResponseValidationError,
APIStatusError,
APITimeoutError,
+ AsyncAnthropic,
AuthenticationError,
BadRequestError,
ConflictError,
@@ -27,6 +28,7 @@ class AnthropicConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = Anthropic(api_key=api_key)
+ self.async_client = AsyncAnthropic(api_key=api_key)
self.exception_mapping = {
PermissionDeniedError: errors.PremProviderPermissionDeniedError,
UnprocessableEntityError: errors.PremProviderUnprocessableEntityError,
@@ -80,7 +82,7 @@ def preprocess_messages(self, messages):
filtered_messages.append(message)
return system_prompt, filtered_messages

- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -100,16 +102,21 @@ def chat_completion(
if max_tokens is None:
max_tokens = 4096

+ request_data = dict(
+ max_tokens=max_tokens,
+ system=system_prompt,
+ messages=messages,
+ model=model,
+ top_p=top_p,
+ temperature=temperature,
+ stream=stream,
+ stop_sequences=stop,
+ )
try:
- response = self.client.messages.create(
- max_tokens=max_tokens,
- system=system_prompt,
- messages=messages,
- model=model,
- top_p=top_p,
- temperature=temperature,
- stream=stream,
- )
+ if stream:
+ return await self.async_client.messages.create(**request_data)
+
+ response = self.client.messages.create(**request_data)
except (
NotFoundError,
APIResponseValidationError,
@@ -127,9 +134,6 @@
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="anthropic", model=model, provider_message=str(error))

- if stream:
- return response
-
plain_response = {
"choices": [
{
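The same refactor pattern repeats in the Azure, Cohere, and Groq connectors below: collect the provider kwargs once in request_data, send them through the new async client when stream=True, and keep the synchronous client for the non-streaming path. A stripped-down illustration of that dispatch, using hypothetical stand-in clients rather than any real SDK:

import asyncio


class FakeSyncClient:
    def create(self, **kwargs):
        # Stands in for a blocking SDK call such as client.chat.completions.create(...)
        return {"stream": False, "echo": kwargs}


class FakeAsyncClient:
    async def create(self, **kwargs):
        # Stands in for the async SDK call whose awaited result is a stream object
        return {"stream": True, "echo": kwargs}


class ExampleConnector:
    def __init__(self) -> None:
        self.client = FakeSyncClient()
        self.async_client = FakeAsyncClient()

    async def chat_completion(self, model, messages, stream=False, max_tokens=None):
        request_data = dict(model=model, messages=messages, stream=stream, max_tokens=max_tokens)
        if stream:
            # Streaming requests go through the async client and are returned as-is.
            return await self.async_client.create(**request_data)
        # Non-streaming requests still use the synchronous client, then get normalized.
        return self.client.create(**request_data)


print(asyncio.run(ExampleConnector().chat_completion("some-model", [{"role": "user", "content": "hi"}])))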
44 changes: 28 additions & 16 deletions prem_utils/connectors/azure.py
@@ -9,6 +9,7 @@
APIResponseValidationError,
APIStatusError,
APITimeoutError,
+ AsyncAzureOpenAI,
AuthenticationError,
AzureOpenAI,
BadRequestError,
@@ -24,14 +25,21 @@
from prem_utils.connectors import utils as connector_utils
from prem_utils.connectors.base import BaseConnector

+ API_VERSION = "2023-10-01-preview"


class AzureOpenAIConnector(BaseConnector):
def __init__(self, api_key: str, base_url: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = AzureOpenAI(
api_key=api_key,
azure_endpoint=base_url,
api_version="2023-10-01-preview",
api_version=API_VERSION,
)
+ self.async_client = AsyncAzureOpenAI(
+ api_key=api_key,
+ azure_endpoint=base_url,
+ api_version=API_VERSION,
+ )
self.exception_mapping = {
APIError: errors.PremProviderAPIErrror,
@@ -67,7 +75,7 @@ def parse_chunk(self, chunk):
],
}

- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -82,19 +90,25 @@ def chat_completion(
temperature: float = 1,
top_p: float = 1,
):
+ request_data = dict(
+ model=model,
+ messages=messages,
+ stream=stream,
+ max_tokens=max_tokens,
+ frequency_penalty=frequency_penalty,
+ presence_penalty=presence_penalty,
+ seed=seed,
+ stop=stop,
+ temperature=temperature,
+ top_p=top_p,
+ logprobs=log_probs,
+ logit_bias=logit_bias,
+ )
try:
- response = self.client.chat.completions.create(
- model=model,
- messages=messages,
- stream=stream,
- max_tokens=max_tokens,
- frequency_penalty=frequency_penalty,
- presence_penalty=presence_penalty,
- seed=seed,
- stop=stop,
- temperature=temperature,
- top_p=top_p,
- )
+ if stream:
+ return await self.async_client.chat.completions.create(**request_data)
+
+ response = self.client.chat.completions.create(**request_data)
except (
NotFoundError,
APIResponseValidationError,
@@ -112,8 +126,6 @@
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="azure", model=model, provider_message=str(error))

- if stream:
- return response
plain_response = {
"choices": [
{
21 changes: 17 additions & 4 deletions prem_utils/connectors/cohere.py
@@ -12,6 +12,7 @@ class CohereConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = cohere.Client(api_key)
+ self.async_client = cohere.AsyncClient(api_key)
self.exception_mapping = {
CohereAPIError: errors.PremProviderAPIErrror,
CohereConnectionError: errors.PremProviderAPIConnectionError,
@@ -49,7 +50,7 @@ def parse_chunk(self, chunk):
],
}

- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -66,19 +67,31 @@
):
chat_history, message = self.preprocess_messages(messages)
try:
+ if stream:
+ return await self.async_client.chat(
+ chat_history=chat_history,
+ max_tokens=max_tokens,
+ message=message,
+ model=model,
+ p=top_p,
+ temperature=temperature,
+ stream=stream,
+ )
+
response = self.client.chat(
chat_history=chat_history,
max_tokens=max_tokens,
message=message,
model="command",
model=model,
p=top_p,
temperature=temperature,
stream=stream,
)

except (CohereAPIError, CohereConnectionError) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="cohere", model=model, provider_message=str(error))

- if stream:
- return response
plain_response = {
"choices": [
{
38 changes: 22 additions & 16 deletions prem_utils/connectors/groq.py
@@ -1,6 +1,6 @@
import logging

- from groq import Groq
+ from groq import AsyncGroq, Groq
from groq._exceptions import (
APIConnectionError,
APIError,
@@ -43,6 +43,7 @@ def __init__(self, api_key: str, prompt_template: str = None):
APIResponseValidationError: errors.PremProviderAPIResponseValidationError,
}
self.client = Groq(api_key=api_key)
+ self.async_client = AsyncGroq(api_key=api_key)

def parse_chunk(self, chunk):
return {
@@ -62,7 +63,7 @@ def parse_chunk(self, chunk):
],
}

- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -83,19 +84,26 @@ def chat_completion(
if "groq" in model:
model = model.replace("groq/", "", 1)

+ request_data = dict(
+ model=model,
+ messages=messages,
+ stream=stream,
+ max_tokens=max_tokens,
+ frequency_penalty=frequency_penalty,
+ presence_penalty=presence_penalty,
+ seed=seed,
+ stop=stop,
+ temperature=temperature,
+ logprobs=log_probs,
+ logit_bias=logit_bias,
+ top_p=top_p,
+ )

try:
- response = self.client.chat.completions.create(
- model=model,
- messages=messages,
- stream=stream,
- max_tokens=max_tokens,
- frequency_penalty=frequency_penalty,
- presence_penalty=presence_penalty,
- seed=seed,
- stop=stop,
- temperature=temperature,
- top_p=top_p,
- )
+ if stream:
+ return await self.async_client.chat.completions.create(**request_data)
+
+ response = self.client.chat.completions.create(**request_data)
except (
APIConnectionError,
APIError,
@@ -114,8 +122,6 @@
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=model, provider_message=str(error))

- if stream:
- return response
plain_response = {
"id": response.id,
"choices": [
82 changes: 42 additions & 40 deletions prem_utils/connectors/mistral.py
@@ -1,5 +1,6 @@
from collections.abc import Sequence

+ from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai.exceptions import MistralAPIException, MistralConnectionException
from mistralai.models.chat_completion import ChatMessage
@@ -13,6 +14,7 @@ class MistralConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = MistralClient(api_key=api_key)
+ self.async_client = MistralAsyncClient(api_key=api_key)
self.exception_mapping = {
MistralAPIException: errors.PremProviderAPIStatusError,
MistralConnectionException: errors.PremProviderAPIConnectionError,
@@ -44,7 +46,7 @@ def build_messages(self, messages):
chat_messages.append(chat_message)
return chat_messages

- def chat_completion(
+ async def chat_completion(
self,
model: str,
messages: list[dict[str]],
@@ -60,47 +62,46 @@
top_p: float = 1,
):
messages = self.build_messages(messages)

+ request_data = dict(
+ model=model,
+ messages=messages,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ random_seed=seed,
+ )
try:
if stream:
- response = self.client.chat_stream(
- model=model,
- messages=messages,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- )
- return response
- else:
- response = self.client.chat(
- model=model,
- messages=messages,
- max_tokens=max_tokens,
- temperature=temperature,
- top_p=top_p,
- )
- plain_response = {
- "choices": [
- {
- "finish_reason": str(choice.finish_reason),
- "index": choice.index,
- "message": {
- "content": choice.message.content,
- "role": choice.message.role,
- },
- }
- for choice in response.choices
- ],
- "created": connector_utils.default_chatcompletion_response_created(),
- "model": response.model,
- "provider_name": "Mistral",
- "provider_id": "mistralai",
- "usage": {
- "completion_tokens": response.usage.completion_tokens,
- "prompt_tokens": response.usage.prompt_tokens,
- "total_tokens": response.usage.total_tokens,
- },
- }
- return plain_response
+ # Client actually returns an AsyncIterator,
+ # not a coroutine, so there's no need to await it
+ return self.async_client.chat_stream(**request_data)
+
+ response = self.client.chat(**request_data)
+
+ plain_response = {
+ "choices": [
+ {
+ "finish_reason": str(choice.finish_reason),
+ "index": choice.index,
+ "message": {
+ "content": choice.message.content,
+ "role": choice.message.role,
+ },
+ }
+ for choice in response.choices
+ ],
+ "created": connector_utils.default_chatcompletion_response_created(),
+ "model": response.model,
+ "provider_name": "Mistral",
+ "provider_id": "mistralai",
+ "usage": {
+ "completion_tokens": response.usage.completion_tokens,
+ "prompt_tokens": response.usage.prompt_tokens,
+ "total_tokens": response.usage.total_tokens,
+ },
+ }
+ return plain_response
except (MistralAPIException, MistralConnectionException) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="mistralai", model=model, provider_message=str(error))
@@ -137,6 +138,7 @@ class MistralAzureConnector(MistralConnector):
def __init__(self, api_key: str, endpoint: str, prompt_template: str = None):
super().__init__(api_key=api_key, prompt_template=prompt_template)
self.client = MistralClient(endpoint=endpoint, api_key=api_key)
+ self.async_client = MistralAsyncClient(endpoint=endpoint, api_key=api_key)
self.exception_mapping = {
MistralAPIException: errors.PremProviderAPIStatusError,
MistralConnectionException: errors.PremProviderAPIConnectionError,
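The Mistral connector is the one place where the streaming branch is not awaited internally: MistralAsyncClient.chat_stream hands back an async iterator directly, as the in-line comment above notes, so chat_completion simply returns it. A rough caller-side sketch (not from this commit), assuming a MISTRAL_API_KEY environment variable and a hypothetical mistral-tiny model id:

import asyncio
import os

from prem_utils.connectors.mistral import MistralConnector


async def main() -> None:
    connector = MistralConnector(api_key=os.environ["MISTRAL_API_KEY"])
    messages = [{"role": "user", "content": "Name three prime numbers."}]

    # Awaiting chat_completion yields the async iterator produced by chat_stream.
    stream = await connector.chat_completion(
        model="mistral-tiny",  # hypothetical model id, not taken from this diff
        messages=messages,
        stream=True,
    )
    async for chunk in stream:
        # Chunks are raw Mistral SDK stream responses; chunk parsing is left to the caller here.
        print(chunk)


asyncio.run(main())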
