Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add new agents docs embeddings functionality #305

Merged
merged 7 commits into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion agents-api/agents_api/clients/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
async def embed(
inputs: list[str],
join_inputs=False,
embed_model_name: str = embedding_model_id,
) -> list[list[float]]:
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
Expand All @@ -17,7 +18,7 @@ async def embed(
"normalize": True,
# FIXME: We should control the truncation ourselves and truncate before sending
"truncate": truncate_embed_text,
"model_id": embedding_model_id,
"model_id": embed_model_name,
},
)
resp.raise_for_status()
Expand Down
147 changes: 147 additions & 0 deletions agents-api/agents_api/embed_models_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import tiktoken
import numpy as np
from typing import TypedDict, Any
from dataclasses import dataclass
from tokenizers import Tokenizer
from agents_api.clients.model import openai_client
from agents_api.clients.embed import embed
from agents_api.exceptions import (
ModelNotSupportedError,
PromptTooBigError,
UnknownTokenizerError,
)


def normalize_l2(x):
    """Scale a vector, or each row of a 2-D array, to unit L2 norm.

    Args:
        x: A 1-D sequence/array (one vector) or a 2-D sequence/array
            (one vector per row).

    Returns:
        A NumPy array with the same shape as ``x``. Vectors whose norm is
        zero are returned as zeros instead of producing NaNs.
    """
    x = np.array(x)
    # A 1-D input is a single vector (reduce over axis 0); otherwise each
    # row is treated as its own vector (reduce over axis 1).
    axis = 0 if x.ndim == 1 else 1
    norm = np.linalg.norm(x, 2, axis=axis, keepdims=True)
    # Swap zero norms for 1 *before* dividing. The previous
    # ``np.where(norm == 0, x, x / norm)`` still evaluated 0 / 0 for zero
    # rows, which emits a RuntimeWarning (or raises under strict errstate).
    norm[norm == 0] = 1
    return x / norm


class EmbeddingInput(TypedDict):
instruction: str | None
text: str


@dataclass
class EmbeddingModel:
embedding_provider: str
embedding_model_name: str
original_embedding_dimensions: int
output_embedding_dimensions: int
context_window: int
tokenizer: Any

@classmethod
def from_model_name(cls, model_name: str):
try:
return _embedding_model_registry[model_name]
except KeyError:
raise ModelNotSupportedError(model_name)

def _token_count(self, text: str) -> int:
tokenize = getattr(self.tokenizer, "tokenize", None)
if tokenize:
return len(tokenize(text))

encode = getattr(self.tokenizer, "encode", None)
if encode:
return len(encode(text))

raise UnknownTokenizerError

def preprocess(self, inputs: list[EmbeddingInput]) -> list[str]:
"""Maybe use this function from embed() to truncate (if needed) or raise an error"""
result: list[str] = []

for i in inputs:
instruction = i.get("instruction", "")
sep = " " if len(instruction) else ""
result.append(f"{instruction}{sep}{i['text']}")

token_count = self._token_count(" ".join(result))
if token_count > self.context_window:
raise PromptTooBigError(token_count, self.context_window)

return result

async def embed(
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
self, inputs: list[EmbeddingInput]
) -> list[np.ndarray | list[float]]:
input = self.preprocess(inputs)
embeddings: list[np.ndarray | list[float]] = []

whiterabbit1983 marked this conversation as resolved.
Show resolved Hide resolved
if self.embedding_provider == "julep":
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
embeddings = await embed(input, embed_model_name=self.embedding_model_name)
elif self.embedding_provider == "openai":
embeddings = (
await openai_client.embeddings.create(
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
input=input, model=self.embedding_model_name
)
.data[0]
.embedding
)

return self.normalize(embeddings)

def normalize(
self, embeddings: list[np.ndarray | list[float]]
) -> list[np.ndarray | list[float]]:
return [
(
e
if len(e) <= self.original_embedding_dimensions
else normalize_l2(e[: self.original_embedding_dimensions])
)
for e in embeddings
]


# Registry of supported embedding models, keyed by model name. Every entry
# currently shares the same dimension and context-window settings, so the
# table below lists only what differs: (provider, tokenizer loader, name).
# NOTE(review): confirm 1024 / 1024 / 8192 actually holds for each model.
_embedding_model_registry = {
    model_name: EmbeddingModel(
        embedding_provider=provider,
        embedding_model_name=model_name,
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=load_tokenizer(model_name),
    )
    for provider, load_tokenizer, model_name in (
        ("openai", tiktoken.encoding_for_model, "text-embedding-3-small"),
        ("openai", tiktoken.encoding_for_model, "text-embedding-3-large"),
        ("julep", Tokenizer.from_pretrained, "Alibaba-NLP/gte-large-en-v1.5"),
        ("julep", Tokenizer.from_pretrained, "BAAI/bge-m3"),
        ("julep", Tokenizer.from_pretrained, "BAAI/llm-embedder"),
    )
}
19 changes: 19 additions & 0 deletions agents-api/agents_api/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
class AgentsBaseException(Exception):
    """Common base class for the exceptions defined in this module."""


class ModelNotSupportedError(AgentsBaseException):
    """Error for a model name that is not supported."""

    def __init__(self, model_name):
        message = f"model {model_name} is not supported"
        super().__init__(message)


class PromptTooBigError(AgentsBaseException):
    """Error for a prompt whose token count exceeds the allowed maximum."""

    def __init__(self, token_count, max_tokens):
        message = (
            f"prompt is too big, {token_count} tokens provided, exceeds maximum of {max_tokens}"
        )
        super().__init__(message)


class UnknownTokenizerError(AgentsBaseException):
    """Error for a tokenizer object whose interface is not recognized."""

    def __init__(self):
        message = "unknown tokenizer"
        super().__init__(message)
19 changes: 16 additions & 3 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Annotated
from uuid import uuid4

from agents_api.autogen.openapi_model import ContentItem
from agents_api.model_registry import validate_configuration
from fastapi import APIRouter, HTTPException, status, Depends
import pandas as pd
Expand Down Expand Up @@ -66,6 +67,8 @@
PatchToolRequest,
PatchAgentRequest,
)
from agents_api.env import embedding_model_id
from agents_api.embed_models_registry import EmbeddingModel


class AgentList(BaseModel):
Expand Down Expand Up @@ -302,7 +305,13 @@ async def list_agents(
@router.post("/agents/{agent_id}/docs", tags=["agents"])
async def create_docs(agent_id: UUID4, request: CreateDoc) -> ResourceCreatedResponse:
doc_id = uuid4()
content = [request.content] if isinstance(request.content, str) else request.content
content = [
(c.model_dump() if isinstance(c, ContentItem) else c)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure that ContentItem has a model_dump() method to avoid potential AttributeError. If it does not exist, consider implementing it or handling this scenario appropriately.

for c in (
[request.content] if isinstance(request.content, str) else request.content
)
]

resp: pd.DataFrame = create_docs_query(
owner_type="agent",
owner_id=agent_id,
Expand All @@ -319,9 +328,13 @@ async def create_docs(agent_id: UUID4, request: CreateDoc) -> ResourceCreatedRes
)

indices, snippets = list(zip(*enumerate(content)))
embeddings = await embed(
model = EmbeddingModel.from_model_name(embedding_model_id)
embeddings = await model.embed(
[
snippet_embed_instruction + request.title + "\n\n" + snippet
{
"instruction": snippet_embed_instruction,
"text": request.title + "\n\n" + snippet,
}
for snippet in snippets
]
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#!/usr/bin/env python3

# Name identifying this migration.
MIGRATION_ID = "change_embeddings_dimensions"
# presumably the Unix timestamp at which the migration was created — TODO confirm
CREATED_AT = 1714566760.731964


# Rebuilds the information_snippets relation with a different embedding width:
# "up" declares <F32; 1024>, "down" restores <F32; 768>. Both directions read
# every existing row and feed it back through :replace.
# NOTE(review): existing non-null embeddings are copied as-is into a column of
# a different width — confirm this succeeds when such rows exist.
change_dimensions = dict(
    up="""
?[
doc_id,
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
] :=
*information_snippets{
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
doc_id,
}

:replace information_snippets {
doc_id: Uuid,
snippet_idx: Int,
=>
title: String,
snippet: String,
embed_instruction: String default 'Encode this passage for retrieval: ',
embedding: <F32; 1024>? default null,
}
""",
    down="""
?[
doc_id,
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
] :=
*information_snippets{
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
doc_id,
}

:replace information_snippets {
doc_id: Uuid,
snippet_idx: Int,
=>
title: String,
snippet: String,
embed_instruction: String default 'Encode this passage for retrieval: ',
embedding: <F32; 768>? default null,
}
""",
)

# HNSW index over information_snippets.embedding for 768-dimensional vectors:
# "up" creates the index, "down" drops it.
snippets_hnsw_768_index = {
    "up": """
::hnsw create information_snippets:embedding_space {
fields: [embedding],
filter: !is_null(embedding),
dim: 768,
distance: Cosine,
m: 64,
ef_construction: 256,
extend_candidates: true,
keep_pruned_connections: false,
}
""",
    "down": """
::hnsw drop information_snippets:embedding_space
""",
}

# Same pair with the directions swapped, so "up" drops the 768-dim index.
drop_snippets_hnsw_768_index = dict(
    up=snippets_hnsw_768_index["down"],
    down=snippets_hnsw_768_index["up"],
)

# HNSW index over information_snippets.embedding for 1024-dimensional vectors:
# "up" creates the index, "down" drops it.
snippets_hnsw_1024_index = {
    "up": """
::hnsw create information_snippets:embedding_space {
fields: [embedding],
filter: !is_null(embedding),
dim: 1024,
distance: Cosine,
m: 64,
ef_construction: 256,
extend_candidates: true,
keep_pruned_connections: false,
}
""",
    "down": """
::hnsw drop information_snippets:embedding_space
""",
}

# Same pair with the directions swapped, so "up" drops the 1024-dim index.
drop_snippets_hnsw_1024_index = dict(
    up=snippets_hnsw_1024_index["down"],
    down=snippets_hnsw_1024_index["up"],
)


# See: https://docs.cozodb.org/en/latest/vector.html#full-text-search-fts
# Full-text-search index over the concatenated title and snippet:
# "up" creates the index, "down" drops it.
information_snippets_fts_index = {
    "up": """
::fts create information_snippets:fts {
extractor: concat(title, ' ', snippet),
tokenizer: Simple,
filters: [Lowercase, Stemmer('english'), Stopwords('en')],
}
""",
    "down": """
::fts drop information_snippets:fts
""",
}

# Same pair with the directions swapped, so "up" drops the FTS index.
drop_information_snippets_fts_index = dict(
    up=information_snippets_fts_index["down"],
    down=information_snippets_fts_index["up"],
)


# Ordered migration steps. ``up`` runs them top to bottom; ``down`` runs them
# bottom to top using each step's "down" query.
queries_to_run = [
    # Drop both indexes before the relation is replaced.
    drop_information_snippets_fts_index,
    drop_snippets_hnsw_768_index,
    # Replace the relation with the new embedding width.
    change_dimensions,
    # Recreate the indexes against the replaced relation.
    snippets_hnsw_1024_index,
    information_snippets_fts_index,
]


def up(client):
    """Apply the migration by running each step's "up" query in list order.

    Args:
        client: Object exposing ``run(query)``; called once per step.
    """
    for step in queries_to_run:
        forward_query = step["up"]
        client.run(forward_query)


def down(client):
    """Revert the migration by running each step's "down" query, last step first.

    Args:
        client: Object exposing ``run(query)``; called once per step.
    """
    for step in queries_to_run[::-1]:
        rollback_query = step["down"]
        client.run(rollback_query)
Loading
Loading