CPU Pinned embedding Layer (intel#9538)
* CPU Pinned embedding
leonardozcm authored Nov 28, 2023
1 parent fe2def0 commit a1a5443
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions python/llm/src/bigdl/llm/transformers/embedding.py
@@ -17,11 +17,40 @@

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn import Parameter
from typing import Optional


# To avoid running out of XPU memory, keep the (typically large) embedding
# weights pinned on the CPU when `cpu_embedding=True`: any attempt to move
# this parameter to an XPU device is redirected back to the CPU.
class CPUPinnedParam(Parameter):
    def to(self, *args, **kwargs):
        # Parse the arguments exactly as Tensor.to does; `device` is None
        # when only a dtype or memory format is requested.
        device, dtype, non_blocking, convert_to_format = \
            torch._C._nn._parse_to(*args, **kwargs)
        if device is not None and device.type == 'xpu':
            if convert_to_format is not None and self.dim() in (4, 5):
                return super().to('cpu', dtype,
                                  non_blocking, memory_format=convert_to_format)
            return super().to('cpu', dtype, non_blocking)
        return super().to(*args, **kwargs)


class LLMEmbedding(torch.nn.Embedding):
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.,
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
_freeze: bool = False,
device=None, dtype=None) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx,
                         max_norm, norm_type, scale_grad_by_freq, sparse,
                         # Pass these by keyword: newer PyTorch versions add a
                         # `_freeze` positional parameter before `device`.
                         _weight=_weight, device=device, dtype=dtype)
        # Re-wrap the weight so later `.to('xpu')` calls leave it on the CPU.
        self.weight = CPUPinnedParam(self.weight.data,
                                     requires_grad=not _freeze)

    def forward(self, x: Tensor):
        if self.weight.device.type != 'cpu':
            # The weight can still be on an XPU device if the model was moved
            # before the parameter was wrapped; bring it back to the CPU and
            # free the cached XPU memory it occupied.
            self.to('cpu')
            torch.xpu.empty_cache()
        # Do the lookup on the CPU, then move only the (much smaller)
        # activation back to the input's device.
        return super().forward(x.to('cpu')).to(x.device)
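
A minimal sketch of the pinning behavior this patch introduces — not part of the commit; it assumes a PyTorch build where the `xpu` device type string is recognized, though no actual XPU device is needed for the redirect itself:

```python
import torch
from bigdl.llm.transformers.embedding import CPUPinnedParam

# A tiny stand-in for an embedding table, wrapped as a CPU-pinned parameter.
w = CPUPinnedParam(torch.randn(4, 8), requires_grad=False)

# A move to an XPU device is intercepted and lands on the CPU instead.
print(w.to('xpu').device.type)      # 'cpu'

# dtype-only conversions (where the parsed device is None) behave as usual.
print(w.to(torch.float16).dtype)    # torch.float16
```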

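For context, a hedged sketch of how this layer might be wired into a model. The `replace_embedding` helper below is hypothetical and not part of this commit (BigDL's actual `cpu_embedding` plumbing lives elsewhere in its model-conversion code), and it copies only the most common `nn.Embedding` settings:

```python
import torch
from bigdl.llm.transformers.embedding import LLMEmbedding

def replace_embedding(model: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: swap each nn.Embedding for an LLMEmbedding,
    reusing the existing weight tensor so no extra copy is made."""
    for name, module in model.named_children():
        if type(module) is torch.nn.Embedding:
            setattr(model, name, LLMEmbedding(
                num_embeddings=module.num_embeddings,
                embedding_dim=module.embedding_dim,
                padding_idx=module.padding_idx,
                _weight=module.weight.data,
                _freeze=not module.weight.requires_grad,
            ))
        else:
            replace_embedding(module)
    return model

# Usage (assuming an XPU device and an `embed_tokens` attribute exist):
#   model = replace_embedding(model).to('xpu')   # embedding weights stay on CPU
#   tokens = torch.randint(0, 100, (1, 16), device='xpu')
#   out = model.embed_tokens(tokens)             # lookup on CPU, result on XPU
```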