Skip to content

Commit

Permalink
feat: Add new agents docs embbeddings functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed May 2, 2024
1 parent 9c8d1a6 commit 4613316
Show file tree
Hide file tree
Showing 6 changed files with 560 additions and 91 deletions.
143 changes: 143 additions & 0 deletions agents-api/agents_api/embed_models_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import tiktoken
import numpy as np
from typing import TypedDict, Any
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from agents_api.clients.model import openai_client
from agents_api.clients.embed import embed
from agents_api.exceptions import ModelNotSupportedError, PromptTooBigError, UnknownTokenizerError


def normalize_l2(x):
x = np.array(x)
if x.ndim == 1:
norm = np.linalg.norm(x)
if norm == 0:
return x
return x / norm
else:
norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
return np.where(norm == 0, x, x / norm)


class EmbeddingInput(TypedDict):
instruction: str | None
text: str


@dataclass
class EmbeddingModel:
embedding_provider: str
embedding_model_name: str
original_embedding_dimensions: int
output_embedding_dimensions: int
context_window: int
tokenizer: Any

@classmethod
def from_model_name(cls, model_name: str):
try:
return _embedding_model_registry[model_name]
except KeyError:
raise ModelNotSupportedError(model_name)

def _token_count(self, text: str) -> int:
tokenize = getattr(self.tokenizer, "tokenize", None)
if tokenize:
return len(tokenize(text))

encode = getattr(self.tokenizer, "encode", None)
if encode:
return len(encode(text))

raise UnknownTokenizerError

def preprocess(self, inputs: list[EmbeddingInput]) -> list[str]:
"""Maybe use this function from embed() to truncate (if needed) or raise an error"""
result: list[str] = []

for i in inputs:
instruction = i.get('instruction', '')
sep = ' ' if len(instruction) else ''
result.append(f"{instruction}{sep}{i['text']}")

token_count = self._token_count(" ".join(result))
if token_count > self.context_window:
raise PromptTooBigError(token_count, self.context_window)

return result

async def embed(
self, inputs: list[EmbeddingInput]
) -> list[np.NDArray | list[float]]:
input = self.preprocess(inputs)
embeddings: list[np.NDArray | list[float]] = []

if self.embedding_provider == "julep":
embeddings = await embed(input)
elif self.embedding_provider == "openai":
embeddings = (
await openai_client.embeddings.create(
input=input, model=self.embedding_model_name
)
.data[0]
.embedding
)

return self.normalize(embeddings)

def normalize(
self, embeddings: list[np.NDArray | list[float]]
) -> list[np.NDArray | list[float]]:
return [
(
e
if len(e) <= self.original_embedding_dimensions
else normalize_l2(e[: self.original_embedding_dimensions])
)
for e in embeddings
]


_embedding_model_registry = {
"text-embeddings-3-small": EmbeddingModel(
embedding_provider="openai",
embedding_model_name="text-embeddings-3-small",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=tiktoken.encoding_for_model("text-embeddings-3-small"),
),
"text-embeddings-3-large": EmbeddingModel(
embedding_provider="openai",
embedding_model_name="text-embeddings-3-large",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=tiktoken.encoding_for_model("text-embeddings-3-large"),
),
"Alibaba-NLP/gte-large-en-v1.5": EmbeddingModel(
embedding_provider="julep",
embedding_model_name="Alibaba-NLP/gte-large-en-v1.5",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=PreTrainedTokenizer.from_pretrained("Alibaba-NLP/gte-large-en-v1.5"),
),
"BAAI/bge-m3": EmbeddingModel(
embedding_provider="julep",
embedding_model_name="BAAI/bge-m3",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=PreTrainedTokenizer.from_pretrained("BAAI/bge-m3"),
),
"BAAI/llm-embedder": EmbeddingModel(
embedding_provider="julep",
embedding_model_name="BAAI/llm-embedder",
original_embedding_dimensions=1024,
output_embedding_dimensions=1024,
context_window=8192,
tokenizer=PreTrainedTokenizer.from_pretrained("BAAI/llm-embedder"),
),
}
17 changes: 17 additions & 0 deletions agents-api/agents_api/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
class BaseException(Exception):
pass


class ModelNotSupportedError(BaseException):
def __init__(self, model_name):
super().__init__(f"model {model_name} is not supported")


class PromptTooBigError(BaseException):
def __init__(self, token_count, max_tokens):
super().__init__(f"prompt is too big, {max_tokens} required, but actual length is {token_count}")


class UnknownTokenizerError(BaseException):
def __init__(self):
super().__init__("unknown tokenizer")
10 changes: 8 additions & 2 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
PatchToolRequest,
PatchAgentRequest,
)
from agents_api.env import embedding_model_id
from agents_api.embed_models_registry import EmbeddingModel


class AgentList(BaseModel):
Expand Down Expand Up @@ -319,9 +321,13 @@ async def create_docs(agent_id: UUID4, request: CreateDoc) -> ResourceCreatedRes
)

indices, snippets = list(zip(*enumerate(content)))
embeddings = await embed(
model = EmbeddingModel.from_model_name(embedding_model_id)
embeddings = await model.embed(
[
snippet_embed_instruction + request.title + "\n\n" + snippet
{
"instruction": snippet_embed_instruction,
"text": request.title + "\n\n" + snippet,
}
for snippet in snippets
]
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# /usr/bin/env python3

MIGRATION_ID = "change_embeddings_dimensions"
CREATED_AT = 1714566760.731964


change_dimensions = {
"up": """
?[
doc_id,
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
] :=
*information_snippets{
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
additional_info_id: doc_id,
}
:replace information_snippets {
doc_id: Uuid,
snippet_idx: Int,
=>
title: String,
snippet: String,
embed_instruction: String default 'Encode this passage for retrieval: ',
embedding: <F32; 1024>? default null,
}
""",
"down": """
?[
doc_id,
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
] :=
*information_snippets{
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
additional_info_id: doc_id,
}
:replace information_snippets {
doc_id: Uuid,
snippet_idx: Int,
=>
title: String,
snippet: String,
embed_instruction: String default 'Encode this passage for retrieval: ',
embedding: <F32; 768>? default null,
}
""",
}

information_snippets_hnsw_index = dict(
up="""
::hnsw create information_snippets:embedding_space {
fields: [embedding],
filter: !is_null(embedding),
dim: 1024,
distance: Cosine,
m: 64,
ef_construction: 256,
extend_candidates: false,
keep_pruned_connections: false,
}
""",
down="""
::hnsw create information_snippets:embedding_space {
fields: [embedding],
filter: !is_null(embedding),
dim: 768,
distance: Cosine,
m: 64,
ef_construction: 256,
extend_candidates: false,
keep_pruned_connections: false,
}
""",
)

drop_index = {
"up": """
::hnsw drop information_snippets:embedding_space
""",
"down": """
::hnsw drop information_snippets:embedding_space
""",
}


queries_to_run = [
drop_index,
change_dimensions,
information_snippets_hnsw_index,
]


def up(client):
for q in queries_to_run:
client.run(q["up"])


def down(client):
for q in reversed(queries_to_run):
client.run(q["down"])
Loading

0 comments on commit 4613316

Please sign in to comment.