Add const norm feature
Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Jan 22, 2025
1 parent 001aa55 commit 3ad6cf3
Showing 1 changed file with 57 additions and 0 deletions.
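
For reference, the new path is gated by two environment variables that the hunk below reads at module import time. A minimal usage sketch, assuming the standard vLLM Python entry point; the LLM class usage and the model name are illustrative and not part of this commit:

import os

# Assumption: set these before vLLM imports hpu_attn.py, since the diff below
# reads both variables at module scope.
os.environ["VLLM_SOFTMAX_CONST_NORM"] = "true"   # route flat_pa through the constant-norm path
os.environ["VLLM_SOFTMAX_CONST_VAL"] = "10.0"    # constant subtracted from the attention scores

from vllm import LLM  # illustrative entry point; not part of this diff

llm = LLM(model="meta-llama/Llama-2-7b-hf")      # hypothetical model choice
outputs = llm.generate("Hello, HPU!")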
57 changes: 57 additions & 0 deletions vllm/attention/backends/hpu_attn.py
@@ -19,6 +19,7 @@
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
from vllm.logger import init_logger
import habana_frameworks.torch.core as htcore

logger = init_logger(__name__)

@@ -55,6 +56,62 @@ def prompt_fsdpa(
    attn_weights = attn_weights.transpose(1, 2)
    return attn_weights

const_norm = os.environ.get('VLLM_SOFTMAX_CONST_NORM', 'false').lower() == 'true'
const_val = float(os.environ.get('VLLM_SOFTMAX_CONST_VAL', '10.0'))
def pa(attn, value, block_groups, block_mapping, batch2block_matmul_op, block2batch_matmul_op):
    # Constant normalization: subtract a fixed constant instead of the per-row max
    attn.sub_(const_val)
    # end of norm
    attn = attn.exp()
    sums = attn.sum(dim=-1).unsqueeze(-1)
    block_sum = sums
    # Sum the block sums that belong to the same sequence
    group_sums = ops.block2batch(sums, block_mapping, block2batch_matmul_op)
    group_sums = ops.batch2block(group_sums, block_mapping, batch2block_matmul_op)
    group_sums.add_(torch.finfo(group_sums.dtype).tiny)
    group_sums = torch.maximum(block_sum, group_sums)
    attn.div_(group_sums)
    attn = ops.matmul_av_op(attn, value)
    return attn

def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
            block_bias, block_scales, block_groups, scale, matmul_qk_op,
            matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,
            keys_fetch_func, values_fetch_func):
    batch_size = query.size(0)
    q_heads = query.size(1)
    kv_heads = key_cache.size(2)

    query = ops.batch2block(scale * query, block_mapping, batch2block_matmul_op).unsqueeze(-2)
    key = keys_fetch_func(key_cache, block_list).transpose(1, 2)
    value = values_fetch_func(value_cache, block_list).transpose(1, 2)
    block_bias = block_bias.view(key.size(0), 1, 1, -1)
    if kv_heads != q_heads:
        block_bias = block_bias.unsqueeze(1)
        query = query.unflatten(1, (kv_heads, -1))
        key = key.unflatten(1, (kv_heads, 1))
        value = value.unflatten(1, (kv_heads, 1))
        key = key.transpose(3, 4)
    else:
        key = key.transpose(2, 3)

    attn = matmul_qk_op(query, key)
    if 'fp32_softmax' in enabled_flags():
        attn = attn.float()
        htcore.mark_step()
    attn = attn + block_bias
    if const_norm:
        attn = pa(attn, value, block_groups, block_mapping, batch2block_matmul_op, block2batch_matmul_op)
    else:
        attn = ops.pipelined_pa(attn, value, block_groups, block_mapping, block_scales=block_scales,
                                batch_size=batch_size, matmul_av_op=matmul_av_op,
                                batch2block_matmul_op=batch2block_matmul_op, block2batch_matmul_op=block2batch_matmul_op)
    attn = ops.block2batch(attn, block_mapping, block2batch_matmul_op)
    attn = attn.squeeze(-2)
    if kv_heads != q_heads:
        attn = attn.flatten(1, 2)
    return attn


class HPUAttentionBackend(AttentionBackend):

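For context on the change above: subtracting a fixed constant from the scores before exponentiation cancels when the result is divided by the sum, so the output matches a standard softmax as long as exp(score - const) neither overflows nor underflows; the intended saving is skipping the per-row max reduction. A small standalone PyTorch sketch of that identity, independent of the HPU-specific ops in this diff:

import torch

scores = torch.randn(4, 128)      # stand-in for per-block attention scores
const_val = 10.0                  # fixed constant, as in VLLM_SOFTMAX_CONST_VAL

# Constant-norm path: exp(x - c) normalized by its own sum.
e = (scores - const_val).exp()
const_norm_softmax = e / e.sum(dim=-1, keepdim=True)

# Reference: standard (max-subtracting) softmax.
reference = torch.softmax(scores, dim=-1)

# The constant cancels in the ratio, so the two agree numerically here;
# a fixed constant can still overflow if scores greatly exceed it.
print(torch.allclose(const_norm_softmax, reference, atol=1e-6))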
