Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use openai's batch processing to create large volumes of embeddings #280

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fff2086
feat: create embedding batches using OpenAI's batch api
kolaente Dec 5, 2024
9335e66
feat: process batch embeddings submitted to openai
kolaente Dec 5, 2024
70d7b7d
fix: only open temp file for writing
kolaente Dec 5, 2024
a10b0fd
chore: move table creation to separate function
kolaente Dec 5, 2024
b1be455
chore: use OpenAI's batch type
kolaente Dec 5, 2024
b0e6a08
feat: generate full chunk id earlier
kolaente Dec 5, 2024
9c3e017
fix: correctly use embeddings endpoint
kolaente Dec 5, 2024
866e041
fix: properly convert time
kolaente Dec 5, 2024
2e57a35
feat: insert all chunks into the db after batch creation
kolaente Dec 5, 2024
5ef0491
fix: correctly process batches
kolaente Dec 5, 2024
fa10108
fix: return documents
kolaente Dec 5, 2024
783a62f
fix: use configured embeddings model
kolaente Dec 9, 2024
1be82ac
chore: rename write embeddings function
kolaente Dec 9, 2024
f7f6d13
chore: adjust function comment
kolaente Dec 9, 2024
d49c352
feat: create batch embedding tables in extension
kolaente Dec 9, 2024
cdad4bc
feat: move all queries to cached properties
kolaente Dec 9, 2024
2b17edc
Merge branch 'main' into feature/openai-batch-processing
kolaente Dec 18, 2024
ba6e179
fix: move batch embedding changes to openai embedder
kolaente Dec 18, 2024
9b4f3c3
fix: lint issues
kolaente Dec 18, 2024
9af30c1
fix: move batch embedding tables creation to embedding functions
kolaente Dec 19, 2024
cf86931
chore: rename text to chunk to match store table
kolaente Dec 19, 2024
916f7b2
feat: add total_attempts and next_attempt_after to openai batch table
kolaente Dec 19, 2024
ef4c382
feat: make fetching queries concurrently safe
kolaente Dec 19, 2024
f96906d
feat: make handling async embeddings more abstract
kolaente Dec 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions projects/extension/sql/idempotent/008-embedding.sql
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,31 @@ create or replace function ai.embedding_openai
, dimensions pg_catalog.int4
, chat_user pg_catalog.text default null
, api_key_name pg_catalog.text default 'OPENAI_API_KEY'
, use_batch_api pg_catalog.bool default false
, embedding_batch_schema pg_catalog.name default null
, embedding_batch_table pg_catalog.name default null
, embedding_batch_chunks_table pg_catalog.name default null
) returns pg_catalog.jsonb
as $func$
declare
_vectorizer_id pg_catalog.int4;
begin
_vectorizer_id = pg_catalog.nextval('ai.vectorizer_id_seq'::pg_catalog.regclass);
embedding_batch_schema = coalesce(embedding_batch_schema, 'ai');
embedding_batch_table = coalesce(embedding_batch_table, pg_catalog.concat('_vectorizer_embedding_batches_', _vectorizer_id));
embedding_batch_chunks_table = coalesce(embedding_batch_chunks_table, pg_catalog.concat('_vectorizer_embedding_batch_chunks_', _vectorizer_id));

select json_object
( 'implementation': 'openai'
, 'config_type': 'embedding'
, 'model': model
, 'dimensions': dimensions
, 'user': chat_user
, 'api_key_name': api_key_name
, 'use_batch_api': use_batch_api
, 'embedding_batch_schema': embedding_batch_schema
, 'embedding_batch_table': embedding_batch_table
, 'embedding_batch_chunks_table': embedding_batch_chunks_table
absent on null
)
$func$ language sql immutable security invoker
Expand Down Expand Up @@ -81,6 +97,9 @@ as $func$
declare
_config_type pg_catalog.text;
_implementation pg_catalog.text;
_embedding_batch_schema pg_catalog.text;
_embedding_batch_table pg_catalog.text;
_embedding_batch_chunks_table pg_catalog.text;
begin
if pg_catalog.jsonb_typeof(config) operator(pg_catalog.!=) 'object' then
raise exception 'embedding config is not a jsonb object';
Expand All @@ -93,6 +112,19 @@ begin
_implementation = config operator(pg_catalog.->>) 'implementation';
case _implementation
when 'openai' then
-- read the batch table identifiers from the config.
-- note: the schema-qualified operator syntax is operator(schema.op) with the
-- right-hand operand placed *after* the closing parenthesis.
_embedding_batch_schema = config operator(pg_catalog.->>) 'embedding_batch_schema';
_embedding_batch_table = config operator(pg_catalog.->>) 'embedding_batch_table';
_embedding_batch_chunks_table = config operator(pg_catalog.->>) 'embedding_batch_chunks_table';

-- make sure embedding batch table name is available
if pg_catalog.to_regclass(pg_catalog.format('%I.%I', _embedding_batch_schema, _embedding_batch_table)) is not null then
    -- report the name that actually collided (the batch table, not the queue table)
    raise exception 'an object named %.% already exists. specify an alternate embedding_batch_table explicitly', _embedding_batch_schema, _embedding_batch_table;
end if;

-- make sure embedding batch chunks table name is available
if pg_catalog.to_regclass(pg_catalog.format('%I.%I', _embedding_batch_schema, _embedding_batch_chunks_table)) is not null then
    raise exception 'an object named %.% already exists. specify an alternate embedding_batch_chunks_table explicitly', _embedding_batch_schema, _embedding_batch_chunks_table;
end if;

-- ok
when 'ollama' then
-- ok
Expand Down
14 changes: 12 additions & 2 deletions projects/extension/sql/idempotent/013-vectorizer-api.sql
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


-------------------------------------------------------------------------------
-- execute_vectorizer
create or replace function ai.execute_vectorizer(vectorizer_id pg_catalog.int4) returns void
Expand Down Expand Up @@ -44,6 +42,7 @@ declare
_vectorizer_id pg_catalog.int4;
_sql pg_catalog.text;
_job_id pg_catalog.int8;
_implementation pg_catalog.text;
begin
-- make sure all the roles listed in grant_to exist
if grant_to is not null then
Expand Down Expand Up @@ -225,6 +224,17 @@ begin
scheduling = pg_catalog.jsonb_insert(scheduling, array['job_id'], pg_catalog.to_jsonb(_job_id));
end if;

-- create batch embedding tables
-- the table identifiers are taken from the embedding config, which is where
-- ai.embedding_openai stores them (embedding_batch_schema, embedding_batch_table,
-- embedding_batch_chunks_table).
-- NOTE(review): tables are created for every openai vectorizer, regardless of
-- use_batch_api, matching the original behavior -- confirm this is intended.
_implementation = embedding operator(pg_catalog.->>) 'implementation';
if _implementation operator(pg_catalog.=) 'openai' then
    perform ai._vectorizer_create_embedding_batches_table
    ( (embedding operator(pg_catalog.->>) 'embedding_batch_schema')::pg_catalog.name
    , (embedding operator(pg_catalog.->>) 'embedding_batch_table')::pg_catalog.name
    , (embedding operator(pg_catalog.->>) 'embedding_batch_chunks_table')::pg_catalog.name
    , grant_to
    );
end if;

insert into ai.vectorizer
( id
, source_schema
Expand Down
98 changes: 98 additions & 0 deletions projects/extension/sql/idempotent/016-openai-batch-api.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
-------------------------------------------------------------------------------
-- _vectorizer_create_embedding_batches_table
-- creates the two tables used to track batch embedding work:
--   * the batches table: one row per submitted batch, keyed by external_batch_id
--   * the batch chunks table: one row per chunk, referencing its batch
--     (rows are removed with their batch via on delete cascade)
-- when grant_to contains at least one role, usage on the schema and
-- select/insert/update/delete on both tables are granted to those roles.
create or replace function ai._vectorizer_create_embedding_batches_table
( embedding_batch_schema name
, embedding_batch_table name
, embedding_batch_chunks_table name
, grant_to name[]
) returns void as
$func$
declare
    _sql text;
    _grantees text;
begin
    -- create the batches table
    select pg_catalog.format
    ( $sql$create table %I.%I(
        external_batch_id VARCHAR(255) PRIMARY KEY,
        input_file_id VARCHAR(255) NOT NULL,
        output_file_id VARCHAR(255),
        status VARCHAR(255) NOT NULL,
        errors JSONB,
        created_at TIMESTAMP(0) NOT NULL DEFAULT NOW(),
        expires_at TIMESTAMP(0),
        completed_at TIMESTAMP(0),
        failed_at TIMESTAMP(0),
        next_attempt_after TIMESTAMPTZ,
        total_attempts BIGINT NOT NULL DEFAULT 0
    )$sql$
    , embedding_batch_schema
    , embedding_batch_table
    ) into strict _sql
    ;
    execute _sql;

    -- create the index used to look batches up by status
    select pg_catalog.format
    ( $sql$create index on %I.%I (status)$sql$
    , embedding_batch_schema
    , embedding_batch_table
    ) into strict _sql
    ;
    execute _sql;

    -- create the batch chunks table
    select pg_catalog.format
    ( $sql$create table %I.%I(
        id VARCHAR(255) PRIMARY KEY,
        embedding_batch_id VARCHAR(255) REFERENCES %I.%I (external_batch_id) ON DELETE CASCADE,
        chunk TEXT
    )$sql$
    , embedding_batch_schema
    , embedding_batch_chunks_table
    , embedding_batch_schema
    , embedding_batch_table
    ) into strict _sql
    ;
    execute _sql;

    -- build the quoted, comma-separated role list once.
    -- string_agg yields null for a null or empty grant_to (and skips null
    -- elements), in which case there is nobody to grant to.
    select pg_catalog.string_agg(pg_catalog.quote_ident(x), ', ')
    into _grantees
    from pg_catalog.unnest(grant_to) x
    ;

    if _grantees is not null then
        -- grant usage on the batch schema to grant_to roles
        select pg_catalog.format
        ( $sql$grant usage on schema %I to %s$sql$
        , embedding_batch_schema
        , _grantees
        ) into strict _sql;
        execute _sql;

        -- grant select, insert, update, delete on batches table to grant_to roles
        select pg_catalog.format
        ( $sql$grant select, insert, update, delete on %I.%I to %s$sql$
        , embedding_batch_schema
        , embedding_batch_table
        , _grantees
        ) into strict _sql;
        execute _sql;

        -- grant select, insert, update, delete on batch chunks table to grant_to roles
        select pg_catalog.format
        ( $sql$grant select, insert, update, delete on %I.%I to %s$sql$
        , embedding_batch_schema
        , embedding_batch_chunks_table
        , _grantees
        ) into strict _sql;
        execute _sql;
    end if;
end;
$func$
language plpgsql volatile security invoker
set search_path to pg_catalog, pg_temp
;

147 changes: 146 additions & 1 deletion projects/pgai/pgai/vectorizer/embedders/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import re
import tempfile
from collections.abc import Sequence
from functools import cached_property
from typing import Any, Literal
Expand All @@ -8,6 +10,7 @@
from openai import resources
from pydantic import BaseModel
from typing_extensions import override
from psycopg import AsyncConnection

from ..embeddings import (
ApiKeyMixin,
Expand All @@ -21,6 +24,7 @@
Usage,
logger,
)
from ..vectorizer import AsyncBatch

TOKEN_CONTEXT_LENGTH_ERROR = "chunk exceeds model context length"

Expand All @@ -39,12 +43,20 @@ class OpenAI(ApiKeyMixin, BaseModel, Embedder):
model (str): The name of the OpenAI model used for embeddings.
dimensions (int | None): Optional dimensions for the embeddings.
user (str | None): Optional user identifier for OpenAI API usage.
use_batch (bool): Whether to use OpenAI Batch API.
embedding_batch_schema (str | None): The schema where the embedding batches are stored.
embedding_batch_table (str | None): The table where the embedding batches are stored.
embedding_batch_chunks_table (str | None): The table where the embedding batch chunks are stored.
"""

implementation: Literal["openai"]
model: str
dimensions: int | None = None
user: str | None = None
use_batch: bool = False
embedding_batch_schema: str | None = None
embedding_batch_table: str | None = None
embedding_batch_chunks_table: str | None = None

@cached_property
def _openai_dimensions(self) -> int | openai.NotGiven:
Expand All @@ -58,9 +70,13 @@ def _openai_dimensions(self) -> int | openai.NotGiven:
def _openai_user(self) -> str | openai.NotGiven:
return self.user if self.user is not None else openai.NOT_GIVEN

@cached_property
def _client(self) -> openai.AsyncOpenAI:
    """
    The asynchronous OpenAI client shared by all API calls of this embedder.

    Built once per instance (cached) from the configured API key, with up to
    3 retries per request. Every request method on it returns a coroutine
    and must be awaited.
    """
    return openai.AsyncOpenAI(api_key=self._api_key, max_retries=3)

@cached_property
def _embedder(self) -> resources.AsyncEmbeddings:
    """
    The embeddings resource of the shared async client.

    Reuses `_client` instead of constructing a second AsyncOpenAI instance,
    so the embeddings, files and batches calls all go through one client.
    """
    return self._client.embeddings

@override
def _max_chunks_per_batch(self) -> int:
Expand Down Expand Up @@ -129,6 +145,55 @@ async def embed(
model_token_length, encoded_documents
)

async def create_and_submit_embedding_batch(
    self,
    documents: list[dict[str, Any]],
) -> AsyncBatch:
    """
    Creates a batch of embeddings using OpenAI's embeddings API as outlined in
    https://platform.openai.com/docs/guides/batch/batch-api?lang=python

    One JSONL request line is written per document, the file is uploaded
    with purpose "batch", and a batch with a 24h completion window is
    created against the /v1/embeddings endpoint.

    Args:
        documents (list[dict[str, Any]]): The chunks to embed. Each entry
            must provide "unique_full_chunk_id" (used as the request's
            custom_id) and "chunk" (the text to embed).

    Returns:
        AsyncBatch: A batch carrying the external batch id, the input file
            id and the status reported by OpenAI at creation time.
    """
    import os  # local import: only needed to remove the temporary file

    # Request parameters shared by every line. Dimensions and user are only
    # sent when configured, mirroring the NOT_GIVEN handling of the
    # `_openai_dimensions` / `_openai_user` properties.
    shared_body: dict[str, Any] = {"model": self.model}
    if not isinstance(self._openai_dimensions, openai.NotGiven):
        shared_body["dimensions"] = self._openai_dimensions
    if not isinstance(self._openai_user, openai.NotGiven):
        shared_body["user"] = self._openai_user

    temp_file = tempfile.NamedTemporaryFile(
        delete=False, suffix=".jsonl", mode="w", encoding="utf-8"
    )
    try:
        # Leaving the `with` block closes (and flushes) the file before it
        # is reopened for the upload.
        with temp_file:
            for document in documents:
                entry = {
                    "custom_id": document["unique_full_chunk_id"],
                    "method": "POST",
                    "url": "/v1/embeddings",
                    "body": {**shared_body, "input": document["chunk"]},
                }
                temp_file.write(json.dumps(entry) + "\n")

        with open(temp_file.name, "rb") as file:
            # `_client` is an AsyncOpenAI client: calls must be awaited.
            batch_input_file = await self._client.files.create(
                file=file,
                purpose="batch",
            )
    finally:
        # The file was created with delete=False, so remove it explicitly,
        # including when writing or uploading fails.
        os.unlink(temp_file.name)

    openai_batch = await self._client.batches.create(
        input_file_id=batch_input_file.id,
        endpoint="/v1/embeddings",
        completion_window="24h",
    )

    batch = AsyncBatch()
    batch.external_batch_id = openai_batch.id
    batch.input_file_id = openai_batch.input_file_id
    batch.status = openai_batch.status

    return batch

async def _filter_by_length_and_embed(
self, model_token_length: int, encoded_documents: list[list[int]]
) -> Sequence[EmbeddingVector | ChunkEmbeddingError]:
Expand Down Expand Up @@ -200,3 +265,83 @@ async def _encode(self, documents: list[str]) -> list[list[int]]:
@cached_property
def _encoder(self) -> tiktoken.Encoding:
return tiktoken.encoding_for_model(self.model)

def is_api_async(self) -> bool:
    """
    Tells whether this embedder is configured to use the batch API.

    Returns:
        bool: The value of the `use_batch` setting.

    NOTE(review): the SQL config builder writes the key `use_batch_api`,
    while this model reads `use_batch` -- confirm the two are mapped.
    """
    batch_mode_enabled = self.use_batch
    return batch_mode_enabled

async def fetch_async_embedding_status(self, batch: AsyncBatch) -> AsyncBatch:
    """
    Refreshes a batch with the state currently reported by OpenAI.

    Args:
        batch (AsyncBatch): The batch to refresh; its `external_batch_id`
            identifies the batch on OpenAI's side.

    Returns:
        AsyncBatch: The same object, with `status`, `completed_at`,
            `failed_at` and `errors` overwritten from the retrieved batch.
    """
    # `_client` is an AsyncOpenAI client, so retrieve() returns a coroutine
    # that has to be awaited before its fields can be read.
    openai_batch = await self._client.batches.retrieve(batch.external_batch_id)

    batch.status = openai_batch.status
    batch.completed_at = openai_batch.completed_at
    batch.failed_at = openai_batch.failed_at
    batch.errors = openai_batch.errors

    return batch

async def process_async_embedding(
    self,
    conn: AsyncConnection,
    batch: AsyncBatch,
):
    """
    Writes embeddings from an OpenAI batch embedding to the database.

    - Downloads the batch's output file and parses it line by line.
    - Loads the chunk texts stored for the batch.
    - Deletes existing embeddings for the items.
    - Writes created embeddings to the database.

    Args:
        conn (AsyncConnection): The database connection.
        batch: The batch as stored in the queue table.

    Returns:
        int: The number of embedding records written.
    """
    # Local import: `np` is used below but is not among the imports visible
    # at the top of this file.
    import numpy as np

    # `_client` is an AsyncOpenAI client: both calls must be awaited.
    openai_batch = await self._client.batches.retrieve(batch.external_batch_id)
    batch_file = await self._client.files.content(openai_batch.output_file_id)

    batch_data = batch_file.text.strip().split("\n")
    all_items = []
    all_records: list[EmbeddingRecord] = []

    async with conn.cursor() as cursor:
        # The query and its parameters are two separate arguments; without
        # the comma the query object itself would be called.
        # NOTE(review): the chunks table references batches through
        # external_batch_id -- confirm `batch.id` is the right key here.
        await cursor.execute(
            self.queries.fetch_chunks_for_batch_id_query,
            (batch.id,),
        )
        embedding_batch_chunks = {row[0]: row[1] for row in await cursor.fetchall()}

    for line in batch_data:
        if not line.strip():
            # An empty output file yields a single empty line; skip it
            # instead of failing in json.loads.
            continue
        json_line = json.loads(line)
        if "custom_id" in json_line and "response" in json_line:
            # custom_id has the form "<pk names>:::<pk values>:::<chunk seq>",
            # with names and values comma separated.
            custom_id = json_line["custom_id"]
            pk_names, document_id, chunk_seq = custom_id.split(":::")
            embedding_data = json_line["response"]["body"]["data"][0]["embedding"]

            resolved_id = document_id.split(",")
            resolved_pk = pk_names.split(",")
            item = dict(zip(resolved_pk, resolved_id, strict=False))
            chunk_text = embedding_batch_chunks[custom_id]
            item[self.vectorizer.config.chunking.chunk_column] = chunk_text

            all_items.append(item)
            # Each entry is a one-element list of records, which is what is
            # handed to _copy_embeddings below.
            all_records.append(
                [resolved_id + [chunk_seq, chunk_text] + [np.array(embedding_data)]]
            )

    await self._delete_embeddings(conn, all_items)
    for records in all_records:
        await self._copy_embeddings(conn, records)

    # Report how many records were written (previously always 0).
    return len(all_records)


async def finalize_async_embedding(
    self,
    batch: AsyncBatch,
):
    """
    Removes the files belonging to a batch from OpenAI.

    Retrieves the batch by its external id and deletes its input file and,
    when one exists, its output file.

    Args:
        batch (AsyncBatch): The batch whose files should be deleted.
    """
    # `_client` is an AsyncOpenAI client: retrieve() must be awaited before
    # the file ids can be read from the result.
    openai_batch = await self._client.batches.retrieve(batch.external_batch_id)

    # output_file_id is optional on the batch object (e.g. no output was
    # produced), so only delete ids that are actually set.
    for file_id in (openai_batch.input_file_id, openai_batch.output_file_id):
        if file_id is not None:
            await self._client.files.delete(file_id)
Loading
Loading