merge code #1

Merged
merged 17 commits on Mar 10, 2024
232 changes: 156 additions & 76 deletions vllm/model_executor/layers/attention.py
@@ -19,8 +19,12 @@
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

from timber.models.timber_attention.attention1_block_gpu import paged_timber_attention

from timber.models.timber_attention.attention1_block_gpu import (
paged_timber_attention,
timber_attention
)
from vllm.transformers_utils import config as vllm_transformers_config
from timber.utils import get_bench
BENCHMARK_ITERATION = 0

class PagedAttention(nn.Module):
@@ -44,6 +48,7 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.num_heads = num_heads
@@ -61,6 +66,8 @@ def __init__(
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

self.layer_index = layer_index

def forward(
self,
@@ -106,88 +113,160 @@ def forward(
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)

hip_k = int(os.environ.get('HIP_K', '1024'))

if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :, None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
BENCHMARK_PROMPT_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'vllm')
is_normal_attention = (key_cache is None) or (value_cache is None) or (input_metadata.block_tables.numel() == 0)
if backend == 'vllm':
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(
query.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1],
)
key = key[:, :, None, :]\
.expand(
key.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1]
)
value = value[:, :, None, :]\
.expand(
value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1]
)
# normal attention
if is_normal_attention:
# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)

# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if BENCHMARK_PROMPT_ATTENTION:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)

if BENCHMARK_PROMPT_ATTENTION:
end.record()
torch.cuda.synchronize()
print(backend, start.elapsed_time(end), output.shape, end='\n')
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
else:
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)
elif backend == 'timber':
# timber supports MQA/GQA
warnings.warn('prompt attention backend is timber')

TDST, H, HID = query.shape
TSRC, H_KV, _HID = key.shape
assert key.shape[:-1] == value.shape[:-1]
assert HID == _HID

query = query.permute(1, 0, 2)
key = key.permute(1, 0, 2)
value = value.permute(1, 0, 2)

if BENCHMARK_PROMPT_ATTENTION:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

assert input_metadata.attn_bias is None
assert self.alibi_slopes is None

output, _ = timber_attention(
q=query * self.scale,
k=key,
v=value,
attention_mask=None,
mask_k=hip_k,
block_size_q=32,
block_size_k=2,
)

output = output.permute(1, 0, 2)
output = output.view(
1,
TDST,
H,
HID,
).contiguous()

if BENCHMARK_PROMPT_ATTENTION:
end.record()
torch.cuda.synchronize()
print(backend, start.elapsed_time(end), output.shape, end='\n')
else:
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)

else:
raise Exception(backend)
else:
# Decoding run.
BENCHMARK_PAGED_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'

# print(f'[{os.getpid()}, {self.layer_index}] query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')

if BENCHMARK_PAGED_ATTENTION:
warnings.warn(f'query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
warnings.warn(f'query_size: {query.shape}({query.dtype}), block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
@@ -203,9 +282,9 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
)
elif backend == 'timber':
warnings.warn('backend is timber')
warnings.warn('paged attention backend is timber')

output, _ = paged_timber_attention(
q=query,
@@ -216,9 +295,9 @@ def forward(
context_lens=input_metadata.context_lens,
max_context_len=input_metadata.max_context_len,
attention_mask=None,
mask_k=1024,
mask_k=hip_k,
block_size_q=32,
block_size_k=2,
block_size_q=16
)

N_H, _, HID = output.shape
@@ -243,11 +322,12 @@ def forward(
"alibi_slopes": self.alibi_slopes,
"output": output,
}, 'cache/llama/vllmout.pth')
print('saved cache/llama/vllmout.pth')

if BENCHMARK_PAGED_ATTENTION:
end.record()
torch.cuda.synchronize()
print(start.elapsed_time(end))
print(f'({backend}) {start.elapsed_time(end)}', end='\r')

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
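Note on usage (not part of the diff): the prompt and decode paths above select their attention backend and tuning knobs from environment variables. Below is a minimal sketch of how a caller might set them before building an engine; only variable names that appear in the hunk above are used (the decode-path selector's name is not visible in this hunk), and the model name is illustrative.

import os

# Read inside PagedAttention.forward() in the diff above.
os.environ.setdefault('PROMPT_ATTENTION_BACKEND', 'timber')  # 'vllm' (default) or 'timber'
os.environ.setdefault('HIP_K', '1024')                       # mask_k handed to the timber kernels
os.environ.setdefault('BENCHMARK_PAGED_ATTENTION', '1')      # '1' prints per-call attention timings

# Set the variables before the first forward pass; importing afterwards keeps things simple.
from vllm import LLM, SamplingParams

llm = LLM(model='meta-llama/Llama-2-7b-hf')  # illustrative model name
out = llm.generate(['Hello, world'], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)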
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -306,7 +306,9 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
dtype=torch.float,
device=pos_freqs.device
)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
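Note on the rotary_embedding.py change (not part of the diff): the YaRN ramp mask used to be created on the default device, so multiplying it with a CUDA-resident pos_freqs tensor could fail with a device mismatch; passing device=pos_freqs.device keeps everything on one device. A small self-contained sketch, with linear_ramp_mask as an illustrative stand-in for the real _yarn_linear_ramp_mask helper:

import torch

def linear_ramp_mask(low: float, high: float, dim: int,
                     dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Illustrative YaRN-style linear ramp between indices `low` and `high`.
    if low == high:
        high += 1e-3  # avoid division by zero
    ramp = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
    return torch.clamp(ramp, 0.0, 1.0)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pos_freqs = torch.rand(32, device=device)
# Without the explicit device, the mask would default to CPU and the elementwise
# multiply below would raise a device-mismatch error on GPU runs.
inv_freq_mask = 1 - linear_ramp_mask(4, 24, pos_freqs.numel(), torch.float, pos_freqs.device)
inv_freq = pos_freqs * inv_freq_mask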
7 changes: 7 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -342,11 +342,15 @@ def load_weights(self,
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
# print('vllm.load_weight: ignore', weight_name)
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
print('vllm.load_weight: ignore', name)
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -355,6 +359,9 @@
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
print('vllm.load_weight: ignore', name)
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
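Note on the llama.py change (not part of the diff): load_weights now reports and skips checkpoint tensors that have no matching model parameter instead of failing with a KeyError, which also makes layer-truncated benchmark runs loadable. A simplified sketch of that guard; the real loader additionally handles the stacked-parameter mapping and shard ids:

from typing import Dict, Iterable, Tuple
import torch

def load_weights_safely(params_dict: Dict[str, torch.nn.Parameter],
                        weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    # Mirrors the guard added above: unknown checkpoint tensors are logged and skipped.
    for name, loaded_weight in weights:
        if name.endswith('.bias') and name not in params_dict:
            continue  # extra biases, e.g. from GPTQ checkpoints
        if name not in params_dict:
            print('vllm.load_weight: ignore', name)
            continue
        with torch.no_grad():
            params_dict[name].copy_(loaded_weight)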
7 changes: 4 additions & 3 deletions vllm/transformers_utils/config.py
@@ -16,7 +16,7 @@
}

# NOTE: For benchmarking
FORCE_SIGNLE_LAYER = False
FORCE_SIGNLE_LAYER = 0

def get_config(
model: str,
@@ -44,7 +44,8 @@ def get_config(
config = config_class.from_pretrained(model, revision=revision)

# NOTE: DEBUG
if FORCE_SIGNLE_LAYER:
config.num_hidden_layers = 1
if FORCE_SIGNLE_LAYER > 0:
assert isinstance(FORCE_SIGNLE_LAYER, int)
config.num_hidden_layers = FORCE_SIGNLE_LAYER

return config
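Note on the config.py change (not part of the diff): FORCE_SIGNLE_LAYER (the misspelling is kept from the source) is now an integer layer count rather than a boolean, so a benchmark can truncate the model to any number of decoder layers. A hypothetical usage sketch, assuming the constant is patched before the engine loads its config; it presumably pairs with the load_weights guard above, which skips the weights of the dropped layers:

import vllm.transformers_utils.config as vllm_config

# 0 (the default) leaves the model untouched; any positive value overrides
# config.num_hidden_layers for quick benchmarking runs.
vllm_config.FORCE_SIGNLE_LAYER = 2

from vllm import LLM

llm = LLM(model='meta-llama/Llama-2-7b-hf')  # illustrative model, now built with 2 layers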