
Commit

add predicted outputs to chat_openai_async and chat_async
wendy-aw committed Jan 24, 2025
1 parent c5efdda commit 0a28132
Showing 2 changed files with 36 additions and 15 deletions.
10 changes: 8 additions & 2 deletions defog_utils/utils_llm.py
@@ -174,6 +174,7 @@ def chat_openai(
"""
Returns the response from the OpenAI API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used.
We use max_completion_tokens here, instead of using max_tokens. This is to support o1 models.
Note this function also supports DeepSeek models as it uses the same API. Simply use the base URL "https://api.deepseek.com"
"""
from openai import OpenAI
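
For context, a minimal sketch of the DeepSeek usage the docstring describes, written against the chat_openai_async signature shown in the next hunk; the prompt text is an illustrative assumption and is not part of this diff:

import asyncio
import os

from defog_utils.utils_llm import chat_openai_async

async def main():
    # Reuse the OpenAI-compatible helper against DeepSeek's endpoint,
    # as suggested by the docstring note above.
    resp = await chat_openai_async(
        model="deepseek-chat",
        messages=[{"role": "user", "content": "Reply with the single word: hello"}],
        base_url="https://api.deepseek.com",
        api_key=os.environ.get("DEEPSEEK_API_KEY", ""),
    )
    print(resp)  # an LLMResponse with the text, timing, and token counts

asyncio.run(main())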

@@ -246,10 +247,12 @@ async def chat_openai_async(
    timeout=100,
    base_url: str = "https://api.openai.com/v1/",
    api_key: str = os.environ.get("OPENAI_API_KEY", ""),
    prediction: Dict[str, str] = None,
) -> LLMResponse:
"""
Returns the response from the OpenAI API, the time taken to generate the response, the number of input tokens used, and the number of output tokens used.
We use max_completion_tokens here, instead of using max_tokens. This is to support o1 models.
Note this function also supports DeepSeek models as it uses the same API. Simply use the base URL "https://api.deepseek.com"
"""
from openai import AsyncOpenAI

@@ -278,6 +281,11 @@ async def chat_openai_async(
"timeout": timeout,
"response_format": response_format,
}

if model in ["gpt-4o", "gpt-4o-mini"] and prediction:
request_params["prediction"] = prediction
del request_params["max_completion_tokens"]
del request_params["response_format"] # completion with prediction output does not support max_completion_tokens and response_format

if model in ["o1-mini", "o1-preview", "o1", "deepseek-chat", "deepseek-reasoner"]:
del request_params["temperature"]
@@ -287,8 +295,6 @@

if "response_format" in request_params and request_params["response_format"]:
del request_params["stop"] # cannot have stop when using response_format, as that often leads to invalid JSON

if "response_format" in request_params and request_params["response_format"]:
response = await client_openai.beta.chat.completions.parse(**request_params)
content = response.choices[0].message.parsed
else:
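For reference, a sketch of passing the new prediction argument to chat_openai_async; the dict shape follows OpenAI's predicted-outputs format, while the SQL prompt and the resp.content field access are illustrative assumptions rather than something this diff defines:

import asyncio

from defog_utils.utils_llm import chat_openai_async

# Text the model's output is expected to largely repeat (hypothetical example)
existing_sql = "SELECT id, name, email FROM users WHERE active = true;"

async def main():
    resp = await chat_openai_async(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": f"Rename the column name to full_name in this query:\n{existing_sql}",
            }
        ],
        # OpenAI predicted outputs: supply the expected output up front to speed up generation
        prediction={"type": "content", "content": existing_sql},
    )
    print(resp.content)  # assumes LLMResponse exposes the text as `content`

asyncio.run(main())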
41 changes: 28 additions & 13 deletions defog_utils/utils_multi_llm.py
@@ -69,7 +69,8 @@ async def chat_async(
    store=True,
    metadata=None,
    timeout=100,  # in seconds
    backup_model=None
    backup_model=None,
    prediction=None
) -> LLMResponse:
"""
Returns the response from the LLM API for a single model that is passed in.
@@ -89,18 +90,32 @@
        model = backup_model
    llm_function = map_model_to_chat_fn_async(model)
    if not model.startswith("deepseek"):
        return await llm_function(
            model=model,
            messages=messages,
            max_completion_tokens=max_completion_tokens,
            temperature=temperature,
            stop=stop,
            response_format=response_format,
            seed=seed,
            store=store,
            metadata=metadata,
            timeout=timeout,
        )
        if prediction and "gpt-4o" in model:
            # completions with a predicted output do not support response_format or max_completion_tokens
            return await llm_function(
                model=model,
                messages=messages,
                temperature=temperature,
                stop=stop,
                seed=seed,
                store=store,
                metadata=metadata,
                timeout=timeout,
                prediction=prediction,
            )
        else:
            return await llm_function(
                model=model,
                messages=messages,
                max_completion_tokens=max_completion_tokens,
                temperature=temperature,
                stop=stop,
                response_format=response_format,
                seed=seed,
                store=store,
                metadata=metadata,
                timeout=timeout,
            )
    else:
        if not os.getenv("DEEPSEEK_API_KEY"):
            raise Exception("DEEPSEEK_API_KEY is not set")
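Finally, a sketch of the same feature through the chat_async dispatcher; the buggy snippet, the gpt-4o-mini choice, and resp.content are illustrative assumptions:

import asyncio

from defog_utils.utils_multi_llm import chat_async

async def main():
    draft = "def add(a, b):\n    return a - b"
    # "gpt-4o" is a substring of the model name and a prediction is supplied, so
    # chat_async takes the predicted-outputs branch and drops response_format
    # and max_completion_tokens before calling chat_openai_async.
    resp = await chat_async(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": f"Fix the bug in this function:\n{draft}"}],
        prediction={"type": "content", "content": draft},
    )
    print(resp.content)  # assumes LLMResponse exposes the text as `content`

asyncio.run(main())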
