From e8c5aad425c6179e4a229c10f6e4d587f82f2323 Mon Sep 17 00:00:00 2001
From: Dmitry Paramonov
Date: Fri, 24 May 2024 15:55:54 +0300
Subject: [PATCH] fix: Choose correct way to generate based on model name

---
 agents-api/agents_api/rec_sum/generate.py | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/agents-api/agents_api/rec_sum/generate.py b/agents-api/agents_api/rec_sum/generate.py
index d8a280d36..e05111faa 100644
--- a/agents-api/agents_api/rec_sum/generate.py
+++ b/agents-api/agents_api/rec_sum/generate.py
@@ -1,21 +1,25 @@
-from openai import AsyncClient
 from tenacity import retry, stop_after_attempt, wait_fixed
-
-
-client = AsyncClient()
+from agents_api.env import model_inference_url, model_api_key
+from agents_api.model_registry import JULEP_MODELS
+from litellm import acompletion
 
 
 @retry(wait=wait_fixed(2), stop=stop_after_attempt(5))
 async def generate(
     messages: list[dict],
-    client: AsyncClient = client,
     model: str = "gpt-4-turbo",
-    **kwargs
+    **kwargs,
 ) -> dict:
-    result = await client.chat.completions.create(
-        model=model, messages=messages, **kwargs
-    )
+    base_url, api_key = None, None
+    if model in JULEP_MODELS:
+        base_url, api_key = model_inference_url, model_api_key
+        model = f"openai/{model}"
 
-    result = result.choices[0].message.__dict__
+    result = await acompletion(
+        model=model,
+        messages=messages,
+        base_url=base_url,
+        api_key=api_key,
+    )
 
-    return result
+    return result.choices[0].message.json()
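
Usage sketch (not part of the patch): a minimal example of calling the updated generate() helper, assuming the agents_api package is importable and that model_inference_url / model_api_key are configured for self-hosted Julep models. The prompt text and the asyncio entry point below are illustrative only.

import asyncio

from agents_api.rec_sum.generate import generate


async def main() -> None:
    # For a model name listed in JULEP_MODELS the patched helper routes the
    # call through litellm to the self-hosted inference URL with the Julep
    # API key; any other model name falls through to litellm's default
    # provider resolution (base_url and api_key stay None).
    message = await generate(
        messages=[{"role": "user", "content": "Summarize this conversation: ..."}],
        model="gpt-4-turbo",
    )
    # The helper returns the serialized message of the first choice,
    # i.e. result.choices[0].message.json().
    print(message)


asyncio.run(main())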