diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 149b1ce4..97d923f8 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -251,9 +251,9 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |tool| tool.name.clone())? - .index(embedding_model); + let vector_store = + InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); + let index = vector_store.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 376c37db..fcbbd20f 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -65,9 +65,11 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |definition| definition.id.clone())? - .index(embedding_model); + // Create vector store with the embeddings + let vector_store = InMemoryVectorStore::from_documents(embeddings); + + // Create vector store index + let index = vector_store.index(embedding_model); let rag_agent = openai_client.agent("gpt-4") .preamble(" diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index bc92f7c5..19ea2064 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -155,9 +155,12 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |tool| tool.name.clone())? 
- .index(embedding_model); + // Create vector store with the embeddings + let vector_store = + InMemoryVectorStore::from_documents_with_id_f(embeddings, |tool| tool.name.clone()); + + // Create vector store index + let index = vector_store.index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index e8cbf894..ca840029 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -24,9 +24,9 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let embeddings = EmbeddingsBuilder::new(model.clone()) + let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .documents(vec![ WordDefinition { id: "doc0".to_string(), @@ -56,9 +56,12 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |definition| definition.id.clone())? - .index(model); + // Create vector store with the embeddings + let vector_store = + InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone()); + + // Create vector store index + let index = vector_store.index(embedding_model); let results = index .top_n::("I need to buy something in a fictional universe. 
What type of money can I use for this?", 1) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index aace89fa..4ee621bd 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -57,9 +57,12 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let index = InMemoryVectorStore::default() - .add_documents_with_id(embeddings, |definition| definition.id.clone())? - .index(search_model); + // Create vector store with the embeddings + let vector_store = + InMemoryVectorStore::from_documents_with_id_f(embeddings, |doc| doc.id.clone()); + + // Create vector store index + let index = vector_store.index(search_model); let results = index .top_n::( diff --git a/rig-core/src/pipeline/builder.rs b/rig-core/src/pipeline/builder.rs new file mode 100644 index 00000000..61ce3419 --- /dev/null +++ b/rig-core/src/pipeline/builder.rs @@ -0,0 +1,261 @@ +use std::future::Future; + +use crate::{completion, vector_store}; + +use super::{agent_ops, op}; + +// pub struct PipelineBuilder { +// _error: std::marker::PhantomData, +// } +pub struct PipelineBuilder; + +impl PipelineBuilder { + /// Chain a function to the current pipeline + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// let chain = pipeline::new() + /// .map(|(x, y)| x + y) + /// .map(|z| format!("Result: {z}!")); + /// + /// let result = chain.call((1, 2)).await; + /// assert_eq!(result, "Result: 3!"); + /// ``` + pub fn map(self, f: F) -> op::Map + where + F: Fn(In) -> T + Send + Sync, + In: Send + Sync, + T: Send + Sync, + Self: Sized, + { + op::Map::new(f) + } + + /// Same as `map` but for asynchronous functions + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// let chain = pipeline::new() + /// .then(|email: String| async move { + /// email.split('@').next().unwrap().to_string() + /// }) + /// .then(|username: String| async move { + /// 
format!("Hello, {}!", username) + /// }); + /// + /// let result = chain.call("bob@gmail.com".to_string()).await; + /// assert_eq!(result, "Hello, bob!"); + /// ``` + pub fn then(self, f: F) -> op::Then + where + F: Fn(In) -> Fut + Send + Sync, + In: Send + Sync, + Fut: Future + Send + Sync, + Fut::Output: Send + Sync, + Self: Sized, + { + op::Then::new(f) + } + + /// Chain an arbitrary operation to the current pipeline. + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// struct MyOp; + /// + /// impl Op for MyOp { + /// type Input = i32; + /// type Output = i32; + /// + /// async fn call(&self, input: Self::Input) -> Self::Output { + /// input + 1 + /// } + /// } + /// + /// let chain = pipeline::new() + /// .chain(MyOp); + /// + /// let result = chain.call(1).await; + /// assert_eq!(result, 2); + /// ``` + pub fn chain(self, op: T) -> T + where + T: op::Op, + Self: Sized, + { + op + } + + /// Chain a lookup operation to the current chain. The lookup operation expects the + /// current chain to output a query string. The lookup operation will use the query to + /// retrieve the top `n` documents from the index and return them with the query string. + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .lookup(index, 2) + /// .chain(|(query, docs): (_, Vec)| async move { + /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) + /// }); + /// + /// let result = chain.call("What is a flurbo?".to_string()).await; + /// ``` + pub fn lookup( + self, + index: I, + n: usize, + ) -> agent_ops::Lookup + where + I: vector_store::VectorStoreIndex, + T: Send + Sync + for<'a> serde::Deserialize<'a>, + In: Into + Send + Sync, + Self: Sized, + { + agent_ops::Lookup::new(index, n) + } + + /// Chain a prompt operation to the current chain. The prompt operation expects the + /// current chain to output a string. 
The prompt operation will use the string to prompt + /// the given agent (or any other type that implements the `Prompt` trait) and return + /// the response. + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let agent = &openai_client.agent("gpt-4").build(); + /// + /// let chain = chain::new() + /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) + /// .prompt(agent); + /// + /// let result = chain.call("Alice".to_string()).await; + /// ``` + pub fn prompt( + self, + prompt: P, + ) -> impl op::Op> + where + P: completion::Prompt, + In: Into + Send + Sync, + Self: Sized, + { + agent_ops::prompt(prompt) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ChainError { + #[error("Failed to prompt agent: {0}")] + PromptError(#[from] completion::PromptError), + + #[error("Failed to lookup documents: {0}")] + LookupError(#[from] vector_store::VectorStoreError), +} + +// pub fn new() -> PipelineBuilder { +// PipelineBuilder { +// _error: std::marker::PhantomData, +// } +// } + +// pub fn with_error() -> PipelineBuilder { +// PipelineBuilder { +// _error: std::marker::PhantomData, +// } +// } + +pub fn new() -> PipelineBuilder { + PipelineBuilder +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::pipeline::{op::Op, parallel::{parallel, Parallel}}; + use agent_ops::tests::{Foo, MockIndex, MockModel}; + + #[tokio::test] + async fn test_prompt_pipeline() { + let model = MockModel; + + let chain = super::new() + .map(|input| format!("User query: {}", input)) + .prompt(model); + + let result = chain + .call("What is a flurbo?") + .await + .expect("Failed to run chain"); + + assert_eq!(result, "Mock response: User query: What is a flurbo?"); + } + + // #[tokio::test] + // async fn test_lookup_pipeline() { + // let index = MockIndex; + + // let chain = super::new() + // .lookup::<_, _, Foo>(index, 1) + // .map_ok(|docs| format!("Top documents:\n{}", docs[0].foo)); + + // let result = chain + // 
.try_call("What is a flurbo?") + // .await + // .expect("Failed to run chain"); + + // assert_eq!( + // result, + // "User query: What is a flurbo?\n\nTop documents:\nbar" + // ); + // } + + #[tokio::test] + async fn test_rag_pipeline() { + let index = MockIndex; + + let chain = super::new() + .chain(parallel!(op::passthrough(), agent_ops::lookup::<_, _, Foo>(index, 1),)) + .map(|(query, maybe_docs)| match maybe_docs { + Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo), + Err(err) => format!("Error: {}", err), + }) + .prompt(MockModel); + + let result = chain + .call("What is a flurbo?") + .await + .expect("Failed to run chain"); + + assert_eq!( + result, + "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar" + ); + } + + #[tokio::test] + async fn test_parallel_chain_compile_check() { + let _ = super::new().chain( + Parallel::new( + op::map(|x: i32| x + 1), + Parallel::new( + op::map(|x: i32| x * 3), + Parallel::new( + op::map(|x: i32| format!("{} is the number!", x)), + op::map(|x: i32| x == 1), + ), + ), + ) + .map(|(r1, (r2, (r3, r4)))| (r1, r2, r3, r4)), + ); + } +} diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 931946eb..aae0a256 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -13,7 +13,7 @@ use crate::{ OneOrMany, }; -/// InMemoryVectorStore is a simple in-memory vector store that stores embeddings +/// [InMemoryVectorStore] is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. #[derive(Clone, Default)] pub struct InMemoryVectorStore { @@ -24,8 +24,49 @@ pub struct InMemoryVectorStore { } impl InMemoryVectorStore { - /// Implement vector search on InMemoryVectorStore. - /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. 
+ /// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings. + /// Ids are automatically generated and will have the form `"doc{n}"` where `n` + /// is the index of the document. + pub fn from_documents(documents: impl IntoIterator)>) -> Self { + let mut store = HashMap::new(); + documents + .into_iter() + .enumerate() + .for_each(|(i, (doc, embeddings))| { + store.insert(format!("doc{i}"), (doc, embeddings)); + }); + + Self { embeddings: store } + } + + /// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings with ids. + pub fn from_documents_with_ids( + documents: impl IntoIterator)>, + ) -> Self { + let mut store = HashMap::new(); + documents.into_iter().for_each(|(i, doc, embeddings)| { + store.insert(i.to_string(), (doc, embeddings)); + }); + + Self { embeddings: store } + } + + /// Create a new [InMemoryVectorStore] from documents and their corresponding embeddings. + /// Document ids are generated using the provided function. + pub fn from_documents_with_id_f( + documents: impl IntoIterator)>, + f: fn(&D) -> String, + ) -> Self { + let mut store = HashMap::new(); + documents.into_iter().for_each(|(doc, embeddings)| { + store.insert(f(&doc), (doc, embeddings)); + }); + + Self { embeddings: store } + } + + /// Implement vector search on [InMemoryVectorStore]. + /// To be used by implementations of [VectorStoreIndex::top_n] and [VectorStoreIndex::top_n_ids] methods. fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { // Sort documents by best embedding distance let mut docs = BinaryHeap::new(); @@ -63,32 +104,44 @@ impl InMemoryVectorStore { docs } - /// Add documents to the store. - /// Returns the store with the added documents. + /// Add documents and their corresponding embeddings to the store. + /// Ids are automatically generated and will have the form `"doc{n}"` where `n` + /// is the index of the document. 
pub fn add_documents( - mut self, - documents: Vec<(String, D, OneOrMany)>, - ) -> Result { - for (id, doc, embeddings) in documents { - self.embeddings.insert(id, (doc, embeddings)); - } + &mut self, + documents: impl IntoIterator)>, + ) { + let current_index = self.embeddings.len(); + documents + .into_iter() + .enumerate() + .for_each(|(index, (doc, embeddings))| { + self.embeddings + .insert(format!("doc{}", index + current_index), (doc, embeddings)); + }); + } - Ok(self) + /// Add documents and their corresponding embeddings to the store with ids. + pub fn add_documents_with_ids( + &mut self, + documents: impl IntoIterator)>, + ) { + documents.into_iter().for_each(|(id, doc, embeddings)| { + self.embeddings.insert(id.to_string(), (doc, embeddings)); + }); } - /// Add documents to the store. Define a function that takes as input the reference of the document and returns its id. - /// Returns the store with the added documents. - pub fn add_documents_with_id( - mut self, + /// Add documents and their corresponding embeddings to the store. + /// Document ids are generated using the provided function. + pub fn add_documents_with_id_f( + &mut self, documents: Vec<(D, OneOrMany)>, - id_f: fn(&D) -> String, - ) -> Result { + f: fn(&D) -> String, + ) { for (doc, embeddings) in documents { - let id = id_f(&doc); + let id = f(&doc); self.embeddings.insert(id, (doc, embeddings)); } - - Ok(self) } /// Get the document by its id and deserialize it into the given type. 
@@ -215,37 +268,138 @@ mod tests { use super::{InMemoryVectorStore, RankingItem}; #[test] - fn test_single_embedding() { - let index = InMemoryVectorStore::default() - .add_documents(vec![ + fn test_auto_ids() { + let mut vector_store = InMemoryVectorStore::from_documents(vec![ + ( + "glarb-garb", + OneOrMany::one(Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }), + ), + ( + "marble-marble", + OneOrMany::one(Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }), + ), + ( + "flumb-flumb", + OneOrMany::one(Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }), + ), + ]); + + vector_store.add_documents(vec![ + ( + "brotato", + OneOrMany::one(Embedding { + document: "brotato".to_string(), + vec: vec![0.3, 0.7, 0.1], + }), + ), + ( + "ping-pong", + OneOrMany::one(Embedding { + document: "ping-pong".to_string(), + vec: vec![0.7, -0.3, 0.0], + }), + ), + ]); + + let mut store = vector_store.embeddings.into_iter().collect::>(); + store.sort_by_key(|(id, _)| id.clone()); + + assert_eq!( + store, + vec![ + ( + "doc0".to_string(), + ( + "glarb-garb", + OneOrMany::one(Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }) + ) + ), ( "doc1".to_string(), - "glarb-garb", - OneOrMany::one(Embedding { - document: "glarb-garb".to_string(), - vec: vec![0.1, 0.1, 0.5], - }), + ( + "marble-marble", + OneOrMany::one(Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }) + ) ), ( "doc2".to_string(), - "marble-marble", - OneOrMany::one(Embedding { - document: "marble-marble".to_string(), - vec: vec![0.7, -0.3, 0.0], - }), + ( + "flumb-flumb", + OneOrMany::one(Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }) + ) ), ( "doc3".to_string(), - "flumb-flumb", - OneOrMany::one(Embedding { - document: "flumb-flumb".to_string(), - vec: vec![0.3, 0.7, 0.1], - }), + ( + "brotato", + OneOrMany::one(Embedding { + 
document: "brotato".to_string(), + vec: vec![0.3, 0.7, 0.1], + }) + ) ), - ]) - .unwrap(); + ( + "doc4".to_string(), + ( + "ping-pong", + OneOrMany::one(Embedding { + document: "ping-pong".to_string(), + vec: vec![0.7, -0.3, 0.0], + }) + ) + ) + ] + ); + } - let ranking = index.vector_search( + #[test] + fn test_single_embedding() { + let vector_store = InMemoryVectorStore::from_documents_with_ids(vec![ + ( + "doc1", + "glarb-garb", + OneOrMany::one(Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }), + ), + ( + "doc2", + "marble-marble", + OneOrMany::one(Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }), + ), + ( + "doc3", + "flumb-flumb", + OneOrMany::one(Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }), + ), + ]); + + let ranking = vector_store.vector_search( &Embedding { document: "glarby-glarble".to_string(), vec: vec![0.0, 0.1, 0.6], @@ -274,57 +428,55 @@ mod tests { #[test] fn test_multiple_embeddings() { - let index = InMemoryVectorStore::default() - .add_documents(vec![ - ( - "doc1".to_string(), - "glarb-garb", - OneOrMany::many(vec![ - Embedding { - document: "glarb-garb".to_string(), - vec: vec![0.1, 0.1, 0.5], - }, - Embedding { - document: "don't-choose-me".to_string(), - vec: vec![-0.5, 0.9, 0.1], - }, - ]) - .unwrap(), - ), - ( - "doc2".to_string(), - "marble-marble", - OneOrMany::many(vec![ - Embedding { - document: "marble-marble".to_string(), - vec: vec![0.7, -0.3, 0.0], - }, - Embedding { - document: "sandwich".to_string(), - vec: vec![0.5, 0.5, -0.7], - }, - ]) - .unwrap(), - ), - ( - "doc3".to_string(), - "flumb-flumb", - OneOrMany::many(vec![ - Embedding { - document: "flumb-flumb".to_string(), - vec: vec![0.3, 0.7, 0.1], - }, - Embedding { - document: "banana".to_string(), - vec: vec![0.1, -0.5, -0.5], - }, - ]) - .unwrap(), - ), - ]) - .unwrap(); - - let ranking = index.vector_search( + let vector_store = 
InMemoryVectorStore::from_documents_with_ids(vec![ + ( + "doc1", + "glarb-garb", + OneOrMany::many(vec![ + Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }, + Embedding { + document: "don't-choose-me".to_string(), + vec: vec![-0.5, 0.9, 0.1], + }, + ]) + .unwrap(), + ), + ( + "doc2", + "marble-marble", + OneOrMany::many(vec![ + Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }, + Embedding { + document: "sandwich".to_string(), + vec: vec![0.5, 0.5, -0.7], + }, + ]) + .unwrap(), + ), + ( + "doc3", + "flumb-flumb", + OneOrMany::many(vec![ + Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }, + Embedding { + document: "banana".to_string(), + vec: vec![0.1, -0.5, -0.5], + }, + ]) + .unwrap(), + ), + ]); + + let ranking = vector_store.vector_search( &Embedding { document: "glarby-glarble".to_string(), vec: vec![0.0, 0.1, 0.6],