From 53787965b0a22c3598bede0d29e94dfa3b579b6a Mon Sep 17 00:00:00 2001 From: Tarrence van As Date: Tue, 3 Dec 2024 14:58:56 -0500 Subject: [PATCH] feat: Add support for Sqlite vector store (#122) * Add sqlite vector store * Improve SqliteVectorStore init * Pass embedding model by ref * Migrate to new embeddings api * Add test * test: disable doctests as they don't work w/ async * chore: fix lock file --------- Co-authored-by: 0xMochan --- Cargo.lock | 159 +++++- Cargo.toml | 3 +- rig-sqlite/Cargo.toml | 27 + rig-sqlite/LICENSE | 7 + rig-sqlite/README.md | 46 ++ rig-sqlite/examples/vector_search_sqlite.rs | 118 +++++ rig-sqlite/src/lib.rs | 557 ++++++++++++++++++++ 7 files changed, 902 insertions(+), 15 deletions(-) create mode 100644 rig-sqlite/Cargo.toml create mode 100644 rig-sqlite/LICENSE create mode 100644 rig-sqlite/README.md create mode 100644 rig-sqlite/examples/vector_search_sqlite.rs create mode 100644 rig-sqlite/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 2b14046e..fa37b771 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,7 +28,7 @@ dependencies = [ "getrandom", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -297,7 +297,7 @@ dependencies = [ "memchr", "num", "regex", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -1555,7 +1555,7 @@ dependencies = [ "itertools 0.12.1", "log", "paste", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -1926,6 +1926,18 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastdivide" version = "0.4.2" @@ -2194,8 +2206,8 @@ dependencies = [ "aho-corasick", "bstr", "log", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -2285,6 +2297,15 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.4.1" @@ -2815,7 +2836,7 @@ dependencies = [ "globset", "log", "memchr", - "regex-automata", + "regex-automata 0.4.9", "same-file", "walkdir", "winapi-util", @@ -3457,6 +3478,17 @@ dependencies = [ "redox_syscall 0.5.7", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -3559,6 +3591,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matchit" version = "0.7.3" @@ -4071,7 +4112,7 @@ checksum = "914a1c2265c98e2446911282c6ac86d8524f495792c38c5bd884f80499c7538a" dependencies = [ "parse-display-derive", "regex", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -4083,7 +4124,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "regex-syntax", + "regex-syntax 0.8.5", "structmeta", "syn 2.0.90", ] @@ -4244,7 +4285,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -4618,8 +4659,17 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -4630,7 +4680,7 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -4639,6 +4689,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -4839,6 +4895,24 @@ dependencies = [ "tokio", ] +[[package]] +name = "rig-sqlite" +version = "0.1.0" +dependencies = [ + "anyhow", + "chrono", + "rig-core", + "rusqlite", + "serde", + "serde_json", + "sqlite-vec", + "tokio", + "tokio-rusqlite", + "tracing", + "tracing-subscriber", + "zerocopy 0.8.11", +] + [[package]] name = "ring" version = "0.17.8" @@ -4864,6 +4938,20 @@ dependencies = [ "byteorder", ] +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags 2.6.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rust-stemmers" version = "1.2.0" @@ -5434,6 +5522,15 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "sqlite-vec" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec77b84fb8dd5f0f8def127226db83b5d1152c5bf367f09af03998b76ba554a" +dependencies = [ + "cc", +] + [[package]] name = "sqlparser" version = "0.47.0" @@ -5724,7 +5821,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d60769b80ad7953d8a7b2c70cdfe722bbcdcac6bccc8ac934c40c034d866fc18" dependencies = [ "byteorder", - "regex-syntax", + "regex-syntax 0.8.5", "utf8-ranges", ] @@ -5998,6 +6095,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rusqlite" +version = "0.6.0" +source = "git+https://github.com/programatik29/tokio-rusqlite#168e9c47fd6c7c9f8032b660f62084e16ade7bac" +dependencies = [ + "crossbeam-channel", + "rusqlite", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -6201,10 +6308,14 @@ version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex", "sharded-slab", "smallvec", "thread_local", + "tracing", "tracing-core", "tracing-log", ] @@ -6837,7 +6948,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cce3b5629d87654b53a49002acc2ce64aa5aa7255f5c718374a37ac7fd98c218" +dependencies = [ + "zerocopy-derive 0.8.11", ] [[package]] @@ -6851,6 +6971,17 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74a82c26c3986af2623ec9eb890ff4aa19c006e30a1133dc9bd1830ec1612e20" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "zerofrom" version = "0.1.5" diff --git a/Cargo.toml b/Cargo.toml index 2f6d642c..df48d1b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,5 +3,6 @@ resolver = "2" members = [ "rig-core", "rig-lancedb", "rig-mongodb", "rig-neo4j", - "rig-qdrant", "rig-core/rig-core-derive" + "rig-qdrant", "rig-core/rig-core-derive", + "rig-sqlite" ] diff --git a/rig-sqlite/Cargo.toml b/rig-sqlite/Cargo.toml new file mode 100644 index 00000000..c61d895d --- /dev/null +++ b/rig-sqlite/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "rig-sqlite" +version = "0.1.0" +edition = "2021" +description = "SQLite-based vector store implementation for the rig framework" +license = "MIT" + +[lib] +doctest = false + +[dependencies] +rig-core = { path = "../rig-core", features = ["derive"] } +rusqlite = { version = "0.32", features = ["bundled"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +sqlite-vec = "0.1" +tokio-rusqlite = { git = "https://github.com/programatik29/tokio-rusqlite", features = [ + "bundled", +] } +tracing = "0.1" +zerocopy = "0.8.10" +chrono = "0.4" + +[dev-dependencies] +anyhow = "1.0.86" +tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/rig-sqlite/LICENSE b/rig-sqlite/LICENSE new file mode 100644 index 00000000..878b5fbc --- /dev/null +++ b/rig-sqlite/LICENSE @@ -0,0 +1,7 @@ +Copyright (c) 2024, Playgrounds Analytics Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/rig-sqlite/README.md b/rig-sqlite/README.md new file mode 100644 index 00000000..2b6fc4b3 --- /dev/null +++ b/rig-sqlite/README.md @@ -0,0 +1,46 @@ +
+ + + + Rig logo + + + + + + + SQLite logo + +
+ +

+ +## Rig-SQLite + +This companion crate implements a Rig vector store based on SQLite. + +## Usage + +Add the companion crate to your `Cargo.toml`, along with the rig-core crate: + +```toml +[dependencies] +rig-sqlite = "0.1.3" +rig-core = "0.4.0" +``` + +You can also run `cargo add rig-sqlite rig-core` to add the most recent versions of the dependencies to your project. + +See the [`/examples`](./examples) folder for usage examples. + +## Important Note + +Before using the SQLite vector store, you must [initialize the SQLite vector extension](https://alexgarcia.xyz/sqlite-vec/rust.html). Add this code before creating your connection: + +```rust +use rusqlite::ffi::sqlite3_auto_extension; +use sqlite_vec::sqlite3_vec_init; + +unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); +} +``` diff --git a/rig-sqlite/examples/vector_search_sqlite.rs b/rig-sqlite/examples/vector_search_sqlite.rs new file mode 100644 index 00000000..90241fb3 --- /dev/null +++ b/rig-sqlite/examples/vector_search_sqlite.rs @@ -0,0 +1,118 @@ +use rig::{ + embeddings::EmbeddingsBuilder, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::VectorStoreIndex, + Embed, +}; +use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable}; +use rusqlite::ffi::sqlite3_auto_extension; +use serde::Deserialize; +use sqlite_vec::sqlite3_vec_init; +use std::env; +use tokio_rusqlite::Connection; + +#[derive(Embed, Clone, Debug, Deserialize)] +struct Document { + id: String, + #[embed] + content: String, +} + +impl SqliteVectorStoreTable for Document { + fn name() -> &'static str { + "documents" + } + + fn schema() -> Vec { + vec![ + Column::new("id", "TEXT PRIMARY KEY"), + Column::new("content", "TEXT"), + ] + } + + fn id(&self) -> String { + self.id.clone() + } + + fn column_values(&self) -> Vec<(&'static str, Box)> { + vec![ + ("id", Box::new(self.id.clone())), + ("content", Box::new(self.content.clone())), + ] + } +} + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive(tracing::Level::DEBUG.into()), + ) + .init(); + + // Initialize OpenAI client + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + + // Initialize the `sqlite-vec`extension + // See: https://alexgarcia.xyz/sqlite-vec/rust.html + unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); + } + + // Initialize SQLite connection + let conn = Connection::open("vector_store.db").await?; + + // Select the embedding model and generate our embeddings + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + + let documents = vec![ + Document { + id: "doc0".to_string(), + content: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + }, + Document { + id: "doc1".to_string(), + content: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + }, + Document { + id: "doc2".to_string(), + content: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + }, + ]; + + let embeddings = EmbeddingsBuilder::new(model.clone()) + .documents(documents)? + .build() + .await?; + + // Initialize SQLite vector store + let vector_store = SqliteVectorStore::new(conn, &model).await?; + + // Add embeddings to vector store + vector_store.add_rows(embeddings).await?; + + // Create a vector index on our vector store + let index = vector_store.index(model); + + // Query the index + let results = index + .top_n::("What is a linglingdong?", 1) + .await? + .into_iter() + .map(|(score, id, doc)| (score, id, doc)) + .collect::>(); + + println!("Results: {:?}", results); + + let id_results = index + .top_n_ids("What is a linglingdong?", 1) + .await? + .into_iter() + .collect::>(); + + println!("ID results: {:?}", id_results); + + Ok(()) +} diff --git a/rig-sqlite/src/lib.rs b/rig-sqlite/src/lib.rs new file mode 100644 index 00000000..e3525247 --- /dev/null +++ b/rig-sqlite/src/lib.rs @@ -0,0 +1,557 @@ +use rig::embeddings::{Embedding, EmbeddingModel}; +use rig::vector_store::{VectorStoreError, VectorStoreIndex}; +use rig::OneOrMany; +use serde::Deserialize; +use std::marker::PhantomData; +use tokio_rusqlite::Connection; +use tracing::{debug, info}; +use zerocopy::IntoBytes; + +#[derive(Debug)] +pub enum SqliteError { + DatabaseError(Box), + SerializationError(Box), + InvalidColumnType(String), +} + +pub trait ColumnValue: Send + Sync { + fn to_sql_string(&self) -> String; + fn column_type(&self) -> &'static str; +} + +pub struct Column { + name: &'static str, + col_type: &'static str, + indexed: bool, +} + +impl Column { + pub fn new(name: &'static str, col_type: &'static str) -> Self { + Self { + name, + col_type, + indexed: false, + } + } + + pub fn indexed(mut self) -> Self { + self.indexed = true; + self + } +} + +/// Example of a document type that can be used with SqliteVectorStore +/// ```rust +/// use rig::Embed; +/// use serde::Deserialize; +/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStoreTable}; +/// +/// #[derive(Embed, Clone, Debug, Deserialize)] +/// struct Document { +/// id: String, +/// #[embed] +/// content: String, +/// } +/// +/// impl SqliteVectorStoreTable for Document { +/// fn name() -> &'static str { +/// "documents" +/// } +/// +/// fn schema() -> Vec { +/// vec![ +/// Column::new("id", "TEXT PRIMARY KEY"), +/// Column::new("content", "TEXT"), +/// ] +/// } +/// +/// fn id(&self) -> String { +/// self.id.clone() +/// } +/// +/// fn column_values(&self) -> Vec<(&'static str, Box)> { +/// vec![ +/// ("id", Box::new(self.id.clone())), +/// ("content", Box::new(self.content.clone())), +/// ] +/// } +/// } +/// ``` +pub trait SqliteVectorStoreTable: Send + Sync + Clone { + fn name() -> &'static str; + fn schema() -> Vec; + fn id(&self) -> String; + fn column_values(&self) -> Vec<(&'static str, Box)>; +} + +#[derive(Clone)] +pub struct SqliteVectorStore { + conn: Connection, + _phantom: PhantomData<(E, T)>, +} + +impl SqliteVectorStore { + pub async fn new(conn: Connection, embedding_model: &E) -> Result { + let dims = embedding_model.ndims(); + let table_name = T::name(); + let schema = T::schema(); + + // Build the table schema + let mut create_table = format!("CREATE TABLE IF NOT EXISTS {} (", table_name); + + // Add columns + let mut first = true; + for column in &schema { + if !first { + create_table.push(','); + } + create_table.push_str(&format!("\n {} {}", column.name, column.col_type)); + first = false; + } + + create_table.push_str("\n)"); + + // Build index creation statements + let mut create_indexes = vec![format!( + "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)", + table_name, table_name + )]; + + // Add indexes for marked columns + for column in schema { + if column.indexed { + create_indexes.push(format!( + "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})", + table_name, column.name, table_name, column.name + )); + } + } + + conn.call(move |conn| { + conn.execute_batch("BEGIN")?; + + // Create document table + conn.execute_batch(&create_table)?; + + // Create indexes + for index_stmt in create_indexes { + conn.execute_batch(&index_stmt)?; + } + + // Create embeddings table + conn.execute_batch(&format!( + "CREATE VIRTUAL TABLE IF NOT EXISTS {}_embeddings USING vec0(embedding float[{}])", + table_name, dims + ))?; + + conn.execute_batch("COMMIT")?; + Ok(()) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + Ok(Self { + conn, + _phantom: PhantomData, + }) + } + + pub fn index(self, model: E) -> SqliteVectorIndex { + SqliteVectorIndex::new(model, self) + } + + pub fn add_rows_with_txn( + &self, + txn: &rusqlite::Transaction<'_>, + documents: Vec<(T, OneOrMany)>, + ) -> Result { + info!("Adding {} documents to store", documents.len()); + let table_name = T::name(); + let mut last_id = 0; + + for (doc, embeddings) in &documents { + debug!("Storing document with id {}", doc.id()); + + let values = doc.column_values(); + let columns = values.iter().map(|(col, _)| *col).collect::>(); + + let placeholders = (1..=values.len()) + .map(|i| format!("?{}", i)) + .collect::>(); + + let insert_sql = format!( + "INSERT OR REPLACE INTO {} ({}) VALUES ({})", + table_name, + columns.join(", "), + placeholders.join(", ") + ); + + txn.execute( + &insert_sql, + rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_string())), + )?; + last_id = txn.last_insert_rowid(); + + let embeddings_sql = format!( + "INSERT INTO {}_embeddings (rowid, embedding) VALUES (?1, ?2)", + table_name + ); + + let mut stmt = txn.prepare(&embeddings_sql)?; + for (i, embedding) in embeddings.iter().enumerate() { + let vec = serialize_embedding(embedding); + debug!( + "Storing embedding {} of {} (size: {} bytes)", + i + 1, + embeddings.len(), + vec.len() * 4 + ); + let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec()); + stmt.execute(rusqlite::params![last_id, blob])?; + } + } + + Ok(last_id) + } + + pub async fn add_rows( + &self, + documents: Vec<(T, OneOrMany)>, + ) -> Result { + let documents = documents.clone(); + let this = self.clone(); + + self.conn + .call(move |conn| { + let tx = conn.transaction().map_err(tokio_rusqlite::Error::from)?; + let result = this.add_rows_with_txn(&tx, documents)?; + tx.commit().map_err(tokio_rusqlite::Error::from)?; + Ok(result) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e))) + } +} + +/// SQLite vector store implementation for Rig. +/// +/// This crate provides a SQLite-based vector store implementation that can be used with Rig. +/// It uses the `sqlite-vec` extension to enable vector similarity search capabilities. +/// +/// # Example +/// ```rust +/// use rig::{ +/// embeddings::EmbeddingsBuilder, +/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +/// vector_store::VectorStoreIndex, +/// Embed, +/// }; +/// use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable}; +/// use serde::Deserialize; +/// use tokio_rusqlite::Connection; +/// +/// #[derive(Embed, Clone, Debug, Deserialize)] +/// struct Document { +/// id: String, +/// #[embed] +/// content: String, +/// } +/// +/// impl SqliteVectorStoreTable for Document { +/// fn name() -> &'static str { +/// "documents" +/// } +/// +/// fn schema() -> Vec { +/// vec![ +/// Column::new("id", "TEXT PRIMARY KEY"), +/// Column::new("content", "TEXT"), +/// ] +/// } +/// +/// fn id(&self) -> String { +/// self.id.clone() +/// } +/// +/// fn column_values(&self) -> Vec<(&'static str, Box)> { +/// vec![ +/// ("id", Box::new(self.id.clone())), +/// ("content", Box::new(self.content.clone())), +/// ] +/// } +/// } +/// +/// let conn = Connection::open("vector_store.db").await?; +/// let openai_client = Client::new("YOUR_API_KEY"); +/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); +/// +/// // Initialize vector store +/// let vector_store = SqliteVectorStore::new(conn, &model).await?; +/// +/// // Create documents +/// let documents = vec![ +/// Document { +/// id: "doc1".to_string(), +/// content: "Example document 1".to_string(), +/// }, +/// Document { +/// id: "doc2".to_string(), +/// content: "Example document 2".to_string(), +/// }, +/// ]; +/// +/// // Generate embeddings +/// let embeddings = EmbeddingsBuilder::new(model.clone()) +/// .documents(documents)? +/// .build() +/// .await?; +/// +/// // Add to vector store +/// vector_store.add_rows(embeddings).await?; +/// +/// // Create index and search +/// let index = vector_store.index(model); +/// let results = index +/// .top_n::("Example query", 2) +/// .await?; +/// ``` +pub struct SqliteVectorIndex { + store: SqliteVectorStore, + embedding_model: E, +} + +impl SqliteVectorIndex { + pub fn new(embedding_model: E, store: SqliteVectorStore) -> Self { + Self { + store, + embedding_model, + } + } +} + +impl VectorStoreIndex + for SqliteVectorIndex +{ + async fn top_n Deserialize<'a>>( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + debug!("Finding top {} matches for query", n); + let embedding = self.embedding_model.embed_text(query).await?; + let query_vec: Vec = serialize_embedding(&embedding); + let table_name = T::name(); + + // Get all column names from SqliteVectorStoreTable + let columns = T::schema(); + let column_names: Vec<&str> = columns.iter().map(|column| column.name).collect(); + + let rows = self + .store + .conn + .call(move |conn| { + // Build SELECT statement with all columns + let select_cols = column_names.join(", "); + let mut stmt = conn.prepare(&format!( + "SELECT d.{}, e.distance + FROM {}_embeddings e + JOIN {} d ON e.rowid = d.rowid + WHERE e.embedding MATCH ?1 AND k = ?2 + ORDER BY e.distance", + select_cols, table_name, table_name + ))?; + + let rows = stmt + .query_map(rusqlite::params![query_vec.as_bytes().to_vec(), n], |row| { + // Create a map of column names to values + let mut map = serde_json::Map::new(); + for (i, col_name) in column_names.iter().enumerate() { + let value: String = row.get(i)?; + map.insert(col_name.to_string(), serde_json::Value::String(value)); + } + let distance: f64 = row.get(column_names.len())?; + let id: String = row.get(0)?; // Assuming id is always first column + + Ok((id, serde_json::Value::Object(map), distance)) + })? + .collect::, _>>()?; + Ok(rows) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + debug!("Found {} potential matches", rows.len()); + let mut top_n = Vec::new(); + for (id, doc_value, distance) in rows { + match serde_json::from_value::(doc_value) { + Ok(doc) => { + top_n.push((distance, id, doc)); + } + Err(e) => { + debug!("Failed to deserialize document {}: {}", id, e); + continue; + } + } + } + + debug!("Returning {} matches", top_n.len()); + Ok(top_n) + } + + async fn top_n_ids( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + debug!("Finding top {} document IDs for query", n); + let embedding = self.embedding_model.embed_text(query).await?; + let query_vec = serialize_embedding(&embedding); + let table_name = T::name(); + + let results = self + .store + .conn + .call(move |conn| { + let mut stmt = conn.prepare(&format!( + "SELECT d.id, e.distance + FROM {0}_embeddings e + JOIN {0} d ON e.rowid = d.rowid + WHERE e.embedding MATCH ?1 AND k = ?2 + ORDER BY e.distance", + table_name + ))?; + + let results = stmt + .query_map( + rusqlite::params![ + query_vec + .iter() + .flat_map(|x| x.to_le_bytes()) + .collect::>(), + n + ], + |row| Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)), + )? + .collect::, _>>()?; + Ok(results) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + debug!("Found {} matching document IDs", results.len()); + Ok(results) + } +} + +fn serialize_embedding(embedding: &Embedding) -> Vec { + embedding.vec.iter().map(|x| *x as f32).collect() +} + +impl ColumnValue for String { + fn to_sql_string(&self) -> String { + self.clone() + } + + fn column_type(&self) -> &'static str { + "TEXT" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable}; + use rig::{ + embeddings::EmbeddingsBuilder, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + Embed, + }; + use rusqlite::ffi::sqlite3_auto_extension; + use sqlite_vec::sqlite3_vec_init; + use tokio_rusqlite::Connection; + + #[derive(Embed, Clone, Debug, Deserialize)] + struct TestDocument { + id: String, + #[embed] + content: String, + } + + impl SqliteVectorStoreTable for TestDocument { + fn name() -> &'static str { + "test_documents" + } + + fn schema() -> Vec { + vec![ + Column::new("id", "TEXT PRIMARY KEY"), + Column::new("content", "TEXT"), + ] + } + + fn id(&self) -> String { + self.id.clone() + } + + fn column_values(&self) -> Vec<(&'static str, Box)> { + vec![ + ("id", Box::new(self.id.clone())), + ("content", Box::new(self.content.clone())), + ] + } + } + + #[tokio::test] + async fn test_vector_search() -> Result<(), anyhow::Error> { + // Initialize the sqlite-vec extension + unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); + } + + // Initialize in-memory SQLite connection + let conn = Connection::open(":memory:").await?; + + // Initialize OpenAI client + let openai_api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + + let documents = vec![ + TestDocument { + id: "doc0".to_string(), + content: "The quick brown fox jumps over the lazy dog".to_string(), + }, + TestDocument { + id: "doc1".to_string(), + content: "The lazy dog sleeps while the quick brown fox runs".to_string(), + }, + ]; + + let embeddings = EmbeddingsBuilder::new(model.clone()) + .documents(documents)? + .build() + .await?; + + // Initialize SQLite vector store + let vector_store = SqliteVectorStore::new(conn, &model).await?; + + // Add embeddings to vector store + vector_store.add_rows(embeddings).await?; + + // Create vector index + let index = vector_store.index(model); + + // Query the index + let results = index + .top_n::("The quick brown fox jumps over the lazy dog", 1) + .await?; + assert_eq!(results.len(), 1); + + let id_results = index + .top_n_ids("The quick brown fox jumps over the lazy dog", 1) + .await?; + assert_eq!(id_results.len(), 1); + + Ok(()) + } +}