Skip to content
This repository has been archived by the owner on Aug 12, 2024. It is now read-only.

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
albemordo committed Mar 19, 2024
1 parent a421060 commit ed8fb9a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 39 deletions.
21 changes: 17 additions & 4 deletions prem_utils/connectors/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class CohereConnector(BaseConnector):
def __init__(self, api_key: str, prompt_template: str = None):
super().__init__(prompt_template=prompt_template)
self.client = cohere.Client(api_key)
self.async_client = cohere.AsyncClient(api_key)
self.exception_mapping = {
CohereAPIError: errors.PremProviderAPIErrror,
CohereConnectionError: errors.PremProviderAPIConnectionError,
Expand Down Expand Up @@ -49,7 +50,7 @@ def parse_chunk(self, chunk):
],
}

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
Expand All @@ -66,19 +67,31 @@ def chat_completion(
):
chat_history, message = self.preprocess_messages(messages)
try:
if stream:
return await self.async_client.chat(
chat_history=chat_history,
max_tokens=max_tokens,
message=message,
model=model,
p=top_p,
temperature=temperature,
stream=stream,
)

response = self.client.chat(
chat_history=chat_history,
max_tokens=max_tokens,
message=message,
model="command",
model=model,
p=top_p,
temperature=temperature,
stream=stream,
)

except (CohereAPIError, CohereConnectionError) as error:
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="cohere", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"choices": [
{
Expand Down
38 changes: 22 additions & 16 deletions prem_utils/connectors/groq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from groq import Groq
from groq import AsyncGroq, Groq
from groq._exceptions import (
APIConnectionError,
APIError,
Expand Down Expand Up @@ -43,6 +43,7 @@ def __init__(self, api_key: str, prompt_template: str = None):
APIResponseValidationError: errors.PremProviderAPIResponseValidationError,
}
self.client = Groq(api_key=api_key)
self.async_client = AsyncGroq(api_key=api_key)

def parse_chunk(self, chunk):
return {
Expand All @@ -62,7 +63,7 @@ def parse_chunk(self, chunk):
],
}

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
Expand All @@ -83,19 +84,26 @@ def chat_completion(
if "groq" in model:
model = model.replace("groq/", "", 1)

request_data = dict(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
logprobs=log_probs,
logit_bias=logit_bias,
top_p=top_p,
)

try:
response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
if stream:
return await self.async_client.chat.completions.create(**request_data)

response = self.client.chat.completions.create(**request_data)
except (
APIConnectionError,
APIError,
Expand All @@ -114,8 +122,6 @@ def chat_completion(
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"id": response.id,
"choices": [
Expand Down
32 changes: 15 additions & 17 deletions prem_utils/connectors/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,8 @@ async def chat_completion(
messages = self.apply_prompt_template(messages)

try:
response = (
self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
if not stream
else await self.async_client.chat.completions.create(
if stream:
return await self.async_client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
Expand All @@ -112,7 +99,20 @@ async def chat_completion(
temperature=temperature,
top_p=top_p,
)

response = self.client.chat.completions.create(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)

except (
NotFoundError,
APIResponseValidationError,
Expand All @@ -130,8 +130,6 @@ async def chat_completion(
custom_exception = self.exception_mapping.get(type(error), errors.PremProviderError)
raise custom_exception(error, provider="openai", model=model, provider_message=str(error))

if stream:
return response
plain_response = {
"id": response.id,
"choices": [
Expand Down
4 changes: 2 additions & 2 deletions prem_utils/connectors/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class PerplexityAIConnector(OpenAIConnector):
def __init__(self, api_key: str, base_url: str = "https://api.perplexity.ai", prompt_template: str = None) -> None:
super().__init__(prompt_template=prompt_template, base_url=base_url, api_key=api_key)

def chat_completion(
async def chat_completion(
self,
model: str,
messages: list[dict[str]],
Expand All @@ -29,7 +29,7 @@ def chat_completion(
if "perplexity" in model:
model = model.replace("perplexity/", "", 1)

return super().chat_completion(
return await super().chat_completion(
model=model,
messages=messages,
max_tokens=max_tokens,
Expand Down

0 comments on commit ed8fb9a

Please sign in to comment.