Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZachZimm committed Sep 11, 2024
1 parent 918ec65 commit 3358783
Showing 1 changed file with 67 additions and 6 deletions.
73 changes: 67 additions & 6 deletions tests/test_fastmlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ChatCompletionResponse,
ChatMessage,
ModelProvider,
Usage,
app,
handle_function_calls,
)
Expand Down Expand Up @@ -53,7 +54,11 @@ async def get_available_models(self):

# Mock functions
def mock_generate(*args, **kwargs):
return "generated response"
return "generated response", {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
}


def mock_vlm_stream_generate(*args, **kwargs):
Expand Down Expand Up @@ -95,6 +100,11 @@ def test_chat_completion_vlm(client):

assert response.status_code == 200
assert "generated response" in response.json()["choices"][0]["message"]["content"]
assert "usage" in response.json()
usage = response.json()["usage"]
assert "prompt_tokens" in usage
assert "completion_tokens" in usage
assert "total_tokens" in usage


def test_chat_completion_lm(client):
Expand All @@ -107,6 +117,11 @@ def test_chat_completion_lm(client):

assert response.status_code == 200
assert "generated response" in response.json()["choices"][0]["message"]["content"]
assert "usage" in response.json()
usage = response.json()["usage"]
assert "prompt_tokens" in usage
assert "completion_tokens" in usage
assert "total_tokens" in usage


@pytest.mark.asyncio
Expand Down Expand Up @@ -141,6 +156,11 @@ async def test_vlm_streaming(client):
assert "delta" in data["choices"][0]
assert "role" in data["choices"][0]["delta"]
assert "content" in data["choices"][0]["delta"]
if "usage" in data:
usage = data["usage"]
assert "prompt_tokens" in usage
assert "completion_tokens" in usage
assert "total_tokens" in usage

assert chunks[-2] == "data: [DONE]"

Expand Down Expand Up @@ -177,6 +197,11 @@ async def test_lm_streaming(client):
assert "delta" in data["choices"][0]
assert "role" in data["choices"][0]["delta"]
assert "content" in data["choices"][0]["delta"]
if "usage" in data:
usage = data["usage"]
assert "prompt_tokens" in usage
assert "completion_tokens" in usage
assert "total_tokens" in usage

assert chunks[-2] == "data: [DONE]"

Expand Down Expand Up @@ -240,8 +265,13 @@ def test_handle_function_calls_json_format():
"""
request = MagicMock()
request.model = "test_model"
token_info = MagicMock()
token_info.prompt_tokens = 10
token_info.completion_tokens = 20
token_info.total_tokens = 30
token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)

response = handle_function_calls(output, request)
response = handle_function_calls(output, request, token_info)

assert isinstance(response, ChatCompletionResponse)
assert len(response.tool_calls) == 1
Expand All @@ -252,6 +282,11 @@ def test_handle_function_calls_json_format():
}
assert "Here's the weather forecast:" in response.choices[0]["message"]["content"]
assert '{"tool_calls":' not in response.choices[0]["message"]["content"]
assert response.usage
usage = response.usage
assert usage.prompt_tokens == 10
assert usage.completion_tokens == 20
assert usage.total_tokens == 30


def test_handle_function_calls_xml_format_old():
Expand All @@ -262,15 +297,22 @@ def test_handle_function_calls_xml_format_old():
"""
request = MagicMock()
request.model = "test_model"
token_info = MagicMock()
token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)

response = handle_function_calls(output, request)
response = handle_function_calls(output, request, token_info)

assert isinstance(response, ChatCompletionResponse)
assert len(response.tool_calls) == 1
assert response.tool_calls[0].function.name == "get_stock_price"
assert json.loads(response.tool_calls[0].function.arguments) == {"symbol": "AAPL"}
assert "Let me check that for you." in response.choices[0]["message"]["content"]
assert "<function_calls>" not in response.choices[0]["message"]["content"]
assert response.usage
usage = response.usage
assert usage.prompt_tokens == 10
assert usage.completion_tokens == 20
assert usage.total_tokens == 30


def test_handle_function_calls_xml_format_new():
Expand All @@ -285,8 +327,9 @@ def test_handle_function_calls_xml_format_new():
"""
request = MagicMock()
request.model = "test_model"
token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)

response = handle_function_calls(output, request)
response = handle_function_calls(output, request, token_info)

assert isinstance(response, ChatCompletionResponse)
assert len(response.tool_calls) == 1
Expand All @@ -300,6 +343,11 @@ def test_handle_function_calls_xml_format_new():
in response.choices[0]["message"]["content"]
)
assert "<function_calls>" not in response.choices[0]["message"]["content"]
assert response.usage
usage = response.usage
assert usage.prompt_tokens == 10
assert usage.completion_tokens == 20
assert usage.total_tokens == 30


def test_handle_function_calls_functools_format():
Expand All @@ -308,8 +356,9 @@ def test_handle_function_calls_functools_format():
"""
request = MagicMock()
request.model = "test_model"
token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)

response = handle_function_calls(output, request)
response = handle_function_calls(output, request, token_info)

assert isinstance(response, ChatCompletionResponse)
assert response.tool_calls is not None
Expand All @@ -321,6 +370,11 @@ def test_handle_function_calls_functools_format():
}
assert "Here are the results:" in response.choices[0]["message"]["content"]
assert "functools[" not in response.choices[0]["message"]["content"]
assert response.usage
usage = response.usage
assert usage.prompt_tokens == 10
assert usage.completion_tokens == 20
assert usage.total_tokens == 30


# Add a new test for multiple function calls in functools format
Expand All @@ -330,7 +384,9 @@ def test_handle_function_calls_multiple_functools():
"""
request = MagicMock()
request.model = "test_model"
response = handle_function_calls(output, request)
token_info = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)

response = handle_function_calls(output, request, token_info)
assert isinstance(response, ChatCompletionResponse)
assert response.tool_calls is not None
assert len(response.tool_calls) == 2
Expand All @@ -342,6 +398,11 @@ def test_handle_function_calls_multiple_functools():
assert json.loads(response.tool_calls[1].function.arguments) == {"timezone": "EST"}
assert "Here are the results:" in response.choices[0]["message"]["content"]
assert "functools[" not in response.choices[0]["message"]["content"]
assert response.usage
usage = response.usage
assert usage.prompt_tokens == 10
assert usage.completion_tokens == 20
assert usage.total_tokens == 30


if __name__ == "__main__":
Expand Down

0 comments on commit 3358783

Please sign in to comment.