diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e00d68cd..cf3409ff 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -4,7 +4,7 @@ name: Lint & Test on: pull_request: branches: - - main + - "**" push: branches: - main diff --git a/Cargo.lock b/Cargo.lock index 4b71fbe4..76f72265 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3305,9 +3305,9 @@ checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" [[package]] name = "openssl" -version = "0.10.64" +version = "0.10.66" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" dependencies = [ "bitflags 2.5.0", "cfg-if", @@ -3337,9 +3337,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.102" +version = "0.9.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c597637d56fbc83893a35eb0dd04b2b8e7a50c91e64e9493e398b5df4fb45fa2" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" dependencies = [ "cc", "libc", diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 04d26dc3..fb168a08 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -2,10 +2,10 @@ use anyhow::Result; use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -251,9 +251,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let mut store = InMemoryVectorStore::default(); - store.add_documents(embeddings).await?; - let index = store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? + .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 3abd8ee9..b4dee8a5 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -2,9 +2,9 @@ use std::env; use rig::{ completion::Prompt, - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; #[tokio::main] @@ -15,9 +15,6 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Create vector store, compute embeddings and load them in the store - let mut vector_store = InMemoryVectorStore::default(); - let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") @@ -25,10 +22,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - vector_store.add_documents(embeddings).await?; - - // Create vector store index - let index = vector_store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? + .index(embedding_model); let rag_agent = openai_client.agent("gpt-4") .preamble(" diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 6e45730b..cdf6b65e 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -1,10 +1,10 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, - embeddings::EmbeddingsBuilder, + embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -150,9 +150,6 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Create vector store, compute tool embeddings and load them in the store - let mut vector_store = InMemoryVectorStore::default(); - let toolset = ToolSet::builder() .dynamic_tool(Add) .dynamic_tool(Subtract) @@ -163,10 +160,20 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - vector_store.add_documents(embeddings).await?; - - // Create vector store index - let index = vector_store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? + .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index e0d68bef..45110606 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -3,7 +3,7 @@ use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, }; #[tokio::main] @@ -21,13 +21,26 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? + .index(model); let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index a49ac231..1e0180d3 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -3,7 +3,7 @@ use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::cohere::{Client, EMBED_ENGLISH_V3}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, }; #[tokio::main] @@ -15,8 +15,6 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let mut vector_store = InMemoryVectorStore::default(); - let embeddings = EmbeddingsBuilder::new(document_model) .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") @@ -24,15 +22,26 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - vector_store.add_documents(embeddings).await?; - - let index = vector_store.index(search_model); + let index = InMemoryVectorStore::default() + .add_documents( + embeddings + .into_iter() + .map( + |DocumentEmbeddings { + id, + document, + embeddings, + }| { (id, document, embeddings) }, + ) + .collect(), + )? + .index(search_model); let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 0e61b588..ec497ac4 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -7,28 +7,29 @@ use std::{ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; -use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; +use super::{VectorStoreError, VectorStoreIndex}; +use crate::embeddings::{Embedding, EmbeddingModel}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. -#[derive(Clone, Default, Deserialize, Serialize)] -pub struct InMemoryVectorStore { - /// The embeddings are stored in a HashMap with the document ID as the key. - embeddings: HashMap, +#[derive(Clone, Default)] +pub struct InMemoryVectorStore { + /// The embeddings are stored in a HashMap. + /// Hashmap key is the document id. + /// Hashmap value is a tuple of the serializable document and its corresponding embeddings. + embeddings: HashMap)>, } -impl InMemoryVectorStore { +impl InMemoryVectorStore { /// Implement vector search on InMemoryVectorStore. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. - fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { + fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { // Sort documents by best embedding distance - let mut docs: EmbeddingRanking = BinaryHeap::new(); + let mut docs = BinaryHeap::new(); - for (id, doc_embeddings) in self.embeddings.iter() { + for (id, (doc, embeddings)) in self.embeddings.iter() { // Get the best context for the document given the prompt - if let Some((distance, embed_doc)) = doc_embeddings - .embeddings + if let Some((distance, embed_doc)) = embeddings .iter() .map(|embedding| { ( @@ -38,12 +39,7 @@ impl InMemoryVectorStore { }) .max_by(|a, b| a.0.cmp(&b.0)) { - docs.push(Reverse(RankingItem( - distance, - id, - doc_embeddings, - embed_doc, - ))); + docs.push(Reverse(RankingItem(distance, id, doc, embed_doc))); }; // If the heap size exceeds n, pop the least old element. @@ -63,77 +59,57 @@ impl InMemoryVectorStore { docs } -} - -/// RankingItem(distance, document_id, document, embed_doc) -#[derive(Eq, PartialEq)] -struct RankingItem<'a>( - OrderedFloat, - &'a String, - &'a DocumentEmbeddings, - &'a String, -); - -impl Ord for RankingItem<'_> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.0.cmp(&other.0) - } -} -impl PartialOrd for RankingItem<'_> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -type EmbeddingRanking<'a> = BinaryHeap>>; - -impl VectorStore for InMemoryVectorStore { - type Q = (); - - async fn add_documents( - &mut self, - documents: Vec, - ) -> Result<(), VectorStoreError> { - for doc in documents { - self.embeddings.insert(doc.id.clone(), doc); + /// Add documents to the store. + /// Returns the store with the added documents. + pub fn add_documents( + mut self, + documents: Vec<(String, D, Vec)>, + ) -> Result { + for (id, doc, embeddings) in documents { + self.embeddings.insert(id, (doc, embeddings)); } - Ok(()) + Ok(self) } - async fn get_document Deserialize<'a>>( + /// Get the document by its id and deserialize it into the given type. + pub fn get_document Deserialize<'a>>( &self, id: &str, ) -> Result, VectorStoreError> { Ok(self .embeddings .get(id) - .map(|document| serde_json::from_value(document.document.clone())) + .map(|(doc, _)| serde_json::from_str(&serde_json::to_string(doc)?)) .transpose()?) } +} - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result, VectorStoreError> { - Ok(self.embeddings.get(id).cloned()) +/// RankingItem(distance, document_id, serializable document, embeddings document) +#[derive(Eq, PartialEq)] +struct RankingItem<'a, D: Serialize>(OrderedFloat, &'a String, &'a D, &'a String); + +impl Ord for RankingItem<'_, D> { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) } +} - async fn get_document_by_query( - &self, - _query: Self::Q, - ) -> Result, VectorStoreError> { - Ok(None) +impl PartialOrd for RankingItem<'_, D> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } -impl InMemoryVectorStore { - pub fn index(self, model: M) -> InMemoryVectorIndex { +type EmbeddingRanking<'a, D> = BinaryHeap>>; + +impl InMemoryVectorStore { + pub fn index(self, model: M) -> InMemoryVectorIndex { InMemoryVectorIndex::new(model, self) } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator))> { self.embeddings.iter() } @@ -144,54 +120,19 @@ impl InMemoryVectorStore { pub fn is_empty(&self) -> bool { self.embeddings.is_empty() } - - /// Uitilty method to create an InMemoryVectorStore from a list of embeddings. - pub async fn from_embeddings( - embeddings: Vec, - ) -> Result { - let mut store = Self::default(); - store.add_documents(embeddings).await?; - Ok(store) - } - - /// Create an InMemoryVectorStore from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. - pub async fn from_documents( - embedding_model: M, - documents: &[(String, T)], - ) -> Result { - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - let store = Self::from_embeddings(embeddings).await?; - Ok(store) - } } -pub struct InMemoryVectorIndex { +pub struct InMemoryVectorIndex { model: M, - pub store: InMemoryVectorStore, + pub store: InMemoryVectorStore, } -impl InMemoryVectorIndex { - pub fn new(model: M, store: InMemoryVectorStore) -> Self { +impl InMemoryVectorIndex { + pub fn new(model: M, store: InMemoryVectorStore) -> Self { Self { model, store } } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator))> { self.store.iter() } @@ -202,49 +143,11 @@ impl InMemoryVectorIndex { pub fn is_empty(&self) -> bool { self.store.is_empty() } - - /// Create an InMemoryVectorIndex from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. - /// The InMemoryVectorIndex is then created from the store and the provided query model. - pub async fn from_documents( - embedding_model: M, - query_model: M, - documents: &[(String, T)], - ) -> Result { - let mut store = InMemoryVectorStore::default(); - - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - store.add_documents(embeddings).await?; - Ok(store.index(query_model)) - } - - /// Utility method to create an InMemoryVectorIndex from a list of embeddings - /// and an embedding model. - pub async fn from_embeddings( - query_model: M, - embeddings: Vec, - ) -> Result { - let store = InMemoryVectorStore::from_embeddings(embeddings).await?; - Ok(store.index(query_model)) - } } -impl VectorStoreIndex for InMemoryVectorIndex { +impl VectorStoreIndex + for InMemoryVectorIndex +{ async fn top_n Deserialize<'a>>( &self, query: &str, @@ -256,12 +159,11 @@ impl VectorStoreIndex for InMemoryVectorI // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| { - let doc_value = serde_json::to_value(doc).map_err(VectorStoreError::JsonError)?; + .map(|Reverse(RankingItem(distance, id, doc, _))| { Ok(( distance.0, - doc.id.clone(), - serde_json::from_value(doc_value).map_err(VectorStoreError::JsonError)?, + id.clone(), + serde_json::from_str(&serde_json::to_string(doc)?)?, )) }) .collect::, _>>() @@ -278,7 +180,150 @@ impl VectorStoreIndex for InMemoryVectorI // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone()))) + .map(|Reverse(RankingItem(distance, id, _, _))| Ok((distance.0, id.clone()))) .collect::, _>>() } } + +#[cfg(test)] +mod tests { + use std::cmp::Reverse; + + use crate::embeddings::Embedding; + + use super::{InMemoryVectorStore, RankingItem}; + + #[test] + fn test_single_embedding() { + let index = InMemoryVectorStore::default() + .add_documents(vec![ + ( + "doc1".to_string(), + "glarb-garb", + vec![Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }], + ), + ( + "doc2".to_string(), + "marble-marble", + vec![Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }], + ), + ( + "doc3".to_string(), + "flumb-flumb", + vec![Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }], + ), + ]) + .unwrap(); + + let ranking = index.vector_search( + &Embedding { + document: "glarby-glarble".to_string(), + vec: vec![0.0, 0.1, 0.6], + }, + 1, + ); + + assert_eq!( + ranking + .into_iter() + .map(|Reverse(RankingItem(distance, id, doc, _))| { + ( + distance.0, + id.clone(), + serde_json::from_str(&serde_json::to_string(doc).unwrap()).unwrap(), + ) + }) + .collect::>(), + vec![( + 0.034444444444444444, + "doc1".to_string(), + "glarb-garb".to_string() + )] + ) + } + + #[test] + fn test_multiple_embeddings() { + let index = InMemoryVectorStore::default() + .add_documents(vec![ + ( + "doc1".to_string(), + "glarb-garb", + vec![ + Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }, + Embedding { + document: "don't-choose-me".to_string(), + vec: vec![-0.5, 0.9, 0.1], + }, + ], + ), + ( + "doc2".to_string(), + "marble-marble", + vec![ + Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }, + Embedding { + document: "sandwich".to_string(), + vec: vec![0.5, 0.5, -0.7], + }, + ], + ), + ( + "doc3".to_string(), + "flumb-flumb", + vec![ + Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }, + Embedding { + document: "banana".to_string(), + vec: vec![0.1, -0.5, -0.5], + }, + ], + ), + ]) + .unwrap(); + + let ranking = index.vector_search( + &Embedding { + document: "glarby-glarble".to_string(), + vec: vec![0.0, 0.1, 0.6], + }, + 1, + ); + + assert_eq!( + ranking + .into_iter() + .map(|Reverse(RankingItem(distance, id, doc, _))| { + ( + distance.0, + id.clone(), + serde_json::from_str(&serde_json::to_string(doc).unwrap()).unwrap(), + ) + }) + .collect::>(), + vec![( + 0.034444444444444444, + "doc1".to_string(), + "glarb-garb".to_string() + )] + ) + } +} diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index b07d348a..396b5514 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -2,7 +2,7 @@ use futures::future::BoxFuture; use serde::Deserialize; use serde_json::Value; -use crate::embeddings::{DocumentEmbeddings, EmbeddingError}; +use crate::embeddings::EmbeddingError; pub mod in_memory_store; @@ -19,36 +19,6 @@ pub enum VectorStoreError { DatastoreError(#[from] Box), } -/// Trait for vector stores -pub trait VectorStore: Send + Sync { - /// Query type for the vector store - type Q; - - /// Add a list of documents to the vector store - fn add_documents( - &mut self, - documents: Vec, - ) -> impl std::future::Future> + Send; - - /// Get the embeddings of a document by its id - fn get_document_embeddings( - &self, - id: &str, - ) -> impl std::future::Future, VectorStoreError>> + Send; - - /// Get the document by its id and deserialize it into the given type - fn get_document Deserialize<'a>>( - &self, - id: &str, - ) -> impl std::future::Future, VectorStoreError>> + Send; - - /// Get the document by a query and deserialize it into the given type - fn get_document_by_query( - &self, - query: Self::Q, - ) -> impl std::future::Future, VectorStoreError>> + Send; -} - /// Trait for vector store indexes pub trait VectorStoreIndex: Send + Sync { /// Get the top n documents based on the distance to the given query. diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 3d062de3..0d31aaa2 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -4,7 +4,7 @@ use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndex}, + vector_store::VectorStoreIndex, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; @@ -29,8 +29,6 @@ async fn main() -> Result<(), anyhow::Error> { .database("knowledgebase") .collection("context"); - let mut vector_store = MongoDbVectorStore::new(collection); - // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); @@ -41,15 +39,15 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Add embeddings to vector store - match vector_store.add_documents(embeddings).await { + match collection.insert_many(embeddings, None).await { Ok(_) => println!("Documents added successfully"), Err(e) => println!("Error adding documents: {:?}", e), } // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "vector_index", SearchParams::default()); + let index = + MongoDbVectorStore::new(collection).index(model, "vector_index", SearchParams::default()); // Query the index let results = index diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 43869989..30dd9e95 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -3,7 +3,7 @@ use mongodb::bson::{self, doc}; use rig::{ embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, - vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, + vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; @@ -16,67 +16,6 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -impl VectorStore for MongoDbVectorStore { - type Q = mongodb::bson::Document; - - async fn add_documents( - &mut self, - documents: Vec, - ) -> Result<(), VectorStoreError> { - self.collection - .insert_many(documents, None) - .await - .map_err(mongodb_to_rig_error)?; - Ok(()) - } - - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result, VectorStoreError> { - self.collection - .find_one(doc! { "_id": id }, None) - .await - .map_err(mongodb_to_rig_error) - } - - async fn get_document serde::Deserialize<'a>>( - &self, - id: &str, - ) -> Result, VectorStoreError> { - Ok(self - .collection - .clone_with_type::() - .aggregate( - [ - doc! {"$match": { "_id": id}}, - doc! {"$project": { "document": 1 }}, - doc! {"$replaceRoot": { "newRoot": "$document" }}, - ], - None, - ) - .await - .map_err(mongodb_to_rig_error)? - .with_type::() - .next() - .await - .transpose() - .map_err(mongodb_to_rig_error)? - .map(|doc| serde_json::from_str(&doc)) - .transpose()?) - } - - async fn get_document_by_query( - &self, - query: Self::Q, - ) -> Result, VectorStoreError> { - self.collection - .find_one(query, None) - .await - .map_err(mongodb_to_rig_error) - } -} - impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. pub fn new(collection: mongodb::Collection) -> Self {