Skip to content

Commit

Permalink
Improve SqliteVectorStore init
Browse files Browse the repository at this point in the history
  • Loading branch information
tarrencev committed Nov 26, 2024
1 parent 00265b1 commit d6a959e
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 44 deletions.
27 changes: 24 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rig-sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ tokio-rusqlite = { git = "https://github.com/programatik29/tokio-rusqlite", feat
"bundled",
] }
tracing = "0.1"
zerocopy = "0.8.10"
chrono = "0.4"

[dev-dependencies]
Expand Down
16 changes: 8 additions & 8 deletions rig-sqlite/examples/vector_search_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ async fn main() -> Result<(), anyhow::Error> {
// Initialize SQLite connection
let conn = Connection::open("vector_store.db").await?;

// Initialize SQLite vector store
let mut vector_store = SqliteVectorStore::new(conn).await?;

// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let embeddings = EmbeddingsBuilder::new(model.clone())
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build()
.await?;
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build()
.await?;

// Initialize SQLite vector store
let mut vector_store = SqliteVectorStore::new(conn, model.clone()).await?;

// Add embeddings to vector store
match vector_store.add_documents(embeddings).await {
Expand Down
74 changes: 41 additions & 33 deletions rig-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ use rig::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel};
use rig::vector_store::{VectorStore, VectorStoreError, VectorStoreIndex};
use rusqlite::OptionalExtension;
use serde::Deserialize;
use std::marker::PhantomData;
use tokio_rusqlite::Connection;
use tracing::{debug, info};
use zerocopy::IntoBytes;

#[derive(Debug)]
pub enum SqliteError {
Expand All @@ -12,50 +14,52 @@ pub enum SqliteError {
}

#[derive(Clone)]
pub struct SqliteVectorStore {
pub struct SqliteVectorStore<E: EmbeddingModel> {
conn: Connection,
_phantom: PhantomData<E>,
}

impl SqliteVectorStore {
pub async fn new(conn: Connection) -> Result<Self, VectorStoreError> {
debug!("Running initial migrations");
impl<E: EmbeddingModel> SqliteVectorStore<E> {
pub async fn new(conn: Connection, embedding_model: E) -> Result<Self, VectorStoreError> {
// Run migrations or create tables if they don't exist
conn.call(|conn| {
conn.execute_batch(
let dims = embedding_model.ndims();
conn.call(move |conn| {
conn.execute_batch(&format!(
"BEGIN;
-- Document tables
CREATE TABLE IF NOT EXISTS documents (
id INTEGER PRIMARY KEY AUTOINCREMENT,
doc_id TEXT UNIQUE NOT NULL,
document_id TEXT UNIQUE NOT NULL,
document TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_doc_id ON documents(doc_id);
CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0(embedding float[1536]);
CREATE INDEX IF NOT EXISTS idx_document_id ON documents(document_id);
CREATE VIRTUAL TABLE IF NOT EXISTS embeddings USING vec0(embedding float[{}]);
COMMIT;",
)
dims
))
.map_err(tokio_rusqlite::Error::from)
})
.await
.map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;

Ok(Self { conn })
Ok(Self {
conn,
_phantom: PhantomData,
})
}

/// Narrows every component of `embedding.vec` to `f32` and returns the
/// values as a new vector, in the same order.
///
/// NOTE(review): presumably `f32` is used because the `embeddings` virtual
/// table column is declared as `float[...]` — confirm against sqlite-vec.
fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
    let narrowed: Vec<f32> = embedding
        .vec
        .iter()
        .copied()
        .map(|component| component as f32)
        .collect();
    narrowed
}

/// Create a new `SqliteVectorIndex` from an existing `SqliteVectorStore`.
pub async fn index<M: EmbeddingModel>(
&self,
model: M,
) -> Result<SqliteVectorIndex<M>, VectorStoreError> {
pub async fn index(&self, model: E) -> Result<SqliteVectorIndex<E>, VectorStoreError> {
Ok(SqliteVectorIndex::new(model, self.clone()))
}
}

impl VectorStore for SqliteVectorStore {
impl<E: EmbeddingModel> VectorStore for SqliteVectorStore<E> {
type Q = String;

async fn add_documents(
Expand All @@ -71,12 +75,12 @@ impl VectorStore for SqliteVectorStore {
debug!("Storing document with id {}", doc.id);
// Store document and get auto-incremented ID
tx.execute(
"INSERT OR REPLACE INTO documents (doc_id, document) VALUES (?1, ?2)",
"INSERT OR REPLACE INTO documents (document_id, document) VALUES (?1, ?2)",
[&doc.id, &doc.document.to_string()],
)
.map_err(tokio_rusqlite::Error::from)?;

let doc_id = tx.last_insert_rowid();
let document_id = tx.last_insert_rowid();

// Store embeddings
let mut stmt = tx
Expand All @@ -90,10 +94,8 @@ impl VectorStore for SqliteVectorStore {
);
for embedding in doc.embeddings {
let vec = Self::serialize_embedding(&embedding);
let blob = rusqlite::types::Value::Blob(
vec.iter().flat_map(|x| x.to_le_bytes()).collect(),
);
stmt.execute(rusqlite::params![doc_id, blob])
let blob = rusqlite::types::Value::Blob(vec.as_bytes().to_vec());
stmt.execute(rusqlite::params![document_id, blob])
.map_err(tokio_rusqlite::Error::from)?;
}
}
Expand All @@ -117,7 +119,7 @@ impl VectorStore for SqliteVectorStore {
.conn
.call(move |conn| {
conn.query_row(
"SELECT document FROM documents WHERE doc_id = ?1",
"SELECT document FROM documents WHERE document_id = ?1",
rusqlite::params![id_clone],
|row| row.get::<_, String>(0),
)
Expand Down Expand Up @@ -153,7 +155,7 @@ impl VectorStore for SqliteVectorStore {
"SELECT e.embedding, d.document
FROM embeddings e
JOIN documents d ON e.rowid = d.id
WHERE d.doc_id = ?1",
WHERE d.document_id = ?1",
)?;

let result = stmt
Expand All @@ -170,7 +172,13 @@ impl VectorStore for SqliteVectorStore {
})?;
let vec = bytes
.chunks(4)
.map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()) as f64)
.map(|chunk| {
f32::from_le_bytes(
chunk
.try_into()
.expect("Invalid chunk length - must be 4 bytes"),
) as f64
})
.collect();
Ok((
rig::embeddings::Embedding {
Expand Down Expand Up @@ -209,7 +217,7 @@ impl VectorStore for SqliteVectorStore {
.conn
.call(move |conn| {
let mut stmt = conn.prepare(
"SELECT d.doc_id, e.distance
"SELECT d.document_id, e.distance
FROM embeddings e
JOIN documents d ON e.rowid = d.id
WHERE e.embedding MATCH ?1 AND k = ?2
Expand Down Expand Up @@ -240,12 +248,12 @@ impl VectorStore for SqliteVectorStore {
}

pub struct SqliteVectorIndex<E: EmbeddingModel> {
store: SqliteVectorStore,
store: SqliteVectorStore<E>,
embedding_model: E,
}

impl<E: EmbeddingModel> SqliteVectorIndex<E> {
pub fn new(embedding_model: E, store: SqliteVectorStore) -> Self {
pub fn new(embedding_model: E, store: SqliteVectorStore<E>) -> Self {
Self {
store,
embedding_model,
Expand All @@ -261,14 +269,14 @@ impl<E: EmbeddingModel + std::marker::Sync> VectorStoreIndex for SqliteVectorInd
) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
debug!("Finding top {} matches for query", n);
let embedding = self.embedding_model.embed_document(query).await?;
let query_vec = SqliteVectorStore::serialize_embedding(&embedding);
let query_vec = SqliteVectorStore::<E>::serialize_embedding(&embedding);

let rows = self
.store
.conn
.call(move |conn| {
let mut stmt = conn.prepare(
"SELECT d.doc_id, d.document, e.distance
"SELECT d.document_id, d.document, e.distance
FROM embeddings e
JOIN documents d ON e.rowid = d.id
WHERE e.embedding MATCH ?1 AND k = ?2
Expand Down Expand Up @@ -323,17 +331,17 @@ impl<E: EmbeddingModel + std::marker::Sync> VectorStoreIndex for SqliteVectorInd
) -> Result<Vec<(f64, String)>, VectorStoreError> {
debug!("Finding top {} document IDs for query", n);
let embedding = self.embedding_model.embed_document(query).await?;
let query_vec = SqliteVectorStore::serialize_embedding(&embedding);
let query_vec = SqliteVectorStore::<E>::serialize_embedding(&embedding);

let results = self
.store
.conn
.call(move |conn| {
let mut stmt = conn.prepare(
"SELECT d.doc_id, e.distance
"SELECT d.document_id, e.distance
FROM embeddings e
JOIN documents d ON e.rowid = d.id
WHERE e.embedding MATCH ?1 AND k = ?2
WHERE e.embedding MATCH ?1 AND k = ?2
ORDER BY e.distance",
)?;

Expand Down

0 comments on commit d6a959e

Please sign in to comment.