Merge pull request #2 from DeepAuto-AI/geon-dev
Geon dev
daniel-geon-park authored Mar 10, 2024
2 parents d9d746e + eb622cd commit e217585
Showing 13 changed files with 259 additions and 63 deletions.
2 changes: 2 additions & 0 deletions vllm/core/scheduler.py
@@ -106,6 +106,8 @@ def __init__(
self.running: Deque[SequenceGroup] = deque()
# Sequence groups in the SWAPPED state.
self.swapped: Deque[SequenceGroup] = deque()

logger.info(f'scheduler max_seq {self.scheduler_config.max_num_seqs}')

@property
def lora_enabled(self) -> bool:
51 changes: 31 additions & 20 deletions vllm/model_executor/layers/attention.py
@@ -68,6 +68,9 @@ def __init__(
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

self.layer_index = layer_index
assert layer_index is not None, 'layer_index must not be None'
self.hip_dense_layers = list(range(int(os.environ.get('HIP_DENSE_LAYERS', '3'))))
self.hip_high_k_layers = {}

def forward(
self,
@@ -115,13 +118,24 @@ def forward(
)

hip_k = int(os.environ.get('HIP_K', '1024'))

benchmark_prompt_attention = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
prompt_backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'timber')

benchmark_paged_attention = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
paged_backend = os.environ.get('PAGED_ATTENTION_BACKEND', 'timber')

if self.layer_index in self.hip_dense_layers:
prompt_backend = 'vllm'
paged_backend = 'vllm'

if (paged_backend == 'timber' or prompt_backend == 'timber') and (self.layer_index in self.hip_high_k_layers):
hip_k *= self.hip_high_k_layers[self.layer_index]

if input_metadata.is_prompt:
# Prompt run.
BENCHMARK_PROMPT_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'vllm')
is_normal_attention = (key_cache is None) or (value_cache is None) or (input_metadata.block_tables.numel() == 0)
if backend == 'vllm':
if prompt_backend == 'vllm':
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
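
Read together with the attribute setup in the previous hunk, the per-layer selection amounts to the standalone sketch below. It is an illustration only, assuming just the logic visible above: the helper name resolve_backends and its high_k_layers argument are invented, and this commit leaves hip_high_k_layers empty.

import os

# Illustrative summary of the per-layer backend selection above; not part of the commit.
def resolve_backends(layer_index, high_k_layers=None):
    high_k_layers = high_k_layers or {}  # the commit initializes this map empty
    hip_k = int(os.environ.get('HIP_K', '1024'))
    prompt_backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'timber')
    paged_backend = os.environ.get('PAGED_ATTENTION_BACKEND', 'timber')
    dense_layers = range(int(os.environ.get('HIP_DENSE_LAYERS', '3')))

    # The first HIP_DENSE_LAYERS layers (default 3) always use the stock vLLM kernels.
    if layer_index in dense_layers:
        prompt_backend = paged_backend = 'vllm'

    # Layers registered with a multiplier get a larger top-k whenever timber is active.
    if 'timber' in (prompt_backend, paged_backend) and layer_index in high_k_layers:
        hip_k *= high_k_layers[layer_index]

    return prompt_backend, paged_backend, hip_k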
@@ -176,7 +190,7 @@ def forward(
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

if BENCHMARK_PROMPT_ATTENTION:
if benchmark_prompt_attention:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
@@ -193,10 +207,10 @@ def forward(
)
output = out.view_as(query)

if BENCHMARK_PROMPT_ATTENTION:
if benchmark_prompt_attention:
end.record()
torch.cuda.synchronize()
print(backend, start.elapsed_time(end), output.shape, end='\n')
print(prompt_backend, start.elapsed_time(end), output.shape, end='\n')
else:
# prefix-enabled attention
output = torch.empty_like(query)
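
The BENCHMARK_* flags wrap the attention call in CUDA events. In isolation, the timing pattern used throughout this file looks like the sketch below, with a dummy matmul standing in for the attention kernel; the time_cuda helper is illustrative.

import torch

# Minimal sketch of the CUDA-event timing pattern behind the BENCHMARK_* flags.
# elapsed_time() reports milliseconds and needs a synchronize() before it is read.
def time_cuda(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = fn()
    end.record()
    torch.cuda.synchronize()
    return out, start.elapsed_time(end)

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
    _, ms = time_cuda(lambda: x @ x)  # stand-in for the attention kernel
    print(f'matmul took {ms:.2f} ms')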
@@ -214,8 +228,7 @@ def forward(
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)
elif backend == 'timber':
# timber support MQA/GQA
elif prompt_backend == 'timber':
warnings.warn('prompt attention backend is timber')

TDST, H, HID = query.shape
@@ -227,12 +240,12 @@ def forward(
key = key.permute(1, 0, 2)
value = value.permute(1, 0, 2)

if BENCHMARK_PROMPT_ATTENTION:
if benchmark_prompt_attention:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

assert input_metadata.attn_bias is None
assert (input_metadata.attn_bias is None) or isinstance(input_metadata.attn_bias, BlockDiagonalCausalMask), f'{input_metadata.attn_bias}'
assert self.alibi_slopes is None

output, _ = timber_attention(
@@ -253,27 +266,25 @@ def forward(
HID,
).contiguous()

if BENCHMARK_PROMPT_ATTENTION:
if benchmark_prompt_attention:
end.record()
torch.cuda.synchronize()
print(backend, start.elapsed_time(end), output.shape, end='\n')
print(prompt_backend, start.elapsed_time(end), output.shape, end='\n')
else:
raise Exception(backend)
raise Exception(prompt_backend)
else:
# Decoding run.
BENCHMARK_PAGED_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'

# print(f'[{os.getpid()}, {self.layer_index}] query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')

if BENCHMARK_PAGED_ATTENTION:
if benchmark_paged_attention:
warnings.warn(f'query_size: {query.shape}({query.dtype}), block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

backend = os.environ.get('PAGED_ATTENTION_BACKEND', 'vllm')
if backend == 'vllm':
if paged_backend == 'vllm':
output = _paged_attention(
query,
key_cache,
Expand All @@ -283,7 +294,7 @@ def forward(
self.scale,
self.alibi_slopes,
)
elif backend == 'timber':
elif paged_backend == 'timber':
warnings.warn('paged attention backend is timber')

output, _ = paged_timber_attention(
@@ -324,10 +335,10 @@ def forward(
}, 'cache/llama/vllmout.pth')
print('saved cache/llama/vllmout.pth')

if BENCHMARK_PAGED_ATTENTION:
if benchmark_paged_attention:
end.record()
torch.cuda.synchronize()
print(f'({backend}) {start.elapsed_time(end)}', end='\r')
print(f'({paged_backend}) {start.elapsed_time(end)}', end='\r')

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
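
Both decode backends (_paged_attention and paged_timber_attention) read the same paged KV cache addressed through input_metadata.block_tables. The toy sketch below shows the logical-to-physical lookup with a deliberately simplified cache layout; the real vLLM cache additionally splits the head dimension, so this is a conceptual illustration, not the actual kernel.

import torch

# Toy sketch of a paged-KV lookup: logical token positions map to (block, offset) slots.
def gather_paged_keys(key_cache, block_table, context_len, block_size):
    # key_cache:   [num_blocks, block_size, num_kv_heads, head_dim]  (simplified layout)
    # block_table: [blocks_per_seq] physical block ids for one sequence
    positions = torch.arange(context_len)
    physical_blocks = block_table[positions // block_size]  # logical block -> physical block
    offsets = positions % block_size                        # slot inside each block
    return key_cache[physical_blocks, offsets]              # [context_len, num_kv_heads, head_dim]

cache = torch.randn(8, 16, 4, 64)  # 8 blocks of 16 slots, 4 KV heads, head_dim 64
table = torch.tensor([5, 2, 7])    # this sequence occupies physical blocks 5, 2, 7
keys = gather_paged_keys(cache, table, context_len=40, block_size=16)
assert keys.shape == (40, 4, 64)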
10 changes: 7 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -91,6 +91,7 @@ def __init__(
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -141,7 +142,8 @@ def __init__(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads
num_kv_heads=self.num_kv_heads,
layer_index=layer_index,
)

def forward(
@@ -166,6 +168,7 @@ def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -181,6 +184,7 @@ def __init__(
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
layer_index=layer_index,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
@@ -243,8 +247,8 @@ def __init__(
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
LlamaDecoderLayer(config, linear_method, layer_index=layer_index)
for layer_index in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

35 changes: 22 additions & 13 deletions vllm/model_executor/models/mixtral.py
@@ -143,15 +143,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class MixtralAttention(nn.Module):

def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -202,6 +204,7 @@ def __init__(self,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
layer_index=layer_index,
)

def forward(
@@ -221,11 +224,11 @@ def forward(


class MixtralDecoderLayer(nn.Module):

def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -238,7 +241,9 @@ def __init__(
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method)
linear_method=linear_method,
layer_index=layer_index,
)
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
@@ -294,8 +299,12 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers)
MixtralDecoderLayer(
config,
linear_method=linear_method,
layer_index=layer_index
)
for layer_index in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

34 changes: 22 additions & 12 deletions vllm/model_executor/models/mixtral_quant.py
@@ -162,14 +162,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

class MixtralAttention(nn.Module):

def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -220,6 +223,7 @@ def __init__(self,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
layer_index=layer_index,
)

def forward(
@@ -239,11 +243,11 @@ def forward(


class MixtralDecoderLayer(nn.Module):

def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -256,7 +260,9 @@ def __init__(
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method)
linear_method=linear_method,
layer_index=layer_index,
)
self.block_sparse_moe = MixtralMoE(config=config,
linear_method=linear_method)
self.input_layernorm = RMSNorm(config.hidden_size,
@@ -309,8 +315,12 @@ def __init__(
config.hidden_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers)
MixtralDecoderLayer(
config,
linear_method=linear_method,
layer_index=layer_index,
)
for layer_index in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

38 changes: 23 additions & 15 deletions vllm/model_executor/models/qwen2.py
@@ -82,15 +82,18 @@ def forward(self, x):

class Qwen2Attention(nn.Module):

def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
use_sliding_window: bool = False,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
use_sliding_window: bool = False,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -135,11 +138,14 @@ def __init__(self,
max_position=max_position,
base=self.rope_theta,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
self.attn = PagedAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
layer_index=layer_index,
)

def forward(
self,
@@ -178,7 +184,9 @@ def __init__(
rope_theta=rope_theta,
use_sliding_window=use_sliding_window,
linear_method=linear_method,
sliding_window=config.sliding_window)
sliding_window=config.sliding_window,
layer_index=layer_idx
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
(Diffs for the remaining changed files were not loaded on this page.)
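
Taken together, every new knob in this commit is an environment variable, so a run can be reconfigured without code changes. A hedged usage sketch follows; the model id and the chosen values are placeholders, and the variables should be set before vLLM spawns its workers.

import os

# Placeholder configuration; set before importing/initializing vLLM so workers inherit it.
os.environ.setdefault('PROMPT_ATTENTION_BACKEND', 'timber')  # or 'vllm' for the dense path
os.environ.setdefault('PAGED_ATTENTION_BACKEND', 'timber')
os.environ.setdefault('HIP_K', '1024')                       # top-k used by the timber kernels
os.environ.setdefault('HIP_DENSE_LAYERS', '3')               # first N layers stay dense
os.environ.setdefault('BENCHMARK_PAGED_ATTENTION', '0')      # '1' prints per-call kernel timings

from vllm import LLM, SamplingParams

llm = LLM(model='meta-llama/Llama-2-7b-hf')  # placeholder model id
out = llm.generate(['Hello, world'], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)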
