From 7223fb8cb3340a7e4d26af30f5cb41c15483fbf0 Mon Sep 17 00:00:00 2001
From: yan ma
Date: Sun, 26 Jan 2025 17:43:01 +0800
Subject: [PATCH] revert

---
 vllm/model_executor/models/mllama.py    |   8 +-
 vllm/worker/hpu_enc_dec_model_runner.py | 211 ++++++++++++++++++++--------
 2 files changed, 143 insertions(+), 76 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index f15318fbd98ff..0127668c731bd 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -988,10 +988,12 @@ def forward(
         # to 2D tensor to align with public vllm input_tokens shape. But this
         # will face the graph building failure issue, still need to investigate.
         assert len(residual.shape) == 3
-        if len(hidden_states.shape)==2:
-            hidden_states = hidden_states.view(residual.size(0), residual.size(1), residual.size(2))
+        if len(hidden_states.shape) == 2:
+            hidden_states = hidden_states.view(residual.size(0),
+                                               residual.size(1),
+                                               residual.size(2))
         full_text_row_masked_out_mask = full_text_row_masked_out_mask.view(
-                hidden_states.size(0), -1, 1)
+            hidden_states.size(0), -1, 1)
         hidden_states = full_text_row_masked_out_mask * hidden_states
         hidden_states = residual + self.cross_attn_attn_gate.tanh(
         ) * hidden_states
diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py
index 6001178e1ea4a..66ede172cf694 100644
--- a/vllm/worker/hpu_enc_dec_model_runner.py
+++ b/vllm/worker/hpu_enc_dec_model_runner.py
@@ -8,6 +8,7 @@
 
 import habana_frameworks.torch as htorch
 import torch
+from PIL import Image
 from vllm_hpu_extension.ops import batch2block, block2batch
 
 from vllm.attention import AttentionMetadata
@@ -20,7 +21,7 @@
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                            Logprob, SequenceData, SequenceGroupMetadata,
                            SequenceOutput)
-from vllm.utils import is_fake_hpu
+from vllm.utils import is_fake_hpu, is_list_of
 from vllm.worker.hpu_model_runner import (HpuModelAdapter, HPUModelRunnerBase,
                                           ModelInputForHPUWithSamplingMetadata,
                                           setup_profiler, subtuple)
@@ -357,85 +358,149 @@ def _prepare_encoder_model_input_tensors(
 
         return attn_metadata
 
+    @torch.inference_mode()
     def profile_run(self) -> None:
-        # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
-        # Workaround to avoid unexpeced OOM failure during profile run
-        max_num_seqs = int(self.scheduler_config.max_num_seqs/2)
-
-        # Profile memory usage with max_num_sequences sequences and the total
-        # number of tokens equal to max_num_batched_tokens.
-        seqs: List[SequenceGroupMetadata] = []
-
-        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-            self.model_config)
-        if max_mm_tokens > 0:
-            logger.info("Starting profile run for multi-modal models.")
-
-        batch_size = 0
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            batch_size += seq_len
-
-            decoder_dummy_data = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry,
-                                          is_encoder_data=False)
-            encoder_dummy_data = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry,
-                                          is_encoder_data=True)
-
-            # Having more tokens is over-conservative but otherwise fine
-            assert len(
-                decoder_dummy_data.seq_data.prompt_token_ids
-            ) >= seq_len, (
-                f"Expected at least {seq_len} dummy tokens for profiling, "
-                f"but got: {len(decoder_dummy_data.seq_data.prompt_token_ids)}"
-            )
-
-            assert decoder_dummy_data.multi_modal_data is None or \
-            encoder_dummy_data.multi_modal_data is None, (
-                "Multi-modal data can't be provided in both encoder and decoder"
-            )
-
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: decoder_dummy_data.seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-                encoder_seq_data=encoder_dummy_data.seq_data,
-                cross_block_table=None,
-                multi_modal_data=decoder_dummy_data.multi_modal_data
-                or encoder_dummy_data.multi_modal_data,
-                multi_modal_placeholders=decoder_dummy_data.
-                multi_modal_placeholders
-                or encoder_dummy_data.multi_modal_placeholders)
-            seqs.append(seq)
-
-        # Run the model with the dummy inputs.
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        # use an empty tensor instead of `None`` to force Dynamo to pass
-        # it by reference, rather by specializing on the value ``None``.
-        # the `dtype` argument does not matter, and we use `float32` as
-        # a placeholder (it has wide hardware support).
         kv_caches = [
-            torch.tensor([], dtype=torch.float32, device=self.device)
+            torch.tensor([], dtype=torch.bfloat16, device=self.device)
             for _ in range(num_layers)
         ]
-        finished_requests_ids = [seq.request_id for seq in seqs]
-        model_input = self.prepare_model_input(
-            seqs, finished_requests_ids=finished_requests_ids)
-        intermediate_tensors = None
-        self.execute_model(model_input, kv_caches, intermediate_tensors)
+        max_batch_size = self.max_num_prefill_seqs
+        _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
+        max_seq_len = min(self.max_num_batched_tokens // max_batch_size,
+                          max_seq_len)
+
+        self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
+                             False)
+        return
+
+    def warmup_scenario(self,
+                        batch_size,
+                        seq_len,
+                        is_prompt,
+                        kv_caches,
+                        is_pt_profiler_run=False,
+                        is_lora_profile_run=False,
+                        temperature=0) -> None:
+        use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
+        scenario_name = ("warmup_"
+                         f"{'prompt' if is_prompt else 'decode'}_"
+                         f"bs{batch_size}_"
+                         f"seq{seq_len}_"
+                         f"graphs{'T' if use_graphs else 'F'}")
+        self.profiler.start('internal', scenario_name)
+        times = 3 if use_graphs or is_pt_profiler_run else 1
+        if is_prompt:
+            seqs = [
+                self.create_dummy_seq_group_metadata(i, seq_len, is_prompt)
+                for i in range(batch_size)
+            ]
+        else:
+            # FIXME: seq_len is actually number of blocks
+            blocks = [seq_len // batch_size for _ in range(batch_size)]
+            blocks[0] += seq_len % batch_size
+            seqs = [
+                self.create_dummy_seq_group_metadata(i,
+                                                     b * self.block_size - 1,
+                                                     is_prompt)
+                for i, b in enumerate(blocks)
+            ]
         torch.hpu.synchronize()
+        profiler = None
+        if is_pt_profiler_run and self.is_driver_worker:
+            profiler = setup_profiler()
+            profiler.start()
+        for _ in range(times):
+            inputs = self.prepare_model_input(seqs)
+            is_single_step = \
+                self.vllm_config.scheduler_config.num_scheduler_steps == 1
+            if is_prompt or is_single_step:
+                self.execute_model(inputs, kv_caches, warmup_mode=True)
+            else:  # decode with multi-step
+                inputs = dataclasses.replace(inputs,
+                                             is_first_multi_step=True,
+                                             is_last_step=False)
+                self.execute_model(inputs,
+                                   kv_caches,
+                                   warmup_mode=True,
+                                   num_steps=2,
+                                   seqs=seqs)
+                inputs = dataclasses.replace(inputs,
+                                             is_first_multi_step=False,
+                                             is_last_step=True)
+                self.execute_model(inputs,
+                                   kv_caches,
+                                   warmup_mode=True,
+                                   num_steps=2,
+                                   seqs=seqs)
+            torch.hpu.synchronize()
+            if profiler:
+                profiler.step()
+        if profiler:
+            profiler.stop()
+        self.profiler.end()
         gc.collect()
-        return
+
+    def create_dummy_seq_group_metadata(self,
+                                        group_id,
+                                        seq_len,
+                                        is_prompt,
+                                        lora_request=None,
+                                        temperature=0):
+        sampling_params = SamplingParams(temperature=temperature)
+        num_blocks = math.ceil(seq_len / self.block_size)
+        cross_block_table: Optional[List[int]] = None
+        encoder_dummy_data \
+            = self.input_registry.dummy_data_for_profiling(
+                self.model_config,
+                seq_len,
+                self.mm_registry,
+                is_encoder_data=True)
+        mm_counts = self.mm_registry.get_mm_limits_per_prompt(
+            self.model_config)
+        num_images = mm_counts["image"]
+        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
+            self.model_config) * num_images
+        seq_len = max(seq_len, 1)
+        if is_prompt:
+            input_len = seq_len
+            output_len = 0
+            block_tables = None
+            cross_block_table = None
+        else:
+            input_len = seq_len - 1
+            output_len = 1
+            block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks}
+            # limit cross blocks to the number of available blocks
+            num_cross_blocks = min(self.bucketing_ctx.num_hpu_blocks,
+                                   max_mm_tokens) // self.block_size
+            cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks
+        prompt_token_ids = [0] * input_len
+        if is_prompt:
+            image_data = encoder_dummy_data.multi_modal_data["image"]
+            if isinstance(image_data, Image.Image):
+                image_data = [image_data]
+            assert is_list_of(image_data, Image.Image)
+            text_prompt_len = input_len - 2 - len(image_data)
+            prompt_token_ids = [128000] + [128256] * len(image_data) + [
+                128000
+            ] + [0] * text_prompt_len
+        output_token_ids = [1] * output_len
+        prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
+        seq_data = SequenceData(prompt_token_ids_array)
+        seq_data.output_token_ids = output_token_ids
+
+        return SequenceGroupMetadata(
+            request_id=str(group_id),
+            is_prompt=is_prompt,
+            seq_data={group_id: seq_data},
+            sampling_params=sampling_params,
+            block_tables=block_tables,
+            encoder_seq_data=encoder_dummy_data.seq_data,
+            multi_modal_data=encoder_dummy_data.multi_modal_data,
+            multi_modal_placeholders=encoder_dummy_data.
+            multi_modal_placeholders,
+            cross_block_table=cross_block_table)
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # NOTE(kzawora): To anyone working on this in the future:
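
The dummy multimodal prompt built in create_dummy_seq_group_metadata above places id 128000, one placeholder id (128256) per image, another 128000, and then zero-padding up to input_len. A minimal standalone sketch of that layout, reusing only the ids the patch hard-codes; the helper name below is illustrative and not part of the patch:

# Sketch only: mirrors the prompt layout used for profiling/warmup above.
from typing import List


def dummy_mllama_prompt(input_len: int, num_images: int = 1) -> List[int]:
    # 128000, one 128256 image placeholder per image, 128000 again,
    # then zero-padding so the prompt is exactly input_len tokens long.
    text_prompt_len = input_len - 2 - num_images
    return [128000] + [128256] * num_images + [128000] + [0] * text_prompt_len


assert len(dummy_mllama_prompt(32)) == 32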
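
Likewise, the decode branch of warmup_scenario above treats seq_len as a total block budget and spreads it across the batch, with the remainder going to the first dummy sequence. A small sketch of that split (function and variable names are illustrative, not from the patch):

# Sketch only: how the decode warmup distributes a total block budget across
# batch_size dummy sequences; the remainder is assigned to sequence 0.
def split_blocks(total_blocks: int, batch_size: int) -> list:
    blocks = [total_blocks // batch_size for _ in range(batch_size)]
    blocks[0] += total_blocks % batch_size
    return blocks


assert split_blocks(10, 4) == [4, 2, 2, 2]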