server tests : more pythonic process management; fix bare except: #6146

Merged
9 commits merged on Mar 20, 2024
74 changes: 22 additions & 52 deletions examples/server/tests/features/environment.py
@@ -5,15 +5,14 @@
 import time
 import traceback
 from contextlib import closing
-
-import psutil
+from subprocess import TimeoutExpired


 def before_scenario(context, scenario):
     context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
     if context.debug:
-        print("DEBUG=ON\n")
-        print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
+        print("DEBUG=ON")
+        print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
     port = 8080
     if 'PORT' in os.environ:
         port = int(os.environ['PORT'])
@@ -27,75 +26,46 @@ def after_scenario(context, scenario):
             return
         if scenario.status == "failed":
             if 'GITHUB_ACTIONS' in os.environ:
-                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
+                print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n")
                 if os.path.isfile('llama.log'):
                     with closing(open('llama.log', 'r')) as f:
                         for line in f:
                             print(line)
                 if not is_server_listening(context.server_fqdn, context.server_port):
-                    print("\x1b[33;101mERROR: Server stopped listening\x1b[0m\n")
+                    print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")

-        if not pid_exists(context.server_process.pid):
+        if context.server_process.poll() is not None:
             assert False, f"Server not running pid={context.server_process.pid} ..."

-        server_graceful_shutdown(context)
+        server_graceful_shutdown(context)  # SIGINT

-        # Wait few for socket to free up
-        time.sleep(0.05)
+        try:
+            context.server_process.wait(0.5)
+        except TimeoutExpired:
+            print(f"server still alive after 500ms, force-killing pid={context.server_process.pid} ...")
+            context.server_process.kill()  # SIGKILL
+            context.server_process.wait()

-        attempts = 0
-        while pid_exists(context.server_process.pid) or is_server_listening(context.server_fqdn, context.server_port):
-            server_kill(context)
+        while is_server_listening(context.server_fqdn, context.server_port):
             time.sleep(0.1)
-            attempts += 1
-            if attempts > 5:
-                server_kill_hard(context)
-    except:
-        exc = sys.exception()
-        print("error in after scenario: \n")
-        print(exc)
-        print("*** print_tb: \n")
-        traceback.print_tb(exc.__traceback__, file=sys.stdout)
+    except Exception:
+        print("ignoring error in after_scenario:")
+        traceback.print_exc(file=sys.stdout)


 def server_graceful_shutdown(context):
-    print(f"shutting down server pid={context.server_process.pid} ...\n")
+    print(f"shutting down server pid={context.server_process.pid} ...")
     if os.name == 'nt':
-        os.kill(context.server_process.pid, signal.CTRL_C_EVENT)
+        interrupt = signal.CTRL_C_EVENT
     else:
-        os.kill(context.server_process.pid, signal.SIGINT)
-
-
-def server_kill(context):
-    print(f"killing server pid={context.server_process.pid} ...\n")
-    context.server_process.kill()
-
-
-def server_kill_hard(context):
-    pid = context.server_process.pid
-    path = context.server_path
-
-    print(f"Server dangling exits, hard killing force {pid}={path}...\n")
-    try:
-        psutil.Process(pid).kill()
-    except psutil.NoSuchProcess:
-        return False
-    return True
+        interrupt = signal.SIGINT
+    context.server_process.send_signal(interrupt)


 def is_server_listening(server_fqdn, server_port):
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
         result = sock.connect_ex((server_fqdn, server_port))
         _is_server_listening = result == 0
         if _is_server_listening:
-            print(f"server is listening on {server_fqdn}:{server_port}...\n")
+            print(f"server is listening on {server_fqdn}:{server_port}...")
     return _is_server_listening
-
-
-def pid_exists(pid):
-    try:
-        psutil.Process(pid)
-    except psutil.NoSuchProcess:
-        return False
-    return True
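
In plain terms, the new teardown asks the server to exit via the Popen handle, waits briefly, and only then escalates to SIGKILL, with liveness checked through poll() instead of psutil. A minimal standalone sketch of the same pattern (the helper name shutdown_server is made up for illustration; the real logic lives in after_scenario and server_graceful_shutdown above):

import os
import signal
import subprocess
from subprocess import TimeoutExpired

def shutdown_server(process: subprocess.Popen) -> None:
    # Ask for a clean exit: CTRL_C_EVENT on Windows, SIGINT elsewhere.
    interrupt = signal.CTRL_C_EVENT if os.name == 'nt' else signal.SIGINT
    process.send_signal(interrupt)
    try:
        process.wait(0.5)   # give the server 500ms to shut down on its own
    except TimeoutExpired:
        process.kill()      # escalate to SIGKILL
        process.wait()      # reap the child so it does not linger as a zombie
    # poll() returns the exit code once the child has terminated
    assert process.poll() is not None, "server process is still running"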

14 changes: 7 additions & 7 deletions examples/server/tests/features/server.feature
@@ -35,9 +35,9 @@ Feature: llama.cpp server
     And metric llamacpp:tokens_predicted is <n_predicted>

     Examples: Prompts
-      | prompt | n_predict | re_content | n_prompt | n_predicted | truncated |
-      | I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not |
-      | Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids)+ | 46 | 64 | not |
+      | prompt | n_predict | re_content | n_prompt | n_predicted | truncated |
+      | I believe the meaning of life is | 8 | (read\|going)+ | 18 | 8 | not |
+      | Write a joke about AI from a very long prompt which will not be truncated | 256 | (princesses\|everyone\|kids\|Anna\|forest)+ | 46 | 64 | not |

   Scenario: Completion prompt truncated
     Given a prompt:
@@ -48,7 +48,7 @@
       Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
       """
     And a completion request with no api error
-    Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry
+    Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl
     And the completion is truncated
     And 109 prompt tokens are processed

@@ -65,9 +65,9 @@ Feature: llama.cpp server
     And the completion is <truncated> truncated

     Examples: Prompts
-      | model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated |
-      | llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird)+ | -1 | 64 | enabled | |
+      | model | system_prompt | user_prompt | max_tokens | re_content | n_prompt | n_predicted | enable_streaming | truncated |
+      | llama-2 | Book | What is the best book | 8 | (Here\|what)+ | 77 | 8 | disabled | not |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird\|Annabyear)+ | -1 | 64 | enabled | |


   Scenario: Tokenize / Detokenize
33 changes: 17 additions & 16 deletions examples/server/tests/features/steps/steps.py
@@ -66,7 +66,7 @@ def step_server_config(context, server_fqdn, server_port):
 def step_download_hf_model(context, hf_file, hf_repo):
     context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
     if context.debug:
-        print(f"model file: {context.model_file}\n")
+        print(f"model file: {context.model_file}")


 @step('a model file {model_file}')
@@ -137,9 +137,12 @@ def step_start_server(context):
     if 'GITHUB_ACTIONS' in os.environ:
         max_attempts *= 2

+    addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
+    family, typ, proto, _, sockaddr = addrs[0]
+
     while True:
-        with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
-            result = sock.connect_ex((context.server_fqdn, context.server_port))
+        with closing(socket.socket(family, typ, proto)) as sock:
+            result = sock.connect_ex(sockaddr)
             if result == 0:
                 print("\x1b[33;46mserver started!\x1b[0m")
                 return
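
The readiness probe above now resolves the host once with socket.getaddrinfo() and opens the socket with whatever family that returns, instead of hard-coding AF_INET. A hedged sketch of the same probe as a standalone helper (the name wait_until_listening and the timeout value are illustrative, not part of the PR):

import socket
import time
from contextlib import closing

def wait_until_listening(host: str, port: int, timeout: float = 30.0) -> bool:
    # Resolve once; getaddrinfo yields (family, type, proto, canonname, sockaddr) tuples.
    family, typ, proto, _, sockaddr = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)[0]
    deadline = time.time() + timeout
    while time.time() < deadline:
        with closing(socket.socket(family, typ, proto)) as sock:
            # connect_ex returns 0 once something accepts connections on sockaddr.
            if sock.connect_ex(sockaddr) == 0:
                return True
        time.sleep(0.1)
    return False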
@@ -209,7 +212,7 @@ async def step_request_completion(context, api_error):
                                           user_api_key=context.user_api_key)
     context.tasks_result.append(completion)
     if context.debug:
-        print(f"Completion response: {completion}\n")
+        print(f"Completion response: {completion}")
     if expect_api_error:
         assert completion == 401, f"completion must be an 401 status code: {completion}"

@@ -354,7 +357,7 @@ def step_prompt_passkey(context, passkey, i_pos):
     prompt += context.prompt_junk_suffix
     if context.debug:
         passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
-        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
+        print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```")
     context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
     context.n_prompts = len(context.prompts)

@@ -363,7 +366,7 @@ def step_prompt_passkey(context, passkey, i_pos):
 @async_run_until_complete
 async def step_oai_chat_completions(context, api_error):
     if context.debug:
-        print(f"Submitting OAI compatible completions request...\n")
+        print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
@@ -508,12 +511,12 @@ async def step_all_embeddings_are_the_same(context):
             embedding1 = np.array(embeddings[i])
             embedding2 = np.array(embeddings[j])
             if context.debug:
-                print(f"embedding1: {embedding1[-8:]}\n")
-                print(f"embedding2: {embedding2[-8:]}\n")
+                print(f"embedding1: {embedding1[-8:]}")
+                print(f"embedding2: {embedding2[-8:]}")
             similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
             msg = f"Similarity between {i} and {j}: {similarity:.10f}"
             if context.debug:
-                print(f"{msg}\n")
+                print(f"{msg}")
             assert np.isclose(similarity, 1.0, rtol=1e-05, atol=1e-08, equal_nan=False), msg


@@ -630,7 +633,7 @@ async def step_prometheus_metrics_exported(context):
             metrics_raw = await metrics_response.text()
             metric_exported = False
             if context.debug:
-                print(f"/metrics answer:\n{metrics_raw}\n")
+                print(f"/metrics answer:\n{metrics_raw}")
             context.metrics = {}
             for metric in parser.text_string_to_metric_families(metrics_raw):
                 match metric.name:
@@ -932,7 +935,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
             last_match = end
         highlighted += content[last_match:]
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
-            print(f"Checking completion response: {highlighted}\n")
+            print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
@@ -942,7 +945,7 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
 async def gather_tasks_results(context):
     n_tasks = len(context.concurrent_tasks)
     if context.debug:
-        print(f"Waiting for all {n_tasks} tasks results...\n")
+        print(f"Waiting for all {n_tasks} tasks results...")
     for task_no in range(n_tasks):
         context.tasks_result.append(await context.concurrent_tasks.pop())
     n_completions = len(context.tasks_result)
@@ -959,7 +962,7 @@ async def wait_for_health_status(context,
                                  slots_processing=None,
                                  expected_slots=None):
     if context.debug:
-        print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
+        print(f"Starting checking for health for expected_health_status={expected_health_status}")
     interval = 0.5
     counter = 0
     if 'GITHUB_ACTIONS' in os.environ:
@@ -1048,8 +1051,6 @@ def start_server_background(context):
     if 'LLAMA_SERVER_BIN_PATH' in os.environ:
        context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
     server_listen_addr = context.server_fqdn
-    if os.name == 'nt':
-        server_listen_addr = '0.0.0.0'
     server_args = [
         '--host', server_listen_addr,
         '--port', context.server_port,
@@ -1088,7 +1089,7 @@ def start_server_background(context):
         server_args.append('--verbose')
     if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
         server_args.extend(['--log-format', "text"])
-    print(f"starting server with: {context.server_path} {server_args}\n")
+    print(f"starting server with: {context.server_path} {server_args}")
     flags = 0
     if 'nt' == os.name:
         flags |= subprocess.DETACHED_PROCESS
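For context, the last hunk keeps the Windows-only subprocess.DETACHED_PROCESS creation flag, presumably so console control events sent during teardown do not reach the test runner itself. A simplified sketch of that launch pattern (start_server_detached is a hypothetical name; the real harness may combine additional Windows creation flags not shown in this diff):

import os
import subprocess

def start_server_detached(server_path, server_args):
    # Launch the server detached from the test runner's console on Windows,
    # and as an ordinary child process elsewhere.
    flags = 0
    if os.name == 'nt':
        flags |= subprocess.DETACHED_PROCESS  # only defined on Windows builds of Python
    return subprocess.Popen([server_path, *map(str, server_args)], creationflags=flags)
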
1 change: 0 additions & 1 deletion examples/server/tests/requirements.txt
@@ -3,5 +3,4 @@ behave~=1.2.6
 huggingface_hub~=0.20.3
 numpy~=1.24.4
 openai~=0.25.0
-psutil~=5.9.8
 prometheus-client~=0.20.0