
Refactor & re-enable HPU RoPE for Gaudi1 (#129)
* Re-enable FusedRoPE for Gaudi1

* Add fallback implementation of RoPE
kzawora-intel authored Jul 29, 2024
1 parent f7dc554 commit 19993b7
Showing 2 changed files with 33 additions and 77 deletions.
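
In outline: the Gaudi-generation gate around the fused-kernel import is gone, so FusedRoPE is now attempted on any HPU build, and the module-local rotate_half/apply_rotary_pos_emb fallback is replaced by a caller-injected RoPEFallback class. Below is a minimal, runnable sketch of the resulting dispatch; apply_rope, _rope_native and _rotate_half are hypothetical illustration names, and only the FusedRoPE import path is taken from the diff that follows.

import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Pair feature i with feature i + d/2: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def _rope_native(x: torch.Tensor, cos: torch.Tensor,
                 sin: torch.Tensor) -> torch.Tensor:
    # Pure-PyTorch RoPE, the same formula the deleted helper used.
    return x * cos + _rotate_half(x) * sin

# Guarded import, as in the diff: on non-HPU builds the Habana module is
# absent, the import fails, and FusedRoPE stays None.
try:
    from habana_frameworks.torch.hpex.kernels import (
        RotaryPosEmbeddingHelperV1 as FusedRoPE)
except ImportError:
    FusedRoPE = None

def apply_rope(q, k, cos, sin):
    # Use the fused HPU kernel when available, otherwise the native path.
    if FusedRoPE is not None:
        return FusedRoPE.apply(q, cos, sin, 0), FusedRoPE.apply(k, cos, sin, 0)
    return _rope_native(q, cos, sin), _rope_native(k, cos, sin)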
102 changes: 27 additions & 75 deletions vllm/hpu/rotary_embed.py
@@ -5,80 +5,25 @@
 # LICENSE file in the root directory of this source tree.
 ###############################################################################
 
-import habana_frameworks.torch.utils.experimental as htexp
 import torch
 import torch.nn as nn
 from vllm.utils import is_hpu
+from vllm.logger import init_logger
 
+logger = init_logger(__name__)
+
-
-def get_device_type():
-    return htexp._get_device_type()
-
-
-def is_gaudi1():
-    return get_device_type() == htexp.synDeviceType.synDeviceGaudi
-
-
-def is_gaudi2():
-    return get_device_type() == htexp.synDeviceType.synDeviceGaudi2
-
-
-def is_gaudi3():
-    return get_device_type() == htexp.synDeviceType.synDeviceGaudi3
-
-
-# TODO: remove this workaround when FusedRoPE properly works on Gaudi
-if not is_gaudi1() and (is_gaudi2() or is_gaudi3()):
+if is_hpu():
     try:
         from habana_frameworks.torch.hpex.kernels import (
             RotaryPosEmbeddingHelperV1 as FusedRoPE)
     except ImportError:
-        print("Not using HPU fused kernel for apply_rotary_pos_emb")
-        FusedRoPE = None
+        logger.warning("Could not import HPU FusedRoPE kernel. "
+                       "vLLM will use forward_native implementation of RoPE.")
+        FusedRoPE = None
 else:
     FusedRoPE = None
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and
-            key tensors. For example, this can be used to pass offsetted
-            position ids when working with a KV-cache.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to
-            unsqueeze cos[position_ids] and sin[position_ids] so that they can
-            be properly broadcasted to the dimensions of q and k. For example,
-            note that cos[position_ids] and sin[position_ids] have the shape
-            [batch_size, seq_len, head_dim]. Then, if q and k have the shape
-            [batch_size, heads, seq_len, head_dim], then setting
-            unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids]
-            broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set
-            unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated
-        using the Rotary Position Embedding.
-    """
-    cos = cos[position_ids]  # .unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids]  # .unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
 
 
 class HpuRotaryEmbedding(nn.Module):
 
     def __init__(self,
@@ -87,7 +32,8 @@ def __init__(self,
                  max_position_embeddings=2048,
                  base=10000,
                  is_neox_style=None,
-                 device='hpu'):
+                 device='hpu',
+                 RoPEFallback=None):
         super().__init__()
 
         self.head_size = head_size
@@ -102,6 +48,14 @@ def __init__(self,
         self._set_cos_sin_cache(seq_len=max_position_embeddings,
                                 device=self.inv_freq.device,
                                 dtype=torch.get_default_dtype())
+        if FusedRoPE is None:
+            assert RoPEFallback is not None, "HPU FusedRoPE kernel could not be imported, and fallback RoPE implementation was not provided!"
+            self.fallback_impl = RoPEFallback(head_size,
+                                              rotary_dim,
+                                              max_position_embeddings,
+                                              base,
+                                              is_neox_style,
+                                              dtype=torch.get_default_dtype())
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
@@ -122,6 +76,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
 
     def forward(self, positions: torch.Tensor, query: torch.Tensor,
                 key: torch.Tensor):
+        if FusedRoPE is None:
+            return self.fallback_impl(positions, query, key)
         if query.dim() == 2:
             query = query.unsqueeze(0)
         if key.dim() == 2:
@@ -141,19 +97,15 @@ def forward(self, positions: torch.Tensor, query: torch.Tensor,
                                self.head_size))
         key = key.reshape((key.shape[0], key.shape[1],
                            key.shape[2] // self.head_size, self.head_size))
-        if query.device.type == "hpu" and FusedRoPE:
-            if len(positions[0]) == 1:
-                cos = self.cos_cached[positions].unsqueeze(2).to(
-                    dtype=query.dtype)
-                sin = self.sin_cached[positions].unsqueeze(2).to(
-                    dtype=query.dtype)
-            else:
-                cos = cos[positions].unsqueeze(2)
-                sin = sin[positions].unsqueeze(2)
-            query, key = FusedRoPE.apply(query, cos, sin,
-                                         0), FusedRoPE.apply(key, cos, sin, 0)
+
+        if len(positions[0]) == 1:
+            cos = self.cos_cached[positions].unsqueeze(2).to(dtype=query.dtype)
+            sin = self.sin_cached[positions].unsqueeze(2).to(dtype=query.dtype)
         else:
-            query, key = apply_rotary_pos_emb(query, key, cos, sin, positions)
+            cos = cos[positions].unsqueeze(2)
+            sin = sin[positions].unsqueeze(2)
+        query, key = FusedRoPE.apply(query, cos, sin,
+                                     0), FusedRoPE.apply(key, cos, sin, 0)
         return query.reshape(
             (query.shape[0], query.shape[1],
              query.shape[2] * query.shape[3])), key.reshape(
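
The rotate_half / apply_rotary_pos_emb helpers deleted above were the old module-local fallback; the commit moves that job to the injected RoPEFallback. For reference, a small self-contained sanity check of the rotate-half formulation those helpers implemented. Everything here is illustrative; it assumes the conventional RoPE cache layout in which per-pair angles are duplicated across both halves, i.e. torch.cat((freqs, freqs)).

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # As in the deleted helper: split the last dim, map (x1, x2) -> (-x2, x1).
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

head_dim = 8
positions = torch.arange(4, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2) / head_dim))
freqs = torch.outer(positions, inv_freq)     # [seq, head_dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)      # angles duplicated across halves
cos, sin = emb.cos(), emb.sin()              # [seq, head_dim]

q = torch.randn(4, head_dim)
q_embed = q * cos + rotate_half(q) * sin     # the deleted helper's formula

# Same result via an explicit 2-D rotation of each (i, i + d/2) pair,
# confirming the formula is a per-pair rotation by position * frequency.
q1, q2 = q.chunk(2, dim=-1)
c, s = freqs.cos(), freqs.sin()
expected = torch.cat((q1 * c - q2 * s, q1 * s + q2 * c), dim=-1)
assert torch.allclose(q_embed, expected, atol=1e-6)
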
8 changes: 6 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -765,8 +765,12 @@ def get_rope(
         return _ROPE_DICT[key]
     if rope_scaling is None:
         if is_hpu():
-            rotary_emb = HpuRotaryEmbedding(head_size, rotary_dim,
-                                            max_position, base, is_neox_style)
+            rotary_emb = HpuRotaryEmbedding(head_size,
+                                            rotary_dim,
+                                            max_position,
+                                            base,
+                                            is_neox_style,
+                                            RoPEFallback=RotaryEmbedding)
         else:
             rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position,
                                          base, is_neox_style, dtype)
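
For completeness, a hedged usage sketch of the new hook, mirroring the get_rope() wiring above. The head counts and sizes are illustrative assumptions, and a Gaudi runtime is assumed since HpuRotaryEmbedding defaults to device='hpu'. On such a build, forward() goes through FusedRoPE when the kernel imported cleanly and through the injected RotaryEmbedding otherwise.

import torch
from vllm.hpu.rotary_embed import HpuRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

# Pass the generic RotaryEmbedding class itself, as get_rope() now does;
# HpuRotaryEmbedding instantiates it only if the FusedRoPE import failed.
rope = HpuRotaryEmbedding(head_size=128,
                          rotary_dim=128,
                          max_position_embeddings=2048,
                          base=10000,
                          is_neox_style=True,
                          RoPEFallback=RotaryEmbedding)

positions = torch.arange(16).unsqueeze(0)  # [batch=1, seq=16]
query = torch.randn(1, 16, 32 * 128)       # [batch, seq, num_heads * head_size]
key = torch.randn(1, 16, 8 * 128)          # [batch, seq, num_kv_heads * head_size]
query, key = rope(positions, query, key)   # fused kernel, or fallback if absent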
