From 4a778a9fd7791a6dc98a3115542d8b247f5ffd64 Mon Sep 17 00:00:00 2001 From: Tarrence van As Date: Mon, 25 Nov 2024 09:08:25 -0500 Subject: [PATCH] Add sqlite vector store --- Cargo.lock | 155 ++++++++- Cargo.toml | 7 +- rig-sqlite/Cargo.toml | 24 ++ rig-sqlite/LICENSE | 7 + rig-sqlite/README.md | 46 +++ rig-sqlite/examples/vector_search_sqlite.rs | 75 +++++ rig-sqlite/src/lib.rs | 335 ++++++++++++++++++++ 7 files changed, 635 insertions(+), 14 deletions(-) create mode 100644 rig-sqlite/Cargo.toml create mode 100644 rig-sqlite/LICENSE create mode 100644 rig-sqlite/README.md create mode 100644 rig-sqlite/examples/vector_search_sqlite.rs create mode 100644 rig-sqlite/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index aefa6a32..5fa49f12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,7 +28,7 @@ dependencies = [ "getrandom", "once_cell", "version_check", - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -297,7 +297,7 @@ dependencies = [ "memchr", "num", "regex", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -1524,7 +1524,7 @@ dependencies = [ "itertools 0.12.1", "log", "paste", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -1862,6 +1862,18 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastdivide" version = "0.4.1" @@ -2118,8 +2130,8 @@ dependencies = [ "aho-corasick", "bstr", "log", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.8", + "regex-syntax 0.8.5", ] [[package]] @@ -2209,6 +2221,15 @@ dependencies = [ "foldhash", ] +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "heck" version = "0.4.1" @@ -2527,7 +2548,7 @@ dependencies = [ "globset", "log", "memchr", - "regex-automata", + "regex-automata 0.4.8", "same-file", "walkdir", "winapi-util", @@ -3161,6 +3182,17 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -3257,6 +3289,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "matches" version = "0.1.10" @@ -3912,7 +3953,7 @@ version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" dependencies = [ - "zerocopy", + "zerocopy 0.7.35", ] [[package]] @@ -4273,8 +4314,17 @@ checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" dependencies = [ "aho-corasick", "memchr", - "regex-automata", - "regex-syntax", + "regex-automata 0.4.8", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", ] [[package]] @@ -4285,7 +4335,7 @@ checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", - "regex-syntax", + "regex-syntax 0.8.5", ] [[package]] @@ -4294,6 +4344,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -4478,6 +4534,24 @@ dependencies = [ "tokio", ] +[[package]] +name = "rig-sqlite" +version = "0.1.0" +dependencies = [ + "anyhow", + "chrono", + "rig-core", + "rusqlite", + "serde", + "serde_json", + "sqlite-vec", + "tokio", + "tokio-rusqlite", + "tracing", + "tracing-subscriber", + "zerocopy 0.8.10", +] + [[package]] name = "ring" version = "0.17.8" @@ -4503,6 +4577,20 @@ dependencies = [ "byteorder", ] +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags 2.6.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + [[package]] name = "rust-stemmers" version = "1.2.0" @@ -5093,6 +5181,15 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "sqlite-vec" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec77b84fb8dd5f0f8def127226db83b5d1152c5bf367f09af03998b76ba554a" +dependencies = [ + "cc", +] + [[package]] name = "sqlparser" version = "0.47.0" @@ -5355,7 +5452,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d60769b80ad7953d8a7b2c70cdfe722bbcdcac6bccc8ac934c40c034d866fc18" dependencies = [ "byteorder", - "regex-syntax", + "regex-syntax 0.8.5", "utf8-ranges", ] @@ -5570,6 +5667,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rusqlite" +version = "0.6.0" +source = "git+https://github.com/programatik29/tokio-rusqlite#168e9c47fd6c7c9f8032b660f62084e16ade7bac" +dependencies = [ + "crossbeam-channel", + "rusqlite", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -5745,10 +5852,14 @@ version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex", "sharded-slab", "smallvec", "thread_local", + "tracing", "tracing-core", "tracing-log", ] @@ -6352,7 +6463,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", - "zerocopy-derive", + "zerocopy-derive 0.7.35", +] + +[[package]] +name = "zerocopy" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a13a42ed30c63171d820889b2981318736915150575b8d2d6dbee7edd68336ca" +dependencies = [ + "zerocopy-derive 0.8.10", ] [[package]] @@ -6366,6 +6486,17 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "zerocopy-derive" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "593e7c96176495043fcb9e87cf7659f4d18679b5bab6b92bdef359c76a7795dd" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "zeroize" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index c8d75273..bcdf87c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,10 @@ [workspace] resolver = "2" members = [ - "rig-core", "rig-lancedb", - "rig-mongodb", "rig-neo4j", + "rig-core", + "rig-lancedb", + "rig-mongodb", + "rig-neo4j", "rig-qdrant", + "rig-sqlite", ] diff --git a/rig-sqlite/Cargo.toml b/rig-sqlite/Cargo.toml new file mode 100644 index 00000000..e64532fc --- /dev/null +++ b/rig-sqlite/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "rig-sqlite" +version = "0.1.0" +edition = "2021" +description = "SQLite-based vector store implementation for the rig framework" +license = "MIT" + +[dependencies] +rig-core = { path = "../rig-core" } +rusqlite = { version = "0.32", features = ["bundled"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +sqlite-vec = "0.1" +tokio-rusqlite = { git = "https://github.com/programatik29/tokio-rusqlite", features = [ + "bundled", +] } +tracing = "0.1" +zerocopy = "0.8.10" +chrono = "0.4" + +[dev-dependencies] +anyhow = "1.0.86" +tokio = { version = "1.38.0", features = ["macros", "rt-multi-thread"] } +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/rig-sqlite/LICENSE b/rig-sqlite/LICENSE new file mode 100644 index 00000000..878b5fbc --- /dev/null +++ b/rig-sqlite/LICENSE @@ -0,0 +1,7 @@ +Copyright (c) 2024, Playgrounds Analytics Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/rig-sqlite/README.md b/rig-sqlite/README.md new file mode 100644 index 00000000..2ce9935a --- /dev/null +++ b/rig-sqlite/README.md @@ -0,0 +1,46 @@ +
+ + + + Rig logo + + + + + + + SQLite logo + +
+ +

+ +## Rig-SQLite + +This companion crate implements a Rig vector store based on SQLite. + +## Usage + +Add the companion crate to your `Cargo.toml`, along with the rig-core crate: + +```toml +[dependencies] +rig-sqlite = "0.1.3" +rig-core = "0.4.0" +``` + +You can also run `cargo add rig-sqlite rig-core` to add the most recent versions of the dependencies to your project. + +See the [`/examples`](./examples) folder for usage examples. + +## Important Note + +Before using the SQLite vector store, you must initialize the SQLite vector extension. Add this code before creating your connection: + +```rust +use rusqlite::ffi::sqlite3_auto_extension; +use sqlite_vec::sqlite3_vec_init; + +unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); +} +``` diff --git a/rig-sqlite/examples/vector_search_sqlite.rs b/rig-sqlite/examples/vector_search_sqlite.rs new file mode 100644 index 00000000..df4dfa4d --- /dev/null +++ b/rig-sqlite/examples/vector_search_sqlite.rs @@ -0,0 +1,75 @@ +use rig::vector_store::VectorStore; +use rig::{ + embeddings::EmbeddingsBuilder, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::VectorStoreIndex, +}; +use rig_sqlite::SqliteVectorStore; +use rusqlite::ffi::sqlite3_auto_extension; +use sqlite_vec::sqlite3_vec_init; +use std::env; +use tokio_rusqlite::Connection; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive(tracing::Level::DEBUG.into()), + ) + .init(); + + // Initialize OpenAI client + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + + unsafe { + sqlite3_auto_extension(Some(std::mem::transmute(sqlite3_vec_init as *const ()))); + } + + // Initialize SQLite connection + let conn = Connection::open("vector_store.db").await?; + + // Initialize SQLite vector store + let mut vector_store = SqliteVectorStore::new(conn).await?; + + // Select the embedding model and generate our embeddings + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + + let embeddings = EmbeddingsBuilder::new(model.clone()) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .build() + .await?; + + // Add embeddings to vector store + match vector_store.add_documents(embeddings).await { + Ok(_) => println!("Documents added successfully"), + Err(e) => println!("Error adding documents: {:?}", e), + } + + // Create a vector index on our vector store + // IMPORTANT: Reuse the same model that was used to generate the embeddings + let index = vector_store.index(model).await?; + + // Query the index + let results = index + .top_n::("What is a linglingdong?", 1) + .await? + .into_iter() + .map(|(score, id, doc)| (score, id, doc)) + .collect::>(); + + println!("Results: {:?}", results); + + let id_results = index + .top_n_ids("What is a linglingdong?", 1) + .await? + .into_iter() + .collect::>(); + + println!("ID results: {:?}", id_results); + + Ok(()) +} diff --git a/rig-sqlite/src/lib.rs b/rig-sqlite/src/lib.rs new file mode 100644 index 00000000..3634e175 --- /dev/null +++ b/rig-sqlite/src/lib.rs @@ -0,0 +1,335 @@ +use rig::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}; +use rig::vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}; +use rusqlite::OptionalExtension; +use serde::Deserialize; +use tokio_rusqlite::Connection; +use tracing::{debug, info}; +use zerocopy::IntoBytes; + +#[derive(Debug)] +pub enum SqliteError { + DatabaseError(Box), + SerializationError(Box), +} + +#[derive(Clone)] +pub struct SqliteVectorStore { + conn: Connection, +} + +impl SqliteVectorStore { + pub async fn new(conn: Connection) -> Result { + debug!("Running initial migrations"); + // Run migrations or create tables if they don't exist + conn.call(|conn| { + conn.execute_batch( + "BEGIN; + -- Document tables + CREATE TABLE IF NOT EXISTS documents ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + doc_id TEXT UNIQUE NOT NULL, + document TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_doc_id ON documents(doc_id); + CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0(embedding float[1536]); + + COMMIT;", + ) + .map_err(|e| tokio_rusqlite::Error::from(e)) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + Ok(Self { conn }) + } + + fn serialize_embedding(embedding: &Embedding) -> Vec { + embedding.vec.iter().map(|x| *x as f32).collect() + } + + /// Create a new `SqliteVectorIndex` from an existing `SqliteVectorStore`. + pub async fn index( + &self, + model: M, + ) -> Result, VectorStoreError> { + Ok(SqliteVectorIndex::new(model, self.clone())) + } +} + +impl VectorStore for SqliteVectorStore { + type Q = String; + + async fn add_documents( + &mut self, + documents: Vec, + ) -> Result<(), VectorStoreError> { + info!("Adding {} documents to store", documents.len()); + self.conn + .call(|conn| { + let tx = conn + .transaction() + .map_err(|e| tokio_rusqlite::Error::from(e))?; + + for doc in documents { + debug!("Storing document with id {}", doc.id); + // Store document and get auto-incremented ID + tx.execute( + "INSERT OR REPLACE INTO documents (doc_id, document) VALUES (?1, ?2)", + &[&doc.id, &doc.document.to_string()], + ) + .map_err(|e| tokio_rusqlite::Error::from(e))?; + + let doc_id = tx.last_insert_rowid(); + + // Store embeddings + let mut stmt = tx + .prepare("INSERT INTO embeddings (rowid, embedding) VALUES (?1, ?2)") + .map_err(|e| tokio_rusqlite::Error::from(e))?; + + debug!( + "Storing {} embeddings for document {}", + doc.embeddings.len(), + doc.id + ); + for embedding in doc.embeddings { + let vec = Self::serialize_embedding(&embedding); + let blob = rusqlite::types::Value::Blob(vec.as_slice().as_bytes().to_vec()); + stmt.execute(rusqlite::params![doc_id, blob]) + .map_err(|e| tokio_rusqlite::Error::from(e))?; + } + } + + tx.commit().map_err(|e| tokio_rusqlite::Error::from(e))?; + Ok(()) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + Ok(()) + } + + async fn get_document Deserialize<'a>>( + &self, + id: &str, + ) -> Result, VectorStoreError> { + debug!("Fetching document with id {}", id); + let id_clone = id.to_string(); + let doc_str = self + .conn + .call(move |conn| { + conn.query_row( + "SELECT document FROM documents WHERE doc_id = ?1", + rusqlite::params![id_clone], + |row| row.get::<_, String>(0), + ) + .optional() + .map_err(|e| tokio_rusqlite::Error::from(e)) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + match doc_str { + Some(doc_str) => { + let doc: T = serde_json::from_str(&doc_str) + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + Ok(Some(doc)) + } + None => { + debug!("No document found with id {}", id); + Ok(None) + } + } + } + + async fn get_document_embeddings( + &self, + id: &str, + ) -> Result, VectorStoreError> { + debug!("Fetching embeddings for document {}", id); + // First get the document + let doc: Option = self.get_document(&id).await?; + + if let Some(doc) = doc { + let id_clone = id.to_string(); + let embeddings = self + .conn + .call(move |conn| { + let mut stmt = conn.prepare( + "SELECT e.embedding + FROM embeddings e + JOIN documents d ON e.rowid = d.id + WHERE d.doc_id = ?1", + )?; + + let embeddings = stmt + .query_map(rusqlite::params![id_clone], |row| { + let bytes: Vec = row.get(0)?; + let vec = bytes + .chunks(4) + .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()) as f64) + .collect(); + Ok(rig::embeddings::Embedding { + vec, + document: "".to_string(), + }) + })? + .collect::, _>>()?; + Ok(embeddings) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + debug!("Found {} embeddings for document {}", embeddings.len(), id); + Ok(Some(DocumentEmbeddings { + id: id.to_string(), + document: doc, + embeddings, + })) + } else { + debug!("No embeddings found for document {}", id); + Ok(None) + } + } + + async fn get_document_by_query( + &self, + query: Self::Q, + ) -> Result, VectorStoreError> { + debug!("Searching for document matching query"); + let result = self + .conn + .call(move |conn| { + let mut stmt = conn.prepare( + "SELECT d.doc_id, e.distance + FROM embeddings e + JOIN documents d ON e.rowid = d.id + WHERE e.embedding MATCH ?1 AND k = ?2 + ORDER BY e.distance", + )?; + + let result = stmt + .query_row(rusqlite::params![query.as_bytes(), 1], |row| { + Ok((row.get::<_, String>(0)?, row.get::<_, f64>(1)?)) + }) + .optional()?; + Ok(result) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + match result { + Some((id, distance)) => { + debug!("Found matching document {} with distance {}", id, distance); + self.get_document_embeddings(&id).await + } + None => { + debug!("No matching documents found"); + Ok(None) + } + } + } +} + +pub struct SqliteVectorIndex { + store: SqliteVectorStore, + embedding_model: E, +} + +impl SqliteVectorIndex { + pub fn new(embedding_model: E, store: SqliteVectorStore) -> Self { + Self { + store, + embedding_model, + } + } +} + +impl VectorStoreIndex for SqliteVectorIndex { + async fn top_n Deserialize<'a>>( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + debug!("Finding top {} matches for query", n); + let embedding = self.embedding_model.embed_document(query).await?; + let query_vec = SqliteVectorStore::serialize_embedding(&embedding); + + let rows = self + .store + .conn + .call(move |conn| { + let mut stmt = conn.prepare( + "SELECT d.doc_id, d.document, e.distance + FROM embeddings e + JOIN documents d ON e.rowid = d.id + WHERE e.embedding MATCH ?1 AND k = ?2 + ORDER BY e.distance", + )?; + + let rows = stmt + .query_map(rusqlite::params![query_vec.as_bytes(), n], |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, f64>(2)?, + )) + })? + .collect::, _>>()?; + Ok(rows) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + debug!("Found {} potential matches", rows.len()); + let mut top_n = Vec::new(); + for (id, doc_str, distance) in rows { + match serde_json::from_str::(&doc_str) { + Ok(doc) => { + top_n.push((distance, id, doc)); + } + Err(e) => { + debug!("Failed to deserialize document {}: {}", id, e); + continue; + } + } + } + + debug!("Returning {} matches", top_n.len()); + Ok(top_n) + } + + async fn top_n_ids( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + debug!("Finding top {} document IDs for query", n); + let embedding = self.embedding_model.embed_document(query).await?; + let query_vec = SqliteVectorStore::serialize_embedding(&embedding); + + let results = self + .store + .conn + .call(move |conn| { + let mut stmt = conn.prepare( + "SELECT d.doc_id, e.distance + FROM embeddings e + JOIN documents d ON e.rowid = d.id + WHERE e.embedding MATCH ?1 AND k = ?2 + ORDER BY e.distance", + )?; + + let results = stmt + .query_map(rusqlite::params![query_vec.as_bytes(), n], |row| { + Ok((row.get::<_, f64>(1)?, row.get::<_, String>(0)?)) + })? + .collect::, _>>()?; + Ok(results) + }) + .await + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + debug!("Found {} matching document IDs", results.len()); + Ok(results) + } +}