Skip to content

Commit

Permalink
feat: Add Vertex AI embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Sep 19, 2024
1 parent 39f71a3 commit 2831486
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 6 deletions.
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from temporalio import activity

from ..clients import cozo
from ..clients import embed as embedder
from ..clients import vertexai as embedder
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload
Expand Down
18 changes: 18 additions & 0 deletions agents-api/agents_api/clients/vertexai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import litellm
from litellm import aembedding

from ..env import google_project_id, vertex_location

# Set litellm's module-level Vertex AI project and location from the values
# imported from ..env; this runs once, when this module is first imported.
litellm.vertex_project = google_project_id
litellm.vertex_location = vertex_location


async def embed(
    inputs: list[str], dimensions: int = 1024, join_inputs: bool = True
) -> list[list[float]]:
    """Embed text with the ``vertex_ai/text-embedding-004`` model via litellm.

    Args:
        inputs: Texts to embed.
        dimensions: Requested embedding size, forwarded unchanged to
            ``aembedding``.
        join_inputs: When true, all inputs are joined with blank lines and
            sent as a single text; otherwise the list is sent as-is.

    Returns:
        The embedding vectors extracted from the response, in response order.
        An empty list is returned when the response carries no data.
    """
    texts = ["\n\n".join(inputs)] if join_inputs else inputs
    response = await aembedding(
        model="vertex_ai/text-embedding-004", input=texts, dimensions=dimensions
    )

    # response.data holds embedding entries (mappings or objects exposing an
    # "embedding" field), not bare vectors; unwrap them so the result matches
    # the declared list[list[float]] return type.
    return [
        entry["embedding"] if isinstance(entry, dict) else entry.embedding
        for entry in response.data or []
    ]
4 changes: 4 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@
temporal_endpoint: Any = env.str("TEMPORAL_ENDPOINT", default="localhost:7233")
temporal_task_queue: Any = env.str("TEMPORAL_TASK_QUEUE", default="julep-task-queue")

# Google cloud
# GOOGLE_PROJECT_ID is read without a default, unlike the settings around it.
# NOTE(review): presumably env.str fails at import time when it is unset — confirm.
google_project_id: str = env.str("GOOGLE_PROJECT_ID")
# Read from VERTEX_LOCATION, with "us-central1" as the default value.
vertex_location: str = env.str("VERTEX_LOCATION", default="us-central1")

# Consolidate environment variables
environment: Dict[str, Any] = dict(
Expand All @@ -97,6 +100,7 @@
temporal_namespace=temporal_namespace,
embedding_model_id=embedding_model_id,
testing=testing,
google_project_id=google_project_id,
)

if debug or testing:
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/models/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from agents_api.autogen.Chat import ChatInput

from ...autogen.openapi_model import DocReference, History
from ...clients import embed
from ...clients import vertexai as embed
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ..docs.search_docs_hybrid import search_docs_hybrid
Expand Down Expand Up @@ -61,7 +61,7 @@ async def gather_messages(
return past_messages, []

# Search matching docs
[query_embedding, *_] = await embed.embed(
query_embedding = await embed.embed(
inputs=[
f"{msg.get('name') or msg['role']}: {msg['content']}"
for msg in new_raw_messages
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/docs/embed.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID

import agents_api.clients.embed as embedder
import agents_api.clients.vertexai as embedder

from ...autogen.openapi_model import (
EmbedQueryRequest,
Expand Down
3 changes: 2 additions & 1 deletion agents-api/tests/test_chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from ward import test

from agents_api.autogen.openapi_model import ChatInput, CreateSessionRequest
from agents_api.clients import embed, litellm
from agents_api.clients import litellm
from agents_api.clients import vertexai as embed
from agents_api.common.protocol.sessions import ChatContext
from agents_api.models.chat.gather_messages import gather_messages
from agents_api.models.chat.prepare_chat_context import prepare_chat_context
Expand Down
6 changes: 6 additions & 0 deletions llm-proxy/litellm-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ model_list:
api_base: os.environ/EMBEDDING_SERVICE_BASE
tags: ["free"]

# Vertex AI embedding model; project and location are taken from the
# GOOGLE_PROJECT_ID and VERTEX_LOCATION environment variables.
- model_name: text-embedding-004
litellm_params:
model: vertex_ai/text-embedding-004
vertex_project: os.environ/GOOGLE_PROJECT_ID
vertex_location: os.environ/VERTEX_LOCATION


# -*= Free models =*-
# -------------------
Expand Down

0 comments on commit 2831486

Please sign in to comment.