diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py b/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
index d001919c976..e6ec52acf8d 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/__init__.py
@@ -20,7 +20,7 @@
 # only search the first bigdl package and end up finding only one sub-package.
 
 from .bigdlllm import *
-from .transformersembeddings import TransformersEmbeddings
+from .transformersembeddings import TransformersEmbeddings, TransformersBgeEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
@@ -28,5 +28,6 @@
     "BloomEmbeddings",
     "GptneoxEmbeddings",
     "StarcoderEmbeddings",
-    "TransformersEmbeddings"
+    "TransformersEmbeddings",
+    "TransformersBgeEmbeddings"
 ]
diff --git a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
index c52a8adf285..9c69f4744c3 100644
--- a/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
+++ b/python/llm/src/bigdl/llm/langchain/embeddings/transformersembeddings.py
@@ -45,6 +45,7 @@
 # THE SOFTWARE.
 
 """Wrapper around BigdlLLM embedding models."""
+import torch
 from typing import Any, Dict, List, Optional
 
 import numpy as np
@@ -181,3 +182,14 @@ def embed_query(self, text: str) -> List[float]:
         text = text.replace("\n", " ")
         embedding = self.embed(text, **self.encode_kwargs)
         return embedding.tolist()
+
+
+# fit specific encode method for langchain.embeddings.HuggingFaceBgeEmbeddings
+# TODO: directly support HuggingFaceBgeEmbeddings
+class TransformersBgeEmbeddings(TransformersEmbeddings):
+
+    def embed(self, text: str, **kwargs):
+        input_ids = self.tokenizer.encode(text, return_tensors="pt", **kwargs)
+        input_ids = input_ids.to(self.model.device)
+        embeddings = self.model(input_ids, return_dict=False)[0].cpu()
+        embeddings = torch.nn.functional.normalize(embeddings[:, 0], p=2, dim=1)
+        return embeddings[0]
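
Note: the new TransformersBgeEmbeddings.embed follows the BGE convention of taking the hidden state of the first (CLS) token and L2-normalizing it. Below is a minimal sketch of the same pooling using plain Hugging Face transformers, outside of the BigDL wrapper; the model id "BAAI/bge-small-en-v1.5" is only an assumed example and is not part of this patch.

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Assumed example checkpoint; any BGE-style embedding model should behave the same way.
    model_id = "BAAI/bge-small-en-v1.5"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id)

    text = "BigDL-LLM brings low-bit optimizations to large language models."
    input_ids = tokenizer.encode(text, return_tensors="pt")
    with torch.no_grad():
        # With return_dict=False, the first output is the last hidden state: (1, seq_len, hidden)
        last_hidden = model(input_ids, return_dict=False)[0]
    cls_embedding = last_hidden[:, 0]  # CLS-token pooling, as in the new embed() above
    cls_embedding = torch.nn.functional.normalize(cls_embedding, p=2, dim=1)
    print(cls_embedding[0].shape)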