From 38887bac766c701ebc66db6463f91a6ff4494ef8 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 24 Sep 2024 20:51:58 +0300 Subject: [PATCH 1/2] feat: VertexAI EF --- chromadbx/embeddings/google.py | 55 ++++++++++++++++++++++++++++++++++ test/embeddings/test_google.py | 40 +++++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 chromadbx/embeddings/google.py create mode 100644 test/embeddings/test_google.py diff --git a/chromadbx/embeddings/google.py b/chromadbx/embeddings/google.py new file mode 100644 index 0000000..54fdd78 --- /dev/null +++ b/chromadbx/embeddings/google.py @@ -0,0 +1,55 @@ +import logging +from typing import Optional, cast + +import numpy as np +import numpy.typing as npt +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings + +logger = logging.getLogger(__name__) + + +class GoogleVertexAiEmbeddings(EmbeddingFunction[Documents]): + def __init__( + self, + model_name: str = "text-embedding-004", + *, + project_id: Optional[str] = None, + location: Optional[str] = "us-central1", + dimensions: Optional[int] = 256, + task_type: Optional[str] = "RETRIEVAL_DOCUMENT", + ) -> None: + """ + Initialize the OnnxRuntimeEmbeddings. + + :param model_name: The name of the model to use. Defaults to "text-embedding-004". + :param project_id: The project ID to use. Defaults to None. + :param location: The location to use. Defaults to None. + :param dimensions: The number of dimensions to use. Defaults to None. + :param task_type: The task type to use. Defaults to "RETRIEVAL_DOCUMENT". https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types + + """ + try: + import vertexai + from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel + if project_id and location: + vertexai.init(project=project_id, location=location) + elif project_id and not location: + vertexai.init(project=project_id) + elif not project_id and location: + vertexai.init(location=location) + else: + vertexai.init() + self._model = TextEmbeddingModel.from_pretrained(model_name) + except ImportError: + raise ValueError( + "The vertexai python package is not installed. Please install it with `pip install vertexai`" + ) + self._dimensions = dimensions + self._task_type = task_type + + def __call__(self, input: Documents) -> Embeddings: + from vertexai.language_models import TextEmbeddingInput + inputs = [TextEmbeddingInput(text, self._task_type) for text in input] + kwargs = dict(output_dimensionality=self._dimensions) if self._dimensions else {} + embeddings = [embedding.values for embedding in self._model.get_embeddings(inputs,**kwargs)] + return cast(Embeddings, embeddings) diff --git a/test/embeddings/test_google.py b/test/embeddings/test_google.py new file mode 100644 index 0000000..2ab0b1e --- /dev/null +++ b/test/embeddings/test_google.py @@ -0,0 +1,40 @@ +from chromadbx.embeddings.google import GoogleVertexAiEmbeddings + + +def test_embed() -> None: + ef = GoogleVertexAiEmbeddings() + embeddings = ef(["hello world", "goodbye world"]) + assert len(embeddings) == 2 + assert len(embeddings[0]) == 256 + assert len(embeddings[1]) == 256 + assert embeddings[0] != embeddings[1] + +def test_with_model() -> None: + ef = GoogleVertexAiEmbeddings( + model_name="text-multilingual-embedding-002", + ) + embeddings = ef(["hello world", "goodbye world"]) + assert len(embeddings) == 2 + assert len(embeddings[0]) == 256 + assert len(embeddings[1]) == 256 + assert embeddings[0] != embeddings[1] + +def test_dimensions() -> None: + ef = GoogleVertexAiEmbeddings( + model_name="text-multilingual-embedding-002", + dimensions=768, + ) + embeddings = ef(["hello world", "goodbye world"]) + assert len(embeddings) == 2 + assert len(embeddings[0]) == 768 + assert len(embeddings[1]) == 768 + assert embeddings[0] != embeddings[1] + +def test_task_type() -> None: + ef = GoogleVertexAiEmbeddings( + task_type="RETRIEVAL_QUERY", + ) + embeddings = ef(["hello world", "goodbye world"]) + assert len(embeddings) == 2 + assert len(embeddings[0]) == 256 + assert len(embeddings[1]) == 256 \ No newline at end of file From b16af7f68562363ce5a55da77931fcea8351dd3b Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Tue, 24 Sep 2024 21:12:03 +0300 Subject: [PATCH 2/2] feat: Added additional auth methods + test --- chromadbx/embeddings/google.py | 45 ++++++++++++++++++++-------------- docs/embeddings.md | 32 ++++++++++++++++++++++++ test/embeddings/test_google.py | 24 +++++++++++++++++- 3 files changed, 82 insertions(+), 19 deletions(-) diff --git a/chromadbx/embeddings/google.py b/chromadbx/embeddings/google.py index 54fdd78..05ec02c 100644 --- a/chromadbx/embeddings/google.py +++ b/chromadbx/embeddings/google.py @@ -1,12 +1,7 @@ -import logging from typing import Optional, cast -import numpy as np -import numpy.typing as npt from chromadb.api.types import Documents, EmbeddingFunction, Embeddings -logger = logging.getLogger(__name__) - class GoogleVertexAiEmbeddings(EmbeddingFunction[Documents]): def __init__( @@ -14,9 +9,13 @@ def __init__( model_name: str = "text-embedding-004", *, project_id: Optional[str] = None, - location: Optional[str] = "us-central1", + location: Optional[str] = None, dimensions: Optional[int] = 256, task_type: Optional[str] = "RETRIEVAL_DOCUMENT", + credentials: Optional[any] = None, + api_key: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_transport: Optional[str] = None, ) -> None: """ Initialize the OnnxRuntimeEmbeddings. @@ -26,19 +25,23 @@ def __init__( :param location: The location to use. Defaults to None. :param dimensions: The number of dimensions to use. Defaults to None. :param task_type: The task type to use. Defaults to "RETRIEVAL_DOCUMENT". https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types - + :param credentials: The credentials to use. Defaults to None. + :param api_key: The API key to use. Defaults to None. + :param api_endpoint: The API endpoint to use. Defaults to None. + :param api_transport: The API transport to use. Defaults to None. """ try: import vertexai - from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel - if project_id and location: - vertexai.init(project=project_id, location=location) - elif project_id and not location: - vertexai.init(project=project_id) - elif not project_id and location: - vertexai.init(location=location) - else: - vertexai.init() + from vertexai.language_models import TextEmbeddingModel + + vertexai.init( + project=project_id, + location=location, + credentials=credentials, + api_key=api_key, + api_endpoint=api_endpoint, + api_transport=api_transport, + ) self._model = TextEmbeddingModel.from_pretrained(model_name) except ImportError: raise ValueError( @@ -49,7 +52,13 @@ def __init__( def __call__(self, input: Documents) -> Embeddings: from vertexai.language_models import TextEmbeddingInput + inputs = [TextEmbeddingInput(text, self._task_type) for text in input] - kwargs = dict(output_dimensionality=self._dimensions) if self._dimensions else {} - embeddings = [embedding.values for embedding in self._model.get_embeddings(inputs,**kwargs)] + kwargs = ( + dict(output_dimensionality=self._dimensions) if self._dimensions else {} + ) + embeddings = [ + embedding.values + for embedding in self._model.get_embeddings(inputs, **kwargs) + ] return cast(Embeddings, embeddings) diff --git a/docs/embeddings.md b/docs/embeddings.md index 4f57c1b..d7d12eb 100644 --- a/docs/embeddings.md +++ b/docs/embeddings.md @@ -95,3 +95,35 @@ col = client.get_or_create_collection("test", embedding_function=ef) col.add(ids=["id1", "id2", "id3"], documents=["lorem ipsum...", "doc2", "doc3"]) ``` + +## Google Vertex AI + +A convenient way to run Google Vertex AI models to generate embeddings. + +Google Vertex AI uses variety of authentication methods. The most secure is either service account key file or Google Application Default Credentials. + +```py +import chromadb +from chromadbx.embeddings.google import GoogleVertexAiEmbeddings + +ef = GoogleVertexAiEmbeddings() + +client = chromadb.Client() + +col = client.get_or_create_collection("test", embedding_function=ef) + +col.add(ids=["id1", "id2", "id3"], documents=["lorem ipsum...", "doc2", "doc3"]) +``` + +### Auth with service account key file + +```py +import chromadb +from chromadbx.embeddings.google import GoogleVertexAiEmbeddings +from google.oauth2 import service_account + +credentials = service_account.Credentials.from_service_account_file("path/to/service-account-key.json") +ef = GoogleVertexAiEmbeddings(credentials=credentials) + +ef(["hello world", "goodbye world"]) +``` diff --git a/test/embeddings/test_google.py b/test/embeddings/test_google.py index 2ab0b1e..92c3237 100644 --- a/test/embeddings/test_google.py +++ b/test/embeddings/test_google.py @@ -1,5 +1,9 @@ +import os +import pytest from chromadbx.embeddings.google import GoogleVertexAiEmbeddings +vai = pytest.importorskip("vertexai", reason="vertexai not installed") + def test_embed() -> None: ef = GoogleVertexAiEmbeddings() @@ -9,6 +13,7 @@ def test_embed() -> None: assert len(embeddings[1]) == 256 assert embeddings[0] != embeddings[1] + def test_with_model() -> None: ef = GoogleVertexAiEmbeddings( model_name="text-multilingual-embedding-002", @@ -19,6 +24,7 @@ def test_with_model() -> None: assert len(embeddings[1]) == 256 assert embeddings[0] != embeddings[1] + def test_dimensions() -> None: ef = GoogleVertexAiEmbeddings( model_name="text-multilingual-embedding-002", @@ -30,6 +36,7 @@ def test_dimensions() -> None: assert len(embeddings[1]) == 768 assert embeddings[0] != embeddings[1] + def test_task_type() -> None: ef = GoogleVertexAiEmbeddings( task_type="RETRIEVAL_QUERY", @@ -37,4 +44,19 @@ def test_task_type() -> None: embeddings = ef(["hello world", "goodbye world"]) assert len(embeddings) == 2 assert len(embeddings[0]) == 256 - assert len(embeddings[1]) == 256 \ No newline at end of file + assert len(embeddings[1]) == 256 + + +def test_credentials() -> None: + file_path = "genai-sa-key.json" + if not os.path.exists(file_path): + pytest.skip(f"File {file_path} does not exist") + from google.oauth2 import service_account + + credentials = service_account.Credentials.from_service_account_file(file_path) + ef = GoogleVertexAiEmbeddings(credentials=credentials) + embeddings = ef(["hello world", "goodbye world"]) + assert len(embeddings) == 2 + assert len(embeddings[0]) == 256 + assert len(embeddings[1]) == 256 + assert embeddings[0] != embeddings[1]