Skip to content

Commit

Permalink
fix: Fix user doc creation
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed May 4, 2024
1 parent d762fa1 commit f6443b1
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 7 deletions.
5 changes: 3 additions & 2 deletions agents-api/agents_api/clients/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
async def embed(
inputs: list[str],
join_inputs=False,
embed_model_name: str = embedding_model_id,
embedding_service_url: str = embedding_service_url,
embedding_model_name: str = embedding_model_id,
) -> list[list[float]]:
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
Expand All @@ -18,7 +19,7 @@ async def embed(
"normalize": True,
# FIXME: We should control the truncation ourselves and truncate before sending
"truncate": truncate_embed_text,
"model_id": embed_model_name,
"model_id": embedding_model_name,
},
)
resp.raise_for_status()
Expand Down
14 changes: 13 additions & 1 deletion agents-api/agents_api/embed_models_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
PromptTooBigError,
UnknownTokenizerError,
)
from agents_api.env import docs_embedding_service_url


def normalize_l2(x):
Expand All @@ -31,6 +32,7 @@ class EmbeddingInput(TypedDict):

@dataclass
class EmbeddingModel:
embedding_service_url: str | None
embedding_provider: str
embedding_model_name: str
original_embedding_dimensions: int
Expand Down Expand Up @@ -78,7 +80,12 @@ async def embed(
embeddings: list[np.ndarray | list[float]] = []

if self.embedding_provider == "julep":
embeddings = await embed(input, embed_model_name=self.embedding_model_name)
embeddings = await embed(
input,
embedding_service_url=self.embedding_service_url
or docs_embedding_service_url,
embedding_model_name=self.embedding_model_name,
)
elif self.embedding_provider == "openai":
embeddings = (
await openai_client.embeddings.create(
Expand All @@ -105,6 +112,7 @@ def normalize(

_embedding_model_registry = {
"text-embedding-3-small": EmbeddingModel(
embedding_service_url=None,
embedding_provider="openai",
embedding_model_name="text-embedding-3-small",
original_embedding_dimensions=1024,
Expand All @@ -113,6 +121,7 @@ def normalize(
tokenizer=tiktoken.encoding_for_model("text-embedding-3-small"),
),
"text-embedding-3-large": EmbeddingModel(
embedding_service_url=None,
embedding_provider="openai",
embedding_model_name="text-embedding-3-large",
original_embedding_dimensions=1024,
Expand All @@ -121,6 +130,7 @@ def normalize(
tokenizer=tiktoken.encoding_for_model("text-embedding-3-large"),
),
"Alibaba-NLP/gte-large-en-v1.5": EmbeddingModel(
embedding_service_url=docs_embedding_service_url,
embedding_provider="julep",
embedding_model_name="Alibaba-NLP/gte-large-en-v1.5",
original_embedding_dimensions=1024,
Expand All @@ -129,6 +139,7 @@ def normalize(
tokenizer=Tokenizer.from_pretrained("Alibaba-NLP/gte-large-en-v1.5"),
),
"BAAI/bge-m3": EmbeddingModel(
embedding_service_url=docs_embedding_service_url,
embedding_provider="julep",
embedding_model_name="BAAI/bge-m3",
original_embedding_dimensions=1024,
Expand All @@ -137,6 +148,7 @@ def normalize(
tokenizer=Tokenizer.from_pretrained("BAAI/bge-m3"),
),
"BAAI/llm-embedder": EmbeddingModel(
embedding_service_url=docs_embedding_service_url,
embedding_provider="julep",
embedding_model_name="BAAI/llm-embedder",
original_embedding_dimensions=1024,
Expand Down
9 changes: 9 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,16 @@
"EMBEDDING_SERVICE_URL", default="http://0.0.0.0:8082/embed"
)

# URL of the embedding endpoint used for document snippets. Read from
# DOCS_EMBEDDING_SERVICE_URL; the default targets port 8083, which is the host
# port published by the docs-text-embeddings-inference compose service.
docs_embedding_service_url: str = env.str(
"DOCS_EMBEDDING_SERVICE_URL", default="http://0.0.0.0:8083/embed"
)

# Model id used by default by the `embed()` client (sent as "model_id" in the
# request payload). Read from EMBEDDING_MODEL_ID.
embedding_model_id: str = env.str(
"EMBEDDING_MODEL_ID", default="BAAI/bge-large-en-v1.5"
)

# Model id used by the agent/user `create_docs` routes when embedding document
# snippets. Read from DOCS_EMBEDDING_MODEL_ID; the default matches the MODEL_ID
# configured for the docs-text-embeddings-inference compose service.
docs_embedding_model_id: str = env.str("DOCS_EMBEDDING_MODEL_ID", default="BAAI/bge-m3")

# Forwarded as the "truncate" field of the embedding request. Read from
# TRUNCATE_EMBED_TEXT; disabled by default.
truncate_embed_text: bool = env.bool("TRUNCATE_EMBED_TEXT", default=False)

# Temporal
Expand Down Expand Up @@ -86,6 +92,9 @@
temporal_worker_url=temporal_worker_url,
temporal_namespace=temporal_namespace,
openai_api_key=openai_api_key,
docs_embedding_model_id=docs_embedding_model_id,
docs_embedding_service_url=docs_embedding_service_url,
embedding_model_id=embedding_model_id,
)

if openai_api_key == "":
Expand Down
4 changes: 2 additions & 2 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
PatchToolRequest,
PatchAgentRequest,
)
from agents_api.env import embedding_model_id
from agents_api.env import docs_embedding_model_id
from agents_api.embed_models_registry import EmbeddingModel


Expand Down Expand Up @@ -328,7 +328,7 @@ async def create_docs(agent_id: UUID4, request: CreateDoc) -> ResourceCreatedRes
)

indices, snippets = list(zip(*enumerate(content)))
model = EmbeddingModel.from_model_name(embedding_model_id)
model = EmbeddingModel.from_model_name(docs_embedding_model_id)
embeddings = await model.embed(
[
{
Expand Down
13 changes: 11 additions & 2 deletions agents-api/agents_api/routers/users/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Annotated
from uuid import uuid4

from agents_api.autogen.openapi_model import ContentItem
from fastapi import APIRouter, HTTPException, status, Depends
import pandas as pd
from pycozo.client import QueryException
Expand Down Expand Up @@ -46,6 +47,7 @@
Doc,
PatchUserRequest,
)
from agents_api.env import docs_embedding_model_id, docs_embedding_service_url


class UserList(BaseModel):
Expand Down Expand Up @@ -238,7 +240,12 @@ async def list_users(
@router.post("/users/{user_id}/docs", tags=["users"])
async def create_docs(user_id: UUID4, request: CreateDoc) -> ResourceCreatedResponse:
doc_id = uuid4()
content = [request.content] if isinstance(request.content, str) else request.content
content = [
(c.model_dump() if isinstance(c, ContentItem) else c)
for c in (
[request.content] if isinstance(request.content, str) else request.content
)
]
resp: pd.DataFrame = create_docs_query(
owner_type="user",
owner_id=user_id,
Expand All @@ -259,7 +266,9 @@ async def create_docs(user_id: UUID4, request: CreateDoc) -> ResourceCreatedResp
[
snippet_embed_instruction + request.title + "\n\n" + snippet
for snippet in snippets
]
],
embedding_service_url=docs_embedding_service_url,
embedding_model_name=docs_embedding_model_id,
)

embed_docs_snippets_query(
Expand Down
18 changes: 18 additions & 0 deletions agents-api/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ services:
count: all
capabilities: [gpu]

# Separate text-embeddings-inference container that serves the docs embedding
# model (MODEL_ID below).
# NOTE(review): presumably this is the service DOCS_EMBEDDING_SERVICE_URL is
# meant to reach - its default in env.py uses port 8083; confirm.
docs-text-embeddings-inference:
container_name: docs-text-embeddings-inference
environment:
- DTYPE=float16
# Same value as the DOCS_EMBEDDING_MODEL_ID default in agents_api/env.py.
- MODEL_ID=BAAI/bge-m3

image: ghcr.io/huggingface/text-embeddings-inference:1.0
ports:
# host port 8083 -> container port 80
- "8083:80"
shm_size: "2gb"
deploy:
resources:
reservations:
devices:
# Reserve all NVIDIA GPUs for this container.
- driver: nvidia
count: all
capabilities: [gpu]

temporal:
image: julepai/temporal:dev
container_name: temporal
Expand Down

0 comments on commit f6443b1

Please sign in to comment.