Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: VertexAI EF #49

Merged
merged 2 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions chromadbx/embeddings/google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Optional, cast

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings


class GoogleVertexAiEmbeddings(EmbeddingFunction[Documents]):
def __init__(
self,
model_name: str = "text-embedding-004",
*,
project_id: Optional[str] = None,
location: Optional[str] = None,
dimensions: Optional[int] = 256,
task_type: Optional[str] = "RETRIEVAL_DOCUMENT",
credentials: Optional[any] = None,
api_key: Optional[str] = None,
api_endpoint: Optional[str] = None,
api_transport: Optional[str] = None,
) -> None:
"""
Initialize the OnnxRuntimeEmbeddings.

:param model_name: The name of the model to use. Defaults to "text-embedding-004".
:param project_id: The project ID to use. Defaults to None.
:param location: The location to use. Defaults to None.
:param dimensions: The number of dimensions to use. Defaults to None.
:param task_type: The task type to use. Defaults to "RETRIEVAL_DOCUMENT". https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/task-types
:param credentials: The credentials to use. Defaults to None.
:param api_key: The API key to use. Defaults to None.
:param api_endpoint: The API endpoint to use. Defaults to None.
:param api_transport: The API transport to use. Defaults to None.
"""
try:
import vertexai
from vertexai.language_models import TextEmbeddingModel

vertexai.init(
project=project_id,
location=location,
credentials=credentials,
api_key=api_key,
api_endpoint=api_endpoint,
api_transport=api_transport,
)
self._model = TextEmbeddingModel.from_pretrained(model_name)
except ImportError:
raise ValueError(
"The vertexai python package is not installed. Please install it with `pip install vertexai`"
)
self._dimensions = dimensions
self._task_type = task_type

def __call__(self, input: Documents) -> Embeddings:
from vertexai.language_models import TextEmbeddingInput

inputs = [TextEmbeddingInput(text, self._task_type) for text in input]
kwargs = (
dict(output_dimensionality=self._dimensions) if self._dimensions else {}
)
embeddings = [
embedding.values
for embedding in self._model.get_embeddings(inputs, **kwargs)
]
return cast(Embeddings, embeddings)
32 changes: 32 additions & 0 deletions docs/embeddings.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,35 @@ col = client.get_or_create_collection("test", embedding_function=ef)

col.add(ids=["id1", "id2", "id3"], documents=["lorem ipsum...", "doc2", "doc3"])
```

## Google Vertex AI

A convenient way to run Google Vertex AI models to generate embeddings.

Google Vertex AI uses variety of authentication methods. The most secure is either service account key file or Google Application Default Credentials.

```py
import chromadb
from chromadbx.embeddings.google import GoogleVertexAiEmbeddings

ef = GoogleVertexAiEmbeddings()

client = chromadb.Client()

col = client.get_or_create_collection("test", embedding_function=ef)

col.add(ids=["id1", "id2", "id3"], documents=["lorem ipsum...", "doc2", "doc3"])
```

### Auth with service account key file

```py
import chromadb
from chromadbx.embeddings.google import GoogleVertexAiEmbeddings
from google.oauth2 import service_account

credentials = service_account.Credentials.from_service_account_file("path/to/service-account-key.json")
ef = GoogleVertexAiEmbeddings(credentials=credentials)

ef(["hello world", "goodbye world"])
```
62 changes: 62 additions & 0 deletions test/embeddings/test_google.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import pytest
from chromadbx.embeddings.google import GoogleVertexAiEmbeddings

vai = pytest.importorskip("vertexai", reason="vertexai not installed")


def test_embed() -> None:
ef = GoogleVertexAiEmbeddings()
embeddings = ef(["hello world", "goodbye world"])
assert len(embeddings) == 2
assert len(embeddings[0]) == 256
assert len(embeddings[1]) == 256
assert embeddings[0] != embeddings[1]


def test_with_model() -> None:
ef = GoogleVertexAiEmbeddings(
model_name="text-multilingual-embedding-002",
)
embeddings = ef(["hello world", "goodbye world"])
assert len(embeddings) == 2
assert len(embeddings[0]) == 256
assert len(embeddings[1]) == 256
assert embeddings[0] != embeddings[1]


def test_dimensions() -> None:
ef = GoogleVertexAiEmbeddings(
model_name="text-multilingual-embedding-002",
dimensions=768,
)
embeddings = ef(["hello world", "goodbye world"])
assert len(embeddings) == 2
assert len(embeddings[0]) == 768
assert len(embeddings[1]) == 768
assert embeddings[0] != embeddings[1]


def test_task_type() -> None:
ef = GoogleVertexAiEmbeddings(
task_type="RETRIEVAL_QUERY",
)
embeddings = ef(["hello world", "goodbye world"])
assert len(embeddings) == 2
assert len(embeddings[0]) == 256
assert len(embeddings[1]) == 256


def test_credentials() -> None:
file_path = "genai-sa-key.json"
if not os.path.exists(file_path):
pytest.skip(f"File {file_path} does not exist")
from google.oauth2 import service_account

credentials = service_account.Credentials.from_service_account_file(file_path)
ef = GoogleVertexAiEmbeddings(credentials=credentials)
embeddings = ef(["hello world", "goodbye world"])
assert len(embeddings) == 2
assert len(embeddings[0]) == 256
assert len(embeddings[1]) == 256
assert embeddings[0] != embeddings[1]
Loading