From 26b66c54964a2c1c03ecb77b5a9d0f7e53bb1b55 Mon Sep 17 00:00:00 2001
From: Pierrick HYMBERT
Date: Thu, 22 Feb 2024 23:38:47 +0100
Subject: [PATCH] server: tests: Fix some random behavior where the wait for
 busy status is missing

---
 examples/server/tests/features/steps/steps.py | 49 ++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 1e27ce274d2f4..71327728ae3df 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -86,17 +86,17 @@ def step_start_server(context):
 async def step_wait_for_the_server_to_be_started(context, expecting_status):
     match expecting_status:
         case 'healthy':
-            await wait_for_health_status(context.base_url, 200, 'ok')
+            await wait_for_health_status(context, context.base_url, 200, 'ok')
 
         case 'ready' | 'idle':
-            await wait_for_health_status(context.base_url, 200, 'ok',
+            await wait_for_health_status(context, context.base_url, 200, 'ok',
                                          params={'fail_on_no_slot': 0, 'include_slots': 0},
                                          slots_idle=context.n_slots,
                                          slots_processing=0,
                                          expected_slots=[{'id': slot_id, 'state': 0}
                                                          for slot_id in range(context.n_slots)])
         case 'busy':
-            await wait_for_health_status(context.base_url, 503,
+            await wait_for_health_status(context, context.base_url, 503,
                                          'no slot available',
                                          params={'fail_on_no_slot': 0, 'include_slots': 0},
                                          slots_idle=0,
@@ -269,17 +269,24 @@ async def step_oai_chat_completions(context):
                                if hasattr(context, 'user_api_key') else None)
 
 
+@step(u'all prompts are predicted')
+@async_run_until_complete
+async def step_impl(context):
+    await all_prompts_are_predicted(context)
+
+
 @step(u'all prompts are predicted with {n_predict} tokens')
 @async_run_until_complete
 async def step_all_prompts_are_predicted(context, n_predict):
-    n_completion_tasks = len(context.concurrent_completion_tasks)
-    print(f"Waiting for all {n_completion_tasks} completion responses...")
-    for task_no in range(n_completion_tasks):
-        context.completions.append(await context.concurrent_completion_tasks.pop())
-    n_completions = len(context.completions)
+    expected_predicted_n = int(n_predict)
+    await all_prompts_are_predicted(context, expected_predicted_n)
+
+
+async def all_prompts_are_predicted(context, expected_predicted_n=None):
+    n_completions = await gather_concurrent_completions_tasks(context)
     assert n_completions > 0
     for i in range(n_completions):
-        assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=int(n_predict))
+        assert_n_tokens_predicted(context.completions.pop(), expected_predicted_n=expected_predicted_n)
 
 
 @step(u'embeddings are computed for')
@@ -448,7 +455,6 @@ async def oai_chat_completions(user_prompt,
                 completion_response['timings']['predicted_n'] += 1
             print(f"XXXXXXXXXXXXXXXXXcompletion_response: {completion_response}")
     else:
-        print(f"raw completion response: {response}")
         if expect_api_error is None or not expect_api_error:
             assert response.status == 200
             assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -512,7 +518,17 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
                f' ```\n{content}\n``` do not match /{re_content}/')
 
 
+async def gather_concurrent_completions_tasks(context):
+    n_completion_tasks = len(context.concurrent_completion_tasks)
+    print(f"Waiting for all {n_completion_tasks} completion responses...")
+    for task_no in range(n_completion_tasks):
+        context.completions.append(await context.concurrent_completion_tasks.pop())
+    n_completions = len(context.completions)
+    return n_completions
+
+
-async def wait_for_health_status(base_url,
+async def wait_for_health_status(context,
+                                 base_url,
                                  expected_http_status_code,
                                  expected_health_status,
                                  params=None,
@@ -545,8 +561,17 @@ async def wait_for_health_status(base_url,
                     assert_slots_status(health['slots'], expected_slots)
                 return
         await asyncio.sleep(interval)
+
         counter += interval
         if counter >= timeout:
+            # Sometimes health requests are triggered after completions are predicted
+            if expected_http_status_code == 503:
+                if len(context.completions) == 0:
+                    print("\x1b[5;37;43mWARNING: forcing concurrent completions task, busy health check missed")
+                    n_completions = await gather_concurrent_completions_tasks(context)
+                    if n_completions > 0:
+                        return
+
             assert False, 'timeout exceeded'
 
 
@@ -572,6 +597,8 @@ def start_server_background(context):
     if 'LLAMA_SERVER_BIN_PATH' in os.environ:
         context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
     server_args = [
+        '--host', context.server_fqdn,
+        '--port', context.server_port,
         '--model', context.model_file
     ]
     if context.server_continuous_batching:
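
The flakiness this patch works around can be reproduced in isolation: a
completion may finish before any /health poll ever observes the 503 "busy"
window, so the only safe timeout path is to gather the finished completion
tasks instead of failing. Below is a minimal standalone sketch of that
pattern; the names fake_completion and wait_for_busy_or_gather are
hypothetical and do not come from steps.py.

import asyncio


async def fake_completion():
    # Hypothetical stand-in for a concurrent completion request: it can
    # finish before any poller ever sees the server in a busy state.
    await asyncio.sleep(0.01)
    return 'predicted'


async def wait_for_busy_or_gather(tasks, completions, timeout=0.3, interval=0.1):
    counter = 0
    while counter < timeout:
        # A real implementation would poll /health here and return as soon
        # as it saw the expected 503 'no slot available' response.
        await asyncio.sleep(interval)
        counter += interval
    # Fallback mirroring the patch: the busy window was missed, so gather
    # the finished completion tasks instead of asserting a timeout.
    while tasks:
        completions.append(await tasks.pop())
    assert len(completions) > 0, 'timeout exceeded'


async def main():
    completions = []
    tasks = [asyncio.create_task(fake_completion()) for _ in range(4)]
    await wait_for_busy_or_gather(tasks, completions)
    print(f"gathered {len(completions)} completions after missing the busy window")


asyncio.run(main())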