Skip to content

Commit

Permalink
Merge pull request #510 from julep-ai/f/vertexai-client
Browse files Browse the repository at this point in the history
Vertex AI client
  • Loading branch information
whiterabbit1983 authored Sep 20, 2024
2 parents 99eef6c + 6b3ab54 commit e2ac3ce
Show file tree
Hide file tree
Showing 61 changed files with 704 additions and 661 deletions.
1 change: 0 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ LITELLM_POSTGRES_PASSWORD=<your_litellm_postgres_password>
LITELLM_MASTER_KEY=<your_litellm_master_key>
LITELLM_SALT_KEY=<your_litellm_salt_key>
LITELLM_REDIS_PASSWORD=<your_litellm_redis_password>
EMBEDDING_SERVICE_BASE=http://text-embeddings-inference-<gpu|cpu> # Use the 'gpu' profile to run on GPU

# Memory Store
# -----------
Expand Down
7 changes: 3 additions & 4 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from beartype import beartype
from temporalio import activity

from ..clients import cozo
from ..clients import embed as embedder
from ..clients import cozo, litellm
from ..env import testing
from ..models.docs.embed_snippets import embed_snippets as embed_snippets_query
from .types import EmbedDocsPayload
Expand All @@ -14,8 +13,8 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
embed_instruction: str = payload.embed_instruction or ""
title: str = payload.title or ""

embeddings = await embedder.embed(
[
embeddings = await litellm.aembedding(
inputs=[
(
embed_instruction + (title + "\n\n" + snippet) if title else snippet
).strip()
Expand Down
28 changes: 0 additions & 28 deletions agents-api/agents_api/clients/embed.py

This file was deleted.

63 changes: 59 additions & 4 deletions agents-api/agents_api/clients/litellm.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@
from functools import wraps
from typing import List
from typing import List, Literal

import litellm
from beartype import beartype
from litellm import acompletion as _acompletion
from litellm import get_supported_openai_params
from litellm import (
acompletion as _acompletion,
)
from litellm import (
aembedding as _aembedding,
)
from litellm import (
get_supported_openai_params,
)
from litellm.utils import CustomStreamWrapper, ModelResponse

from ..env import litellm_master_key, litellm_url
from ..env import (
embedding_dimensions,
embedding_model_id,
litellm_master_key,
litellm_url,
)

__all__: List[str] = ["acompletion"]

# TODO: Should check if this is really needed
litellm.drop_params = True


@wraps(_acompletion)
@beartype
async def acompletion(
*, model: str, messages: list[dict], custom_api_key: None | str = None, **kwargs
) -> ModelResponse | CustomStreamWrapper:
if not custom_api_key:
model = f"openai/{model}" # FIXME: This is for litellm

supported_params = get_supported_openai_params(model)
settings = {k: v for k, v in kwargs.items() if k in supported_params}
Expand All @@ -27,3 +45,40 @@ async def acompletion(
base_url=None if custom_api_key else litellm_url,
api_key=custom_api_key or litellm_master_key,
)


@wraps(_aembedding)
@beartype
async def aembedding(
*,
inputs: str | list[str],
model: str = embedding_model_id,
dimensions: int = embedding_dimensions,
join_inputs: bool = False,
custom_api_key: None | str = None,
**settings,
) -> list[list[float]]:
if not custom_api_key:
model = f"openai/{model}" # FIXME: This is for litellm

if isinstance(inputs, str):
input = [inputs]
else:
input = ["\n\n".join(inputs)] if join_inputs else inputs

response = await _aembedding(
model=model,
input=input,
# dimensions=dimensions, # FIXME: litellm doesn't support dimensions correctly
api_base=None if custom_api_key else litellm_url,
api_key=custom_api_key or litellm_master_key,
drop_params=True,
**settings,
)

embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data

# FIXME: Truncation should be handled by litellm
result = [embedding["embedding"][:dimensions] for embedding in embedding_list]

return result
10 changes: 4 additions & 6 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# -----
task_max_parallelism: int = env.int("AGENTS_API_TASK_MAX_PARALLELISM", default=100)


# Debug
# -----
debug: bool = env.bool("AGENTS_API_DEBUG", default=False)
Expand All @@ -51,6 +52,7 @@

api_key_header_name: str = env.str("AGENTS_API_KEY_HEADER_NAME", default="X-Auth-Key")


# Litellm API
# -----------
litellm_url: str = env.str("LITELLM_URL", default="http://0.0.0.0:4000")
Expand All @@ -59,13 +61,11 @@

# Embedding service
# -----------------
embedding_service_base: str = env.str(
"EMBEDDING_SERVICE_BASE", default="http://0.0.0.0:8082"
)
embedding_model_id: str = env.str(
"EMBEDDING_MODEL_ID", default="Alibaba-NLP/gte-large-en-v1.5"
)
truncate_embed_text: bool = env.bool("TRUNCATE_EMBED_TEXT", default=True)

embedding_dimensions: int = env.int("EMBEDDING_DIMENSIONS", default=1024)


# Temporal
Expand All @@ -91,8 +91,6 @@
api_key_header_name=api_key_header_name,
hostname=hostname,
api_prefix=api_prefix,
embedding_service_base=embedding_service_base,
truncate_embed_text=truncate_embed_text,
temporal_worker_url=temporal_worker_url,
temporal_namespace=temporal_namespace,
embedding_model_id=embedding_model_id,
Expand Down
15 changes: 8 additions & 7 deletions agents-api/agents_api/models/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from agents_api.autogen.Chat import ChatInput

from ...autogen.openapi_model import DocReference, History
from ...clients import embed
from ...clients import litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ..docs.search_docs_hybrid import search_docs_hybrid
Expand Down Expand Up @@ -61,12 +61,13 @@ async def gather_messages(
return past_messages, []

# Search matching docs
[query_embedding, *_] = await embed.embed(
inputs=[
f"{msg.get('name') or msg['role']}: {msg['content']}"
for msg in new_raw_messages
],
join_inputs=True,
[query_embedding, *_] = await litellm.aembedding(
inputs="\n\n".join(
[
f"{msg.get('name') or msg['role']}: {msg['content']}"
for msg in new_raw_messages
]
),
)
query_text = new_raw_messages[-1]["content"]

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/create_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/create_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_201_CREATED

import agents_api.models as models
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/delete_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_202_ACCEPTED

from ...autogen.openapi_model import ResourceDeletedResponse
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/delete_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import ResourceDeletedResponse
from ...dependencies.developer_id import get_developer_id
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/get_agent_details.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import Agent
from ...dependencies.developer_id import get_developer_id
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/list_agent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import ListResponse, Tool
from ...dependencies.developer_id import get_developer_id
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/list_agents.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
from json import JSONDecodeError
from typing import Annotated, Literal
from uuid import UUID

from fastapi import Depends, HTTPException, status
from uuid import UUID

from ...autogen.openapi_model import Agent, ListResponse
from ...dependencies.developer_id import get_developer_id
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/patch_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_200_OK

from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/patch_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import (
PatchToolRequest,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/agents/update_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_200_OK

from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/agents/update_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import (
ResourceUpdatedResponse,
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/docs/create_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from uuid import UUID, uuid4

from fastapi import BackgroundTasks, Depends
from uuid import UUID
from starlette.status import HTTP_201_CREATED
from temporalio.client import Client as TemporalClient

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/delete_doc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID
from starlette.status import HTTP_202_ACCEPTED

from ...autogen.openapi_model import ResourceDeletedResponse
Expand Down
7 changes: 3 additions & 4 deletions agents-api/agents_api/routers/docs/embed.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import Annotated

from fastapi import Depends
from uuid import UUID

import agents_api.clients.embed as embedder
from fastapi import Depends

from ...autogen.openapi_model import (
EmbedQueryRequest,
EmbedQueryResponse,
)
from ...clients import litellm
from ...dependencies.developer_id import get_developer_id
from .router import router

Expand All @@ -23,6 +22,6 @@ async def embed(
[text_to_embed] if isinstance(text_to_embed, str) else text_to_embed
)

vectors = await embedder.embed(inputs=text_to_embed)
vectors = await litellm.aembedding(inputs=text_to_embed)

return EmbedQueryResponse(vectors=vectors)
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/get_doc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import Doc
from ...dependencies.developer_id import get_developer_id
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/list_docs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
from json import JSONDecodeError
from typing import Annotated, Literal
from uuid import UUID

from fastapi import Depends, HTTPException, status
from uuid import UUID

from ...autogen.openapi_model import Doc, ListResponse
from ...dependencies.developer_id import get_developer_id
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/docs/search_docs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import time
from typing import Annotated, Any, Dict, List, Optional, Tuple, Union
from uuid import UUID

from fastapi import Depends
from uuid import UUID

from ...autogen.openapi_model import (
DocSearchResponse,
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/routers/jobs/routers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Literal
from uuid import UUID

from fastapi import APIRouter
from uuid import UUID
from temporalio.client import WorkflowExecutionStatus

from agents_api.autogen.openapi_model import JobStatus
Expand Down
Loading

0 comments on commit e2ac3ce

Please sign in to comment.