-
Notifications
You must be signed in to change notification settings - Fork 894
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add new agents docs embeddings functionality
- Loading branch information
1 parent
9c8d1a6
commit a7e7e1d
Showing
6 changed files
with
560 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import tiktoken | ||
import numpy as np | ||
from typing import TypedDict, Any | ||
from dataclasses import dataclass | ||
from transformers import PreTrainedTokenizer | ||
from agents_api.clients.model import openai_client | ||
from agents_api.clients.embed import embed | ||
from agents_api.exceptions import ModelNotSupportedError, PromptTooBigError, UnknownTokenizerError | ||
|
||
|
||
def normalize_l2(x):
    """Scale a vector, or each row of a matrix, to unit Euclidean length.

    Args:
        x: Array-like. A 1-D input is treated as a single vector; any other
            input is normalized along axis 1 (one vector per row).

    Returns:
        A NumPy array with the same shape as ``x``. All-zero vectors are
        returned unscaled instead of being divided by zero.
    """
    x = np.array(x)
    if x.ndim == 1:
        norm = np.linalg.norm(x)
        if norm == 0:
            return x
        return x / norm

    norm = np.linalg.norm(x, 2, axis=1, keepdims=True)
    # Swap zero norms for 1 *before* dividing: zero rows stay zero and NumPy
    # never evaluates 0 / 0, so no RuntimeWarning is emitted for them.
    safe_norm = np.where(norm == 0, 1.0, norm)
    return x / safe_norm
|
||
|
||
class EmbeddingInput(TypedDict): | ||
instruction: str | None | ||
text: str | ||
|
||
|
||
@dataclass | ||
class EmbeddingModel: | ||
embedding_provider: str | ||
embedding_model_name: str | ||
original_embedding_dimensions: int | ||
output_embedding_dimensions: int | ||
context_window: int | ||
tokenizer: Any | ||
|
||
@classmethod | ||
def from_model_name(cls, model_name: str): | ||
try: | ||
return _embedding_model_registry[model_name] | ||
except KeyError: | ||
raise ModelNotSupportedError(model_name) | ||
|
||
def _token_count(self, text: str) -> int: | ||
tokenize = getattr(self.tokenizer, "tokenize", None) | ||
if tokenize: | ||
return len(tokenize(text)) | ||
|
||
encode = getattr(self.tokenizer, "encode", None) | ||
if encode: | ||
return len(encode(text)) | ||
|
||
raise UnknownTokenizerError | ||
|
||
def preprocess(self, inputs: list[EmbeddingInput]) -> list[str]: | ||
"""Maybe use this function from embed() to truncate (if needed) or raise an error""" | ||
result: list[str] = [] | ||
|
||
for i in inputs: | ||
instruction = i.get('instruction', '') | ||
sep = ' ' if len(instruction) else '' | ||
result.append(f"{instruction}{sep}{i['text']}") | ||
|
||
token_count = self._token_count(" ".join(result)) | ||
if token_count > self.context_window: | ||
raise PromptTooBigError(token_count, self.context_window) | ||
|
||
return result | ||
|
||
async def embed( | ||
self, inputs: list[EmbeddingInput] | ||
) -> list[np.NDArray | list[float]]: | ||
input = self.preprocess(inputs) | ||
embeddings: list[np.NDArray | list[float]] = [] | ||
|
||
if self.embedding_provider == "julep": | ||
embeddings = await embed(input) | ||
elif self.embedding_provider == "openai": | ||
embeddings = ( | ||
await openai_client.embeddings.create( | ||
input=input, model=self.embedding_model_name | ||
) | ||
.data[0] | ||
.embedding | ||
) | ||
|
||
return self.normalize(embeddings) | ||
|
||
def normalize( | ||
self, embeddings: list[np.NDArray | list[float]] | ||
) -> list[np.NDArray | list[float]]: | ||
return [ | ||
( | ||
e | ||
if len(e) <= self.original_embedding_dimensions | ||
else normalize_l2(e[: self.original_embedding_dimensions]) | ||
) | ||
for e in embeddings | ||
] | ||
|
||
|
||
# Supported embedding models, keyed by the name given to
# EmbeddingModel.from_model_name().
#
# NOTE: every tokenizer below is loaded while this module is being imported.
# NOTE(review): all entries share the same dimension and context-window
# values; verify them against each model's published specification.
# NOTE(review): tokenizers for the "julep" models are loaded through the
# PreTrainedTokenizer base class — confirm this resolves the right tokenizer.
_embedding_model_registry = {
    "text-embedding-3-small": EmbeddingModel(
        embedding_provider="openai",
        embedding_model_name="text-embedding-3-small",
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=tiktoken.encoding_for_model("text-embedding-3-small"),
    ),
    "text-embedding-3-large": EmbeddingModel(
        embedding_provider="openai",
        embedding_model_name="text-embedding-3-large",
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=tiktoken.encoding_for_model("text-embedding-3-large"),
    ),
    "Alibaba-NLP/gte-large-en-v1.5": EmbeddingModel(
        embedding_provider="julep",
        embedding_model_name="Alibaba-NLP/gte-large-en-v1.5",
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=PreTrainedTokenizer.from_pretrained("Alibaba-NLP/gte-large-en-v1.5"),
    ),
    "BAAI/bge-m3": EmbeddingModel(
        embedding_provider="julep",
        embedding_model_name="BAAI/bge-m3",
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=PreTrainedTokenizer.from_pretrained("BAAI/bge-m3"),
    ),
    "BAAI/llm-embedder": EmbeddingModel(
        embedding_provider="julep",
        embedding_model_name="BAAI/llm-embedder",
        original_embedding_dimensions=1024,
        output_embedding_dimensions=1024,
        context_window=8192,
        tokenizer=PreTrainedTokenizer.from_pretrained("BAAI/llm-embedder"),
    ),
}

# The OpenAI models were first registered under misspelled names
# ("text-embeddings-…"). Keep those keys working for existing callers; they
# resolve to the correctly named model objects.
_embedding_model_registry["text-embeddings-3-small"] = _embedding_model_registry[
    "text-embedding-3-small"
]
_embedding_model_registry["text-embeddings-3-large"] = _embedding_model_registry[
    "text-embedding-3-large"
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
class BaseException(Exception):
    """Common base class for the error types defined in this module.

    NOTE: the name shadows Python's builtin ``BaseException`` wherever this
    class is imported unqualified.
    """
|
||
|
||
class ModelNotSupportedError(BaseException):
    """Signals that a requested model name is not supported.

    Args:
        model_name: the name that was requested; included in the message.
    """

    def __init__(self, model_name):
        super().__init__(f"model {model_name} is not supported")
|
||
|
||
class PromptTooBigError(BaseException):
    """Signals that a prompt holds more tokens than the allowed maximum.

    Args:
        token_count: number of tokens actually found in the prompt.
        max_tokens: the limit the prompt had to stay within.
    """

    def __init__(self, token_count, max_tokens):
        message = f"prompt is too big, {max_tokens} required, but actual length is {token_count}"
        super().__init__(message)
|
||
|
||
class UnknownTokenizerError(BaseException):
    """Signals a tokenizer whose interface is not recognised."""

    def __init__(self):
        message = "unknown tokenizer"
        super().__init__(message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
agents-api/migrations/migrate_1714566760_change_embeddings_dimensions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#!/usr/bin/env python3

# Name and creation time (Unix seconds) identifying this migration.
# NOTE(review): presumably read by the migration runner — confirm.
MIGRATION_ID = "change_embeddings_dimensions"
CREATED_AT = 1714566760.731964
|
||
|
||
# Rebuilds `information_snippets` with a different embedding width: every row
# is read back and the relation is replaced with the new column type.
# "__DIM__" is substituted with the vector length for each direction.
# NOTE(review): both directions read the key as `additional_info_id` but write
# it as `doc_id`, so "down" looks for a column "up" has already renamed —
# confirm which name the stored relation has before each direction runs.
# NOTE(review): existing non-null embeddings are copied unchanged into a column
# of a different width — confirm the database accepts that.
_RESIZE_SNIPPETS_TEMPLATE = """
?[
doc_id,
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
] :=
*information_snippets{
snippet_idx,
title,
snippet,
embed_instruction,
embedding,
additional_info_id: doc_id,
}
:replace information_snippets {
doc_id: Uuid,
snippet_idx: Int,
=>
title: String,
snippet: String,
embed_instruction: String default 'Encode this passage for retrieval: ',
embedding: <F32; __DIM__>? default null,
}
"""

change_dimensions = dict(
    up=_RESIZE_SNIPPETS_TEMPLATE.replace("__DIM__", "1024"),
    down=_RESIZE_SNIPPETS_TEMPLATE.replace("__DIM__", "768"),
)
|
||
# Creates the HNSW index over `information_snippets.embedding`.
# "__DIM__" is substituted with the vector length for each direction.
_CREATE_HNSW_INDEX_TEMPLATE = """
::hnsw create information_snippets:embedding_space {
fields: [embedding],
filter: !is_null(embedding),
dim: __DIM__,
distance: Cosine,
m: 64,
ef_construction: 256,
extend_candidates: false,
keep_pruned_connections: false,
}
"""

information_snippets_hnsw_index = {
    "up": _CREATE_HNSW_INDEX_TEMPLATE.replace("__DIM__", "1024"),
    "down": _CREATE_HNSW_INDEX_TEMPLATE.replace("__DIM__", "768"),
}
|
||
# Dropping the HNSW index is the same statement in both directions.
_DROP_HNSW_INDEX = """
::hnsw drop information_snippets:embedding_space
"""

drop_index = dict(up=_DROP_HNSW_INDEX, down=_DROP_HNSW_INDEX)
|
||
|
||
# Migration steps in execution order: drop the index, rebuild the relation,
# then create the index again.
queries_to_run = [drop_index, change_dimensions, information_snippets_hnsw_index]
|
||
|
||
def up(client):
    """Apply the migration: run every step's "up" statement in listed order.

    Args:
        client: database client exposing ``run(query)``.
    """
    forward_statements = (step["up"] for step in queries_to_run)
    for statement in forward_statements:
        client.run(statement)
|
||
|
||
def down(client):
    """Revert the migration: run every step's "down" statement in listed order.

    Args:
        client: database client exposing ``run(query)``.
    """
    # The "down" statements are not inverses of their "up" counterparts. They
    # are the same three operations (drop index, rebuild relation, create
    # index) aimed at 768 dimensions. They therefore run in the same order as
    # "up"; reversing would try to create the 768-dim index while the old
    # index still exists and the relation is still 1024-dim.
    for step in queries_to_run:
        client.run(step["down"])
Oops, something went wrong.