diff --git a/python/llm/src/ipex_llm/transformers/embedding.py b/python/llm/src/ipex_llm/transformers/embedding.py
index 0bc1553db27..2a8a23fb11c 100644
--- a/python/llm/src/ipex_llm/transformers/embedding.py
+++ b/python/llm/src/ipex_llm/transformers/embedding.py
@@ -15,6 +15,7 @@
 #
 
+import numpy
 import torch
 from torch import Tensor
 from torch.nn import functional as F
 
@@ -68,14 +69,56 @@ def __init__(self,
                  _freeze: bool = False,
                  device=None, dtype=None) -> None:
         super().__init__(num_embeddings, embedding_dim, padding_idx,
-                         max_norm, norm_type, scale_grad_by_freq, sparse,
-                         _weight, device, dtype)
+                         max_norm, norm_type, scale_grad_by_freq,
+                         sparse, _weight, _freeze, device, dtype)
         self.weight = CPUPinnedParam(self.weight.data, requires_grad=not _freeze)
 
     def forward(self, x: Tensor):
         return super().forward(x.to('cpu')).to(x.device)
 
 
+class DiskEmbedding(torch.nn.Embedding):
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 padding_idx: Optional[int] = None,
+                 max_norm: Optional[float] = None,
+                 norm_type: float = 2.,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 _weight: Optional[Tensor] = None,
+                 _freeze: bool = False,
+                 device=None, dtype=None) -> None:
+        super().__init__(num_embeddings, embedding_dim, padding_idx,
+                         max_norm, norm_type, scale_grad_by_freq,
+                         sparse, _weight, _freeze, device, dtype)
+        self.filename = "embeddings.bin"
+        self.weight.data.flatten().half().numpy().tofile(self.filename)
+        dummy_weight = torch.empty(0, 0, dtype=self.weight.dtype, device=self.weight.device)
+        self.weight = torch.nn.Parameter(dummy_weight, requires_grad=False)
+
+    def forward(self, input_ids: Tensor):
+        ids = input_ids.cpu().flatten()
+
+        embeds = []
+        with open(self.filename, 'rb') as f:
+            for idx in ids:
+                f.seek(idx * self.embedding_dim * 2)
+                buffer = f.read(self.embedding_dim * 2)
+                embeds.append(torch.frombuffer(buffer, dtype=torch.half))
+        embeds = torch.stack(embeds).to(device=input_ids.device, dtype=self.weight.dtype)
+        return embeds.view(*input_ids.size(), self.embedding_dim)
+
+    def restore(self):
+        with open(self.filename, 'rb') as f:
+            buffer = f.read()
+        embeds = torch.frombuffer(buffer, dtype=torch.half).clone()
+        embeds = embeds.view(self.num_embeddings, self.embedding_dim).to(
+            device=self.weight.device, dtype=self.weight.dtype
+        )
+        self.weight = torch.nn.Parameter(embeds, requires_grad=False)
+
+
 class LowBitEmbedding(torch.nn.Embedding):
     def __init__(self,
                  num_embeddings: int,
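
For context, a minimal, hypothetical usage sketch of the DiskEmbedding layer this patch adds. It is not part of the patch: the vocabulary/hidden sizes and the random weight are illustrative, and torch >= 2.1 is assumed since the patched super().__init__ call forwards the _freeze argument.

# Hypothetical usage sketch (not part of the patch); exercises the
# DiskEmbedding class added above.
import torch
from ipex_llm.transformers.embedding import DiskEmbedding

vocab_size, hidden_size = 32000, 4096          # illustrative sizes
weight = torch.randn(vocab_size, hidden_size)  # illustrative weights

# Construction dumps the weights as fp16 to "embeddings.bin" and replaces
# self.weight with an empty placeholder, freeing the in-memory copy.
disk_emb = DiskEmbedding(vocab_size, hidden_size, _weight=weight)

# forward() seeks into the file and reads back only the requested rows
# (embedding_dim * 2 bytes per token id, since rows are stored as fp16).
input_ids = torch.tensor([[11, 22, 33]])
embeds = disk_emb(input_ids)                   # shape: (1, 3, hidden_size)

# restore() reloads the full matrix from disk into an in-memory Parameter.
disk_emb.restore()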