Add exponential bucketing integration #642

Open

wants to merge 10 commits into base: habana_main
2 changes: 1 addition & 1 deletion requirements-hpu.txt
@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@d4f37bb
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@8ce8112
30 changes: 12 additions & 18 deletions vllm/core/scheduler.py
@@ -130,25 +130,19 @@ def _generic_padding_fn(self, batch_size, max_seq_len) -> int:
         return batch_size * max_seq_len
 
     def _hpu_padding_fn(self, batch_size, max_seq_len):
-        from vllm_hpu_extension.bucketing import (HPUBucketingGlobalState,
-                                                  find_bucket)
-        padded_bs = batch_size
-        padded_seq = max_seq_len
-
-        hpu_bucketing_global_state = HPUBucketingGlobalState()
-
-        bs_cfg = hpu_bucketing_global_state.prompt_bs_bucket_cfg
-        if bs_cfg is not None:
-            padded_bs = find_bucket(batch_size, bs_cfg)
+        use_exponential_bucketing = os.environ.get(
+            'VLLM_EXPONENTIAL_BUCKETING', 'true').lower() == 'true'
+        if use_exponential_bucketing:
+            from vllm_hpu_extension.bucketing.exponential import (
+                HPUExponentialBucketingContext as HPUBucketingContext)
         else:
-            logger.warning(
-                "prompt_bs_bucket_cfg was not set! Using unpadded batch size.")
-        seq_cfg = hpu_bucketing_global_state.prompt_seq_bucket_cfg
-        if seq_cfg is not None:
-            padded_seq = find_bucket(max_seq_len, seq_cfg)
-        else:
-            logger.warning("prompt_seq_bucket_cfg was not set! "
-                           "Using unpadded sequence length.")
+            from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
+
+        hpu_bucketing_context = HPUBucketingContext.get_instance()
+        padded_bs = hpu_bucketing_context.get_padded_prompt_batch_size(
+            batch_size)
+        padded_seq = hpu_bucketing_context.get_padded_prompt_seq_len(
+            max_seq_len)
         return padded_bs * padded_seq
 
     def _padding_fn_selector(self):

Review comment (on the VLLM_EXPONENTIAL_BUCKETING line): Should this environment variable be documented in the manual?
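For reference on the question above, this is how the new switch behaves, as a minimal sketch based only on the code in this diff (the select_bucketing_context helper name is illustrative and not part of the PR):

import os

# Illustrative helper mirroring the selection logic added to _hpu_padding_fn.
# VLLM_EXPONENTIAL_BUCKETING defaults to 'true', so exponential bucketing is
# active unless the variable is explicitly set to something else.
def select_bucketing_context():
    use_exponential = os.environ.get(
        'VLLM_EXPONENTIAL_BUCKETING', 'true').lower() == 'true'
    if use_exponential:
        from vllm_hpu_extension.bucketing.exponential import (
            HPUExponentialBucketingContext as HPUBucketingContext)
    else:
        from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
    return HPUBucketingContext

Setting VLLM_EXPONENTIAL_BUCKETING=false (or any value other than 'true', case-insensitively) before launching vLLM falls back to the previous linear bucketing.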
22 changes: 16 additions & 6 deletions vllm/worker/hpu_model_runner.py
@@ -20,7 +20,6 @@
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
 import vllm_hpu_extension.environment as environment
-from vllm_hpu_extension.bucketing import HPUBucketingContext
 from vllm_hpu_extension.flags import enabled_flags
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.ops import batch2block, block2batch
@@ -654,10 +653,19 @@ def __init__(
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
         self.seen_configs: set = set()
         self._mem_margin: Optional[int] = None
+        self.use_exponential_bucketing = os.environ.get(
+            'VLLM_EXPONENTIAL_BUCKETING', 'true').lower() == 'true'
+        if self.use_exponential_bucketing:
+            from vllm_hpu_extension.bucketing.exponential import (
+                HPUExponentialBucketingContext as HPUBucketingContext)
+        else:
+            from vllm_hpu_extension.bucketing.linear import HPUBucketingContext
+
         self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs,
                                                  self.max_num_prefill_seqs,
                                                  self.block_size,
-                                                 self.max_num_batched_tokens)
+                                                 self.max_num_batched_tokens,
+                                                 self.max_model_len)
         self.graphed_buckets: Set[Any] = set()
 
         self._set_gc_threshold()
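Read together with the scheduler change above, the intended pattern appears to be: the model runner constructs the bucketing context once with its full configuration, and the scheduler later retrieves that same object through get_instance(). Below is a minimal sketch of that construct-once/look-up-later pattern; ToyBucketingContext and its fields are purely illustrative and are not the vllm-hpu-extension classes.

from typing import Optional


class ToyBucketingContext:
    """Illustrative singleton-style context (not the real extension class)."""

    _instance: Optional["ToyBucketingContext"] = None

    def __init__(self, max_num_seqs: int, block_size: int,
                 max_num_batched_tokens: int, max_model_len: int):
        self.max_num_seqs = max_num_seqs
        self.block_size = block_size
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_model_len = max_model_len
        # Remember the most recently constructed context so other
        # components (e.g. the scheduler) can look it up later.
        type(self)._instance = self

    @classmethod
    def get_instance(cls) -> "ToyBucketingContext":
        if cls._instance is None:
            raise RuntimeError("Bucketing context has not been created yet")
        return cls._instance


# The runner would build the context once with its configuration...
ctx = ToyBucketingContext(max_num_seqs=128, block_size=128,
                          max_num_batched_tokens=8192, max_model_len=4096)
# ...and the scheduler would fetch the same object later.
assert ToyBucketingContext.get_instance() is ctx

The extra max_model_len argument passed in this hunk suggests the new bucketing contexts need the model's maximum length to bound the buckets they generate.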
@@ -1583,7 +1591,9 @@ def profile_run(self) -> None:
         _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
         max_batch_size = min(self.max_num_seqs,
                              self.max_num_batched_tokens // max_seq_len)
-
+        #NOTE(kzawora): this is nasty - we need to generate buckets for prompt
+        # ahead of time, without kv cache
+        self.bucketing_ctx.generate_prompt_buckets()
         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)
         return
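The NOTE in this hunk, together with the warmup_model hunks below, splits bucket generation in two: prompt buckets are now generated during profile_run, before any KV cache exists, while decode buckets are generated at the start of warmup_model, once kv_caches is available and the number of cache blocks is known. A rough sketch of the resulting ordering follows (the free-standing function signatures are simplified stand-ins, not vLLM's actual methods):

# Rough sketch of the bucket-generation order after this PR; only
# generate_prompt_buckets / generate_decode_buckets come from the diff,
# the surrounding function shapes are hypothetical.

def profile_run(bucketing_ctx):
    # Prompt buckets do not need the KV cache, so they are generated
    # here, ahead of time.
    bucketing_ctx.generate_prompt_buckets()
    # ... run the profiling warmup scenario ...

def warmup_model(bucketing_ctx, kv_caches):
    # Decode buckets depend on how many cache blocks were allocated,
    # so they are generated once kv_caches exists.
    max_blocks = kv_caches[0][0].size(0)
    bucketing_ctx.generate_decode_buckets(max_blocks)
    # ... graph capture / warmup scenarios follow ...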
@@ -1802,6 +1812,9 @@ def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
 
     @torch.inference_mode()
     def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
+        max_blocks = kv_caches[0][0].size(0)
+        self.bucketing_ctx.generate_decode_buckets(max_blocks)
+
         if profile := os.environ.get('VLLM_PT_PROFILE', None):
             phase, bs, seq_len, graph = profile.split('_')
             is_prompt = phase == 'prompt'
@@ -1811,9 +1824,6 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
                                  True)
             raise AssertionError("Finished profiling")
-        max_blocks = kv_caches[0][0].size(0)
-        self.bucketing_ctx.generate_prompt_buckets()
-        self.bucketing_ctx.generate_decode_buckets(max_blocks)
         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
             multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION',
                                         'true').lower() == 'true' else 1