server: tests: Fix some random behavior where the wait for busy status is missing
phymbert committed Feb 22, 2024
1 parent aa591ef commit 26b66c5
Showing 1 changed file with 38 additions and 11 deletions.
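The flake being fixed, as the diff below shows: scenarios that expect the server to report busy (HTTP 503, 'no slot available') could time out, because the concurrent completion requests sometimes finish before the health poll ever observes the busy window. The fix threads the behave context into wait_for_health_status so that, on timeout while expecting 503, the test can drain the pending completion tasks and pass if they already succeeded. A minimal, self-contained sketch of the underlying race (all names here are illustrative, not from the repo):

import asyncio

async def worker(state):
    await asyncio.sleep(0.1)   # request arrives shortly after polling starts
    state['busy'] = True
    await asyncio.sleep(0.05)  # busy window is shorter than the poll interval
    state['busy'] = False

async def poller(state, interval=0.5, timeout=2.0):
    waited = 0.0
    while waited < timeout:
        if state['busy']:
            return True        # observed the busy state
        await asyncio.sleep(interval)
        waited += interval
    return False               # the busy window fell between two polls

async def main():
    state = {'busy': False}
    _, seen = await asyncio.gather(worker(state), poller(state))
    print('busy observed:', seen)  # typically False: the race this commit works around

asyncio.run(main())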
examples/server/tests/features/steps/steps.py (38 additions, 11 deletions)
@@ -86,17 +86,17 @@ def step_start_server(context):
 async def step_wait_for_the_server_to_be_started(context, expecting_status):
     match expecting_status:
         case 'healthy':
-            await wait_for_health_status(context.base_url, 200, 'ok')
+            await wait_for_health_status(context, context.base_url, 200, 'ok')
 
         case 'ready' | 'idle':
-            await wait_for_health_status(context.base_url, 200, 'ok',
+            await wait_for_health_status(context, context.base_url, 200, 'ok',
                                          params={'fail_on_no_slot': 0, 'include_slots': 0},
                                          slots_idle=context.n_slots,
                                          slots_processing=0,
                                          expected_slots=[{'id': slot_id, 'state': 0}
                                                          for slot_id in range(context.n_slots)])
         case 'busy':
-            await wait_for_health_status(context.base_url, 503,
+            await wait_for_health_status(context, context.base_url, 503,
                                          'no slot available',
                                          params={'fail_on_no_slot': 0, 'include_slots': 0},
                                          slots_idle=0,
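For readers skimming the hunk: wait_for_health_status (which now also takes context) polls the server's /health endpoint until both the HTTP status code and the reported health status match. A rough standalone sketch of that loop, assuming aiohttp and the /health JSON shape implied by the parameters above ('status', plus 'slots' when include_slots is set); the function and parameter names here are illustrative:

import asyncio
import aiohttp

async def poll_health(base_url, expected_http_status_code, expected_health_status,
                      params=None, interval=0.5, timeout=10):
    # Poll GET /health until both status code and payload match, or give up.
    waited = 0
    async with aiohttp.ClientSession() as session:
        while waited < timeout:
            async with session.get(f'{base_url}/health', params=params) as response:
                if response.status == expected_http_status_code:
                    health = await response.json()
                    if health['status'] == expected_health_status:
                        return health
            await asyncio.sleep(interval)
            waited += interval
    raise TimeoutError(f'server never reported {expected_health_status!r}')

# e.g. asyncio.run(poll_health('http://localhost:8080', 200, 'ok'))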
@@ -269,17 +269,24 @@ async def step_oai_chat_completions(context):
                               if hasattr(context, 'user_api_key') else None)
 
 
+@async_run_until_complete
+@step(u'all prompts are predicted')
+async def step_impl(context):
+    await all_prompts_are_predicted(context)
+
+
 @step(u'all prompts are predicted with {n_predict} tokens')
 @async_run_until_complete
 async def step_all_prompts_are_predicted(context, n_predict):
-    n_completion_tasks = len(context.concurrent_completion_tasks)
-    print(f"Waiting for all {n_completion_tasks} completion responses...")
-    for task_no in range(n_completion_tasks):
-        context.completions.append(await context.concurrent_completion_tasks.pop())
-    n_completions = len(context.completions)
+    expected_predicted_n = int(n_predict)
+    await all_prompts_are_predicted(context, expected_predicted_n)
+
+
+async def all_prompts_are_predicted(context, expected_predicted_n):
+    n_completions = await gather_concurrent_completions_tasks(context)
     assert n_completions > 0
     for i in range(n_completions):
-        assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=int(n_predict))
+        assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=expected_predicted_n)
 
 
 @step(u'embeddings are computed for')
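The refactor above splits the step definitions from a shared all_prompts_are_predicted coroutine. behave step functions are ordinarily synchronous, so the file pairs @step with an @async_run_until_complete decorator; its definition is outside this diff, but a plausible minimal version (an assumption, not the repository's code) looks like:

import asyncio
from functools import wraps

def async_run_until_complete(coroutine_fn):
    # Drive an async step to completion inside behave's synchronous runner.
    @wraps(coroutine_fn)
    def wrapper(*args, **kwargs):
        return asyncio.get_event_loop().run_until_complete(coroutine_fn(*args, **kwargs))
    return wrapper

Note that, as committed, the new bare step calls all_prompts_are_predicted(context) while the helper declares expected_predicted_n with no default; since assert_n_tokens_predicted already defaults expected_predicted_n=None, a =None default on the helper appears to be the intent.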
@@ -448,7 +455,6 @@ async def oai_chat_completions(user_prompt,
                             completion_response['timings']['predicted_n'] += 1
-            print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}")
         else:
             print(f"raw completion response: {response}")
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -512,7 +518,17 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
                f' ```\n{content}\n``` do not match /{re_content}/')
 
 
-async def wait_for_health_status(base_url,
+async def gather_concurrent_completions_tasks(context):
+    n_completion_tasks = len(context.concurrent_completion_tasks)
+    print(f"Waiting for all {n_completion_tasks} completion responses...")
+    for task_no in range(n_completion_tasks):
+        context.completions.append(await context.concurrent_completion_tasks.pop())
+    n_completions = len(context.completions)
+    return n_completions
+
+
+async def wait_for_health_status(context,
+                                 base_url,
                                  expected_http_status_code,
                                  expected_health_status,
                                  params=None,
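The new gather_concurrent_completions_tasks helper centralizes the draining of context.concurrent_completion_tasks that the prompt-prediction step used to do inline. A self-contained sketch of the pattern, with a stand-in for behave's context object and a fake completion coroutine (both illustrative, not from the repo):

import asyncio

class Context:                       # stand-in for behave's context object
    def __init__(self):
        self.concurrent_completion_tasks = []
        self.completions = []

async def fake_completion(n):        # placeholder for a real completion request
    await asyncio.sleep(0.01)
    return {'timings': {'predicted_n': n}}

async def main():
    context = Context()
    for n in (8, 16):
        context.concurrent_completion_tasks.append(
            asyncio.create_task(fake_completion(n)))
    # mirrors gather_concurrent_completions_tasks from the hunk above
    while context.concurrent_completion_tasks:
        context.completions.append(await context.concurrent_completion_tasks.pop())
    print(len(context.completions), 'completions gathered')

asyncio.run(main())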
@@ -545,8 +561,17 @@ async def wait_for_health_status(base_url,
                     assert_slots_status(health['slots'], expected_slots)
                 return
             await asyncio.sleep(interval)
+
             counter += interval
             if counter >= timeout:
+                # Sometimes health requests are triggered after completions are predicted
+                if expected_http_status_code == 503:
+                    if len(context.completions) == 0:
+                        print("\x1b[5;37;43mWARNING: forcing concurrent completions task, busy health check missed")
+                        n_completions = await gather_concurrent_completions_tasks(context)
+                        if n_completions > 0:
+                            return
+
                 assert False, 'timeout exceeded'
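This timeout branch is the actual fix: when the poll was waiting for 503 (busy) and timed out without any completion having been recorded, the completions may simply have finished between two polls, so the step force-gathers the outstanding tasks and treats a non-empty result as success instead of failing the scenario. The \x1b[5;37;43m escape (blinking white-on-yellow) makes the missed-busy warning stand out in CI logs.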


@@ -572,6 +597,8 @@ def start_server_background(context):
     if 'LLAMA_SERVER_BIN_PATH' in os.environ:
         context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
     server_args = [
+        '--host', context.server_fqdn,
+        '--port', context.server_port,
         '--model', context.model_file
     ]
     if context.server_continuous_batching:
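Finally, the server is now started with an explicit --host/--port binding, so the health and completion requests target a known address rather than whatever default the binary picks. start_server_background is only partially shown; a minimal launch sketch, under the assumption that it shells out to the server binary with subprocess (the function name and server_args list suggest as much):

import subprocess

def start_server(server_path, server_fqdn, server_port, model_file):
    server_args = [
        '--host', server_fqdn,       # added by this commit
        '--port', str(server_port),  # added by this commit
        '--model', model_file,
    ]
    # Launch the server as a background process; callers poll /health until ready.
    return subprocess.Popen([server_path, *server_args])

# e.g. proc = start_server('./server', '127.0.0.1', 8080, 'model.gguf')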
