Commit 52c2625
add test_cache_vs_nocache_prompt
ngxson committed Nov 26, 2024
1 parent 217c9e4 commit 52c2625
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions examples/server/tests/unit/test_completion.py
@@ -62,7 +62,7 @@ def test_consistent_result_same_seed(n_slots: int):
             "prompt": "I believe the meaning of life is",
             "seed": 42,
             "temperature": 1.0,
-            "cache_prompt": False,
+            "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
         })
         if last_res is not None:
             assert res.body["content"] == last_res.body["content"]
@@ -80,7 +80,7 @@ def test_different_result_different_seed(n_slots: int):
             "prompt": "I believe the meaning of life is",
             "seed": seed,
             "temperature": 1.0,
-            "cache_prompt": False,
+            "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
         })
         if last_res is not None:
             assert res.body["content"] != last_res.body["content"]
@@ -99,13 +99,32 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float
             "prompt": "I believe the meaning of life is",
             "seed": 42,
             "temperature": temperature,
-            "cache_prompt": False,
+            "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
         })
         if last_res is not None:
             assert res.body["content"] == last_res.body["content"]
         last_res = res
 
 
+@pytest.mark.skip(reason="This test fails on linux, need to be fixed")
+def test_cache_vs_nocache_prompt():
+    global server
+    server.start()
+    res_cache = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": True,
+    })
+    res_no_cache = server.make_request("POST", "/completion", data={
+        "prompt": "I believe the meaning of life is",
+        "seed": 42,
+        "temperature": 1.0,
+        "cache_prompt": False,
+    })
+    assert res_cache.body["content"] == res_no_cache.body["content"]
+
+
 def test_completion_with_tokens_input():
     global server
     server.temperature = 0.0
