diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py
index 3222ff9e7..0fad43d92 100644
--- a/agents-api/agents_api/activities/embed_docs.py
+++ b/agents-api/agents_api/activities/embed_docs.py
@@ -2,7 +2,7 @@
 from temporalio import activity
 
 from ..clients import cozo
-from ..clients import embed as embedder
+from ..clients import vertexai as embedder
 from ..env import testing
 from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
 from .types import EmbedDocsPayload
diff --git a/agents-api/agents_api/clients/vertexai.py b/agents-api/agents_api/clients/vertexai.py
new file mode 100644
index 000000000..eaaee0182
--- /dev/null
+++ b/agents-api/agents_api/clients/vertexai.py
@@ -0,0 +1,33 @@
+import litellm
+from litellm import aembedding
+
+from ..env import google_project_id, vertex_location
+
+# Configure litellm's Vertex AI project and region once, at import time.
+litellm.vertex_project = google_project_id
+litellm.vertex_location = vertex_location
+
+
+async def embed(
+    inputs: list[str], dimensions: int = 1024, join_inputs: bool = True
+) -> list[list[float]]:
+    """Embed text with the `vertex_ai/text-embedding-004` model via litellm.
+
+    Args:
+        inputs: Texts to embed.
+        dimensions: Requested size of each embedding vector.
+        join_inputs: When true, join all inputs with blank lines and embed
+            them as one text, so the result holds a single vector.
+
+    Returns:
+        One embedding vector per embedded text, or an empty list when the
+        response carries no data.
+    """
+    texts = ["\n\n".join(inputs)] if join_inputs else inputs
+    response = await aembedding(
+        model="vertex_ai/text-embedding-004", input=texts, dimensions=dimensions
+    )
+
+    # `response.data` holds OpenAI-style items; return only the raw vectors so
+    # the result matches the annotated `list[list[float]]`.
+    return [item["embedding"] for item in response.data or []]
diff --git a/agents-api/agents_api/env.py b/agents-api/agents_api/env.py
index 85d33f0e6..b5253788f 100644
--- a/agents-api/agents_api/env.py
+++ b/agents-api/agents_api/env.py
@@ -77,6 +77,9 @@
 temporal_endpoint: Any = env.str("TEMPORAL_ENDPOINT", default="localhost:7233")
 temporal_task_queue: Any = env.str("TEMPORAL_TASK_QUEUE", default="julep-task-queue")
 
+# Google cloud
+google_project_id: str = env.str("GOOGLE_PROJECT_ID")
+vertex_location: str = env.str("VERTEX_LOCATION", default="us-central1")
 
 # Consolidate environment variables
 environment: Dict[str, Any] = dict(
@@ -97,6 +100,7 @@
     temporal_namespace=temporal_namespace,
     embedding_model_id=embedding_model_id,
     testing=testing,
+    google_project_id=google_project_id,
 )
 
 if debug or testing:
diff --git a/agents-api/agents_api/models/chat/gather_messages.py b/agents-api/agents_api/models/chat/gather_messages.py
index f8e08632d..f1b64d05d 100644
--- a/agents-api/agents_api/models/chat/gather_messages.py
+++ b/agents-api/agents_api/models/chat/gather_messages.py
@@ -9,7 +9,7 @@
 from agents_api.autogen.Chat import ChatInput
 
 from ...autogen.openapi_model import DocReference, History
-from ...clients import embed
+from ...clients import vertexai as embed
 from ...common.protocol.developers import Developer
 from ...common.protocol.sessions import ChatContext
 from ..docs.search_docs_hybrid import search_docs_hybrid
diff --git a/agents-api/agents_api/routers/docs/embed.py b/agents-api/agents_api/routers/docs/embed.py
index 2c6b7b641..932fcbf6c 100644
--- a/agents-api/agents_api/routers/docs/embed.py
+++ b/agents-api/agents_api/routers/docs/embed.py
@@ -1,9 +1,9 @@
 from typing import Annotated
+from uuid import UUID
 
 from fastapi import Depends
-from uuid import UUID
 
-import agents_api.clients.embed as embedder
+import agents_api.clients.vertexai as embedder
 
 from ...autogen.openapi_model import (
     EmbedQueryRequest,
diff --git a/agents-api/tests/test_chat_routes.py b/agents-api/tests/test_chat_routes.py
index 416849c69..2c71db801 100644
--- a/agents-api/tests/test_chat_routes.py
+++ b/agents-api/tests/test_chat_routes.py
@@ -3,7 +3,8 @@
 from ward import test
 
 from agents_api.autogen.openapi_model import ChatInput, CreateSessionRequest
-from agents_api.clients import embed, litellm
+from agents_api.clients import litellm
+from agents_api.clients import vertexai as embed
 from agents_api.common.protocol.sessions import ChatContext
 from agents_api.models.chat.gather_messages import gather_messages
 from agents_api.models.chat.prepare_chat_context import prepare_chat_context
diff --git a/llm-proxy/litellm-config.yaml b/llm-proxy/litellm-config.yaml
index e91087461..0e2cf0254 100644
--- a/llm-proxy/litellm-config.yaml
+++ b/llm-proxy/litellm-config.yaml
@@ -103,6 +103,12 @@ model_list:
     api_base: os.environ/EMBEDDING_SERVICE_BASE
     tags: ["free"]
 
+- model_name: text-embedding-004
+  litellm_params:
+    model: vertex_ai/text-embedding-004
+    vertex_project: os.environ/GOOGLE_PROJECT_ID
+    vertex_location: os.environ/VERTEX_LOCATION
+
 
 # -*= Free models =*-
 # -------------------