Skip to content

Commit

Permalink
update tests, actually make the async tests work
Browse files Browse the repository at this point in the history
  • Loading branch information
rishsriv committed Jan 18, 2025
1 parent 7f42eda commit 7e48c91
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 109 deletions.
92 changes: 1 addition & 91 deletions tests/test_utils_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,94 +244,4 @@ def test_chat_gemini_sql(self):
self.assertIsInstance(response, LLMResponse)
self.check_sql(response.content)
self.assertEqual(response.input_tokens, 87)
self.assertTrue(response.output_tokens < 10)

def test_chat_json_anthropic(self):
response = chat_anthropic(
messages_json,
model="claude-3-haiku-20240307",
max_completion_tokens=100,
seed=0,
json_mode=True,
)
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)

@pytest.mark.asyncio
async def test_chat_json_anthropic_async(self):
response = await chat_anthropic_async(
"claude-3-haiku-20240307",
messages_json,
max_completion_tokens=100,
seed=0,
json_mode=True,
)
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)

def test_chat_json_openai(self):
response = chat_openai(
messages_json, model="gpt-4o-mini", seed=0, json_mode=True
)
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)

@pytest.mark.asyncio
async def test_chat_json_openai_async(self):
response = await chat_openai_async(
"gpt-4o-mini", messages_json, seed=0, json_mode=True
)
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)

def test_chat_json_together(self):
response = chat_together(
messages_json,
model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
seed=0,
json_mode=True,
)
print(response)
self.assertIsInstance(response, LLMResponse)
raw_output = response.content
resp_dict = json.loads(raw_output)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)

@pytest.mark.asyncio
async def test_chat_json_together_async(self):
response = await chat_together_async(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
messages_json,
seed=0,
json_mode=True,
)
print(response)
self.assertIsInstance(response, LLMResponse)
resp_dict = json.loads(response.content)
self.check_sql(resp_dict["sql"])
self.assertIsInstance(resp_dict["reasoning"], str)
self.assertIsInstance(response.input_tokens, int)
self.assertIsInstance(response.output_tokens, int)
self.assertTrue(response.output_tokens < 10)
76 changes: 58 additions & 18 deletions tests/test_utils_multi_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
chat_together_async,
)

from pydantic import BaseModel, Field

messages_sql = [
{
"role": "system",
Expand Down Expand Up @@ -46,10 +48,35 @@
"select count(order_id) as total_orders from orders",
]

class ResponseFormat(BaseModel):
    """Structured-output schema passed as `response_format`: the model must
    return its reasoning plus the final SQL string."""

    # reasoning: free-text explanation produced before the SQL
    reasoning: str
    # sql: the generated SQL query itself
    sql: str

# Chat transcript used by the structured-output tests: same SQL-generation task
# as `messages_sql`, but consumed together with the `ResponseFormat` schema.
messages_sql_structured = [
    {
        "role": "system",
        "content": "Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases.",
    },
    {
        "role": "user",
        # Plain triple-quoted string: the original `f"""…"""` had no
        # placeholders, so the f-prefix was dead weight (ruff F541).
        "content": """Question: What is the total number of orders?
Schema:
```sql
CREATE TABLE orders (
  order_id int,
  customer_id int,
  employee_id int,
  order_date date
);
```
""",
    },
]


class TestChatClients(unittest.TestCase):
class TestChatClients(unittest.IsolatedAsyncioTestCase):
def check_sql(self, sql: str):
self.assertIn(sql.strip(";\n").lower(), acceptable_sql)
self.assertIn(sql.replace("```sql", "").replace("```", "").strip(";\n").lower(), acceptable_sql)

def test_map_model_to_chat_fn(self):
self.assertEqual(
Expand Down Expand Up @@ -110,7 +137,6 @@ def test_simple_chat(self):
max_completion_tokens=20,
temperature=0.0,
stop=[";"],
json_mode=False,
seed=0,
)
self.assertIsInstance(responses, dict)
Expand Down Expand Up @@ -139,7 +165,6 @@ def test_sql_chat(self):
max_completion_tokens=20,
temperature=0.0,
stop=[";"],
json_mode=False,
seed=0,
)
self.assertIsInstance(responses, dict)
Expand All @@ -159,7 +184,9 @@ async def test_simple_chat_async(self):
"claude-3-haiku-20240307",
"gpt-4o-mini",
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"o1-mini",
# "o1-mini", --o1-mini seems to be having issues, and o3-mini will be out soon anyway. so leaving out for now
"o1",
"gemini-2.0-flash-exp"
]
messages = [
{"role": "user", "content": "Return a greeting in not more than 2 words\n"}
Expand All @@ -168,45 +195,58 @@ async def test_simple_chat_async(self):
response = await chat_async(
model,
messages,
max_completion_tokens=20,
max_completion_tokens=4000,
temperature=0.0,
stop=[";"],
json_mode=False,
seed=0,
)
print(model, response)
self.assertIsInstance(response, LLMResponse)
self.assertIsInstance(response.content, str)
self.assertIsInstance(response.time, float)
self.assertLess(
response.input_tokens, 50
) # higher as default system prompt is added in together's API when none provided
self.assertLess(response.output_tokens, 20)

@pytest.mark.asyncio
async def test_sql_chat_async(self):
models = [
"claude-3-haiku-20240307",
"gpt-4o-mini",
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"o1-mini",
# "o1-mini", --o1-mini seems to be having issues, and o3-mini will be out soon anyway. so leaving out for now
"o1",
"gemini-2.0-flash-exp"
]
for model in models:
response = await chat_async(
model,
messages_sql,
max_completion_tokens=20,
max_completion_tokens=4000,
temperature=0.0,
stop=[";"],
json_mode=False,
seed=0,
)
print(model, response)
self.assertIsInstance(response, LLMResponse)
self.check_sql(response.content)
self.assertIsInstance(response.time, float)
self.assertLess(response.input_tokens, 110)
self.assertLess(response.output_tokens, 20)

@pytest.mark.asyncio
async def test_sql_chat_structured_async(self):
models = [
"gpt-4o",
"o1",
"gemini-2.0-flash-exp",
]
for model in models:
response = await chat_async(
model,
messages_sql_structured,
max_completion_tokens=4000,
temperature=0.0,
stop=[";"],
seed=0,
response_format=ResponseFormat,
)
print(model, response)
self.check_sql(response.content.sql)
self.assertIsInstance(response.content.reasoning, str)


if __name__ == "__main__":
Expand Down

0 comments on commit 7e48c91

Please sign in to comment.