Skip to content

Commit

Permalink
Migrate to new embeddings api
Browse files Browse the repository at this point in the history
  • Loading branch information
tarrencev committed Nov 30, 2024
1 parent 7ff68e2 commit 3a4d90b
Show file tree
Hide file tree
Showing 3 changed files with 391 additions and 257 deletions.
22 changes: 11 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

71 changes: 56 additions & 15 deletions rig-sqlite/examples/vector_search_sqlite.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,47 @@
use rig::vector_store::VectorStore;
use rig::{
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::VectorStoreIndex,
Embed,
};
use rig_sqlite::SqliteVectorStore;
use rig_sqlite::{Column, ColumnValue, SqliteVectorStore, SqliteVectorStoreTable};
use rusqlite::ffi::sqlite3_auto_extension;
use serde::Deserialize;
use sqlite_vec::sqlite3_vec_init;
use std::env;
use tokio_rusqlite::Connection;

/// A text document record used by this example.
///
/// Implements `SqliteVectorStoreTable` further down in this file, where it is
/// mapped onto a `documents` table with `id` and `content` columns.
/// `Deserialize` is derived, and the query in `main` names this type in
/// `top_n::<Document>` — presumably so results can be decoded back into it
/// (TODO confirm against the `rig` vector-store API).
#[derive(Embed, Clone, Debug, Deserialize)]
struct Document {
/// Unique identifier of the document; stored in the `id` primary-key column.
id: String,
/// Text body of the document. Tagged `#[embed]` — presumably this marks the
/// field whose contents are sent to the embedding model (verify against the
/// `rig::Embed` derive documentation).
#[embed]
content: String,
}

/// Describes how a [`Document`] is laid out as a row of the SQLite vector store.
impl SqliteVectorStoreTable for Document {
    /// Name of the table that holds the documents.
    fn name() -> &'static str {
        "documents"
    }

    /// Column definitions for the table: a textual primary key plus the text body.
    fn schema() -> Vec<Column> {
        let id_column = Column::new("id", "TEXT PRIMARY KEY");
        let content_column = Column::new("content", "TEXT");
        vec![id_column, content_column]
    }

    /// Returns an owned copy of this document's identifier.
    fn id(&self) -> String {
        self.id.to_owned()
    }

    /// Pairs every column name from [`Self::schema`] with a boxed copy of the
    /// corresponding field value, in the same order (`id`, then `content`).
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
        // Annotate the bindings so the boxed `String`s coerce to trait objects.
        let id_value: Box<dyn ColumnValue> = Box::new(self.id.to_owned());
        let content_value: Box<dyn ColumnValue> = Box::new(self.content.to_owned());
        vec![("id", id_value), ("content", content_value)]
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
Expand All @@ -35,29 +67,38 @@ async fn main() -> Result<(), anyhow::Error> {
// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let documents = vec![
Document {
id: "doc0".to_string(),
content: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(),
},
Document {
id: "doc1".to_string(),
content: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
},
Document {
id: "doc2".to_string(),
content: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
},
];

let embeddings = EmbeddingsBuilder::new(model.clone())
.simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build()
.await?;
.documents(documents)?
.build()
.await?;

// Initialize SQLite vector store
let mut vector_store = SqliteVectorStore::new(conn, &model).await?;
let vector_store = SqliteVectorStore::new(conn, &model).await?;

// Add embeddings to vector store
match vector_store.add_documents(embeddings).await {
Ok(_) => println!("Documents added successfully"),
Err(e) => println!("Error adding documents: {:?}", e),
}
vector_store.add_rows(embeddings).await?;

// Create a vector index on our vector store
// IMPORTANT: Reuse the same model that was used to generate the embeddings
let index = vector_store.index(model).await?;
let index = vector_store.index(model);

// Query the index
let results = index
.top_n::<String>("What is a linglingdong?", 1)
.top_n::<Document>("What is a linglingdong?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc))
Expand Down
Loading

0 comments on commit 3a4d90b

Please sign in to comment.