Skip to content

Commit

Permalink
Adds timeout retry backoff
Browse files (browse the repository at this point in the history)
  • Loading branch information
kgrofelnik committed Apr 16, 2024
1 parent 7bb96a3 commit e073228
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 16 deletions.
31 changes: 21 additions & 10 deletions oracles/src/domain/llm/generate_response_use_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import List
from typing import Optional

import groq
from groq import AsyncGroq
import openai
from openai import AsyncOpenAI
from openai import RateLimitError
from openai import APIError
from openai.types.chat import ChatCompletion
from groq.types.chat import ChatCompletion as GroqChatCompletion

Expand All @@ -16,7 +16,9 @@
from src.entities import PromptType


@backoff.on_exception(backoff.expo, RateLimitError)
@backoff.on_exception(
backoff.expo, (openai.RateLimitError, openai.APITimeoutError), max_tries=3
)
async def _generate(model: str, messages: List[dict]) -> Optional[str]:
client = AsyncOpenAI(
api_key=settings.OPEN_AI_API_KEY,
Expand All @@ -28,7 +30,9 @@ async def _generate(model: str, messages: List[dict]) -> Optional[str]:
return chat_completion.choices[0].message.content


@backoff.on_exception(backoff.expo, RateLimitError)
@backoff.on_exception(
backoff.expo, (openai.RateLimitError, openai.APITimeoutError), max_tries=3
)
async def _generate_openai_with_params(chat: Chat) -> Optional[ChatCompletion]:
client = AsyncOpenAI(
api_key=settings.OPEN_AI_API_KEY,
Expand All @@ -49,11 +53,16 @@ async def _generate_openai_with_params(chat: Chat) -> Optional[ChatCompletion]:
tool_choice=chat.config.tool_choice,
user=chat.config.user,
)
assert chat_completion.choices[0].message.content or chat_completion.choices[0].message.tool_calls
assert (
chat_completion.choices[0].message.content
or chat_completion.choices[0].message.tool_calls
)
return chat_completion


@backoff.on_exception(backoff.expo, RateLimitError)
@backoff.on_exception(
backoff.expo, (groq.RateLimitError, groq.APITimeoutError), max_tries=3
)
async def _generate_groq_with_params(chat: Chat) -> Optional[GroqChatCompletion]:
client = AsyncGroq(
api_key=settings.GROQ_API_KEY,
Expand All @@ -72,16 +81,18 @@ async def _generate_groq_with_params(chat: Chat) -> Optional[GroqChatCompletion]
top_p=chat.config.top_p,
user=chat.config.user,
)
assert chat_completion.choices[0].message.content or chat_completion.choices[0].message.tool_calls
assert (
chat_completion.choices[0].message.content
or chat_completion.choices[0].message.tool_calls
)
return chat_completion


async def execute(model: str, chat: Chat) -> LLMResult:
try:
if not chat.config or chat.prompt_type == PromptType.DEFAULT:
chat.prompt_type = PromptType.DEFAULT
response = await _generate(
model=model, messages=chat.messages)
response = await _generate(model=model, messages=chat.messages)
elif chat.prompt_type == PromptType.OPENAI:
response = await _generate_openai_with_params(chat)
elif chat.prompt_type == PromptType.GROQ:
Expand All @@ -92,7 +103,7 @@ async def execute(model: str, chat: Chat) -> LLMResult:
chat_completion=response,
error="",
)
except APIError as api_error:
except openai.APIError as api_error:
print(f"OpenAI API error: {api_error}", flush=True)
return LLMResult(
chat_completion=None,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import backoff
import settings
from typing import Optional
import openai
from openai import AsyncOpenAI
from openai import RateLimitError
from openai import APIError
from src.domain.tools.image_generation.entities import ImageGenerationResult


@backoff.on_exception(backoff.expo, RateLimitError)
@backoff.on_exception(
backoff.expo, (openai.RateLimitError, openai.APITimeoutError), max_tries=3
)
async def _generate_image(prompt: str) -> Optional[ImageGenerationResult]:
client = AsyncOpenAI(
api_key=settings.OPEN_AI_API_KEY,
Expand All @@ -29,7 +30,7 @@ async def execute(prompt: str) -> Optional[ImageGenerationResult]:
url=response.data[0].url,
error="",
)
except APIError as api_error:
except openai.APIError as api_error:
print(f"OpenAI API error: {api_error}", flush=True)
return ImageGenerationResult(
url="",
Expand Down
6 changes: 4 additions & 2 deletions oracles/src/repositories/knowledge_base_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import settings
import numpy as np
from io import BytesIO
import openai
from openai import AsyncOpenAI
from openai import RateLimitError
from collections import OrderedDict
from typing import List, Any, Tuple, Dict
from src.domain.knowledge_base.entities import Document
Expand Down Expand Up @@ -92,7 +92,9 @@ async def exists(self, name: str) -> bool:
except KeyError:
return False

@backoff.on_exception(backoff.expo, RateLimitError, max_tries=3)
@backoff.on_exception(
backoff.expo, (openai.RateLimitError, openai.APITimeoutError), max_tries=3
)
async def _create_embedding(self, texts: List[str]) -> List[float]:
response = await self.openai_client.embeddings.create(
input=texts, model="text-embedding-3-small"
Expand Down

0 comments on commit e073228

Please sign in to comment.