diff --git a/oracles/src/domain/llm/generate_response_use_case.py b/oracles/src/domain/llm/generate_response_use_case.py index a6c9c39..60b1ae6 100644 --- a/oracles/src/domain/llm/generate_response_use_case.py +++ b/oracles/src/domain/llm/generate_response_use_case.py @@ -2,10 +2,10 @@ from typing import List from typing import Optional +import groq from groq import AsyncGroq +import openai from openai import AsyncOpenAI -from openai import RateLimitError -from openai import APIError from openai.types.chat import ChatCompletion from groq.types.chat import ChatCompletion as GroqChatCompletion @@ -16,7 +16,9 @@ from src.entities import PromptType -@backoff.on_exception(backoff.expo, RateLimitError) +@backoff.on_exception( + backoff.expo, (openai.RateLimitError, openai.APITimeoutError), max_tries=3 +) async def _generate(model: str, messages: List[dict]) -> Optional[str]: client = AsyncOpenAI( api_key=settings.OPEN_AI_API_KEY, @@ -28,7 +30,9 @@ async def _generate(model: str, messages: List[dict]) -> Optional[str]: return chat_completion.choices[0].message.content -@backoff.on_exception(backoff.expo, RateLimitError) +@backoff.on_exception( + backoff.expo, (openai.RateLimitError, openai.APITimeoutError), max_tries=3 +) async def _generate_openai_with_params(chat: Chat) -> Optional[ChatCompletion]: client = AsyncOpenAI( api_key=settings.OPEN_AI_API_KEY, @@ -49,11 +53,16 @@ async def _generate_openai_with_params(chat: Chat) -> Optional[ChatCompletion]: tool_choice=chat.config.tool_choice, user=chat.config.user, ) - assert chat_completion.choices[0].message.content or chat_completion.choices[0].message.tool_calls + assert ( + chat_completion.choices[0].message.content + or chat_completion.choices[0].message.tool_calls + ) return chat_completion -@backoff.on_exception(backoff.expo, RateLimitError) +@backoff.on_exception( + backoff.expo, (groq.RateLimitError, groq.APITimeoutError), max_tries=3 +) async def _generate_groq_with_params(chat: Chat) -> Optional[GroqChatCompletion]: client = AsyncGroq( api_key=settings.GROQ_API_KEY, @@ -72,7 +81,10 @@ async def _generate_groq_with_params(chat: Chat) -> Optional[GroqChatCompletion] top_p=chat.config.top_p, user=chat.config.user, ) - assert chat_completion.choices[0].message.content or chat_completion.choices[0].message.tool_calls + assert ( + chat_completion.choices[0].message.content + or chat_completion.choices[0].message.tool_calls + ) return chat_completion @@ -80,8 +92,7 @@ async def execute(model: str, chat: Chat) -> LLMResult: try: if not chat.config or chat.prompt_type == PromptType.DEFAULT: chat.prompt_type = PromptType.DEFAULT - response = await _generate( - model=model, messages=chat.messages) + response = await _generate(model=model, messages=chat.messages) elif chat.prompt_type == PromptType.OPENAI: response = await _generate_openai_with_params(chat) elif chat.prompt_type == PromptType.GROQ: @@ -92,7 +103,7 @@ async def execute(model: str, chat: Chat) -> LLMResult: chat_completion=response, error="", ) - except APIError as api_error: + except openai.APIError as api_error: print(f"OpenAI API error: {api_error}", flush=True) return LLMResult( chat_completion=None, diff --git a/oracles/src/domain/tools/image_generation/generate_image_use_case.py b/oracles/src/domain/tools/image_generation/generate_image_use_case.py index 08df43b..3ed3bf1 100644 --- a/oracles/src/domain/tools/image_generation/generate_image_use_case.py +++ b/oracles/src/domain/tools/image_generation/generate_image_use_case.py @@ -1,13 +1,14 @@ import backoff import settings from typing import Optional +import openai from openai import AsyncOpenAI -from openai import RateLimitError -from openai import APIError from src.domain.tools.image_generation.entities import ImageGenerationResult -@backoff.on_exception(backoff.expo, RateLimitError) +@backoff.on_exception( + backoff.expo, (openai.RateLimitError, openai.APITimeoutError), max_tries=3 +) async def _generate_image(prompt: str) -> Optional[ImageGenerationResult]: client = AsyncOpenAI( api_key=settings.OPEN_AI_API_KEY, @@ -29,7 +30,7 @@ async def execute(prompt: str) -> Optional[ImageGenerationResult]: url=response.data[0].url, error="", ) - except APIError as api_error: + except openai.APIError as api_error: print(f"OpenAI API error: {api_error}", flush=True) return ImageGenerationResult( url="", diff --git a/oracles/src/repositories/knowledge_base_repository.py b/oracles/src/repositories/knowledge_base_repository.py index 53c20a7..c4125e8 100644 --- a/oracles/src/repositories/knowledge_base_repository.py +++ b/oracles/src/repositories/knowledge_base_repository.py @@ -5,8 +5,8 @@ import settings import numpy as np from io import BytesIO +import openai from openai import AsyncOpenAI -from openai import RateLimitError from collections import OrderedDict from typing import List, Any, Tuple, Dict from src.domain.knowledge_base.entities import Document @@ -92,7 +92,9 @@ async def exists(self, name: str) -> bool: except KeyError: return False - @backoff.on_exception(backoff.expo, RateLimitError, max_tries=3) + @backoff.on_exception( + backoff.expo, (openai.RateLimitError, openai.APITimeoutError), max_tries=3 + ) async def _create_embedding(self, texts: List[str]) -> List[float]: response = await self.openai_client.embeddings.create( input=texts, model="text-embedding-3-small"