clean and re-enable logging
IsaacRe committed Oct 1, 2024
1 parent 163b319 commit 17da8eb
Showing 3 changed files with 17 additions and 3 deletions.
7 changes: 6 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -445,7 +445,8 @@ def main(args: argparse.Namespace):
     print(f"Max decoding batch: {max_decoding_batch}")
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} tokens/s")
-    BENCHMARKER.summarize()
+    if args.latency_breakdown:
+        BENCHMARKER.summarize()
 
     # Output JSON results if specified
     if args.output_json:
@@ -749,6 +750,10 @@ def main(args: argparse.Namespace):
         '--real-text',
         action='store_true',
     )
+    parser.add_argument(
+        '--latency-breakdown',
+        action='store_true',
+    )
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
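
For reference, a minimal standalone sketch (not from the repository) of the argparse pattern the new flag relies on: with action='store_true' the attribute defaults to False, so the latency-breakdown summary stays off unless the flag is passed on the command line.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--latency-breakdown', action='store_true')

# Without the flag the attribute is False, so the summarize() call above would be skipped.
print(parser.parse_args([]).latency_breakdown)                        # False
print(parser.parse_args(['--latency-breakdown']).latency_breakdown)   # True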
11 changes: 10 additions & 1 deletion vllm/core/scheduler.py
@@ -434,6 +434,8 @@ def __init__(
         # avoiding frequent compression-preemption cycles.
         self.lock_prefill = False
 
+        self.logged_decoding = False
+
     @property
     def next_cache_id(self):
         return (self.cache_id + 1) % self.num_cache_iters
@@ -1129,7 +1131,14 @@ def _schedule_default(self) -> SchedulerOutputs:
         preempted = (len(running_scheduled.preempted) +
                      len(running_scheduled.swapped_out))
         self.max_decoding_batch = max(self.max_decoding_batch, len(running_scheduled.decode_seq_groups))
-        # print(f"{len(self.running)}/{len(self.waiting)} (running/waiting) - {len(prefills.seq_groups)} prefill, {len(running_scheduled.decode_seq_groups)} decode")
 
+        if len(prefills.seq_groups) > 0:
+            self.logged_decoding = False
+
+        if not self.logged_decoding:
+            print(f"{len(self.running)}/{len(self.waiting)} (running/waiting) - {len(prefills.seq_groups)} prefill, {len(running_scheduled.decode_seq_groups)} decode")
+            if len(running_scheduled.decode_seq_groups) > 0:
+                self.logged_decoding = True
+
         # There should be no prefill from running queue because this policy
         # doesn't allow chunked prefills.
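
The logged_decoding flag acts as a one-shot latch: the batch composition is printed once when decoding starts, and the latch is re-armed whenever a new prefill batch appears so the next decode phase is logged again. A self-contained sketch of that behavior (names and the step() interface are illustrative, not the scheduler's API):

class SchedulerLogDemo:
    """Standalone sketch of the logging latch: log once when decoding starts,
    re-arm whenever a new prefill batch shows up."""

    def __init__(self):
        self.logged_decoding = False

    def step(self, num_prefills, num_decodes):
        if num_prefills > 0:
            self.logged_decoding = False
        if not self.logged_decoding:
            print(f"{num_prefills} prefill, {num_decodes} decode")
            if num_decodes > 0:
                self.logged_decoding = True


demo = SchedulerLogDemo()
demo.step(4, 0)   # logged (prefill phase, latch stays armed)
demo.step(0, 4)   # logged once when decoding starts
demo.step(0, 4)   # silent: already logged this decode phase
demo.step(2, 2)   # new prefills re-arm the latch, logged again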
2 changes: 1 addition & 1 deletion vllm/kvcompress/scheduler.py
@@ -541,7 +541,7 @@ def _schedule_compression(self, seqs: List[Sequence], sampling_params: List[Samp
                 self.total_evicted_kvs.get(seq.seq_id, 0) + (torch.clamp(freed_block_count[seq.seq_id] * self.block_size - empty_slots, min=0)).sum().item()
             )
             seq_evicted_kvs = self.total_evicted_kvs[seq.seq_id]
-            # print(f'Seq {seq.seq_id} evicted {seq_evicted_kvs} KVs (~{seq_evicted_kvs / self.config.num_kv_heads / self.config.num_layers} tokens) so far')
+            print(f'Seq {seq.seq_id} evicted {seq_evicted_kvs} KVs (~{seq_evicted_kvs / self.config.num_kv_heads / self.config.num_layers} tokens) so far')
 
         CHECKPOINTER.checkpoint('schedule_compression__cache_moves_indices', self.cache_move_indices)
         CHECKPOINTER.checkpoint('schedule_compression__cache_moves_count', cache_moves_count)
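
A hedged sketch of the token estimate in the re-enabled log line: the counter tracks evicted key/value slots per head per layer, so dividing by num_kv_heads and num_layers approximates whole tokens. The numbers below are made up for illustration; the real values come from self.config.

# Illustrative values only, not the repository's configuration.
num_kv_heads = 8
num_layers = 32
seq_evicted_kvs = 51_200

approx_tokens = seq_evicted_kvs / num_kv_heads / num_layers
print(f"evicted {seq_evicted_kvs} KVs (~{approx_tokens} tokens)")  # ~200.0 tokens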
