From a1a544337e4414b902ef162857922c352f5e9fd7 Mon Sep 17 00:00:00 2001
From: Zhao Changmin
Date: Tue, 28 Nov 2023 09:46:31 +0800
Subject: [PATCH] CPU Pinned embedding Layer (#9538)

* CPU Pinned embedding

---
 .../src/bigdl/llm/transformers/embedding.py  | 35 +++++++++++++++++--
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/embedding.py b/python/llm/src/bigdl/llm/transformers/embedding.py
index aa99e2d7f83..a6fc5589a0d 100644
--- a/python/llm/src/bigdl/llm/transformers/embedding.py
+++ b/python/llm/src/bigdl/llm/transformers/embedding.py
@@ -17,11 +17,40 @@
 
 import torch
 from torch import Tensor
+from torch.nn import functional as F
+from torch.nn import Parameter
+from typing import Optional
+
+
+# To avoid running out of XPU memory, keep the embedding weight pinned on the
+# CPU (when `cpu_embedding==True`) even while the rest of the model moves to XPU.
+class CPUPinnedParam(Parameter):
+    def to(self, *args, **kwargs):
+        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+        if device is not None and device.type == 'xpu':
+            if convert_to_format is not None and self.dim() in (4, 5):
+                return super().to('cpu', dtype,
+                                  non_blocking, memory_format=convert_to_format)
+            return super().to('cpu', dtype, non_blocking)
+        return super().to(*args, **kwargs)
 
 
 class LLMEmbedding(torch.nn.Embedding):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None,
+                 norm_type: float = 2.,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 _weight: Optional[Tensor] = None,
+                 _freeze: bool = False,
+                 device=None, dtype=None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx,
+                         max_norm, norm_type, scale_grad_by_freq, sparse,
+                         _weight, device=device, dtype=dtype)
+        self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
+
     def forward(self, x: Tensor):
-        if self.weight.device != 'cpu':
-            self.to('cpu')
-            torch.xpu.empty_cache()
         return super().forward(x.to('cpu')).to(x.device)
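
A minimal sketch of the behavior this patch introduces, assuming an
XPU-enabled PyTorch build (e.g. with intel_extension_for_pytorch
installed); the vocabulary and hidden sizes are illustrative only:

import torch
from bigdl.llm.transformers.embedding import LLMEmbedding

emb = LLMEmbedding(num_embeddings=32000, embedding_dim=4096)

# CPUPinnedParam.to() intercepts the device move, so the weight stays on
# the CPU even though the rest of a model would land on the XPU.
emb = emb.to('xpu')
assert emb.weight.device.type == 'cpu'

# forward() copies the token ids to the CPU, performs the lookup there,
# and returns the embeddings on the input's original device.
ids = torch.tensor([[1, 2, 3]], device='xpu')
out = emb(ids)
assert out.device.type == 'xpu'

Compared with the forward-time self.to('cpu') logic this patch removes,
overriding to() means the embedding table (often hundreds of megabytes)
is never allocated on the XPU in the first place, so the initial model
move cannot fail on it.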
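
The `cpu_embedding==True` referenced in the code comment is the
user-facing switch. Assuming it is plumbed through bigdl-llm's
from_pretrained path as the comment suggests (the checkpoint name is
illustrative), usage would look roughly like:

from bigdl.llm.transformers import AutoModelForCausalLM

# Hypothetical end-to-end usage; `cpu_embedding=True` is assumed to swap
# nn.Embedding for the pinned LLMEmbedding defined in this patch.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",  # illustrative checkpoint
    load_in_4bit=True,
    cpu_embedding=True,
)
model = model.to('xpu')  # other weights move to XPU; the embedding stays on CPU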