From d6a959ece9c8af754996305773cffe65ed8a6c1e Mon Sep 17 00:00:00 2001 From: Tarrence van As Date: Mon, 25 Nov 2024 20:55:57 -0500 Subject: [PATCH] Improve SqliteVectorStore init --- Cargo.lock | 27 +++++++- rig-sqlite/Cargo.toml | 1 + rig-sqlite/examples/vector_search_sqlite.rs | 16 ++--- rig-sqlite/src/lib.rs | 74 ++++++++++++--------- 4 files changed, 74 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8d318fbc..cd11f61b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,7 +28,7 @@ dependencies = [ "getrandom", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -3953,7 +3953,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -4549,6 +4549,7 @@ dependencies = [ "tokio-rusqlite", "tracing", "tracing-subscriber", + "zerocopy 0.8.11", ] [[package]] @@ -6462,7 +6463,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cce3b5629d87654b53a49002acc2ce64aa5aa7255f5c718374a37ac7fd98c218" +dependencies = [ + "zerocopy-derive 0.8.11", ] [[package]] @@ -6476,6 +6486,17 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74a82c26c3986af2623ec9eb890ff4aa19c006e30a1133dc9bd1830ec1612e20" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/rig-sqlite/Cargo.toml b/rig-sqlite/Cargo.toml index 49d6fde0..e64532fc 100644 --- a/rig-sqlite/Cargo.toml +++ b/rig-sqlite/Cargo.toml @@ -15,6 +15,7 @@ tokio-rusqlite = { git = "https://github.com/programatik29/tokio-rusqlite", feat "bundled", ] } tracing = "0.1" +zerocopy = "0.8.10" chrono = "0.4" [dev-dependencies] diff --git a/rig-sqlite/examples/vector_search_sqlite.rs b/rig-sqlite/examples/vector_search_sqlite.rs index 20fab71b..49c2ba73 100644 --- a/rig-sqlite/examples/vector_search_sqlite.rs +++ b/rig-sqlite/examples/vector_search_sqlite.rs @@ -32,18 +32,18 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize SQLite connection let conn = Connection::open("vector_store.db").await?; - // Initialize SQLite vector store - let mut vector_store = SqliteVectorStore::new(conn).await?; - // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") - .build() - .await?; + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .build() + .await?; + + // Initialize SQLite vector store + let mut vector_store = SqliteVectorStore::new(conn, model.clone()).await?; // Add embeddings to vector store match vector_store.add_documents(embeddings).await { diff --git a/rig-sqlite/src/lib.rs b/rig-sqlite/src/lib.rs index 7ee91312..2467be0a 100644 --- a/rig-sqlite/src/lib.rs +++ b/rig-sqlite/src/lib.rs @@ -2,8 +2,10 @@ use rig::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}; use rig::vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}; use rusqlite::OptionalExtension; use serde::Deserialize; +use std::marker::PhantomData; use tokio_rusqlite::Connection; use tracing::{debug, info}; +use zerocopy::IntoBytes; #[derive(Debug)] pub enum SqliteError { @@ -12,34 +14,39 @@ pub enum SqliteError { } #[derive(Clone)] -pub struct SqliteVectorStore { +pub struct SqliteVectorStore { conn: Connection, + _phantom: PhantomData, } -impl SqliteVectorStore { - pub async fn new(conn: Connection) -> Result { - debug!("Running initial migrations"); +impl SqliteVectorStore { + pub async fn new(conn: Connection, embedding_model: E) -> Result { // Run migrations or create tables if they don't exist - conn.call(|conn| { - conn.execute_batch( + let dims = embedding_model.ndims(); + conn.call(move |conn| { + conn.execute_batch(&format!( "BEGIN; -- Document tables CREATE TABLE IF NOT EXISTS documents ( id INTEGER PRIMARY KEY AUTOINCREMENT, - doc_id TEXT UNIQUE NOT NULL, + document_id TEXT UNIQUE NOT NULL, document TEXT NOT NULL ); - CREATE INDEX IF NOT EXISTS idx_doc_id ON documents(doc_id); - CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0(embedding float[1536]); + CREATE INDEX IF NOT EXISTS idx_document_id ON documents(document_id); + CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0(embedding float[{}]); COMMIT;", - ) + dims + )) .map_err(tokio_rusqlite::Error::from) }) .await .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - Ok(Self { conn }) + Ok(Self { + conn, + _phantom: PhantomData, + }) } fn serialize_embedding(embedding: &Embedding) -> Vec { @@ -47,15 +54,12 @@ impl SqliteVectorStore { } /// Create a new `SqliteVectorIndex` from an existing `SqliteVectorStore`. - pub async fn index( - &self, - model: M, - ) -> Result, VectorStoreError> { + pub async fn index(&self, model: E) -> Result, VectorStoreError> { Ok(SqliteVectorIndex::new(model, self.clone())) } } -impl VectorStore for SqliteVectorStore { +impl VectorStore for SqliteVectorStore { type Q = String; async fn add_documents( @@ -71,12 +75,12 @@ impl VectorStore for SqliteVectorStore { debug!("Storing document with id {}", doc.id); // Store document and get auto-incremented ID tx.execute( - "INSERT OR REPLACE INTO documents (doc_id, document) VALUES (?1, ?2)", + "INSERT OR REPLACE INTO documents (document_id, document) VALUES (?1, ?2)", [&doc.id, &doc.document.to_string()], ) .map_err(tokio_rusqlite::Error::from)?; - let doc_id = tx.last_insert_rowid(); + let document_id = tx.last_insert_rowid(); // Store embeddings let mut stmt = tx @@ -90,10 +94,8 @@ impl VectorStore for SqliteVectorStore { ); for embedding in doc.embeddings { let vec = Self::serialize_embedding(&embedding); - let blob = rusqlite::types::Value::Blob( - vec.iter().flat_map(|x| x.to_le_bytes()).collect(), - ); - stmt.execute(rusqlite::params![doc_id, blob]) + let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec()); + stmt.execute(rusqlite::params![document_id, blob]) .map_err(tokio_rusqlite::Error::from)?; } } @@ -117,7 +119,7 @@ impl VectorStore for SqliteVectorStore { .conn .call(move |conn| { conn.query_row( - "SELECT document FROM documents WHERE doc_id = ?1", + "SELECT document FROM documents WHERE document_id = ?1", rusqlite::params![id_clone], |row| row.get::<_, String>(0), ) @@ -153,7 +155,7 @@ impl VectorStore for SqliteVectorStore { "SELECT e.embedding, d.document FROM embeddings e JOIN documents d ON e.rowid = d.id - WHERE d.doc_id = ?1", + WHERE d.document_id = ?1", )?; let result = stmt @@ -170,7 +172,13 @@ impl VectorStore for SqliteVectorStore { })?; let vec = bytes .chunks(4) - .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()) as f64) + .map(|chunk| { + f32::from_le_bytes( + chunk + .try_into() + .expect("Invalid chunk length - must be 4 bytes"), + ) as f64 + }) .collect(); Ok(( rig::embeddings::Embedding { @@ -209,7 +217,7 @@ impl VectorStore for SqliteVectorStore { .conn .call(move |conn| { let mut stmt = conn.prepare( - "SELECT d.doc_id, e.distance + "SELECT d.document_id, e.distance FROM embeddings e JOIN documents d ON e.rowid = d.id WHERE e.embedding MATCH ?1 AND k = ?2 @@ -240,12 +248,12 @@ impl VectorStore for SqliteVectorStore { } pub struct SqliteVectorIndex { - store: SqliteVectorStore, + store: SqliteVectorStore, embedding_model: E, } impl SqliteVectorIndex { - pub fn new(embedding_model: E, store: SqliteVectorStore) -> Self { + pub fn new(embedding_model: E, store: SqliteVectorStore) -> Self { Self { store, embedding_model, @@ -261,14 +269,14 @@ impl VectorStoreIndex for SqliteVectorInd ) -> Result, VectorStoreError> { debug!("Finding top {} matches for query", n); let embedding = self.embedding_model.embed_document(query).await?; - let query_vec = SqliteVectorStore::serialize_embedding(&embedding); + let query_vec = SqliteVectorStore::::serialize_embedding(&embedding); let rows = self .store .conn .call(move |conn| { let mut stmt = conn.prepare( - "SELECT d.doc_id, d.document, e.distance + "SELECT d.document_id, d.document, e.distance FROM embeddings e JOIN documents d ON e.rowid = d.id WHERE e.embedding MATCH ?1 AND k = ?2 @@ -323,17 +331,17 @@ impl VectorStoreIndex for SqliteVectorInd ) -> Result, VectorStoreError> { debug!("Finding top {} document IDs for query", n); let embedding = self.embedding_model.embed_document(query).await?; - let query_vec = SqliteVectorStore::serialize_embedding(&embedding); + let query_vec = SqliteVectorStore::::serialize_embedding(&embedding); let results = self .store .conn .call(move |conn| { let mut stmt = conn.prepare( - "SELECT d.doc_id, e.distance + "SELECT d.document_id, e.distance FROM embeddings e JOIN documents d ON e.rowid = d.id - WHERE e.embedding MATCH ?1 AND k = ?2 + WHERE e.embedding MATCH ?1 AND k = ?2 ORDER BY e.distance", )?;