Merge pull request huggingface#5 from kaixuanliu/ipex
add hpu flashBert support
yuanwu2017 authored Aug 29, 2024
2 parents 081ab41 + cbc3ee2 commit f7d1e1b
Showing 5 changed files with 114 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Dockerfile-intel
@@ -77,7 +77,7 @@ RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url
RUN cd backends/python/server && \
    make install

FROM vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest AS hpu
FROM vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
ENV HUGGINGFACE_HUB_CACHE=/data \
    PORT=80

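A quick way to confirm that the bumped base image (Gaudi 1.17.0, PyTorch 2.3.1) exposes an HPU device is a short check inside the hpu stage; a minimal sketch, assuming the habana_frameworks stack shipped with the pytorch-installer image:

# Minimal sanity check for the new hpu base image (illustrative, not part of the commit).
import torch
import habana_frameworks.torch.core as htcore  # registers the "hpu" device with torch

print(torch.__version__)         # the 1.17.0 image ships PyTorch 2.3.1
print(torch.hpu.is_available())  # True when a Gaudi card is visible to the container
print(torch.hpu.device_count())
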
11 changes: 5 additions & 6 deletions backends/python/server/text_embeddings_server/models/__init__.py
@@ -37,6 +37,7 @@ def get_model(model_path: Path, dtype: Optional[str]) :
        raise RuntimeError(f"Unknown dtype {dtype}")

    device = get_device()
    logger.info(f"backend device: {device}")
    config = AutoConfig.from_pretrained(model_path)
    if config.model_type == "bert":
        config: BertConfig
@@ -48,14 +49,12 @@ def get_model(model_path: Path, dtype: Optional[str]) :
        ):
            return FlashBert(model_path, device, datatype)  # type: ignore
        if use_ipex() and device.type in ["cpu", "xpu"]:
            import intel_extension_for_pytorch as ipex
            return FlashBert(model_path, device, datatype)  # type: ignore
        if device.type == "hpu":
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
            adapt_transformers_to_gaudi()
            model_handle = DefaultModel(model_path, device, datatype)
            model_handle.model = wrap_in_hpu_graph(model_handle.model, disable_tensor_cache=True)
            return model_handle
            import habana_frameworks.torch.core as htcore
            return FlashBert(model_path, device, datatype)

        return DefaultModel(model_path, device, datatype)
    else:
        return DefaultModel(model_path, device, datatype)
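
With this change, a BERT-style checkpoint on a Gaudi host is served through FlashBert rather than an HPU-graph-wrapped DefaultModel. A hedged usage sketch (the model path below is a placeholder; assumes "bfloat16" is an accepted dtype string):

# Illustrative only: exercises the new HPU branch of get_model.
from pathlib import Path
from text_embeddings_server.models import get_model

model = get_model(Path("/data/bge-base-en-v1.5"), "bfloat16")  # placeholder checkpoint path
print(type(model).__name__)  # FlashBert on Gaudi (or cuda/ipex), DefaultModel otherwise
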
39 changes: 33 additions & 6 deletions backends/python/server/text_embeddings_server/models/flash_bert.py
@@ -1,27 +1,43 @@
import torch

from pathlib import Path
from torch import nn
import torch.nn.functional as F
from typing import Type, List
from safetensors import safe_open
from transformers.activations import ACT2FN
from transformers.models.bert import BertConfig
from opentelemetry import trace


from text_embeddings_server.models import Model
from text_embeddings_server.models.types import FlashBatch, Embedding
from text_embeddings_server.utils.flash_attn import attention
from text_embeddings_server.utils.device import use_ipex

tracer = trace.get_tracer(__name__)

def hpu_add_layer_norm(
    add: torch.Tensor,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    epsilon: float,
    add_back: bool
):
    if add is not None:
        added_tensor = torch.add(add, x, alpha=1.0)
        output = F.layer_norm(added_tensor, [x.size(-1)], weight, bias, epsilon)
        if add_back:
            add.add_(x)
        return output
    else:
        return F.layer_norm(x, [x.size(-1)], weight=weight, bias=bias, eps=epsilon)
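
Because hpu_add_layer_norm is built only from torch.add and F.layer_norm, its fused add-then-normalize semantics can be checked on any device; a small sketch (assumes the function above is in scope):

# Sketch: with add_back=False the result should equal a plain residual-add followed by layer norm.
import torch
import torch.nn.functional as F

hidden = 8
add = torch.randn(4, hidden)
x = torch.randn(4, hidden)
w, b = torch.ones(hidden), torch.zeros(hidden)

ref = F.layer_norm(add + x, [hidden], w, b, 1e-12)
out = hpu_add_layer_norm(add.clone(), x, w, b, 1e-12, add_back=False)
print(torch.allclose(ref, out))  # True; with add_back=True, `add` is also updated in place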

class FastLayerNorm:
    def __init__(self, prefix, handle, device, dtype, config: BertConfig):
        self.weight = handle.get_tensor(f"{prefix}.weight").to(dtype).to(device)
        self.bias = handle.get_tensor(f"{prefix}.bias").to(dtype).to(device)
        self.variance_epsilon = config.layer_norm_eps
        self.device = device
        self.use_ipex = use_ipex()

    def forward(self, hidden_states, residual=None):
        # Flash attention imports
@@ -48,7 +64,7 @@ def forward(self, hidden_states, residual=None):
            )
            if res is None:
                res = hidden_states
        elif use_ipex():
        elif self.use_ipex:
            import intel_extension_for_pytorch as ipex
            normed_hidden_states = ipex.llm.functional.add_layer_norm(
                residual,
@@ -60,7 +76,16 @@ )
            )

            res = residual if residual is not None else hidden_states

        elif self.device.type == "hpu":
            normed_hidden_states = hpu_add_layer_norm(
                residual,
                hidden_states,
                self.weight,
                self.bias,
                self.variance_epsilon,
                residual is not None
            )
            res = residual if residual is not None else hidden_states
        return normed_hidden_states, res


@@ -242,7 +267,9 @@ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        config = BertConfig.from_pretrained(model_path)
        with safe_open(model_path / "model.safetensors", framework="pt") as f:
            model = FlashBertModel(f, device, dtype, config)

        if device.type == "hpu":
            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
            model = wrap_in_hpu_graph(model, disable_tensor_cache=False)
        self.hidden_size = config.hidden_size

        super(FlashBert, self).__init__(model=model, dtype=dtype, device=device)
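
The same HPU-graph wrapping used above can be applied to any nn.Module; a minimal sketch, assuming a Gaudi device and the habana_frameworks stack (the toy module and shapes are illustrative):

# Illustrative: mirrors how FlashBert wraps its model in an HPU graph above.
import torch
from torch import nn
import habana_frameworks.torch.core as htcore
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

module = nn.Linear(16, 16).to("hpu").to(torch.bfloat16)
module = wrap_in_hpu_graph(module, disable_tensor_cache=False)

x = torch.randn(4, 16, device="hpu", dtype=torch.bfloat16)
out = module(x)  # the first call records the HPU graph; later same-shape calls replay it
print(out.shape)
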
6 changes: 3 additions & 3 deletions backends/python/server/text_embeddings_server/utils/device.py
@@ -27,11 +27,11 @@ def get_major_and_minor_from_version(full_version):
        return False
    return True

def _is_hpu() -> bool:
def is_hpu() -> bool:
    is_hpu_available = True
    try:
        subprocess.run(["hl-smi"], capture_output=True, check=True)
    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
    except:
        is_hpu_available = False
    return is_hpu_available

@@ -43,7 +43,7 @@ def get_device() :
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif _is_hpu():
    elif is_hpu():
        import habana_frameworks.torch.core as htcore
        if hasattr(torch, "hpu") and torch.hpu.is_available():  # type: ignore
            device = torch.device("hpu")
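
The renamed is_hpu() is now importable by other modules (flash_attn.py below relies on it); a short sketch of how the detection chain resolves:

# Sketch: is_hpu() shells out to `hl-smi`, so it is safe to call on any host.
# get_device() prefers cuda, then falls back to hpu as added here.
from text_embeddings_server.utils.device import get_device, is_hpu, use_ipex

print(is_hpu())      # True only when hl-smi runs successfully
print(use_ipex())
print(get_device())  # e.g. device(type='hpu') inside the Gaudi container
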
78 changes: 72 additions & 6 deletions backends/python/server/text_embeddings_server/utils/flash_attn.py
@@ -1,6 +1,6 @@
import os
import torch
from text_embeddings_server.utils.device import use_ipex
from text_embeddings_server.utils.device import use_ipex, is_hpu

from loguru import logger

@@ -10,7 +10,10 @@
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2 = False

if use_ipex():
is_hpu = is_hpu()
use_ipex = use_ipex()

if use_ipex or is_hpu:
    HAS_FLASH_ATTN_V2 = True
else:
    if not torch.cuda.is_available():
@@ -54,14 +57,77 @@
        HAS_FLASH_ATTN = True


def hpu_attn(q, k, v, out, seqlen_q, seqlen_k, max_seqlen_q, max_seqlen_k, softmax_scale, is_causal=False):
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
    total_q, num_head, head_size = q.size()
    total_k, num_head_k, _ = k.size()
    batch_size = seqlen_q.size(0) - 1
    seqlen_q_ = seqlen_q.clone()
    seqlen_q_[:batch_size] = seqlen_q[1:]
    seqlen_q = (seqlen_q_ - seqlen_q)[:batch_size]
    seqlen_k_ = seqlen_k.clone()
    seqlen_k_[:batch_size] = seqlen_k[1:]
    seqlen_k = (seqlen_k_ - seqlen_k)[:batch_size]

    pad_q = torch.zeros(
        [batch_size, max_seqlen_q, num_head, head_size],
        dtype=q.dtype,
        device=q.device,
    )
    pad_k = torch.zeros(
        [batch_size, max_seqlen_k, num_head_k, head_size],
        dtype=k.dtype,
        device=k.device,
    )
    pad_v = torch.zeros(
        [batch_size, max_seqlen_k, num_head_k, head_size],
        dtype=v.dtype,
        device=v.device,
    )
    q_mask = torch.arange(0, max_seqlen_q, device=q.device)[None, :].repeat(
        batch_size, 1
    )
    q_mask = q_mask < seqlen_q[:, None].repeat(1, q_mask.size(-1))
    k_mask = torch.arange(0, max_seqlen_k, device=k.device)[None, :].repeat(
        batch_size, 1
    )
    k_mask = k_mask < seqlen_k[:, None].repeat(1, k_mask.size(-1))
    align_mask_seqlen = max_seqlen_k
    attn_mask = torch.empty(
        [batch_size, 1, 1, align_mask_seqlen],
        dtype=q.dtype,
        device=q.device,
    ).fill_(float("-inf"))
    attn_mask[:, :, :, :max_seqlen_k].masked_fill_(k_mask[:, None, None, :], 0)

    pad_q[q_mask] = q
    pad_k[k_mask] = k
    pad_v[k_mask] = v

    pad_q = pad_q.permute(0, 2, 1, 3)
    pad_k = pad_k.permute(0, 2, 1, 3)
    pad_v = pad_v.permute(0, 2, 1, 3)
    if is_causal:
        attn_mask = None

    out_ = FusedSDPA.apply(pad_q, pad_k, pad_v, attn_mask, 0.0, is_causal, softmax_scale)
    out_ = out_.permute(0, 2, 1, 3)
    out.copy_(out_[q_mask])
    return out


def attention(q, k, v, out, cu_seqlens, max_s, softmax_scale, is_causal=False):
    if HAS_FLASH_ATTN_V2:
        if use_ipex():
        if use_ipex:
            import intel_extension_for_pytorch as ipex
            return ipex.llm.functional.varlen_attention(q, k, v, out, cu_seqlens, cu_seqlens,
                                                        max_s, max_s, 0, softmax_scale,
                                                        zero_tensors=False, is_causal=False,
                                                        return_softmax=False, gen_=None)
        elif is_hpu:
            return hpu_attn(q, k, v, out, cu_seqlens, cu_seqlens,
                            max_s, max_s, softmax_scale, is_causal=False)

        else:
            return flash_attn_2_cuda.varlen_fwd(
                q,
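
The dispatcher keeps the upstream varlen calling convention on every backend: q, k, v and out are packed as [total_tokens, num_heads, head_size], with cu_seqlens marking sequence boundaries. A hedged sketch of a call that would take the new HPU path (shapes, head count and dtype are illustrative; assumes a Gaudi host):

# Illustrative call into attention(); on HPU it lands in hpu_attn, on CUDA in flash-attn.
import torch
from text_embeddings_server.utils.flash_attn import attention

device = torch.device("hpu")  # use "cuda" to exercise the flash-attn path instead
total_tokens, heads, head_size = 7, 2, 8  # two sequences of 3 and 4 tokens packed together

q = torch.randn(total_tokens, heads, head_size, device=device, dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)

cu_seqlens = torch.tensor([0, 3, 7], device=device, dtype=torch.int32)
attention(q, k, v, out, cu_seqlens, max_s=4, softmax_scale=head_size ** -0.5)
print(out.shape)  # torch.Size([7, 2, 8]); results are written into `out` in place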
