Fix logging of stop reason for streaming requests
Also don't set/return seed or other (random) sampling params when in greedy mode.

Signed-off-by: Nick Hill <[email protected]>
njhill committed Jun 13, 2024
1 parent 79b7364 commit 670ec70
Showing 1 changed file with 29 additions and 14 deletions.
43 changes: 29 additions & 14 deletions vllm/entrypoints/grpc/grpc_server.py
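
For readers skimming the diff below, here is a minimal, self-contained sketch of the bookkeeping pattern this commit introduces for streaming requests: keep a handle on the first response (which doubles as the record that gets logged) and on the last streamed response, then copy the final stop reason and token count onto the first response before logging. The Response class and helper names are illustrative stand-ins, not the actual grpc_server.py types.

# Minimal sketch of the streaming bookkeeping pattern; Response and
# stream_and_log are illustrative stand-ins, not the grpc_server.py API.
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Optional, Tuple


@dataclass
class Response:
    text: str = ""
    generated_token_count: int = 0
    stop_reason: Optional[str] = None


async def stream_and_log(
        deltas: AsyncIterator[Tuple[str, Optional[str]]]
) -> AsyncIterator[Response]:
    first_response: Optional[Response] = None
    last_response: Optional[Response] = None
    full_output = ""
    token_count = 0

    async for text, stop_reason in deltas:
        if first_response is None:
            # First iteration: emit the header response and keep a handle on
            # it, since it doubles as the record that gets logged at the end.
            first_response = Response()
            yield first_response

        token_count += 1
        full_output += text
        last_response = Response(text=text,
                                 generated_token_count=token_count,
                                 stop_reason=stop_reason)
        yield last_response

    if first_response is None:
        return  # nothing was generated, so there is nothing to log

    # Fold the final state back into the first response for logging only, so
    # the logged record carries the real stop reason and final token count.
    first_response.text = full_output
    first_response.stop_reason = last_response.stop_reason
    first_response.generated_token_count = last_response.generated_token_count
    print("logged:", first_response)


async def _demo() -> None:
    async def deltas():
        for item in [("Hello", None), (" world", "EOS_TOKEN")]:
            yield item

    async for resp in stream_and_log(deltas()):
        print("streamed:", resp)


if __name__ == "__main__":
    asyncio.run(_demo())

The same idea appears in GenerateStream below, where first_response is edited for logging purposes only and then passed to logs.log_response.
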
@@ -248,8 +248,8 @@ async def GenerateStream(
 
         resp_options = request.params.response
 
-        first = True
         first_response = None
+        last_response = None
         last_output_length = 0
         last_token_count = 0
         time_limit_reached = False
@@ -258,13 +258,13 @@ async def GenerateStream(
         #TODO handle cancellation
         async for result in result_generator:
             last_engine_response = result
-            if first:
+            if first_response is None:
                 service_metrics.observe_queue_time(result)
                 first_response = self._convert_input_details(
                     result, resp_options, sampling_params,
                     GenerationResponse())
+                last_response = first_response
                 yield first_response
-                first = False
 
             output = result.outputs[0]
 
@@ -273,23 +273,31 @@ async def GenerateStream(
                     time_limit_reached = True
 
             # Convert output text and token_ids to deltas
-            yield self._convert_output(output, resp_options, max_is_tok_limit,
-                                       time_limit_reached, last_output_length,
-                                       last_token_count)
-            if time_limit_reached:
-                break
+            last_response = self._convert_output(output, resp_options,
+                max_is_tok_limit, time_limit_reached, last_output_length,
+                last_token_count)
+            yield last_response
 
             last_output_length = len(output.text)
             last_token_count = len(output.token_ids)
             # Save full output for logging
             full_output = output.text
 
+            if time_limit_reached:
+                break
 
         # Edit up the first_response for logging purposes only
+        if first_response is None:
+            # We didn't output anything!
+            return
 
         # Log and record metrics
+        assert last_response is not None
         first_response.text = full_output
-        first_response.generated_token_count = last_token_count
+        first_response.stop_reason = last_response.stop_reason
+        first_response.stop_sequence = last_response.stop_sequence
+        first_response.generated_token_count = (
+            last_response.generated_token_count)
         logs.log_response(request=request, response=first_response,
                           start_time=start_time,
                           engine_metrics=last_engine_response.metrics
@@ -427,18 +435,24 @@ async def _validate_and_convert_params(
         deadline = time.time(
         ) + time_limit_millis / 1000.0 if time_limit_millis > 0 else None
 
+        random_sampling_params: Dict[str, Any]
+        if greedy:
+            random_sampling_params = {"temperature": 0.0}
+        else:
+            random_sampling_params = {
+                "temperature": with_default(sampling.temperature, 1.0),
+                "top_k": with_default(sampling.top_k, -1),
+                "top_p": with_default(sampling.top_p, 1.0),
+                "seed": sampling.seed if sampling.HasField("seed") else None,
+            }
+
         try:
             sampling_params = SamplingParams(
                 logprobs=logprobs,
                 prompt_logprobs=logprobs
                 if resp_options.input_tokens else None,
                 max_tokens=max_new_tokens,
                 min_tokens=min_new_tokens,
-                temperature=with_default(sampling.temperature, 1.0)
-                if not greedy else 0.0,
-                top_k=with_default(sampling.top_k, -1),
-                top_p=with_default(sampling.top_p, 1.0),
-                seed=sampling.seed if sampling.HasField("seed") else None,
                 repetition_penalty=with_default(
                     decoding.repetition_penalty, 1.0),
                 logits_processors=logits_processors,
@@ -447,6 +461,7 @@ async def _validate_and_convert_params(
                 if stopping.HasField("include_stop_sequence") else
                 self.default_include_stop_seqs,
                 skip_special_tokens=self.skip_special_tokens,
+                **random_sampling_params
             )
         except ValueError as vllm_validation_error:
             # There may be validation cases caught by vLLM that are not covered
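
As a usage note on the sampling change above: in greedy mode the request now contributes only temperature=0.0, so top_k, top_p, and seed fall back to their SamplingParams defaults instead of being set (and returned) from the request. Below is a hedged sketch of the same pattern against vLLM's public SamplingParams API; build_sampling_params, greedy, and request_seed are illustrative names, not the gRPC request fields.

# Hedged sketch of the greedy-vs-random split, assuming vLLM's public
# SamplingParams constructor; the function and argument names are invented
# for illustration.
from typing import Any, Dict

from vllm import SamplingParams


def build_sampling_params(greedy: bool, request_seed: int = 1234) -> SamplingParams:
    random_sampling_params: Dict[str, Any]
    if greedy:
        # Greedy: temperature 0.0 only; no top_k/top_p/seed are passed.
        random_sampling_params = {"temperature": 0.0}
    else:
        random_sampling_params = {
            "temperature": 0.8,
            "top_k": 50,
            "top_p": 0.95,
            "seed": request_seed,
        }
    return SamplingParams(max_tokens=64, **random_sampling_params)
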
