fix timeout issue, add cost in LLMResponse
rishsriv committed Dec 15, 2024
1 parent 66d27b9 commit 0f5f3b3
Showing 1 changed file with 113 additions and 44 deletions.
157 changes: 113 additions & 44 deletions defog_utils/utils_llm.py
@@ -3,13 +3,68 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Any

LLM_COSTS_PER_TOKEN = {
    "gpt-4o": {"input_cost_per1k": 0.0025, "output_cost_per1k": 0.01},
    "gpt-4o-mini": {"input_cost_per1k": 0.00015, "output_cost_per1k": 0.0006},
    "o1-preview": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.06},
    "o1-mini": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.012},
    "gpt-4-turbo": {"input_cost_per1k": 0.01, "output_cost_per1k": 0.03},
    "gpt-3.5-turbo": {"input_cost_per1k": 0.0005, "output_cost_per1k": 0.0015},
    "claude-3-5-sonnet": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.015},
    "claude-3-5-haiku": {"input_cost_per1k": 0.00025, "output_cost_per1k": 0.00125},
    "claude-3-opus": {"input_cost_per1k": 0.015, "output_cost_per1k": 0.075},
    "claude-3-sonnet": {"input_cost_per1k": 0.003, "output_cost_per1k": 0.015},
    "claude-3-haiku": {"input_cost_per1k": 0.00025, "output_cost_per1k": 0.00125},
    "gemini-1.5-pro": {"input_cost_per1k": 0.00125, "output_cost_per1k": 0.005},
    "gemini-1.5-flash": {"input_cost_per1k": 0.000075, "output_cost_per1k": 0.0003},
    "gemini-1.5-flash-8b": {"input_cost_per1k": 0.0000375, "output_cost_per1k": 0.00015},
    "gemini-2.0-flash": {
        # assume same cost as 1.5
        "input_cost_per1k": 0.000075,
        "output_cost_per1k": 0.0003,
    },
}
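# Illustrative example (not part of this commit): with the rates above, a "gpt-4o"
# call that uses 2,000 input tokens and 500 output tokens costs
# 2000 / 1000 * 0.0025 + 500 / 1000 * 0.01 = 0.005 + 0.005 = 0.01 USD.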


@dataclass
class LLMResponse:
    content: Any
    model: str
    time: float
    input_tokens: int
    output_tokens: int
    output_tokens_details: Optional[Dict[str, int]] = None
    cost: Optional[float] = None

    def __post_init__(self):
        if self.model in LLM_COSTS_PER_TOKEN:
            model_name = self.model
        else:
            # if there is no exact match (for example, if the model name is
            # "gpt-4o-2024-08-06"), try to find the closest match
            model_name = None
            potential_model_names = []

            # find all known model names that appear in the given model string
            for mname in LLM_COSTS_PER_TOKEN.keys():
                if mname in self.model:
                    potential_model_names.append(mname)

            if len(potential_model_names) == 1:
                model_name = potential_model_names[0]
            elif len(potential_model_names) > 1:
                # if there are multiple potential matches, use the longest one
                model_name = max(potential_model_names, key=len)

        if model_name:
            self.cost = (
                self.input_tokens
                / 1000
                * LLM_COSTS_PER_TOKEN[model_name]["input_cost_per1k"]
                + self.output_tokens
                / 1000
                * LLM_COSTS_PER_TOKEN[model_name]["output_cost_per1k"]
            )
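# Illustrative example (not part of this commit): for model="gpt-4o-mini-2024-07-18",
# both "gpt-4o" and "gpt-4o-mini" are substring matches, so the longer key
# "gpt-4o-mini" is used; 1,000 input and 200 output tokens would then cost
# 1000 / 1000 * 0.00015 + 200 / 1000 * 0.0006 = 0.00027 USD.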


def chat_anthropic(
@@ -48,10 +103,11 @@ def chat_anthropic(
    if len(response.content) == 0:
        raise Exception("Max tokens reached")
    return LLMResponse(
-        response.content[0].text,
-        round(time.time() - t, 3),
-        response.usage.input_tokens,
-        response.usage.output_tokens,
+        model=model,
+        content=response.content[0].text,
+        time=round(time.time() - t, 3),
+        input_tokens=response.usage.input_tokens,
+        output_tokens=response.usage.output_tokens,
    )


@@ -74,7 +130,7 @@ async def chat_anthropic_async(
"""
from anthropic import AsyncAnthropic

client_anthropic = AsyncAnthropic(timeout=timeout)
client_anthropic = AsyncAnthropic()
t = time.time()
if len(messages) >= 1 and messages[0].get("role") == "system":
sys_msg = messages[0]["content"]
@@ -88,16 +144,18 @@
        max_tokens=max_completion_tokens,
        temperature=temperature,
        stop_sequences=stop,
+        timeout=timeout
    )
    if response.stop_reason == "max_tokens":
        raise Exception("Max tokens reached")
    if len(response.content) == 0:
        raise Exception("Max tokens reached")
    return LLMResponse(
-        response.content[0].text,
-        round(time.time() - t, 3),
-        response.usage.input_tokens,
-        response.usage.output_tokens,
+        model=model,
+        content=response.content[0].text,
+        time=round(time.time() - t, 3),
+        input_tokens=response.usage.input_tokens,
+        output_tokens=response.usage.output_tokens,
    )


@@ -124,7 +182,7 @@ def chat_openai(
            sys_msg = messages[0]["content"]
            messages = messages[1:]
            messages[0]["content"] = sys_msg + messages[0]["content"]

        response = client_openai.chat.completions.create(
            messages=messages,
            model=model,
@@ -138,7 +196,9 @@ def chat_openai(
                max_completion_tokens=max_completion_tokens,
                temperature=temperature,
                stop=stop,
-                response_format={"type": "json_object"} if json_mode else response_format,
+                response_format=(
+                    {"type": "json_object"} if json_mode else response_format
+                ),
                seed=seed,
            )
        else:
@@ -154,18 +214,19 @@ def chat_openai(
raise Exception("Max tokens reached")
if len(response.choices) == 0:
raise Exception("Max tokens reached")

if response_format and model not in ["o1-mini", "o1-preview", "o1"]:
content = response.choices[0].message.parsed
else:
content = response.choices[0].message.content

return LLMResponse(
content,
round(time.time() - t, 3),
response.usage.prompt_tokens,
response.usage.completion_tokens,
response.usage.completion_tokens_details,
model=model,
content=content,
time=round(time.time() - t, 3),
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
output_tokens_details=response.usage.completion_tokens_details,
)


@@ -188,20 +249,21 @@ async def chat_openai_async(
"""
from openai import AsyncOpenAI

client_openai = AsyncOpenAI(timeout=timeout)
client_openai = AsyncOpenAI()
t = time.time()
if model.startswith("o1"):
if messages[0].get("role") == "system":
sys_msg = messages[0]["content"]
messages = messages[1:]
messages[0]["content"] = sys_msg + messages[0]["content"]

response = await client_openai.chat.completions.create(
messages=messages,
model=model,
max_completion_tokens=max_completion_tokens,
store=store,
metadata=metadata,
timeout=timeout,
)
else:
if response_format or json_mode:
@@ -211,7 +273,9 @@ async def chat_openai_async(
                max_completion_tokens=max_completion_tokens,
                temperature=temperature,
                stop=stop,
-                response_format={"type": "json_object"} if json_mode else response_format,
+                response_format=(
+                    {"type": "json_object"} if json_mode else response_format
+                ),
                seed=seed,
                store=store,
                metadata=metadata,
@@ -227,24 +291,25 @@ async def chat_openai_async(
                store=store,
                metadata=metadata,
            )

    if response_format and not model.startswith("o1"):
        content = response.choices[0].message.parsed
    else:
        content = response.choices[0].message.content

    if response.choices[0].finish_reason == "length":
        print("Max tokens reached")
        raise Exception("Max tokens reached")
    if len(response.choices) == 0:
        print("Empty response")
        raise Exception("No response")
    return LLMResponse(
-        content,
-        round(time.time() - t, 3),
-        response.usage.prompt_tokens,
-        response.usage.completion_tokens,
-        response.usage.completion_tokens_details,
+        model=model,
+        content=content,
+        time=round(time.time() - t, 3),
+        input_tokens=response.usage.prompt_tokens,
+        output_tokens=response.usage.completion_tokens,
+        output_tokens_details=response.usage.completion_tokens_details,
    )


@@ -280,10 +345,11 @@ def chat_together(
    if len(response.choices) == 0:
        raise Exception("Max tokens reached")
    return LLMResponse(
-        response.choices[0].message.content,
-        round(time.time() - t, 3),
-        response.usage.prompt_tokens,
-        response.usage.completion_tokens,
+        model=model,
+        content=response.choices[0].message.content,
+        time=round(time.time() - t, 3),
+        input_tokens=response.usage.prompt_tokens,
+        output_tokens=response.usage.completion_tokens,
    )


@@ -322,10 +388,11 @@ async def chat_together_async(
    if len(response.choices) == 0:
        raise Exception("Max tokens reached")
    return LLMResponse(
-        response.choices[0].message.content,
-        round(time.time() - t, 3),
-        response.usage.prompt_tokens,
-        response.usage.completion_tokens,
+        model=model,
+        content=response.choices[0].message.content,
+        time=round(time.time() - t, 3),
+        input_tokens=response.usage.prompt_tokens,
+        output_tokens=response.usage.completion_tokens,
    )


@@ -353,7 +420,7 @@ def chat_gemini(
        messages = messages[1:]
    else:
        system_msg = None

    message = "\n".join([i["content"] for i in messages])

    generation_config = types.GenerateContentConfig(
@@ -365,9 +432,9 @@ def chat_gemini(

    if response_format:
        # use Pydantic classes for response_format
-        generation_config.response_mime_type = 'application/json'
+        generation_config.response_mime_type = "application/json"
        generation_config.response_schema = response_format

    try:
        response = client.models.generate_content(
            model=model,
@@ -381,8 +448,9 @@ def chat_gemini(
        if response_format:
            # convert the content into Pydantic class
            content = response_format.parse_raw(content)

        return LLMResponse(
+            model=model,
            content=content,
            time=round(time.time() - t, 3),
            input_tokens=response.usage_metadata.prompt_token_count,
@@ -401,7 +469,7 @@ async def chat_gemini_async(
    seed: int = 0,
    store=True,
    metadata=None,
    timeout=100,  # does not have timeout method
) -> LLMResponse:
    from google import genai
    from google.genai import types
@@ -415,7 +483,7 @@ async def chat_gemini_async(
        messages = messages[1:]
    else:
        system_msg = None

    message = "\n".join([i["content"] for i in messages])

    generation_config = types.GenerateContentConfig(
@@ -427,7 +495,7 @@ async def chat_gemini_async(

    if response_format:
        # use Pydantic classes for response_format
-        generation_config.response_mime_type = 'application/json'
+        generation_config.response_mime_type = "application/json"
        generation_config.response_schema = response_format

    try:
@@ -445,8 +513,9 @@ async def chat_gemini_async(
            content = response_format.parse_raw(content)

        return LLMResponse(
+            model=model,
            content=content,
            time=round(time.time() - t, 3),
            input_tokens=response.usage_metadata.prompt_token_count,
            output_tokens=response.usage_metadata.candidates_token_count,
        )

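For illustration, here is a minimal usage sketch of the two changes in this commit, passing the timeout per request and reading the auto-computed cost. It is not part of the diff above; it assumes the package is importable as defog_utils, that ANTHROPIC_API_KEY is set in the environment, and that chat_anthropic_async keeps the parameters shown in the diff.

import asyncio

from defog_utils.utils_llm import chat_anthropic_async


async def main():
    # the timeout is now forwarded to the individual request rather than the client constructor
    response = await chat_anthropic_async(
        messages=[{"role": "user", "content": "Say hello in one word."}],
        model="claude-3-5-sonnet-20241022",
        timeout=30,
    )
    # cost is filled in by LLMResponse.__post_init__: the dated model name above
    # is matched to the "claude-3-5-sonnet" entry in LLM_COSTS_PER_TOKEN
    print(response.content)
    print(response.input_tokens, response.output_tokens, response.cost)


asyncio.run(main())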