From c269d3552b892f883a683b7bed05b928cf829bbc Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 14:07:40 -0500 Subject: [PATCH 1/3] feat: Improve `InMemoryVectorStore` API --- rig-core/examples/calculator_chatbot.rs | 5 +- rig-core/examples/rag.rs | 8 +- rig-core/examples/rag_dynamic_tools.rs | 8 +- rig-core/examples/vector_search.rs | 15 +- rig-core/examples/vector_search_cohere.rs | 11 +- rig-core/src/pipeline/builder.rs | 261 +++++++++++++++++++ rig-core/src/vector_store/in_memory_store.rs | 253 ++++++++++++++---- 7 files changed, 493 insertions(+), 68 deletions(-) create mode 100644 rig-core/src/pipeline/builder.rs diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 149b1ce4..f1107312 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -251,9 +251,8 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |tool| tool.name.clone())? - .index(embedding_model); + let vector_store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); + let index = vector_store.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 376c37db..fcbbd20f 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -65,9 +65,11 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |definition| definition.id.clone())? 
- .index(embedding_model); + // Create vector store with the embeddings + let vector_store = InMemoryVectorStore::from_documents(embeddings); + + // Create vector store index + let index = vector_store.index(embedding_model); let rag_agent = openai_client.agent("gpt-4") .preamble(" diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index bc92f7c5..48b7dbeb 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -155,9 +155,11 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |tool| tool.name.clone())? - .index(embedding_model); + // Create vector store with the embeddings + let vector_store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); + + // Create vector store index + let index = vector_store.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index e8cbf894..c0799e60 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -24,9 +24,9 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let embeddings = EmbeddingsBuilder::new(model.clone()) + let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ WordDefinition { id: "doc0".to_string(), @@ -56,9 +56,14 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |definition| definition.id.clone())? 
- .index(model); + // Create vector store with the embeddings + let vector_store = InMemoryVectorStore::from_documents_with_id_f( + embeddings, + |doc| doc.id.clone(), + ); + + // Create vector store index + let index = vector_store.index(embedding_model); let results = index .top_n::("I need to buy something in a fictional universe. What type of money can I use for this?", 1) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index aace89fa..628b69dc 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -57,9 +57,14 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |definition| definition.id.clone())? - .index(search_model); + // Create vector store with the embeddings + let vector_store = InMemoryVectorStore::from_documents_with_id_f( + embeddings, + |doc| doc.id.clone() + ); + + // Create vector store index + let index = vector_store.index(search_model); let results = index .top_n::( diff --git a/rig-core/src/pipeline/builder.rs b/rig-core/src/pipeline/builder.rs new file mode 100644 index 00000000..61ce3419 --- /dev/null +++ b/rig-core/src/pipeline/builder.rs @@ -0,0 +1,261 @@ +use std::future::Future; + +use crate::{completion, vector_store}; + +use super::{agent_ops, op}; + +// pub struct PipelineBuilder { +// _error: std::marker::PhantomData, +// } +pub struct PipelineBuilder; + +impl PipelineBuilder { + /// Chain a function to the current pipeline + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// let chain = pipeline::new() + /// .map(|(x, y)| x + y) + /// .map(|z| format!("Result: {z}!")); + /// + /// let result = chain.call((1, 2)).await; + /// assert_eq!(result, "Result: 3!"); + /// ``` + pub fn map(self, f: F) -> op::Map + where + F: Fn(In) -> T + Send + Sync, + In: Send + Sync, + T: Send + Sync, + Self: Sized, + { + 
op::Map::new(f) + } + + /// Same as `map` but for asynchronous functions + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// let chain = pipeline::new() + /// .then(|email: String| async move { + /// email.split('@').next().unwrap().to_string() + /// }) + /// .then(|username: String| async move { + /// format!("Hello, {}!", username) + /// }); + /// + /// let result = chain.call("bob@gmail.com".to_string()).await; + /// assert_eq!(result, "Hello, bob!"); + /// ``` + pub fn then(self, f: F) -> op::Then + where + F: Fn(In) -> Fut + Send + Sync, + In: Send + Sync, + Fut: Future + Send + Sync, + Fut::Output: Send + Sync, + Self: Sized, + { + op::Then::new(f) + } + + /// Chain an arbitrary operation to the current pipeline. + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// struct MyOp; + /// + /// impl Op for MyOp { + /// type Input = i32; + /// type Output = i32; + /// + /// async fn call(&self, input: Self::Input) -> Self::Output { + /// input + 1 + /// } + /// } + /// + /// let chain = pipeline::new() + /// .chain(MyOp); + /// + /// let result = chain.call(1).await; + /// assert_eq!(result, 2); + /// ``` + pub fn chain(self, op: T) -> T + where + T: op::Op, + Self: Sized, + { + op + } + + /// Chain a lookup operation to the current chain. The lookup operation expects the + /// current chain to output a query string. The lookup operation will use the query to + /// retrieve the top `n` documents from the index and return them with the query string. 
+ /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .lookup(index, 2) + /// .chain(|(query, docs): (_, Vec)| async move { + /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) + /// }); + /// + /// let result = chain.call("What is a flurbo?".to_string()).await; + /// ``` + pub fn lookup( + self, + index: I, + n: usize, + ) -> agent_ops::Lookup + where + I: vector_store::VectorStoreIndex, + T: Send + Sync + for<'a> serde::Deserialize<'a>, + In: Into + Send + Sync, + Self: Sized, + { + agent_ops::Lookup::new(index, n) + } + + /// Chain a prompt operation to the current chain. The prompt operation expects the + /// current chain to output a string. The prompt operation will use the string to prompt + /// the given agent (or any other type that implements the `Prompt` trait) and return + /// the response. + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let agent = &openai_client.agent("gpt-4").build(); + /// + /// let chain = chain::new() + /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) + /// .prompt(agent); + /// + /// let result = chain.call("Alice".to_string()).await; + /// ``` + pub fn prompt( + self, + prompt: P, + ) -> impl op::Op> + where + P: completion::Prompt, + In: Into + Send + Sync, + Self: Sized, + { + agent_ops::prompt(prompt) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ChainError { + #[error("Failed to prompt agent: {0}")] + PromptError(#[from] completion::PromptError), + + #[error("Failed to lookup documents: {0}")] + LookupError(#[from] vector_store::VectorStoreError), +} + +// pub fn new() -> PipelineBuilder { +// PipelineBuilder { +// _error: std::marker::PhantomData, +// } +// } + +// pub fn with_error() -> PipelineBuilder { +// PipelineBuilder { +// _error: std::marker::PhantomData, +// } +// } + +pub fn new() -> PipelineBuilder { + PipelineBuilder +} + +#[cfg(test)] +mod tests { 
+ + use super::*; + use crate::pipeline::{op::Op, parallel::{parallel, Parallel}}; + use agent_ops::tests::{Foo, MockIndex, MockModel}; + + #[tokio::test] + async fn test_prompt_pipeline() { + let model = MockModel; + + let chain = super::new() + .map(|input| format!("User query: {}", input)) + .prompt(model); + + let result = chain + .call("What is a flurbo?") + .await + .expect("Failed to run chain"); + + assert_eq!(result, "Mock response: User query: What is a flurbo?"); + } + + // #[tokio::test] + // async fn test_lookup_pipeline() { + // let index = MockIndex; + + // let chain = super::new() + // .lookup::<_, _, Foo>(index, 1) + // .map_ok(|docs| format!("Top documents:\n{}", docs[0].foo)); + + // let result = chain + // .try_call("What is a flurbo?") + // .await + // .expect("Failed to run chain"); + + // assert_eq!( + // result, + // "User query: What is a flurbo?\n\nTop documents:\nbar" + // ); + // } + + #[tokio::test] + async fn test_rag_pipeline() { + let index = MockIndex; + + let chain = super::new() + .chain(parallel!(op::passthrough(), agent_ops::lookup::<_, _, Foo>(index, 1),)) + .map(|(query, maybe_docs)| match maybe_docs { + Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo), + Err(err) => format!("Error: {}", err), + }) + .prompt(MockModel); + + let result = chain + .call("What is a flurbo?") + .await + .expect("Failed to run chain"); + + assert_eq!( + result, + "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar" + ); + } + + #[tokio::test] + async fn test_parallel_chain_compile_check() { + let _ = super::new().chain( + Parallel::new( + op::map(|x: i32| x + 1), + Parallel::new( + op::map(|x: i32| x * 3), + Parallel::new( + op::map(|x: i32| format!("{} is the number!", x)), + op::map(|x: i32| x == 1), + ), + ), + ) + .map(|(r1, (r2, (r3, r4)))| (r1, r2, r3, r4)), + ); + } +} diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 
931946eb..cb666b99 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -13,7 +13,7 @@ use crate::{ OneOrMany, }; -/// InMemoryVectorStore is a simple in-memory vector store that stores embeddings +/// [InMemoryVectorStore] is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. #[derive(Clone, Default)] pub struct InMemoryVectorStore { @@ -24,8 +24,48 @@ pub struct InMemoryVectorStore { } impl InMemoryVectorStore { - /// Implement vector search on InMemoryVectorStore. - /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. + /// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings. + /// Ids are automatically generated have will have the form `"doc{n}"` where `n` + /// is the index of the document. + pub fn from_documents(documents: impl IntoIterator)>) -> Self { + let mut store = HashMap::new(); + documents.into_iter() + .enumerate() + .for_each(|(i, (doc, embeddings))| { + store.insert(format!("doc{i}"), (doc, embeddings)); + }); + + Self { embeddings: store } + } + + /// Create a new [InMemoryVectorStore] from documents and and their corresponding embeddings with ids. + pub fn from_documents_with_ids(documents: impl IntoIterator)>) -> Self { + let mut store = HashMap::new(); + documents.into_iter() + .for_each(|(i, doc, embeddings)| { + store.insert(i.to_string(), (doc, embeddings)); + }); + + Self { embeddings: store } + } + + /// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings. + /// Document ids are generated using the provided function. 
+ pub fn from_documents_with_id_f( + documents: impl IntoIterator)>, + f: fn(&D) -> String, + ) -> Self { + let mut store = HashMap::new(); + documents.into_iter() + .for_each(|(doc, embeddings)| { + store.insert(f(&doc), (doc, embeddings)); + }); + + Self { embeddings: store } + } + + /// Implement vector search on [InMemoryVectorStore]. + /// To be used by implementations of [VectorStoreIndex::top_n] and [VectorStoreIndex::top_n_ids] methods. fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { // Sort documents by best embedding distance let mut docs = BinaryHeap::new(); @@ -63,32 +103,43 @@ impl InMemoryVectorStore { docs } - /// Add documents to the store. - /// Returns the store with the added documents. + /// Add documents and their corresponding embeddings to the store. + /// Ids are automatically generated and will have the form `"doc{n}"` where `n` + /// is the index of the document. pub fn add_documents( - mut self, - documents: Vec<(String, D, OneOrMany)>, - ) -> Result { - for (id, doc, embeddings) in documents { - self.embeddings.insert(id, (doc, embeddings)); - } + &mut self, + documents: impl IntoIterator)>, + ) -> () { + let current_index = self.embeddings.len(); + documents.into_iter() + .enumerate() + .for_each(|(index, (doc, embeddings))| { + self.embeddings.insert(format!("doc{}", index + current_index), (doc, embeddings)); + }); + } - Ok(self) + /// Add documents and their corresponding embeddings to the store with ids. + pub fn add_documents_with_ids( + &mut self, + documents: impl IntoIterator)>, + ) -> () { + documents.into_iter() + .for_each(|(id, doc, embeddings)| { + self.embeddings.insert(id.to_string(), (doc, embeddings)); + }); } - /// Add documents to the store. Define a function that takes as input the reference of the document and returns its id. - /// Returns the store with the added documents. 
- pub fn add_documents_with_id( - mut self, + /// Add documents and their corresponding embeddings to the store. + /// Document ids are generated using the provided function. + pub fn add_documents_with_id_f( + &mut self, documents: Vec<(D, OneOrMany)>, - id_f: fn(&D) -> String, - ) -> Result { + f: fn(&D) -> String, + ) -> () { for (doc, embeddings) in documents { - let id = id_f(&doc); + let id = f(&doc); self.embeddings.insert(id, (doc, embeddings)); } - - Ok(self) } /// Get the document by its id and deserialize it into the given type. @@ -210,42 +261,144 @@ impl VectorStoreIndex mod tests { use std::cmp::Reverse; - use crate::{embeddings::embedding::Embedding, OneOrMany}; + use crate::{embeddings::embedding::Embedding, vector_store, OneOrMany}; use super::{InMemoryVectorStore, RankingItem}; #[test] - fn test_single_embedding() { - let index = InMemoryVectorStore::default() - .add_documents(vec![ + fn test_auto_ids() { + let mut vector_store = InMemoryVectorStore::from_documents(vec![ + ( + "glarb-garb", + OneOrMany::one(Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }), + ), + ( + "marble-marble", + OneOrMany::one(Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }), + ), + ( + "flumb-flumb", + OneOrMany::one(Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }), + ), + ]); + + vector_store.add_documents(vec![ + ( + "brotato", + OneOrMany::one(Embedding { + document: "brotato".to_string(), + vec: vec![0.3, 0.7, 0.1], + }), + ), + ( + "ping-pong", + OneOrMany::one(Embedding { + document: "ping-pong".to_string(), + vec: vec![0.7, -0.3, 0.0], + }), + ), + ]); + + let mut store = vector_store.embeddings.into_iter() + .collect::>(); + store.sort_by_key(|(id, _)| id.clone()); + + assert_eq!( + store, + vec![ + ( + "doc0".to_string(), + ( + "glarb-garb", + OneOrMany::one(Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }) + ) + ), ( 
"doc1".to_string(), - "glarb-garb", - OneOrMany::one(Embedding { - document: "glarb-garb".to_string(), - vec: vec![0.1, 0.1, 0.5], - }), + ( + "marble-marble", + OneOrMany::one(Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }) + ) ), ( "doc2".to_string(), - "marble-marble", - OneOrMany::one(Embedding { - document: "marble-marble".to_string(), - vec: vec![0.7, -0.3, 0.0], - }), + ( + "flumb-flumb", + OneOrMany::one(Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }) + ) ), ( "doc3".to_string(), - "flumb-flumb", - OneOrMany::one(Embedding { - document: "flumb-flumb".to_string(), - vec: vec![0.3, 0.7, 0.1], - }), + ( + "brotato", + OneOrMany::one(Embedding { + document: "brotato".to_string(), + vec: vec![0.3, 0.7, 0.1], + }) + ) ), - ]) - .unwrap(); + ( + "doc4".to_string(), + ( + "ping-pong", + OneOrMany::one(Embedding { + document: "ping-pong".to_string(), + vec: vec![0.7, -0.3, 0.0], + }) + ) + ) + ] + ); + } - let ranking = index.vector_search( + #[test] + fn test_single_embedding() { + let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![ + ( + "doc1", + "glarb-garb", + OneOrMany::one(Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }), + ), + ( + "doc2", + "marble-marble", + OneOrMany::one(Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }), + ), + ( + "doc3", + "flumb-flumb", + OneOrMany::one(Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }), + ), + ]); + + let ranking = vector_store.vector_search( &Embedding { document: "glarby-glarble".to_string(), vec: vec![0.0, 0.1, 0.6], @@ -274,10 +427,9 @@ mod tests { #[test] fn test_multiple_embeddings() { - let index = InMemoryVectorStore::default() - .add_documents(vec![ + let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![ ( - "doc1".to_string(), + "doc1", "glarb-garb", OneOrMany::many(vec![ Embedding { @@ -292,7 
+444,7 @@ mod tests { .unwrap(), ), ( - "doc2".to_string(), + "doc2", "marble-marble", OneOrMany::many(vec![ Embedding { @@ -307,7 +459,7 @@ mod tests { .unwrap(), ), ( - "doc3".to_string(), + "doc3", "flumb-flumb", OneOrMany::many(vec![ Embedding { @@ -321,10 +473,9 @@ mod tests { ]) .unwrap(), ), - ]) - .unwrap(); + ]); - let ranking = index.vector_search( + let ranking = vector_store.vector_search( &Embedding { document: "glarby-glarble".to_string(), vec: vec![0.0, 0.1, 0.6], From 564b9a63b311e4642d3497c2bdd338d24b1877a8 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 14:09:19 -0500 Subject: [PATCH 2/3] style: clippy+fmt --- rig-core/examples/calculator_chatbot.rs | 3 +- rig-core/examples/rag_dynamic_tools.rs | 3 +- rig-core/examples/vector_search.rs | 6 +- rig-core/examples/vector_search_cohere.rs | 6 +- rig-core/src/vector_store/in_memory_store.rs | 137 ++++++++++--------- 5 files changed, 77 insertions(+), 78 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index f1107312..97d923f8 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -251,7 +251,8 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let vector_store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); + let vector_store = + InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); let index = vector_store.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 48b7dbeb..19ea2064 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -156,7 +156,8 @@ async fn main() -> Result<(), anyhow::Error> { .await?; // Create vector store with the embeddings - let vector_store = 
InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); + let vector_store = + InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); // Create vector store index let index = vector_store.index(embedding_model); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index c0799e60..ca840029 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -57,10 +57,8 @@ async fn main() -> Result<(), anyhow::Error> { .await?; // Create vector store with the embeddings - let vector_store = InMemoryVectorStore::from_documents_with_id_f( - embeddings, - |doc| doc.id.clone(), - ); + let vector_store = + InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone()); // Create vector store index let index = vector_store.index(embedding_model); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 628b69dc..4ee621bd 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -58,10 +58,8 @@ async fn main() -> Result<(), anyhow::Error> { .await?; // Create vector store with the embeddings - let vector_store = InMemoryVectorStore::from_documents_with_id_f( - embeddings, - |doc| doc.id.clone() - ); + let vector_store = + InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone()); // Create vector store index let index = vector_store.index(search_model); diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index cb666b99..32e55bcf 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -25,11 +25,12 @@ pub struct InMemoryVectorStore { impl InMemoryVectorStore { /// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings. 
- /// Ids are automatically generated have will have the form `"doc{n}"` where `n` + /// Ids are automatically generated have will have the form `"doc{n}"` where `n` /// is the index of the document. pub fn from_documents(documents: impl IntoIterator)>) -> Self { let mut store = HashMap::new(); - documents.into_iter() + documents + .into_iter() .enumerate() .for_each(|(i, (doc, embeddings))| { store.insert(format!("doc{i}"), (doc, embeddings)); @@ -39,12 +40,13 @@ impl InMemoryVectorStore { } /// Create a new [InMemoryVectorStore] from documents and and their corresponding embeddings with ids. - pub fn from_documents_with_ids(documents: impl IntoIterator)>) -> Self { + pub fn from_documents_with_ids( + documents: impl IntoIterator)>, + ) -> Self { let mut store = HashMap::new(); - documents.into_iter() - .for_each(|(i, doc, embeddings)| { - store.insert(i.to_string(), (doc, embeddings)); - }); + documents.into_iter().for_each(|(i, doc, embeddings)| { + store.insert(i.to_string(), (doc, embeddings)); + }); Self { embeddings: store } } @@ -56,10 +58,9 @@ impl InMemoryVectorStore { f: fn(&D) -> String, ) -> Self { let mut store = HashMap::new(); - documents.into_iter() - .for_each(|(doc, embeddings)| { - store.insert(f(&doc), (doc, embeddings)); - }); + documents.into_iter().for_each(|(doc, embeddings)| { + store.insert(f(&doc), (doc, embeddings)); + }); Self { embeddings: store } } @@ -109,12 +110,14 @@ impl InMemoryVectorStore { pub fn add_documents( &mut self, documents: impl IntoIterator)>, - ) -> () { + ) { let current_index = self.embeddings.len(); - documents.into_iter() + documents + .into_iter() .enumerate() .for_each(|(index, (doc, embeddings))| { - self.embeddings.insert(format!("doc{}", index + current_index), (doc, embeddings)); + self.embeddings + .insert(format!("doc{}", index + current_index), (doc, embeddings)); }); } @@ -122,11 +125,10 @@ impl InMemoryVectorStore { pub fn add_documents_with_ids( &mut self, documents: impl IntoIterator)>, - ) -> () { 
- documents.into_iter() - .for_each(|(id, doc, embeddings)| { - self.embeddings.insert(id.to_string(), (doc, embeddings)); - }); + ) { + documents.into_iter().for_each(|(id, doc, embeddings)| { + self.embeddings.insert(id.to_string(), (doc, embeddings)); + }); } /// Add documents and their corresponding embeddings to the store. @@ -135,7 +137,7 @@ impl InMemoryVectorStore { &mut self, documents: Vec<(D, OneOrMany)>, f: fn(&D) -> String, - ) -> () { + ) { for (doc, embeddings) in documents { let id = f(&doc); self.embeddings.insert(id, (doc, embeddings)); @@ -308,8 +310,7 @@ mod tests { ), ]); - let mut store = vector_store.embeddings.into_iter() - .collect::>(); + let mut store = vector_store.embeddings.into_iter().collect::>(); store.sort_by_key(|(id, _)| id.clone()); assert_eq!( @@ -428,52 +429,52 @@ mod tests { #[test] fn test_multiple_embeddings() { let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![ - ( - "doc1", - "glarb-garb", - OneOrMany::many(vec![ - Embedding { - document: "glarb-garb".to_string(), - vec: vec![0.1, 0.1, 0.5], - }, - Embedding { - document: "don't-choose-me".to_string(), - vec: vec![-0.5, 0.9, 0.1], - }, - ]) - .unwrap(), - ), - ( - "doc2", - "marble-marble", - OneOrMany::many(vec![ - Embedding { - document: "marble-marble".to_string(), - vec: vec![0.7, -0.3, 0.0], - }, - Embedding { - document: "sandwich".to_string(), - vec: vec![0.5, 0.5, -0.7], - }, - ]) - .unwrap(), - ), - ( - "doc3", - "flumb-flumb", - OneOrMany::many(vec![ - Embedding { - document: "flumb-flumb".to_string(), - vec: vec![0.3, 0.7, 0.1], - }, - Embedding { - document: "banana".to_string(), - vec: vec![0.1, -0.5, -0.5], - }, - ]) - .unwrap(), - ), - ]); + ( + "doc1", + "glarb-garb", + OneOrMany::many(vec![ + Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }, + Embedding { + document: "don't-choose-me".to_string(), + vec: vec![-0.5, 0.9, 0.1], + }, + ]) + .unwrap(), + ), + ( + "doc2", + "marble-marble", + 
OneOrMany::many(vec![ + Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }, + Embedding { + document: "sandwich".to_string(), + vec: vec![0.5, 0.5, -0.7], + }, + ]) + .unwrap(), + ), + ( + "doc3", + "flumb-flumb", + OneOrMany::many(vec![ + Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }, + Embedding { + document: "banana".to_string(), + vec: vec![0.1, -0.5, -0.5], + }, + ]) + .unwrap(), + ), + ]); let ranking = vector_store.vector_search( &Embedding { From 8a88d55da812d5e92e2fdb25d43c5ddfcf03ac4d Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 14:11:29 -0500 Subject: [PATCH 3/3] test: fix test --- rig-core/src/vector_store/in_memory_store.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 32e55bcf..aae0a256 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -263,7 +263,7 @@ impl VectorStoreIndex mod tests { use std::cmp::Reverse; - use crate::{embeddings::embedding::Embedding, vector_store, OneOrMany}; + use crate::{embeddings::embedding::Embedding, OneOrMany}; use super::{InMemoryVectorStore, RankingItem};