Hack #9

Open · wants to merge 1 commit into main
llm_bench/load_test.py (152 changes: 119 additions & 33 deletions)

@@ -15,6 +15,11 @@
import orjson
import threading

from typing import Iterator, Dict, Optional
import urllib.parse
import itertools
import requests

try:
    import locust_plugins
except ImportError:
@@ -274,7 +279,7 @@ def format_payload(self, prompt, max_tokens, images):
data["logprobs"] = self.parsed_options.logprobs
return data

-    def parse_output_json(self, data, prompt):
    def parse_output_json(self, data, prompt=None):
        usage = data.get("usage", None)

        assert len(data["choices"]) == 1, f"Too many choices {len(data['choices'])}"
@@ -436,6 +441,38 @@ def parse_output_json(self, data, prompt=None):
            usage_tokens=None,
            prompt_usage_tokens=None,
        )

class State:
    def __init__(self, opts):
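        # itertools.cycle replays the recorded requests indefinitely, so the
        # load test never runs out of input.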
        self._requests = itertools.cycle(self._read_requests(opts.requests_dir))
        self._opts = opts
        # if opts.response_file:
        #     self._response_file = open(opts.response_file, "w")
        # else:
        #     self._response_file = None

    def next_request(self):
        return next(self._requests)

    def _read_requests(self, requests_dir: str) -> Iterator[Dict]:
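        # Each input file is read as JSONL: one JSON-encoded request object per line.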
        if os.path.isdir(requests_dir):
            files = [
                os.path.join(requests_dir, file) for file in os.listdir(requests_dir)
            ]
        else:
            files = [requests_dir]
        for file in files:
            with open(file, "r") as f:
                for line in f:
                    data = json.loads(line)
                    yield data

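# Process-wide singleton: every simulated user shares one request iterator.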
state: Optional[State] = None


def get_state(opts) -> State:
    global state
    if state is None:
        state = State(opts)
    return state


class TgiProvider(BaseProvider):
@@ -579,6 +616,8 @@ def _guess_provider(self):
        )

    def _on_start(self):
        self._global_state = get_state(self.environment.parsed_options)

        self.client.headers["Content-Type"] = "application/json"
        if self.environment.parsed_options.api_key:
            self.client.headers["Authorization"] = (
@@ -659,42 +698,81 @@ def _on_start(self):
        time.sleep(random.random())

        self.first_done = False

    def _get_input(self):
-        def _maybe_randomize(prompt):
-            if not self.environment.parsed_options.prompt_randomize:
-                return prompt
-
-            # single letters are single tokens
-            num_random_tokens = (len(prompt) - len(PROMPT_SUFFIX)) // len(
-                PROMPT_PREFIX_TOKEN
-            )
-            return (
-                " ".join(
-                    chr(ord("a") + random.randint(0, 25))
-                    for _ in range(num_random_tokens)
-                )
-                + " "
-                + prompt[-len(PROMPT_SUFFIX) :]
-            )
        opts = self.environment.parsed_options
        data = self._global_state.next_request()

-        if isinstance(self.input, str):
-            return _maybe_randomize(self.input), None
-        else:
-            item = self.input[random.randint(0, len(self.input) - 1)]
-            assert "prompt" in item
-            return _maybe_randomize(item["prompt"]), item.get("images", None)
        body = data["body"]
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        body["model"] = self.model
        for header, value in data.get("headers", {}).items():
            header = header.lower()
            if header.startswith("x-fireworks") or header == "x-request-id":
                headers[header] = value
        # if opts.model_name is not None:
        #     body["model"] = opts.model_name

        if opts.api_key is not None:
            headers["Authorization"] = "Bearer " + opts.api_key
        # if not opts.prompt_cache:
        #     body["prompt_cache_max_len"] = 0
        if opts.logprobs is not None:
            body["logprobs"] = opts.logprobs
        if opts.temperature is not None:
            body["temperature"] = opts.temperature
        return body
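
        # NOTE: the "headers" dict built above is only consumed by the
        # commented-out request below; the live path in generate_text posts
        # without it.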

        # url = urllib.parse.urlparse(data["url"])
        # with self.client.post(
        #     url.path,
        #     data=json.dumps(body),
        #     stream=True,
        #     catch_response=True,
        #     headers=headers,
        # ) as response:
        #     self._global_state.process_response(response)

        # def _maybe_randomize(prompt):
        #     if not self.environment.parsed_options.prompt_randomize:
        #         return prompt

        #     # single letters are single tokens
        #     num_random_tokens = (len(prompt) - len(PROMPT_SUFFIX)) // len(
        #         PROMPT_PREFIX_TOKEN
        #     )
        #     return (
        #         " ".join(
        #             chr(ord("a") + random.randint(0, 25))
        #             for _ in range(num_random_tokens)
        #         )
        #         + " "
        #         + prompt[-len(PROMPT_SUFFIX) :]
        #     )

        # if isinstance(self.input, str):
        #     return _maybe_randomize(self.input), None
        # else:
        #     item = self.input[random.randint(0, len(self.input) - 1)]
        #     assert "prompt" in item
        #     return _maybe_randomize(item["prompt"]), item.get("images", None)

    @task
    def generate_text(self):
-        max_tokens = self.max_tokens_sampler.sample()
-        prompt, images = self._get_input()
-        data = self.provider_formatter.format_payload(prompt, max_tokens, images)
        # max_tokens = self.max_tokens_sampler.sample()
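        # The replayed request body is expected to carry its own max_tokens;
        # the local value below is only a placeholder.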
        max_tokens = 0
        # prompt, images = self._get_input()
        # data = self.provider_formatter.format_payload(prompt, max_tokens, images)
        body = self._get_input()
        t_start = time.perf_counter()

        with self.client.post(
            self.provider_formatter.get_url(),
-            data=json.dumps(data),
            data=json.dumps(body),
            stream=True,
            catch_response=True,
        ) as response:
@@ -732,7 +810,7 @@ def generate_text(self):
                        done = True
                        continue
                    data = orjson.loads(chunk)
-                    out = self.provider_formatter.parse_output_json(data, prompt)
                    out = self.provider_formatter.parse_output_json(data)
                    if out.usage_tokens:
                        total_usage_tokens = (
                            total_usage_tokens or 0
@@ -746,6 +824,7 @@ def generate_text(self):
                            total_logprob_tokens or 0
                        ) + out.logprob_tokens
                except Exception as e:

                    print(f"Failed to parse response: {chunk} with error {repr(e)}")
                    response.failure(e)
                    return
@@ -791,10 +870,10 @@ def generate_text(self):
            add_custom_metric("time_to_first_token", dur_first_token * 1000)
            add_custom_metric("total_latency", dur_total * 1000)
            if num_tokens:
-                if num_tokens != max_tokens:
-                    print(
-                        f"WARNING: wrong number of tokens: {num_tokens}, expected {max_tokens}"
-                    )
                # if num_tokens != max_tokens:
                #     print(
                #         f"WARNING: wrong number of tokens: {num_tokens}, expected {max_tokens}"
                #     )
                add_custom_metric("num_tokens", num_tokens)
                add_custom_metric(
                    "latency_per_token", dur_generation / num_tokens * 1000, num_tokens
@@ -966,6 +1045,13 @@ def init_parser(parser):
        default=0,
        help="Maximum length of the prompt cache to use. Defaults to 0 (no caching).",
    )
    parser.add_argument(
        "-d",
        "--requests-dir",
        type=str,
        required=True,
        help="Path to a single JSONL file, or a directory of JSONL files, with requests to replay.",
    )
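    # Note: based on how _read_requests and _get_input consume the data, each
    # line of a --requests-dir file is a JSON object whose "body" dict becomes
    # the POST payload ("model" is overwritten by the harness), with optional
    # "headers" (only "x-fireworks-*" and "x-request-id" are forwarded) and a
    # "url" read only by the commented-out code. A hypothetical example line:
    #   {"headers": {"x-request-id": "req-123"}, "body": {"prompt": "Hello", "max_tokens": 16}}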


@events.quitting.add_listener