From 3a4d90b18c971e970caefafa49a6a6c1310e1deb Mon Sep 17 00:00:00 2001 From: Tarrence van As Date: Fri, 29 Nov 2024 21:33:01 -0500 Subject: [PATCH] Migrate to new embeddings api --- Cargo.lock | 22 +- rig-sqlite/examples/vector_search_sqlite.rs | 71 ++- rig-sqlite/src/lib.rs | 555 ++++++++++++-------- 3 files changed, 391 insertions(+), 257 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9c4ec4bb..a2852332 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6800,6 +6800,17 @@ dependencies = [ "syn 2.0.89", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74a82c26c3986af2623ec9eb890ff4aa19c006e30a1133dc9bd1830ec1612e20" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "zerofrom" version = "0.1.5" @@ -6821,17 +6832,6 @@ dependencies = [ "synstructure", ] -[[package]] -name = "zerocopy-derive" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74a82c26c3986af2623ec9eb890ff4aa19c006e30a1133dc9bd1830ec1612e20" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.79", -] - [[package]] name = "zeroize" version = "1.8.1" diff --git a/rig-sqlite/examples/vector_search_sqlite.rs b/rig-sqlite/examples/vector_search_sqlite.rs index 85c68efc..90241fb3 100644 --- a/rig-sqlite/examples/vector_search_sqlite.rs +++ b/rig-sqlite/examples/vector_search_sqlite.rs @@ -1,15 +1,47 @@ -use rig::vector_store::VectorStore; use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, + Embed, }; -use rig_sqlite::SqliteVectorStore; +use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable}; use rusqlite::ffi::sqlite3_auto_extension; +use serde::Deserialize; use sqlite_vec::sqlite3_vec_init; use std::env; use tokio_rusqlite::Connection; +#[derive(Embed, Clone, Debug, Deserialize)] +struct Document { + id: String, + #[embed] + content: String, +} + +impl SqliteVectorStoreTable for Document { + fn name() -> &'static str { + "documents" + } + + fn schema() -> Vec { + vec![ + Column::new("id", "TEXT PRIMARY KEY"), + Column::new("content", "TEXT"), + ] + } + + fn id(&self) -> String { + self.id.clone() + } + + fn column_values(&self) -> Vec<(&'static str, Box)> { + vec![ + ("id", Box::new(self.id.clone())), + ("content", Box::new(self.content.clone())), + ] + } +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt() @@ -35,29 +67,38 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let documents = vec![ + Document { + id: "doc0".to_string(), + content: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + }, + Document { + id: "doc1".to_string(), + content: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + }, + Document { + id: "doc2".to_string(), + content: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + }, + ]; + let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") - .build() - .await?; + .documents(documents)? + .build() + .await?; // Initialize SQLite vector store - let mut vector_store = SqliteVectorStore::new(conn, &model).await?; + let vector_store = SqliteVectorStore::new(conn, &model).await?; // Add embeddings to vector store - match vector_store.add_documents(embeddings).await { - Ok(_) => println!("Documents added successfully"), - Err(e) => println!("Error adding documents: {:?}", e), - } + vector_store.add_rows(embeddings).await?; // Create a vector index on our vector store - // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model).await?; + let index = vector_store.index(model); // Query the index let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, id, doc)| (score, id, doc)) diff --git a/rig-sqlite/src/lib.rs b/rig-sqlite/src/lib.rs index 6f777400..e59af602 100644 --- a/rig-sqlite/src/lib.rs +++ b/rig-sqlite/src/lib.rs @@ -1,6 +1,6 @@ -use rig::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}; -use rig::vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}; -use rusqlite::OptionalExtension; +use rig::embeddings::{Embedding, EmbeddingModel}; +use rig::vector_store::{VectorStoreError, VectorStoreIndex}; +use rig::OneOrMany; use serde::Deserialize; use std::marker::PhantomData; use tokio_rusqlite::Connection; @@ -11,34 +11,141 @@ use zerocopy::IntoBytes; pub enum SqliteError { DatabaseError(Box), SerializationError(Box), + InvalidColumnType(String), +} + +pub trait ColumnValue: Send + Sync { + fn to_sql_string(&self) -> String; + fn column_type(&self) -> &'static str; +} + +pub struct Column { + name: &'static str, + col_type: &'static str, + indexed: bool, +} + +impl Column { + pub fn new(name: &'static str, col_type: &'static str) -> Self { + Self { + name, + col_type, + indexed: false, + } + } + + pub fn indexed(mut self) -> Self { + self.indexed = true; + self + } +} + +/// Example of a document type that can be used with SqliteVectorStore +/// ```rust +/// use rig::Embed; +/// use serde::Deserialize; +/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable}; +/// +/// #[derive(Embed, Clone, Debug, Deserialize)] +/// struct Document { +/// id: String, +/// #[embed] +/// content: String, +/// } +/// +/// impl SqliteVectorStoreTable for Document { +/// fn name() -> &'static str { +/// "documents" +/// } +/// +/// fn schema() -> Vec { +/// vec![ +/// Column::new("id", "TEXT PRIMARY KEY"), +/// Column::new("content", "TEXT"), +/// ] +/// } +/// +/// fn id(&self) -> String { +/// self.id.clone() +/// } +/// +/// fn column_values(&self) -> Vec<(&'static str, Box)> { +/// vec![ +/// ("id", Box::new(self.id.clone())), +/// ("content", Box::new(self.content.clone())), +/// ] +/// } +/// } +/// ``` +pub trait SqliteVectorStoreTable: Send + Sync + Clone { + fn name() -> &'static str; + fn schema() -> Vec; + fn id(&self) -> String; + fn column_values(&self) -> Vec<(&'static str, Box)>; } #[derive(Clone)] -pub struct SqliteVectorStore { +pub struct SqliteVectorStore { conn: Connection, - _phantom: PhantomData, + _phantom: PhantomData<(E, T)>, } -impl SqliteVectorStore { +impl SqliteVectorStore { pub async fn new(conn: Connection, embedding_model: &E) -> Result { - // Run migrations or create tables if they don't exist let dims = embedding_model.ndims(); + let table_name = T::name(); + let schema = T::schema(); + + // Build the table schema + let mut create_table = format!("CREATE TABLE IF NOT EXISTS {} (", table_name); + + // Add columns + let mut first = true; + for column in &schema { + if !first { + create_table.push_str(","); + } + create_table.push_str(&format!("\n {} {}", column.name, column.col_type)); + first = false; + } + + create_table.push_str("\n)"); + + // Build index creation statements + let mut create_indexes = vec![format!( + "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)", + table_name, table_name + )]; + + // Add indexes for marked columns + for column in schema { + if column.indexed { + create_indexes.push(format!( + "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})", + table_name, column.name, table_name, column.name + )); + } + } + conn.call(move |conn| { + conn.execute_batch("BEGIN")?; + + // Create document table + conn.execute_batch(&create_table)?; + + // Create indexes + for index_stmt in create_indexes { + conn.execute_batch(&index_stmt)?; + } + + // Create embeddings table conn.execute_batch(&format!( - "BEGIN; - -- Document tables - CREATE TABLE IF NOT EXISTS documents ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - document_id TEXT UNIQUE NOT NULL, - document TEXT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_document_id ON documents(document_id); - CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0(embedding float[{}]); - - COMMIT;", - dims - )) - .map_err(tokio_rusqlite::Error::from) + "CREATE VIRTUAL TABLE IF NOT EXISTS {}_embeddings USING vec0(embedding float[{}])", + table_name, dims + ))?; + + conn.execute_batch("COMMIT")?; + Ok(()) }) .await .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; @@ -49,211 +156,166 @@ impl SqliteVectorStore { }) } - fn serialize_embedding(embedding: &Embedding) -> Vec { - embedding.vec.iter().map(|x| *x as f32).collect() + pub fn index(self, model: E) -> SqliteVectorIndex { + SqliteVectorIndex::new(model, self) } - /// Create a new `SqliteVectorIndex` from an existing `SqliteVectorStore`. - pub async fn index(&self, model: E) -> Result, VectorStoreError> { - Ok(SqliteVectorIndex::new(model, self.clone())) - } -} - -impl VectorStore for SqliteVectorStore { - type Q = String; - - async fn add_documents( - &mut self, - documents: Vec, - ) -> Result<(), VectorStoreError> { - info!("Adding {} documents to store", documents.len()); - self.conn - .call(|conn| { - let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?; - - for doc in documents { - debug!("Storing document with id {}", doc.id); - // Store document and get auto-incremented ID - tx.execute( - "INSERT OR REPLACE INTO documents (document_id, document) VALUES (?1, ?2)", - [&doc.id, &doc.document.to_string()], - ) - .map_err(tokio_rusqlite::Error::from)?; - - let document_id = tx.last_insert_rowid(); - - // Store embeddings - let mut stmt = tx - .prepare("INSERT INTO embeddings (rowid, embedding) VALUES (?1, ?2)") - .map_err(tokio_rusqlite::Error::from)?; - - debug!( - "Storing {} embeddings for document {}", - doc.embeddings.len(), - doc.id - ); - for embedding in doc.embeddings { - let vec = Self::serialize_embedding(&embedding); - let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec()); - stmt.execute(rusqlite::params![document_id, blob]) - .map_err(tokio_rusqlite::Error::from)?; - } - } - - tx.commit().map_err(tokio_rusqlite::Error::from)?; - Ok(()) - }) - .await - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - - Ok(()) - } - - async fn get_document Deserialize<'a>>( + pub fn add_rows_with_txn( &self, - id: &str, - ) -> Result, VectorStoreError> { - debug!("Fetching document with id {}", id); - let id_clone = id.to_string(); - let doc_str = self - .conn - .call(move |conn| { - conn.query_row( - "SELECT document FROM documents WHERE document_id = ?1", - rusqlite::params![id_clone], - |row| row.get::<_, String>(0), - ) - .optional() - .map_err(tokio_rusqlite::Error::from) - }) - .await - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - - match doc_str { - Some(doc_str) => { - let doc: T = serde_json::from_str(&doc_str) - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - Ok(Some(doc)) - } - None => { - debug!("No document found with id {}", id); - Ok(None) + txn: &rusqlite::Transaction<'_>, + documents: Vec<(T, OneOrMany)>, + ) -> Result { + info!("Adding {} documents to store", documents.len()); + let table_name = T::name(); + let mut last_id = 0; + + for (doc, embeddings) in &documents { + debug!("Storing document with id {}", doc.id()); + + let values = doc.column_values(); + let columns = values.iter().map(|(col, _)| *col).collect::>(); + let placeholders = (1..=values.len()) + .map(|i| format!("?{}", i)) + .collect::>(); + + txn.execute( + &format!( + "INSERT OR REPLACE INTO {} ({}) VALUES ({})", + table_name, + columns.join(", "), + placeholders.join(", ") + ), + rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())), + )?; + last_id = txn.last_insert_rowid(); + + let mut stmt = txn.prepare(&format!( + "INSERT INTO {}_embeddings (rowid, embedding) VALUES (?1, ?2)", + table_name + ))?; + debug!( + "Storing {} embeddings for document {}", + embeddings.len(), + doc.id() + ); + for embedding in embeddings.iter() { + let vec = serialize_embedding(&embedding); + let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec()); + stmt.execute(rusqlite::params![last_id, blob])?; } } - } - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result, VectorStoreError> { - debug!("Fetching embeddings for document {}", id); - let id_clone = id.to_string(); - let result = self - .conn - .call(move |conn| { - let mut stmt = conn.prepare( - "SELECT e.embedding, d.document - FROM embeddings e - JOIN documents d ON e.rowid = d.id - WHERE d.document_id = ?1", - )?; - - let result = stmt - .query_map(rusqlite::params![id_clone], |row| { - let bytes: Vec = row.get(0)?; - let doc_str: String = row.get(1)?; - let doc: serde_json::Value = - serde_json::from_str(&doc_str).map_err(|e| { - rusqlite::Error::FromSqlConversionFailure( - 0, - rusqlite::types::Type::Text, - Box::new(e), - ) - })?; - let vec = bytes - .chunks(4) - .map(|chunk| { - f32::from_le_bytes( - chunk - .try_into() - .expect("Invalid chunk length - must be 4 bytes"), - ) as f64 - }) - .collect(); - Ok(( - rig::embeddings::Embedding { - vec, - document: "".to_string(), - }, - doc, - )) - })? - .collect::, _>>()?; - Ok(result) - }) - .await - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - - if let Some((_, doc)) = result.first() { - let embeddings: Vec = result.iter().map(|(e, _)| e.clone()).collect(); - debug!("Found {} embeddings for document {}", embeddings.len(), id); - Ok(Some(DocumentEmbeddings { - id: id.to_string(), - document: doc.clone(), - embeddings, - })) - } else { - debug!("No embeddings found for document {}", id); - Ok(None) - } + Ok(last_id) } - async fn get_document_by_query( + pub async fn add_rows( &self, - query: Self::Q, - ) -> Result, VectorStoreError> { - debug!("Searching for document matching query"); - let result = self - .conn - .call(move |conn| { - let mut stmt = conn.prepare( - "SELECT d.document_id, e.distance - FROM embeddings e - JOIN documents d ON e.rowid = d.id - WHERE e.embedding MATCH ?1 AND k = ?2 - ORDER BY e.distance", - )?; + documents: Vec<(T, OneOrMany)>, + ) -> Result { + let documents = documents.clone(); + let this = self.clone(); - let result = stmt - .query_row(rusqlite::params![query.as_bytes(), 1], |row| { - Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?)) - }) - .optional()?; + self.conn + .call(move |conn| { + let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?; + let result = this.add_rows_with_txn(&tx, documents)?; + tx.commit().map_err(tokio_rusqlite::Error::from)?; Ok(result) }) .await - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - - match result { - Some((id, distance)) => { - debug!("Found matching document {} with distance {}", id, distance); - self.get_document_embeddings(&id).await - } - None => { - debug!("No matching documents found"); - Ok(None) - } - } + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e))) } } -pub struct SqliteVectorIndex { - store: SqliteVectorStore, +/// SQLite vector store implementation for Rig. +/// +/// This crate provides a SQLite-based vector store implementation that can be used with Rig. +/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities. +/// +/// # Example +/// ```rust +/// use rig::{ +/// embeddings::EmbeddingsBuilder, +/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +/// vector_store::VectorStoreIndex, +/// Embed, +/// }; +/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable}; +/// use serde::Deserialize; +/// use tokio_rusqlite::Connection; +/// +/// #[derive(Embed, Clone, Debug, Deserialize)] +/// struct Document { +/// id: String, +/// #[embed] +/// content: String, +/// } +/// +/// impl SqliteVectorStoreTable for Document { +/// fn name() -> &'static str { +/// "documents" +/// } +/// +/// fn schema() -> Vec { +/// vec![ +/// Column::new("id", "TEXT PRIMARY KEY"), +/// Column::new("content", "TEXT"), +/// ] +/// } +/// +/// fn id(&self) -> String { +/// self.id.clone() +/// } +/// +/// fn column_values(&self) -> Vec<(&'static str, Box)> { +/// vec![ +/// ("id", Box::new(self.id.clone())), +/// ("content", Box::new(self.content.clone())), +/// ] +/// } +/// } +/// +/// let conn = Connection::open("vector_store.db").await?; +/// let openai_client = Client::new("YOUR_API_KEY"); +/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); +/// +/// // Initialize vector store +/// let vector_store = SqliteVectorStore::new(conn, &model).await?; +/// +/// // Create documents +/// let documents = vec![ +/// Document { +/// id: "doc1".to_string(), +/// content: "Example document 1".to_string(), +/// }, +/// Document { +/// id: "doc2".to_string(), +/// content: "Example document 2".to_string(), +/// }, +/// ]; +/// +/// // Generate embeddings +/// let embeddings = EmbeddingsBuilder::new(model.clone()) +/// .documents(documents)? +/// .build() +/// .await?; +/// +/// // Add to vector store +/// vector_store.add_rows(embeddings).await?; +/// +/// // Create index and search +/// let index = vector_store.index(model); +/// let results = index +/// .top_n::("Example query", 2) +/// .await?; +/// ``` +pub struct SqliteVectorIndex { + store: SqliteVectorStore, embedding_model: E, } -impl SqliteVectorIndex { - pub fn new(embedding_model: E, store: SqliteVectorStore) -> Self { +impl SqliteVectorIndex { + pub fn new(embedding_model: E, store: SqliteVectorStore) -> Self { Self { store, embedding_model, @@ -261,35 +323,50 @@ impl SqliteVectorIndex { } } -impl VectorStoreIndex for SqliteVectorIndex { - async fn top_n Deserialize<'a>>( +impl VectorStoreIndex + for SqliteVectorIndex +{ + async fn top_n Deserialize<'a>>( &self, query: &str, n: usize, - ) -> Result, VectorStoreError> { + ) -> Result, VectorStoreError> { debug!("Finding top {} matches for query", n); - let embedding = self.embedding_model.embed_document(query).await?; - let query_vec = SqliteVectorStore::::serialize_embedding(&embedding); + let embedding = self.embedding_model.embed_text(query).await?; + let query_vec: Vec = serialize_embedding(&embedding); + let table_name = T::name(); + + // Get all column names from SqliteVectorStoreTable + let columns = T::schema(); + let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect(); let rows = self .store .conn .call(move |conn| { - let mut stmt = conn.prepare( - "SELECT d.document_id, d.document, e.distance - FROM embeddings e - JOIN documents d ON e.rowid = d.id + // Build SELECT statement with all columns + let select_cols = column_names.join(", "); + let mut stmt = conn.prepare(&format!( + "SELECT d.{}, e.distance + FROM {}_embeddings e + JOIN {} d ON e.rowid = d.id WHERE e.embedding MATCH ?1 AND k = ?2 ORDER BY e.distance", - )?; + select_cols, table_name, table_name + ))?; let rows = stmt .query_map(rusqlite::params![query_vec.as_bytes().to_vec(), n], |row| { - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, f64>(2)?, - )) + // Create a map of column names to values + let mut map = serde_json::Map::new(); + for (i, col_name) in column_names.iter().enumerate() { + let value: String = row.get(i)?; + map.insert(col_name.to_string(), serde_json::Value::String(value)); + } + let distance: f64 = row.get(column_names.len())?; + let id: String = row.get(0)?; // Assuming id is always first column + + Ok((id, serde_json::Value::Object(map), distance)) })? .collect::, _>>()?; Ok(rows) @@ -299,8 +376,8 @@ impl VectorStoreIndex for SqliteVectorInd debug!("Found {} potential matches", rows.len()); let mut top_n = Vec::new(); - for (id, doc_str, distance) in rows { - match serde_json::from_str::(&doc_str) { + for (id, doc_value, distance) in rows { + match serde_json::from_value::(doc_value) { Ok(doc) => { top_n.push((distance, id, doc)); } @@ -321,20 +398,22 @@ impl VectorStoreIndex for SqliteVectorInd n: usize, ) -> Result, VectorStoreError> { debug!("Finding top {} document IDs for query", n); - let embedding = self.embedding_model.embed_document(query).await?; - let query_vec = SqliteVectorStore::::serialize_embedding(&embedding); + let embedding = self.embedding_model.embed_text(query).await?; + let query_vec = serialize_embedding(&embedding); + let table_name = T::name(); let results = self .store .conn .call(move |conn| { - let mut stmt = conn.prepare( - "SELECT d.document_id, e.distance - FROM embeddings e - JOIN documents d ON e.rowid = d.id + let mut stmt = conn.prepare(&format!( + "SELECT d.id, e.distance + FROM {0}_embeddings e + JOIN {0} d ON e.rowid = d.id WHERE e.embedding MATCH ?1 AND k = ?2 ORDER BY e.distance", - )?; + table_name + ))?; let results = stmt .query_map( @@ -357,3 +436,17 @@ impl VectorStoreIndex for SqliteVectorInd Ok(results) } } + +fn serialize_embedding(embedding: &Embedding) -> Vec { + embedding.vec.iter().map(|x| *x as f32).collect() +} + +impl ColumnValue for String { + fn to_sql_string(&self) -> String { + self.clone() + } + + fn column_type(&self) -> &'static str { + "TEXT" + } +}