From e28bf0a4ec444f8671b9e7b056c0d7859adcc32c Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 15 Nov 2024 10:30:50 -0500 Subject: [PATCH 01/26] feat(chain): Initial prototype for agentic chain feature --- rig-core/examples/chain.rs | 61 +++++++ rig-core/src/agent.rs | 6 + rig-core/src/chain.rs | 320 +++++++++++++++++++++++++++++++++++++ rig-core/src/lib.rs | 1 + 4 files changed, 388 insertions(+) create mode 100644 rig-core/examples/chain.rs create mode 100644 rig-core/src/chain.rs diff --git a/rig-core/examples/chain.rs b/rig-core/examples/chain.rs new file mode 100644 index 00000000..24516bf4 --- /dev/null +++ b/rig-core/examples/chain.rs @@ -0,0 +1,61 @@ +use std::env; + +use rig::{ + chain::{self, Chain}, + embeddings::EmbeddingsBuilder, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, +}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create OpenAI client + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + + let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + + // Create vector store, compute embeddings and load them in the store + let mut vector_store = InMemoryVectorStore::default(); + + let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) + .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") + .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") + .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .build() + .await?; + + vector_store.add_documents(embeddings).await?; + + // Create vector store index + let index = vector_store.index(embedding_model); + + let agent = 
openai_client.agent("gpt-4") + .preamble(" + You are a dictionary assistant here to assist the user in understanding the meaning of words. + ") + .build(); + + let chain = chain::new() + // Retrieve top document from the index and return it with the prompt + .lookup(index, 2) + // Format the prompt with the context documents + .map(|(query, docs): (_, Vec)| { + format!( + "User question: {}\n\nAdditional word definitions:\n{}", + query, + docs.join("\n") + ) + }) + // Prompt the agent + .prompt(&agent); + + // Prompt the agent and print the response + let response = chain + .call("What does \"glarb-glarb\" mean?".to_string()) + .await; + + println!("{}", response); + + Ok(()) +} diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 2cbe5fbd..5cbac1cc 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -251,6 +251,12 @@ impl Prompt for Agent { } } +impl Prompt for &Agent { + async fn prompt(&self, prompt: &str) -> Result { + self.chat(prompt, vec![]).await + } +} + impl Chat for Agent { async fn chat(&self, prompt: &str, chat_history: Vec) -> Result { match self.completion(prompt, chat_history).await?.send().await? 
{ diff --git a/rig-core/src/chain.rs b/rig-core/src/chain.rs new file mode 100644 index 00000000..3a5a0913 --- /dev/null +++ b/rig-core/src/chain.rs @@ -0,0 +1,320 @@ +use std::{future::Future, marker::PhantomData}; + +use crate::{completion, vector_store}; + +pub trait Chain: Send + Sync { + type Input: Send; + type Output; + + fn call(self, input: Self::Input) -> impl std::future::Future + Send; + + /// Chain a function to the output of the current chain + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .chain(|email: String| async move { + /// email.split('@').next().unwrap().to_string() + /// }) + /// .chain(|username: String| async move { + /// format!("Hello, {}!", username) + /// }); + /// + /// let result = chain.call("bob@gmail.com".to_string()).await; + /// assert_eq!(result, "Hello, bob!"); + /// ``` + fn chain(self, f: F) -> Chained + where + F: Fn(Self::Output) -> Fut, + Fut: Future, + Self: Sized, + { + Chained::new(self, f) + } + + // fn try_chain(self, f: F) -> TryChained + // where + // F: Fn(Self::Output) -> Fut, + // Fut: TryFuture, + // Self: Sized + // { + // TryChained::new(self, f) + // } + + /// Same as `chain` but for synchronous functions + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .map(|(x, y)| x + y) + /// .map(|z| format!("Result: {z}!")); + /// + /// let result = chain.call((1, 2)).await; + /// assert_eq!(result, "Result: 3!"); + /// ``` + fn map(self, f: F) -> Map + where + F: Fn(Self::Output) -> T, + Self: Sized, + { + Map::new(self, f) + } + + /// Chain a lookup operation to the current chain. The lookup operation expects the + /// current chain to output a query string. The lookup operation will use the query to + /// retrieve the top `n` documents from the index and return them with the query string. 
+ /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .lookup(index, 2) + /// .chain(|(query, docs): (_, Vec)| async move { + /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) + /// }); + /// + /// let result = chain.call("What is a flurbo?".to_string()).await; + /// ``` + fn lookup(self, index: I, n: usize) -> Lookup + where + Self::Output: Into, + Self: Sized, + { + Lookup::new(self, index, n) + } + + /// Chain a prompt operation to the current chain. The prompt operation expects the + /// current chain to output a string. The prompt operation will use the string to prompt + /// the given agent (or any other type that implements the `Prompt` trait) and return + /// the response. + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let agent = &openai_client.agent("gpt-4").build(); + /// + /// let chain = chain::new() + /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) + /// .prompt(agent); + /// + /// let result = chain.call("Alice".to_string()).await; + /// ``` + fn prompt

(self, prompt: P) -> Prompt + where + P: completion::Prompt, + Self::Output: Into, + Self: Sized, + { + Prompt::new(self, prompt) + } +} + +// pub struct Root { +// f: F, +// _phantom: PhantomData, +// } + +// impl Root +// where +// F: FnOnce(A) -> Fut, +// Fut: Future, +// { +// pub fn new(f: F) -> Self { +// Self { +// f, +// _phantom: PhantomData, +// } +// } +// } + +// impl Chain for Root +// where +// F: FnOnce(A) -> Fut + Send + Sync, +// Fut: Future + Send + Sync, +// { +// type Input = A; +// type Output = Fut::Output; + +// async fn call(self, input: Self::Input) -> Self::Output { +// (self.f)(input).await +// } +// } + +pub struct Empty(PhantomData); + +impl Chain for Empty { + type Input = T; + type Output = T; + + async fn call(self, value: Self::Input) -> Self::Output { + value + } +} + +pub struct Chained { + chain: Ch, + f: F, +} + +impl Chained { + fn new(chain: Ch, f: F) -> Self { + Self { chain, f } + } +} + +impl Chain for Chained +where + Ch: Chain, + F: Fn(Ch::Output) -> Fut + Send + Sync, + Fut: Future + Send, +{ + type Input = Ch::Input; + type Output = Fut::Output; + + async fn call(self, input: Self::Input) -> Self::Output { + let output = self.chain.call(input).await; + (self.f)(output).await + } +} + +// pub struct TryChained { +// chain: Ch, +// f: F, +// } + +// impl TryChained { +// pub fn new(chain: Ch, f: F) -> Self { +// Self { +// chain, +// f, +// } +// } +// } + +// impl Chain for TryChained +// where +// Ch: Chain, +// F: Fn(Ch::Output) -> Fut + Send + Sync, +// Fut: Future>, +// { +// type Input = Ch::Input; +// type Output = Fut::Output; + +// async fn call(self, input: Self::Input) -> Self::Output { +// let output = self.chain.call(input).await; +// (self.f)(output).await +// } +// } + +pub struct Map { + chain: Ch, + f: F, +} + +impl Map { + fn new(chain: Ch, f: F) -> Self { + Self { chain, f } + } +} + +impl Chain for Map +where + Ch: Chain, + F: Fn(Ch::Output) -> T + Send + Sync, + T: Send, +{ + type Input = Ch::Input; + 
type Output = T; + + async fn call(self, input: Self::Input) -> Self::Output { + let output = self.chain.call(input).await; + (self.f)(output) + } +} + +pub struct Lookup { + chain: Ch, + index: I, + n: usize, + _t: PhantomData, +} + +impl Lookup +where + Ch: Chain, +{ + pub fn new(chain: Ch, index: I, n: usize) -> Self { + Self { + chain, + index, + n, + _t: PhantomData, + } + } +} + +impl Chain for Lookup +where + I: vector_store::VectorStoreIndex, + Ch: Chain, + Ch::Output: Into, + T: Send + Sync + for<'a> serde::Deserialize<'a>, +{ + type Input = Ch::Input; + type Output = (String, Vec); + + async fn call(self, input: Self::Input) -> Self::Output { + let query = self.chain.call(input).await.into(); + + let docs = self + .index + .top_n::(&query, self.n) + .await + .expect("Failed to get top n documents") + .into_iter() + .map(|(_, _, doc)| doc) + .collect(); + + (query, docs) + } +} + +pub struct Prompt { + chain: Ch, + prompt: P, +} + +impl Prompt { + pub fn new(chain: Ch, prompt: P) -> Self { + Self { chain, prompt } + } +} + +impl Chain for Prompt +where + Ch: Chain, + Ch::Output: Into, + P: completion::Prompt, +{ + type Input = Ch::Input; + type Output = String; + + async fn call(self, input: Self::Input) -> Self::Output { + let output = self.chain.call(input).await.into(); + + self.prompt + .prompt(&output) + .await + .expect("Failed to prompt agent") + } +} + +pub fn new() -> Empty { + Empty(PhantomData) +} diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 9da3abfc..368c9a5b 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -79,6 +79,7 @@ //! implement the [VectorStoreIndex](crate::vector_store::VectorStoreIndex) trait. 
pub mod agent; +pub mod chain; pub mod cli_chatbot; pub mod completion; pub mod embeddings; From f3c81b361f859c55c164b9e2134ad70d2cc5ed4a Mon Sep 17 00:00:00 2001 From: Christophe Date: Tue, 19 Nov 2024 11:03:07 -0500 Subject: [PATCH 02/26] feat: Add chain error handling --- rig-core/examples/chain.rs | 10 +- rig-core/src/{ => chain}/chain.rs | 44 ++-- rig-core/src/chain/mod.rs | 24 ++ rig-core/src/chain/try_chain.rs | 382 ++++++++++++++++++++++++++++++ 4 files changed, 432 insertions(+), 28 deletions(-) rename rig-core/src/{ => chain}/chain.rs (88%) create mode 100644 rig-core/src/chain/mod.rs create mode 100644 rig-core/src/chain/try_chain.rs diff --git a/rig-core/examples/chain.rs b/rig-core/examples/chain.rs index 24516bf4..0d2441ec 100644 --- a/rig-core/examples/chain.rs +++ b/rig-core/examples/chain.rs @@ -1,7 +1,7 @@ use std::env; use rig::{ - chain::{self, Chain}, + chain::{self, TryChain}, embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, @@ -40,7 +40,7 @@ async fn main() -> Result<(), anyhow::Error> { // Retrieve top document from the index and return it with the prompt .lookup(index, 2) // Format the prompt with the context documents - .map(|(query, docs): (_, Vec)| { + .map_ok(|(query, docs): (_, Vec)| { format!( "User question: {}\n\nAdditional word definitions:\n{}", query, @@ -51,11 +51,9 @@ async fn main() -> Result<(), anyhow::Error> { .prompt(&agent); // Prompt the agent and print the response - let response = chain - .call("What does \"glarb-glarb\" mean?".to_string()) - .await; + let response = chain.try_call("What does \"glarb-glarb\" mean?").await?; - println!("{}", response); + println!("{:?}", response); Ok(()) } diff --git a/rig-core/src/chain.rs b/rig-core/src/chain/chain.rs similarity index 88% rename from rig-core/src/chain.rs rename to rig-core/src/chain/chain.rs index 3a5a0913..dd77c6fa 100644 --- a/rig-core/src/chain.rs +++ 
b/rig-core/src/chain/chain.rs @@ -1,12 +1,13 @@ -use std::{future::Future, marker::PhantomData}; +use futures::Future; +use std::marker::PhantomData; use crate::{completion, vector_store}; pub trait Chain: Send + Sync { type Input: Send; - type Output; + type Output: Send; - fn call(self, input: Self::Input) -> impl std::future::Future + Send; + fn call(&self, input: Self::Input) -> impl std::future::Future + Send; /// Chain a function to the output of the current chain /// @@ -148,11 +149,17 @@ pub trait Chain: Send + Sync { pub struct Empty(PhantomData); +impl Default for Empty { + fn default() -> Self { + Self(PhantomData) + } +} + impl Chain for Empty { type Input = T; type Output = T; - async fn call(self, value: Self::Input) -> Self::Output { + async fn call(&self, value: Self::Input) -> Self::Output { value } } @@ -173,11 +180,12 @@ where Ch: Chain, F: Fn(Ch::Output) -> Fut + Send + Sync, Fut: Future + Send, + Fut::Output: Send, { type Input = Ch::Input; type Output = Fut::Output; - async fn call(self, input: Self::Input) -> Self::Output { + async fn call(&self, input: Self::Input) -> Self::Output { let output = self.chain.call(input).await; (self.f)(output).await } @@ -218,7 +226,7 @@ pub struct Map { } impl Map { - fn new(chain: Ch, f: F) -> Self { + pub(crate) fn new(chain: Ch, f: F) -> Self { Self { chain, f } } } @@ -232,7 +240,7 @@ where type Input = Ch::Input; type Output = T; - async fn call(self, input: Self::Input) -> Self::Output { + async fn call(&self, input: Self::Input) -> Self::Output { let output = self.chain.call(input).await; (self.f)(output) } @@ -267,21 +275,20 @@ where T: Send + Sync + for<'a> serde::Deserialize<'a>, { type Input = Ch::Input; - type Output = (String, Vec); + type Output = Result<(String, Vec), vector_store::VectorStoreError>; - async fn call(self, input: Self::Input) -> Self::Output { + async fn call(&self, input: Self::Input) -> Self::Output { let query = self.chain.call(input).await.into(); let docs = self .index 
.top_n::(&query, self.n) - .await - .expect("Failed to get top n documents") + .await? .into_iter() .map(|(_, _, doc)| doc) .collect(); - (query, docs) + Ok((query, docs)) } } @@ -303,18 +310,11 @@ where P: completion::Prompt, { type Input = Ch::Input; - type Output = String; + type Output = Result; - async fn call(self, input: Self::Input) -> Self::Output { + async fn call(&self, input: Self::Input) -> Self::Output { let output = self.chain.call(input).await.into(); - self.prompt - .prompt(&output) - .await - .expect("Failed to prompt agent") + self.prompt.prompt(&output).await } } - -pub fn new() -> Empty { - Empty(PhantomData) -} diff --git a/rig-core/src/chain/mod.rs b/rig-core/src/chain/mod.rs new file mode 100644 index 00000000..4f72da77 --- /dev/null +++ b/rig-core/src/chain/mod.rs @@ -0,0 +1,24 @@ +pub mod chain; +pub mod try_chain; + +pub use chain::Chain; +pub use try_chain::{Empty, TryChain}; + +use crate::{completion, vector_store}; + +#[derive(Debug, thiserror::Error)] +pub enum ChainError { + #[error("Failed to prompt agent: {0}")] + PromptError(#[from] completion::PromptError), + + #[error("Failed to lookup documents: {0}")] + LookupError(#[from] vector_store::VectorStoreError), +} + +pub fn new() -> Empty { + Empty::default() +} + +pub fn with_error() -> Empty { + Empty::default() +} diff --git a/rig-core/src/chain/try_chain.rs b/rig-core/src/chain/try_chain.rs new file mode 100644 index 00000000..050c82c9 --- /dev/null +++ b/rig-core/src/chain/try_chain.rs @@ -0,0 +1,382 @@ +// use std::{marker::PhantomData}; + +use std::marker::PhantomData; + +use futures::{Future, FutureExt, TryFutureExt}; + +use crate::{completion, vector_store}; + +use super::Chain; + +pub trait TryChain: Send + Sync { + type Input: Send; + type Output: Send; + type Error; + + fn try_call( + &self, + input: Self::Input, + ) -> impl Future> + Send; + + /// Chain a function to the output of the current chain + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; 
+ /// + /// let chain = chain::new() + /// .chain(|email: String| async move { + /// email.split('@').next().unwrap().to_string() + /// }) + /// .chain(|username: String| async move { + /// format!("Hello, {}!", username) + /// }); + /// + /// let result = chain.call("bob@gmail.com".to_string()).await; + /// assert_eq!(result, "Hello, bob!"); + /// ``` + fn chain_ok(self, f: F) -> ChainOk + where + F: Fn(Self::Output) -> Fut, + Fut: Future, + Self: Sized, + { + ChainOk::new(self, f) + } + + /// Same as `chain` but for synchronous functions + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .map(|(x, y)| x + y) + /// .map(|z| format!("Result: {z}!")); + /// + /// let result = chain.call((1, 2)).await; + /// assert_eq!(result, "Result: 3!"); + /// ``` + fn map_ok(self, f: F) -> MapOk + where + F: Fn(Self::Output) -> T, + Self: Sized, + { + MapOk::new(self, f) + } + + fn map_err(self, f: F) -> MapErr + where + F: Fn(Self::Error) -> E, + Self: Sized, + { + MapErr::new(self, f) + } + + /// Chain a lookup operation to the current chain. The lookup operation expects the + /// current chain to output a query string. The lookup operation will use the query to + /// retrieve the top `n` documents from the index and return them with the query string. + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .lookup(index, 2) + /// .chain(|(query, docs): (_, Vec)| async move { + /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) + /// }); + /// + /// let result = chain.call("What is a flurbo?".to_string()).await; + /// ``` + fn lookup(self, index: I, n: usize) -> LookupOk + where + I: vector_store::VectorStoreIndex, + Self::Output: Into, + Self::Error: From, + Self: Sized, + { + LookupOk::new(self, index, n) + } + + /// Chain a prompt operation to the current chain. 
The prompt operation expects the + /// current chain to output a string. The prompt operation will use the string to prompt + /// the given agent (or any other type that implements the `Prompt` trait) and return + /// the response. + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let agent = &openai_client.agent("gpt-4").build(); + /// + /// let chain = chain::new() + /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) + /// .prompt(agent); + /// + /// let result = chain.call("Alice".to_string()).await; + /// ``` + fn prompt

(self, prompt: P) -> PromptOk + where + P: completion::Prompt, + Self::Output: Into, + Self::Error: From, + Self: Sized, + { + PromptOk::new(self, prompt) + } +} + +// #[derive(Debug, thiserror::Error)] +// pub enum ChainError { +// #[error("Failed to prompt agent: {0}")] +// PromptError(#[from] completion::PromptError), + +// #[error("Failed to lookup documents: {0}")] +// LookupError(#[from] vector_store::VectorStoreError), + +// #[error("Failed to chain operation: {0}")] +// ChainError(#[from] Box), +// } + +impl TryChain for Ch +where + Ch: Chain>, + In: Send, + Out: Send, +{ + type Input = In; + type Output = Out; + type Error = E; + + async fn try_call(&self, input: Self::Input) -> Result { + self.call(input).await + } +} + +// pub struct Empty(PhantomData); + +// impl Chain for Empty { +// type Input = T; +// type Output = T; + +// async fn call(&self, value: Self::Input) -> Self::Output { +// value +// } +// } + +// pub struct TryChained { +// chain: Ch, +// f: F, +// } + +// impl TryChained { +// fn new(chain: Ch, f: F) -> Self { +// Self { chain, f } +// } +// } + +// impl Chain for TryChained +// where +// Ch: Chain, +// F: Fn(Ch::Output) -> Fut + Send + Sync, +// Fut: Future + Send, +// { +// type Input = Ch::Input; +// type Output = Fut::Output; + +// async fn call(&self, input: Self::Input) -> Self::Output { +// let output = self.chain.call(input).await; +// (self.f)(output).await +// } +// } + +pub struct ChainOk { + chain: Ch, + f: F, +} + +impl ChainOk { + pub fn new(chain: Ch, f: F) -> Self { + Self { chain, f } + } +} + +impl TryChain for ChainOk +where + Ch: TryChain, + F: Fn(Ch::Output) -> Fut + Send + Sync, + Fut: Future + Send, + Fut::Output: Send, +{ + type Input = Ch::Input; + type Output = Fut::Output; + type Error = Ch::Error; + + async fn try_call(&self, input: Self::Input) -> Result { + self.chain + .try_call(input) + .and_then(|value| (self.f)(value).map(Ok)) + .await + } +} + +pub struct MapOk { + chain: Ch, + f: F, +} + +impl MapOk { 
+ fn new(chain: Ch, f: F) -> Self { + Self { chain, f } + } +} + +impl TryChain for MapOk +where + Ch: TryChain, + F: Fn(Ch::Output) -> T + Send + Sync, + T: Send, +{ + type Input = Ch::Input; + type Output = T; + type Error = Ch::Error; + + async fn try_call(&self, input: Self::Input) -> Result { + self.chain + .try_call(input) + .map_ok(|value| (self.f)(value)) + .await + } +} + +pub struct MapErr { + chain: Ch, + f: F, +} + +impl MapErr { + fn new(chain: Ch, f: F) -> Self { + Self { chain, f } + } +} + +impl TryChain for MapErr +where + Ch: TryChain, + F: Fn(Ch::Error) -> E + Send + Sync, + E: Send, +{ + type Input = Ch::Input; + type Output = Ch::Output; + type Error = E; + + async fn try_call(&self, input: Self::Input) -> Result { + self.chain + .try_call(input) + .map_err(|error| (self.f)(error)) + .await + } +} + +pub struct LookupOk { + chain: Ch, + index: I, + n: usize, + _t: PhantomData, +} + +impl LookupOk { + pub fn new(chain: Ch, index: I, n: usize) -> Self { + Self { + chain, + index, + n, + _t: PhantomData, + } + } +} + +impl TryChain for LookupOk +where + I: vector_store::VectorStoreIndex, + Ch: TryChain, + Ch::Output: Into, + Ch::Error: From, + T: Send + Sync + for<'a> serde::Deserialize<'a>, +{ + type Input = Ch::Input; + type Output = (String, Vec); + type Error = Ch::Error; + + async fn try_call(&self, input: Self::Input) -> Result { + self.chain + .try_call(input) + .and_then(|query| async { + let query: String = query.into(); + + let docs = self + .index + .top_n::(&query, self.n) + .await? 
+ .into_iter() + .map(|(_, _, doc)| doc) + .collect(); + + Ok((query, docs)) + }) + .await + } +} + +pub struct PromptOk { + chain: Ch, + prompt: P, +} + +impl PromptOk { + pub fn new(chain: Ch, prompt: P) -> Self { + Self { chain, prompt } + } +} + +impl TryChain for PromptOk +where + Ch: TryChain, + Ch::Output: Into, + Ch::Error: From, + P: completion::Prompt, +{ + type Input = Ch::Input; + type Output = String; + type Error = Ch::Error; + + async fn try_call(&self, input: Self::Input) -> Result { + self.chain + .try_call(input) + .and_then(|prompt| async { + let prompt: String = prompt.into(); + + Ok(self.prompt.prompt(&prompt).await?) + }) + .await + } +} + +pub struct Empty(PhantomData<(T, E)>); + +impl Default for Empty { + fn default() -> Self { + Self(PhantomData) + } +} + +impl TryChain for Empty { + type Input = T; + type Output = T; + type Error = E; + + async fn try_call(&self, value: Self::Input) -> Result { + Ok(value) + } +} From 1a59d779c3d77deef673629cd5700928bb1e10c9 Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 21 Nov 2024 17:11:31 -0500 Subject: [PATCH 03/26] feat: Add parallel ops and rename `chain` to `pipeline` --- rig-core/examples/chain.rs | 83 ++++-- rig-core/src/chain/chain.rs | 320 ----------------------- rig-core/src/chain/mod.rs | 24 -- rig-core/src/chain/try_chain.rs | 382 ---------------------------- rig-core/src/lib.rs | 2 +- rig-core/src/pipeline/agent_ops.rs | 161 ++++++++++++ rig-core/src/pipeline/mod.rs | 264 +++++++++++++++++++ rig-core/src/pipeline/op.rs | 374 +++++++++++++++++++++++++++ rig-core/src/pipeline/parallel.rs | 263 +++++++++++++++++++ rig-core/src/pipeline/try_op.rs | 391 +++++++++++++++++++++++++++++ 10 files changed, 1523 insertions(+), 741 deletions(-) delete mode 100644 rig-core/src/chain/chain.rs delete mode 100644 rig-core/src/chain/mod.rs delete mode 100644 rig-core/src/chain/try_chain.rs create mode 100644 rig-core/src/pipeline/agent_ops.rs create mode 100644 rig-core/src/pipeline/mod.rs create mode 
100644 rig-core/src/pipeline/op.rs create mode 100644 rig-core/src/pipeline/parallel.rs create mode 100644 rig-core/src/pipeline/try_op.rs diff --git a/rig-core/examples/chain.rs b/rig-core/examples/chain.rs index 0d2441ec..ed7e6c2f 100644 --- a/rig-core/examples/chain.rs +++ b/rig-core/examples/chain.rs @@ -1,8 +1,9 @@ use std::env; use rig::{ - chain::{self, TryChain}, embeddings::EmbeddingsBuilder, + parallel, + pipeline::{self, agent_ops::lookup, passthrough, Op}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; @@ -36,24 +37,78 @@ async fn main() -> Result<(), anyhow::Error> { ") .build(); - let chain = chain::new() - // Retrieve top document from the index and return it with the prompt - .lookup(index, 2) - // Format the prompt with the context documents - .map_ok(|(query, docs): (_, Vec)| { - format!( - "User question: {}\n\nAdditional word definitions:\n{}", - query, - docs.join("\n") - ) + let chain = pipeline::new() + // Chain a parallel operation to the current chain. The parallel operation will + // perform a lookup operation to retrieve additional context from the user prompt + // while simultaneously applying a passthrough operation. The latter will allow + // us to forward the initial prompt to the next operation in the chain. + .chain(parallel!( + passthrough(), + lookup::<_, _, String>(index, 1), // Required to specify document type + )) + // Chain a "map" operation to the current chain, which will combine the user + // prompt with the retrieved context documents to create the final prompt. + // If an error occurs during the lookup operation, we will log the error and + // simply return the initial prompt. + .map(|(prompt, maybe_docs)| match maybe_docs { + Ok(docs) => format!( + "Non standard word definitions:\n{}\n\n{}", + docs.join("\n"), + prompt, + ), + Err(err) => { + println!("Error: {}! 
Prompting without additional context", err); + format!("{prompt}") + } }) - // Prompt the agent - .prompt(&agent); + // Chain a "prompt" operation which will prompt out agent with the final prompt + .prompt(agent); // Prompt the agent and print the response - let response = chain.try_call("What does \"glarb-glarb\" mean?").await?; + let response = chain.call("What does \"glarb-glarb\" mean?").await?; println!("{:?}", response); Ok(()) } + +// trait Foo { +// fn foo(&self); +// } + +// impl Foo<(T,)> for F +// where +// F: Fn(T) -> Out, +// { +// fn foo(&self) { +// todo!() +// } +// } + +// impl Foo<(T1, T2)> for F +// where +// F: Fn(T1, T2) -> Out, +// { +// fn foo(&self) { +// todo!() +// } +// } + +// impl Foo<(T1, T2, T3)> for F +// where +// F: Fn(T1, T2, T3) -> Out, +// { +// fn foo(&self) { +// todo!() +// } +// } + +// impl Foo<((Fut, T,),)> for F +// where +// F: Fn(T) -> Fut, +// Fut: Future, +// { +// fn foo(&self) { +// todo!() +// } +// } diff --git a/rig-core/src/chain/chain.rs b/rig-core/src/chain/chain.rs deleted file mode 100644 index dd77c6fa..00000000 --- a/rig-core/src/chain/chain.rs +++ /dev/null @@ -1,320 +0,0 @@ -use futures::Future; -use std::marker::PhantomData; - -use crate::{completion, vector_store}; - -pub trait Chain: Send + Sync { - type Input: Send; - type Output: Send; - - fn call(&self, input: Self::Input) -> impl std::future::Future + Send; - - /// Chain a function to the output of the current chain - /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let chain = chain::new() - /// .chain(|email: String| async move { - /// email.split('@').next().unwrap().to_string() - /// }) - /// .chain(|username: String| async move { - /// format!("Hello, {}!", username) - /// }); - /// - /// let result = chain.call("bob@gmail.com".to_string()).await; - /// assert_eq!(result, "Hello, bob!"); - /// ``` - fn chain(self, f: F) -> Chained - where - F: Fn(Self::Output) -> Fut, - Fut: Future, - Self: Sized, - { - 
Chained::new(self, f) - } - - // fn try_chain(self, f: F) -> TryChained - // where - // F: Fn(Self::Output) -> Fut, - // Fut: TryFuture, - // Self: Sized - // { - // TryChained::new(self, f) - // } - - /// Same as `chain` but for synchronous functions - /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let chain = chain::new() - /// .map(|(x, y)| x + y) - /// .map(|z| format!("Result: {z}!")); - /// - /// let result = chain.call((1, 2)).await; - /// assert_eq!(result, "Result: 3!"); - /// ``` - fn map(self, f: F) -> Map - where - F: Fn(Self::Output) -> T, - Self: Sized, - { - Map::new(self, f) - } - - /// Chain a lookup operation to the current chain. The lookup operation expects the - /// current chain to output a query string. The lookup operation will use the query to - /// retrieve the top `n` documents from the index and return them with the query string. - /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let chain = chain::new() - /// .lookup(index, 2) - /// .chain(|(query, docs): (_, Vec)| async move { - /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) - /// }); - /// - /// let result = chain.call("What is a flurbo?".to_string()).await; - /// ``` - fn lookup(self, index: I, n: usize) -> Lookup - where - Self::Output: Into, - Self: Sized, - { - Lookup::new(self, index, n) - } - - /// Chain a prompt operation to the current chain. The prompt operation expects the - /// current chain to output a string. The prompt operation will use the string to prompt - /// the given agent (or any other type that implements the `Prompt` trait) and return - /// the response. 
- /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let agent = &openai_client.agent("gpt-4").build(); - /// - /// let chain = chain::new() - /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) - /// .prompt(agent); - /// - /// let result = chain.call("Alice".to_string()).await; - /// ``` - fn prompt

(self, prompt: P) -> Prompt - where - P: completion::Prompt, - Self::Output: Into, - Self: Sized, - { - Prompt::new(self, prompt) - } -} - -// pub struct Root { -// f: F, -// _phantom: PhantomData, -// } - -// impl Root -// where -// F: FnOnce(A) -> Fut, -// Fut: Future, -// { -// pub fn new(f: F) -> Self { -// Self { -// f, -// _phantom: PhantomData, -// } -// } -// } - -// impl Chain for Root -// where -// F: FnOnce(A) -> Fut + Send + Sync, -// Fut: Future + Send + Sync, -// { -// type Input = A; -// type Output = Fut::Output; - -// async fn call(self, input: Self::Input) -> Self::Output { -// (self.f)(input).await -// } -// } - -pub struct Empty(PhantomData); - -impl Default for Empty { - fn default() -> Self { - Self(PhantomData) - } -} - -impl Chain for Empty { - type Input = T; - type Output = T; - - async fn call(&self, value: Self::Input) -> Self::Output { - value - } -} - -pub struct Chained { - chain: Ch, - f: F, -} - -impl Chained { - fn new(chain: Ch, f: F) -> Self { - Self { chain, f } - } -} - -impl Chain for Chained -where - Ch: Chain, - F: Fn(Ch::Output) -> Fut + Send + Sync, - Fut: Future + Send, - Fut::Output: Send, -{ - type Input = Ch::Input; - type Output = Fut::Output; - - async fn call(&self, input: Self::Input) -> Self::Output { - let output = self.chain.call(input).await; - (self.f)(output).await - } -} - -// pub struct TryChained { -// chain: Ch, -// f: F, -// } - -// impl TryChained { -// pub fn new(chain: Ch, f: F) -> Self { -// Self { -// chain, -// f, -// } -// } -// } - -// impl Chain for TryChained -// where -// Ch: Chain, -// F: Fn(Ch::Output) -> Fut + Send + Sync, -// Fut: Future>, -// { -// type Input = Ch::Input; -// type Output = Fut::Output; - -// async fn call(self, input: Self::Input) -> Self::Output { -// let output = self.chain.call(input).await; -// (self.f)(output).await -// } -// } - -pub struct Map { - chain: Ch, - f: F, -} - -impl Map { - pub(crate) fn new(chain: Ch, f: F) -> Self { - Self { chain, f } - } -} - -impl 
Chain for Map -where - Ch: Chain, - F: Fn(Ch::Output) -> T + Send + Sync, - T: Send, -{ - type Input = Ch::Input; - type Output = T; - - async fn call(&self, input: Self::Input) -> Self::Output { - let output = self.chain.call(input).await; - (self.f)(output) - } -} - -pub struct Lookup { - chain: Ch, - index: I, - n: usize, - _t: PhantomData, -} - -impl Lookup -where - Ch: Chain, -{ - pub fn new(chain: Ch, index: I, n: usize) -> Self { - Self { - chain, - index, - n, - _t: PhantomData, - } - } -} - -impl Chain for Lookup -where - I: vector_store::VectorStoreIndex, - Ch: Chain, - Ch::Output: Into, - T: Send + Sync + for<'a> serde::Deserialize<'a>, -{ - type Input = Ch::Input; - type Output = Result<(String, Vec), vector_store::VectorStoreError>; - - async fn call(&self, input: Self::Input) -> Self::Output { - let query = self.chain.call(input).await.into(); - - let docs = self - .index - .top_n::(&query, self.n) - .await? - .into_iter() - .map(|(_, _, doc)| doc) - .collect(); - - Ok((query, docs)) - } -} - -pub struct Prompt { - chain: Ch, - prompt: P, -} - -impl Prompt { - pub fn new(chain: Ch, prompt: P) -> Self { - Self { chain, prompt } - } -} - -impl Chain for Prompt -where - Ch: Chain, - Ch::Output: Into, - P: completion::Prompt, -{ - type Input = Ch::Input; - type Output = Result; - - async fn call(&self, input: Self::Input) -> Self::Output { - let output = self.chain.call(input).await.into(); - - self.prompt.prompt(&output).await - } -} diff --git a/rig-core/src/chain/mod.rs b/rig-core/src/chain/mod.rs deleted file mode 100644 index 4f72da77..00000000 --- a/rig-core/src/chain/mod.rs +++ /dev/null @@ -1,24 +0,0 @@ -pub mod chain; -pub mod try_chain; - -pub use chain::Chain; -pub use try_chain::{Empty, TryChain}; - -use crate::{completion, vector_store}; - -#[derive(Debug, thiserror::Error)] -pub enum ChainError { - #[error("Failed to prompt agent: {0}")] - PromptError(#[from] completion::PromptError), - - #[error("Failed to lookup documents: {0}")] - 
LookupError(#[from] vector_store::VectorStoreError), -} - -pub fn new() -> Empty { - Empty::default() -} - -pub fn with_error() -> Empty { - Empty::default() -} diff --git a/rig-core/src/chain/try_chain.rs b/rig-core/src/chain/try_chain.rs deleted file mode 100644 index 050c82c9..00000000 --- a/rig-core/src/chain/try_chain.rs +++ /dev/null @@ -1,382 +0,0 @@ -// use std::{marker::PhantomData}; - -use std::marker::PhantomData; - -use futures::{Future, FutureExt, TryFutureExt}; - -use crate::{completion, vector_store}; - -use super::Chain; - -pub trait TryChain: Send + Sync { - type Input: Send; - type Output: Send; - type Error; - - fn try_call( - &self, - input: Self::Input, - ) -> impl Future> + Send; - - /// Chain a function to the output of the current chain - /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let chain = chain::new() - /// .chain(|email: String| async move { - /// email.split('@').next().unwrap().to_string() - /// }) - /// .chain(|username: String| async move { - /// format!("Hello, {}!", username) - /// }); - /// - /// let result = chain.call("bob@gmail.com".to_string()).await; - /// assert_eq!(result, "Hello, bob!"); - /// ``` - fn chain_ok(self, f: F) -> ChainOk - where - F: Fn(Self::Output) -> Fut, - Fut: Future, - Self: Sized, - { - ChainOk::new(self, f) - } - - /// Same as `chain` but for synchronous functions - /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let chain = chain::new() - /// .map(|(x, y)| x + y) - /// .map(|z| format!("Result: {z}!")); - /// - /// let result = chain.call((1, 2)).await; - /// assert_eq!(result, "Result: 3!"); - /// ``` - fn map_ok(self, f: F) -> MapOk - where - F: Fn(Self::Output) -> T, - Self: Sized, - { - MapOk::new(self, f) - } - - fn map_err(self, f: F) -> MapErr - where - F: Fn(Self::Error) -> E, - Self: Sized, - { - MapErr::new(self, f) - } - - /// Chain a lookup operation to the current chain. 
The lookup operation expects the - /// current chain to output a query string. The lookup operation will use the query to - /// retrieve the top `n` documents from the index and return them with the query string. - /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let chain = chain::new() - /// .lookup(index, 2) - /// .chain(|(query, docs): (_, Vec)| async move { - /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) - /// }); - /// - /// let result = chain.call("What is a flurbo?".to_string()).await; - /// ``` - fn lookup(self, index: I, n: usize) -> LookupOk - where - I: vector_store::VectorStoreIndex, - Self::Output: Into, - Self::Error: From, - Self: Sized, - { - LookupOk::new(self, index, n) - } - - /// Chain a prompt operation to the current chain. The prompt operation expects the - /// current chain to output a string. The prompt operation will use the string to prompt - /// the given agent (or any other type that implements the `Prompt` trait) and return - /// the response. - /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let agent = &openai_client.agent("gpt-4").build(); - /// - /// let chain = chain::new() - /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) - /// .prompt(agent); - /// - /// let result = chain.call("Alice".to_string()).await; - /// ``` - fn prompt

(self, prompt: P) -> PromptOk - where - P: completion::Prompt, - Self::Output: Into, - Self::Error: From, - Self: Sized, - { - PromptOk::new(self, prompt) - } -} - -// #[derive(Debug, thiserror::Error)] -// pub enum ChainError { -// #[error("Failed to prompt agent: {0}")] -// PromptError(#[from] completion::PromptError), - -// #[error("Failed to lookup documents: {0}")] -// LookupError(#[from] vector_store::VectorStoreError), - -// #[error("Failed to chain operation: {0}")] -// ChainError(#[from] Box), -// } - -impl TryChain for Ch -where - Ch: Chain>, - In: Send, - Out: Send, -{ - type Input = In; - type Output = Out; - type Error = E; - - async fn try_call(&self, input: Self::Input) -> Result { - self.call(input).await - } -} - -// pub struct Empty(PhantomData); - -// impl Chain for Empty { -// type Input = T; -// type Output = T; - -// async fn call(&self, value: Self::Input) -> Self::Output { -// value -// } -// } - -// pub struct TryChained { -// chain: Ch, -// f: F, -// } - -// impl TryChained { -// fn new(chain: Ch, f: F) -> Self { -// Self { chain, f } -// } -// } - -// impl Chain for TryChained -// where -// Ch: Chain, -// F: Fn(Ch::Output) -> Fut + Send + Sync, -// Fut: Future + Send, -// { -// type Input = Ch::Input; -// type Output = Fut::Output; - -// async fn call(&self, input: Self::Input) -> Self::Output { -// let output = self.chain.call(input).await; -// (self.f)(output).await -// } -// } - -pub struct ChainOk { - chain: Ch, - f: F, -} - -impl ChainOk { - pub fn new(chain: Ch, f: F) -> Self { - Self { chain, f } - } -} - -impl TryChain for ChainOk -where - Ch: TryChain, - F: Fn(Ch::Output) -> Fut + Send + Sync, - Fut: Future + Send, - Fut::Output: Send, -{ - type Input = Ch::Input; - type Output = Fut::Output; - type Error = Ch::Error; - - async fn try_call(&self, input: Self::Input) -> Result { - self.chain - .try_call(input) - .and_then(|value| (self.f)(value).map(Ok)) - .await - } -} - -pub struct MapOk { - chain: Ch, - f: F, -} - -impl MapOk { 
- fn new(chain: Ch, f: F) -> Self { - Self { chain, f } - } -} - -impl TryChain for MapOk -where - Ch: TryChain, - F: Fn(Ch::Output) -> T + Send + Sync, - T: Send, -{ - type Input = Ch::Input; - type Output = T; - type Error = Ch::Error; - - async fn try_call(&self, input: Self::Input) -> Result { - self.chain - .try_call(input) - .map_ok(|value| (self.f)(value)) - .await - } -} - -pub struct MapErr { - chain: Ch, - f: F, -} - -impl MapErr { - fn new(chain: Ch, f: F) -> Self { - Self { chain, f } - } -} - -impl TryChain for MapErr -where - Ch: TryChain, - F: Fn(Ch::Error) -> E + Send + Sync, - E: Send, -{ - type Input = Ch::Input; - type Output = Ch::Output; - type Error = E; - - async fn try_call(&self, input: Self::Input) -> Result { - self.chain - .try_call(input) - .map_err(|error| (self.f)(error)) - .await - } -} - -pub struct LookupOk { - chain: Ch, - index: I, - n: usize, - _t: PhantomData, -} - -impl LookupOk { - pub fn new(chain: Ch, index: I, n: usize) -> Self { - Self { - chain, - index, - n, - _t: PhantomData, - } - } -} - -impl TryChain for LookupOk -where - I: vector_store::VectorStoreIndex, - Ch: TryChain, - Ch::Output: Into, - Ch::Error: From, - T: Send + Sync + for<'a> serde::Deserialize<'a>, -{ - type Input = Ch::Input; - type Output = (String, Vec); - type Error = Ch::Error; - - async fn try_call(&self, input: Self::Input) -> Result { - self.chain - .try_call(input) - .and_then(|query| async { - let query: String = query.into(); - - let docs = self - .index - .top_n::(&query, self.n) - .await? 
- .into_iter() - .map(|(_, _, doc)| doc) - .collect(); - - Ok((query, docs)) - }) - .await - } -} - -pub struct PromptOk { - chain: Ch, - prompt: P, -} - -impl PromptOk { - pub fn new(chain: Ch, prompt: P) -> Self { - Self { chain, prompt } - } -} - -impl TryChain for PromptOk -where - Ch: TryChain, - Ch::Output: Into, - Ch::Error: From, - P: completion::Prompt, -{ - type Input = Ch::Input; - type Output = String; - type Error = Ch::Error; - - async fn try_call(&self, input: Self::Input) -> Result { - self.chain - .try_call(input) - .and_then(|prompt| async { - let prompt: String = prompt.into(); - - Ok(self.prompt.prompt(&prompt).await?) - }) - .await - } -} - -pub struct Empty(PhantomData<(T, E)>); - -impl Default for Empty { - fn default() -> Self { - Self(PhantomData) - } -} - -impl TryChain for Empty { - type Input = T; - type Output = T; - type Error = E; - - async fn try_call(&self, value: Self::Input) -> Result { - Ok(value) - } -} diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 368c9a5b..25601238 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -79,13 +79,13 @@ //! implement the [VectorStoreIndex](crate::vector_store::VectorStoreIndex) trait. 
pub mod agent; -pub mod chain; pub mod cli_chatbot; pub mod completion; pub mod embeddings; pub mod extractor; pub(crate) mod json_utils; pub mod loaders; +pub mod pipeline; pub mod providers; pub mod tool; pub mod vector_store; diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs new file mode 100644 index 00000000..755bd74f --- /dev/null +++ b/rig-core/src/pipeline/agent_ops.rs @@ -0,0 +1,161 @@ +use crate::{completion, vector_store}; + +use super::Op; + +pub struct Lookup { + index: I, + n: usize, + _in: std::marker::PhantomData, + _t: std::marker::PhantomData, +} + +impl Lookup +where + I: vector_store::VectorStoreIndex, +{ + pub fn new(index: I, n: usize) -> Self { + Self { + index, + n, + _in: std::marker::PhantomData, + _t: std::marker::PhantomData, + } + } +} + +impl Op for Lookup +where + I: vector_store::VectorStoreIndex, + In: Into + Send + Sync, + T: Send + Sync + for<'a> serde::Deserialize<'a>, +{ + type Input = In; + type Output = Result, vector_store::VectorStoreError>; + + async fn call(&self, input: Self::Input) -> Self::Output { + let query: String = input.into(); + + let docs = self + .index + .top_n::(&query, self.n) + .await? 
+ .into_iter() + .map(|(_, _, doc)| doc) + .collect(); + + Ok(docs) + } +} + +pub fn lookup(index: I, n: usize) -> Lookup +where + I: vector_store::VectorStoreIndex, + In: Into + Send + Sync, + T: Send + Sync + for<'a> serde::Deserialize<'a>, +{ + Lookup::new(index, n) +} + +pub struct Prompt { + prompt: P, + _in: std::marker::PhantomData, +} + +impl Prompt { + pub fn new(prompt: P) -> Self { + Self { + prompt, + _in: std::marker::PhantomData, + } + } +} + +impl Op for Prompt +where + P: completion::Prompt, + In: Into + Send + Sync, +{ + type Input = In; + type Output = Result; + + async fn call(&self, input: Self::Input) -> Self::Output { + let prompt: String = input.into(); + self.prompt.prompt(&prompt).await + } +} + +pub fn prompt(prompt: P) -> Prompt +where + P: completion::Prompt, + In: Into + Send + Sync, +{ + Prompt::new(prompt) +} + +#[cfg(test)] +pub mod tests { + use super::*; + use completion::{Prompt, PromptError}; + use vector_store::{VectorStoreError, VectorStoreIndex}; + + pub struct MockModel; + + impl Prompt for MockModel { + async fn prompt(&self, prompt: &str) -> Result { + Ok(format!("Mock response: {}", prompt)) + } + } + + pub struct MockIndex; + + impl VectorStoreIndex for MockIndex { + async fn top_n serde::Deserialize<'a> + std::marker::Send>( + &self, + _query: &str, + _n: usize, + ) -> Result, VectorStoreError> { + let doc = serde_json::from_value(serde_json::json!({ + "foo": "bar", + })) + .unwrap(); + + Ok(vec![(1.0, "doc1".to_string(), doc)]) + } + + async fn top_n_ids( + &self, + _query: &str, + _n: usize, + ) -> Result, VectorStoreError> { + Ok(vec![(1.0, "doc1".to_string())]) + } + } + + #[derive(Debug, serde::Deserialize, PartialEq)] + pub struct Foo { + pub foo: String, + } + + #[tokio::test] + async fn test_lookup() { + let index = MockIndex; + let lookup = lookup::(index, 1); + + let result = lookup.call("query".to_string()).await.unwrap(); + assert_eq!( + result, + vec![Foo { + foo: "bar".to_string() + }] + ); + } + + 
#[tokio::test] + async fn test_prompt() { + let model = MockModel; + let prompt = prompt::(model); + + let result = prompt.call("hello".to_string()).await.unwrap(); + assert_eq!(result, "Mock response: hello"); + } +} diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs new file mode 100644 index 00000000..2cefeb9d --- /dev/null +++ b/rig-core/src/pipeline/mod.rs @@ -0,0 +1,264 @@ +pub mod agent_ops; +pub mod op; +pub mod try_op; +#[macro_use] +pub mod parallel; + +use std::future::Future; + +use agent_ops::{lookup, prompt as prompt_op}; +pub use op::{map, passthrough, then, Op}; +pub use try_op::TryOp; + +use crate::{completion, vector_store}; + +pub struct PipelineBuilder { + _error: std::marker::PhantomData, +} + +impl PipelineBuilder { + /// Chain a function to the current pipeline + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// let chain = pipeline::new() + /// .map(|(x, y)| x + y) + /// .map(|z| format!("Result: {z}!")); + /// + /// let result = chain.call((1, 2)).await; + /// assert_eq!(result, "Result: 3!"); + /// ``` + pub fn map(self, f: F) -> impl Op + where + F: Fn(In) -> T + Send + Sync, + In: Send + Sync, + T: Send + Sync, + Self: Sized, + { + map(f) + } + + /// Same as `map` but for asynchronous functions + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// let chain = pipeline::new() + /// .then(|email: String| async move { + /// email.split('@').next().unwrap().to_string() + /// }) + /// .then(|username: String| async move { + /// format!("Hello, {}!", username) + /// }); + /// + /// let result = chain.call("bob@gmail.com".to_string()).await; + /// assert_eq!(result, "Hello, bob!"); + /// ``` + pub fn then(self, f: F) -> impl Op + where + F: Fn(In) -> Fut + Send + Sync, + In: Send + Sync, + Fut: Future + Send + Sync, + Fut::Output: Send + Sync, + Self: Sized, + { + then(f) + } + + /// Chain an arbitrary operation to the current pipeline. 
+ /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// struct MyOp; + /// + /// impl Op for MyOp { + /// type Input = i32; + /// type Output = i32; + /// + /// async fn call(&self, input: Self::Input) -> Self::Output { + /// input + 1 + /// } + /// } + /// + /// let chain = pipeline::new() + /// .chain(MyOp); + /// + /// let result = chain.call(1).await; + /// assert_eq!(result, 2); + /// ``` + pub fn chain(self, op: T) -> impl Op + where + T: Op, + Self: Sized, + { + op + } + + /// Chain a lookup operation to the current chain. The lookup operation expects the + /// current chain to output a query string. The lookup operation will use the query to + /// retrieve the top `n` documents from the index and return them with the query string. + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .lookup(index, 2) + /// .chain(|(query, docs): (_, Vec)| async move { + /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) + /// }); + /// + /// let result = chain.call("What is a flurbo?".to_string()).await; + /// ``` + pub fn lookup( + self, + index: I, + n: usize, + ) -> impl Op, vector_store::VectorStoreError>> + where + I: vector_store::VectorStoreIndex, + T: Send + Sync + for<'a> serde::Deserialize<'a>, + In: Into + Send + Sync, + // E: From + Send + Sync, + Self: Sized, + { + lookup(index, n) + } + + /// Chain a prompt operation to the current chain. The prompt operation expects the + /// current chain to output a string. The prompt operation will use the string to prompt + /// the given agent (or any other type that implements the `Prompt` trait) and return + /// the response. 
+ /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let agent = &openai_client.agent("gpt-4").build(); + /// + /// let chain = chain::new() + /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) + /// .prompt(agent); + /// + /// let result = chain.call("Alice".to_string()).await; + /// ``` + pub fn prompt( + self, + prompt: P, + ) -> impl Op> + where + P: completion::Prompt, + In: Into + Send + Sync, + // E: From + Send + Sync, + Self: Sized, + { + prompt_op(prompt) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ChainError { + #[error("Failed to prompt agent: {0}")] + PromptError(#[from] completion::PromptError), + + #[error("Failed to lookup documents: {0}")] + LookupError(#[from] vector_store::VectorStoreError), +} + +pub fn new() -> PipelineBuilder { + PipelineBuilder { + _error: std::marker::PhantomData, + } +} + +pub fn with_error() -> PipelineBuilder { + PipelineBuilder { + _error: std::marker::PhantomData, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use agent_ops::tests::{Foo, MockIndex, MockModel}; + use parallel::parallel; + + #[tokio::test] + async fn test_prompt_pipeline() { + let model = MockModel; + + let chain = super::new() + .map(|input| format!("User query: {}", input)) + .prompt(model); + + let result = chain + .call("What is a flurbo?") + .await + .expect("Failed to run chain"); + + assert_eq!(result, "Mock response: User query: What is a flurbo?"); + } + + #[tokio::test] + async fn test_prompt_pipeline_error() { + let model = MockModel; + + let chain = super::with_error::<()>() + .map(|input| format!("User query: {}", input)) + .prompt(model); + + let result = chain + .try_call("What is a flurbo?") + .await + .expect("Failed to run chain"); + + assert_eq!(result, "Mock response: User query: What is a flurbo?"); + } + + #[tokio::test] + async fn test_lookup_pipeline() { + let index = MockIndex; + + let chain = super::new() + .lookup::<_, _, Foo>(index, 1) + 
.map_ok(|docs| format!("Top documents:\n{}", docs[0].foo)); + + let result = chain + .try_call("What is a flurbo?") + .await + .expect("Failed to run chain"); + + assert_eq!( + result, + "User query: What is a flurbo?\n\nTop documents:\nbar" + ); + } + + #[tokio::test] + async fn test_rag_pipeline() { + let index = MockIndex; + + let chain = super::new() + .chain(parallel!(passthrough(), lookup::<_, _, Foo>(index, 1),)) + .map(|(query, maybe_docs)| match maybe_docs { + Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo), + Err(err) => format!("Error: {}", err), + }) + .prompt(MockModel); + + let result = chain + .call("What is a flurbo?") + .await + .expect("Failed to run chain"); + + assert_eq!( + result, + "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar" + ); + } +} diff --git a/rig-core/src/pipeline/op.rs b/rig-core/src/pipeline/op.rs new file mode 100644 index 00000000..7208bc27 --- /dev/null +++ b/rig-core/src/pipeline/op.rs @@ -0,0 +1,374 @@ +use std::future::Future; + +#[allow(unused_imports)] // Needed since this is used in a macro rule +use futures::join; + +// ================================================================ +// Core Op trait +// ================================================================ +pub trait Op: Send + Sync { + type Input: Send + Sync; + type Output: Send + Sync; + + fn call(&self, input: Self::Input) -> impl Future + Send; + + /// Chain a function to the current pipeline + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// let chain = pipeline::new() + /// .map(|(x, y)| x + y) + /// .map(|z| format!("Result: {z}!")); + /// + /// let result = chain.call((1, 2)).await; + /// assert_eq!(result, "Result: 3!"); + /// ``` + fn map(self, f: F) -> impl Op + where + F: Fn(Self::Output) -> T + Send + Sync, + T: Send + Sync, + Self: Sized, + { + Sequential::new(self, map(f)) + } + + /// Same as `map` but for asynchronous functions + /// + /// # Example + 
/// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// let chain = pipeline::new() + /// .then(|email: String| async move { + /// email.split('@').next().unwrap().to_string() + /// }) + /// .then(|username: String| async move { + /// format!("Hello, {}!", username) + /// }); + /// + /// let result = chain.call("bob@gmail.com".to_string()).await; + /// assert_eq!(result, "Hello, bob!"); + /// ``` + fn then(self, f: F) -> impl Op + where + F: Fn(Self::Output) -> Fut + Send + Sync, + Fut: Future + Send + Sync, + Fut::Output: Send + Sync, + Self: Sized, + { + Sequential::new(self, then(f)) + } + + /// Chain an arbitrary operation to the current pipeline. + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// struct MyOp; + /// + /// impl Op for MyOp { + /// type Input = i32; + /// type Output = i32; + /// + /// async fn call(&self, input: Self::Input) -> Self::Output { + /// input + 1 + /// } + /// } + /// + /// let chain = pipeline::new() + /// .chain(MyOp); + /// + /// let result = chain.call(1).await; + /// assert_eq!(result, 2); + /// ``` + fn chain(self, op: T) -> impl Op + where + T: Op, + Self: Sized, + { + Sequential::new(self, op) + } + + /// Chain a lookup operation to the current chain. The lookup operation expects the + /// current chain to output a query string. The lookup operation will use the query to + /// retrieve the top `n` documents from the index and return them with the query string. 
+ /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let chain = chain::new() + /// .lookup(index, 2) + /// .chain(|(query, docs): (_, Vec)| async move { + /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) + /// }); + /// + /// let result = chain.call("What is a flurbo?".to_string()).await; + /// ``` + fn lookup( + self, + index: I, + n: usize, + ) -> impl Op, vector_store::VectorStoreError>> + where + I: vector_store::VectorStoreIndex, + T: Send + Sync + for<'a> serde::Deserialize<'a>, + Self::Output: Into, + Self: Sized, + { + Sequential::new(self, Lookup::new(index, n)) + } + + /// Chain a prompt operation to the current chain. The prompt operation expects the + /// current chain to output a string. The prompt operation will use the string to prompt + /// the given agent (or any other type that implements the `Prompt` trait) and return + /// the response. + /// + /// # Example + /// ```rust + /// use rig::chain::{self, Chain}; + /// + /// let agent = &openai_client.agent("gpt-4").build(); + /// + /// let chain = chain::new() + /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) + /// .prompt(agent); + /// + /// let result = chain.call("Alice".to_string()).await; + /// ``` + fn prompt

( + self, + prompt: P, + ) -> impl Op> + where + P: completion::Prompt, + Self::Output: Into, + Self: Sized, + { + Sequential::new(self, Prompt::new(prompt)) + } +} + +impl Op for &T { + type Input = T::Input; + type Output = T::Output; + + #[inline] + async fn call(&self, input: Self::Input) -> Self::Output { + (*self).call(input).await + } +} + +// ================================================================ +// Op combinators +// ================================================================ +pub struct Sequential { + prev: Op1, + op: Op2, +} + +impl Sequential { + pub fn new(prev: Op1, op: Op2) -> Self { + Self { prev, op } + } +} + +impl Op for Sequential +where + Op1: Op, + Op2: Op, +{ + type Input = Op1::Input; + type Output = Op2::Output; + + #[inline] + async fn call(&self, input: Self::Input) -> Self::Output { + let prev = self.prev.call(input).await; + self.op.call(prev).await + } +} + +use crate::{completion, vector_store}; + +use super::agent_ops::{Lookup, Prompt}; + +// ================================================================ +// Core Op implementations +// ================================================================ +pub struct Map { + f: F, + _t: std::marker::PhantomData, +} + +impl Map { + pub fn new(f: F) -> Self { + Self { + f, + _t: std::marker::PhantomData, + } + } +} + +impl Op for Map +where + F: Fn(T) -> Out + Send + Sync, + T: Send + Sync, + Out: Send + Sync, +{ + type Input = T; + type Output = Out; + + #[inline] + async fn call(&self, input: Self::Input) -> Self::Output { + (self.f)(input) + } +} + +pub fn map(f: F) -> impl Op +where + F: Fn(T) -> Out + Send + Sync, + T: Send + Sync, + Out: Send + Sync, +{ + Map::new(f) +} + +pub fn passthrough() -> impl Op +where + T: Send + Sync, +{ + Map::new(|x| x) +} + +pub struct Then { + f: F, + _t: std::marker::PhantomData, +} + +impl Then { + fn new(f: F) -> Self { + Self { + f, + _t: std::marker::PhantomData, + } + } +} + +impl Op for Then +where + F: Fn(T) -> Fut + Send + 
Sync, + T: Send + Sync, + Fut: Future + Send, + Fut::Output: Send + Sync, +{ + type Input = T; + type Output = Fut::Output; + + #[inline] + async fn call(&self, input: Self::Input) -> Self::Output { + (self.f)(input).await + } +} + +pub fn then(f: F) -> impl Op +where + F: Fn(T) -> Fut + Send + Sync, + T: Send + Sync, + Fut: Future + Send, + Fut::Output: Send + Sync, +{ + Then::new(f) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_sequential_constructor() { + let op1 = map(|x: i32| x + 1); + let op2 = map(|x: i32| x * 2); + let op3 = map(|x: i32| x * 3); + + let pipeline = Sequential::new(Sequential::new(op1, op2), op3); + + let result = pipeline.call(1).await; + assert_eq!(result, 12); + } + + #[tokio::test] + async fn test_sequential_chain() { + let pipeline = map(|x: i32| x + 1) + .map(|x| x * 2) + .then(|x| async move { x * 3 }); + + let result = pipeline.call(1).await; + assert_eq!(result, 12); + } + + // #[tokio::test] + // async fn test_flatten() { + // let op = Parallel::new( + // Parallel::new( + // map(|x: i32| x + 1), + // map(|x: i32| x * 2), + // ), + // map(|x: i32| x * 3), + // ); + + // let pipeline = flatten::<_, (_, _, _)>(op); + + // let result = pipeline.call(1).await; + // assert_eq!(result, (2, 2, 3)); + // } + + // #[tokio::test] + // async fn test_parallel_macro() { + // let op1 = map(|x: i32| x + 1); + // let op2 = map(|x: i32| x * 3); + // let op3 = map(|x: i32| format!("{} is the number!", x)); + // let op4 = map(|x: i32| x - 1); + + // let pipeline = parallel!(op1, op2, op3, op4); + + // let result = pipeline.call(1).await; + // assert_eq!(result, (2, 3, "1 is the number!".to_string(), 0)); + // } + + // #[tokio::test] + // async fn test_parallel_join() { + // let op3 = map(|x: i32| format!("{} is the number!", x)); + + // let pipeline = Sequential::new( + // map(|x: i32| x + 1), + // then(|x| { + // // let op1 = map(|x: i32| x * 2); + // // let op2 = map(|x: i32| x * 3); + // let op3 = &op3; + + // 
async move { + // join!( + // (&map(|x: i32| x * 2)).call(x), + // { + // let op = map(|x: i32| x * 3); + // op.call(x) + // }, + // op3.call(x), + // ) + // }}), + // ); + + // let result = pipeline.call(1).await; + // assert_eq!(result, (2, 3, "1 is the number!".to_string())); + // } + + // #[test] + // fn test_flatten() { + // let x = (1, (2, (3, 4))); + // let result = flatten!(0, 1, 1, 1, 1); + // assert_eq!(result, (1, 2, 3, 4)); + // } +} diff --git a/rig-core/src/pipeline/parallel.rs b/rig-core/src/pipeline/parallel.rs new file mode 100644 index 00000000..e02c5efb --- /dev/null +++ b/rig-core/src/pipeline/parallel.rs @@ -0,0 +1,263 @@ +use futures::join; + +use super::Op; + +pub struct Parallel { + op1: Op1, + op2: Op2, +} + +impl Parallel { + pub fn new(op1: Op1, op2: Op2) -> Self { + Self { op1, op2 } + } +} + +impl Op for Parallel +where + Op1: Op, + Op1::Input: Clone, + Op2: Op, +{ + type Input = Op1::Input; + type Output = (Op1::Output, Op2::Output); + + #[inline] + async fn call(&self, input: Self::Input) -> Self::Output { + join!(self.op1.call(input.clone()), self.op2.call(input)) + } +} + +// See https://doc.rust-lang.org/src/core/future/join.rs.html#48 +#[macro_export] +macro_rules! parallel_internal { + // Last recursive step + ( + // Accumulate a token for each future that has been expanded: "_ _ _". + current_position: [ + $($underscores:tt)* + ] + // Accumulate values and their positions in the tuple: `_0th () _1st ( _ ) …`. + values_and_positions: [ + $($acc:tt)* + ] + // Munch one value. + munching: [ + $current:tt + ] + ) => ( + $crate::parallel_internal! { + current_position: [ + $($underscores)* + _ + ] + values_and_positions: [ + $($acc)* + $current ( $($underscores)* + ) + ] + munching: [] + } + ); + + // Recursion step: map each value with its "position" (underscore count). + ( + // Accumulate a token for each future that has been expanded: "_ _ _". 
+ current_position: [ + $($underscores:tt)* + ] + // Accumulate values and their positions in the tuple: `_0th () _1st ( _ ) …`. + values_and_positions: [ + $($acc:tt)* + ] + // Munch one value. + munching: [ + $current:tt + $($rest:tt)+ + ] + ) => ( + $crate::parallel_internal! { + current_position: [ + $($underscores)* + _ + ] + values_and_positions: [ + $($acc)* + $current ( $($underscores)* ) + ] + munching: [ + $($rest)* + ] + } + ); + + // End of recursion: flatten the values. + ( + current_position: [ + $($max:tt)* + ] + values_and_positions: [ + $( + $val:tt ( $($pos:tt)* ) + )* + ] + munching: [] + ) => ({ + $crate::parallel_op!($($val),*) + .map(|output| { + ($( + { + let $crate::tuple_pattern!(x $($pos)*) = output; + x + } + ),+) + }) + }) +} + +#[macro_export] +macro_rules! parallel_op { + ($op1:tt, $op2:tt) => { + $crate::pipeline::parallel::Parallel::new($op1, $op2) + }; + ($op1:tt $(, $ops:tt)*) => { + $crate::pipeline::parallel::Parallel::new( + $op1, + $crate::parallel_op!($($ops),*) + ) + }; +} + +#[macro_export] +macro_rules! tuple_pattern { + ($id:ident +) => { + $id + }; + ($id:ident) => { + ($id, ..) + }; + ($id:ident _ $($symbols:tt)*) => { + (_, $crate::tuple_pattern!($id $($symbols)*)) + }; +} + +#[macro_export] +macro_rules! parallel { + ($($es:expr),+ $(,)?) => { + $crate::parallel_internal! 
{ + current_position: [] + values_and_positions: [] + munching: [ + $($es)+ + ] + } + }; +} + +pub use parallel; +pub use parallel_internal; + +#[cfg(test)] +mod tests { + use super::*; + use crate::pipeline::{ + self, + op::{map, Sequential}, + passthrough, then, + }; + + #[tokio::test] + async fn test_parallel() { + let op1 = map(|x: i32| x + 1); + let op2 = map(|x: i32| x * 3); + let pipeline = Parallel::new(op1, op2); + + let result = pipeline.call(1).await; + assert_eq!(result, (2, 3)); + } + + #[tokio::test] + async fn test_parallel_nested() { + let op1 = map(|x: i32| x + 1); + let op2 = map(|x: i32| x * 3); + let op3 = map(|x: i32| format!("{} is the number!", x)); + let op4 = map(|x: i32| x - 1); + + let pipeline = Parallel::new(Parallel::new(Parallel::new(op1, op2), op3), op4); + + let result = pipeline.call(1).await; + assert_eq!(result, (((2, 3), "1 is the number!".to_string()), 0)); + } + + #[tokio::test] + async fn test_parallel_nested_rev() { + let op1 = map(|x: i32| x + 1); + let op2 = map(|x: i32| x * 3); + let op3 = map(|x: i32| format!("{} is the number!", x)); + let op4 = map(|x: i32| x == 1); + + let pipeline = Parallel::new(op1, Parallel::new(op2, Parallel::new(op3, op4))); + + let result = pipeline.call(1).await; + assert_eq!(result, (2, (3, ("1 is the number!".to_string(), true)))); + } + + #[tokio::test] + async fn test_sequential_and_parallel() { + let op1 = map(|x: i32| x + 1); + let op2 = map(|x: i32| x * 2); + let op3 = map(|x: i32| x * 3); + let op4 = map(|(x, y): (i32, i32)| x + y); + + let pipeline = Sequential::new(Sequential::new(op1, Parallel::new(op2, op3)), op4); + + let result = pipeline.call(1).await; + assert_eq!(result, 10); + } + + #[tokio::test] + async fn test_parallel_chain_compile_check() { + let _ = pipeline::new().chain( + Parallel::new( + map(|x: i32| x + 1), + Parallel::new( + map(|x: i32| x * 3), + Parallel::new( + map(|x: i32| format!("{} is the number!", x)), + map(|x: i32| x == 1), + ), + ), + ) + .map(|(r1, (r2, 
(r3, r4)))| (r1, r2, r3, r4)), + ); + } + + #[tokio::test] + async fn test_parallel_pass_through() { + let pipeline = then(|x| { + let op = Parallel::new(Parallel::new(passthrough(), passthrough()), passthrough()); + + async move { + let ((r1, r2), r3) = op.call(x).await; + (r1, r2, r3) + } + }); + + let result = pipeline.call(1).await; + assert_eq!(result, (1, 1, 1)); + } + + #[tokio::test] + async fn test_parallel_macro() { + let op2 = map(|x: i32| x * 2); + + let pipeline = parallel!( + passthrough(), + op2, + map(|x: i32| format!("{} is the number!", x)), + map(|x: i32| x == 1) + ); + + let result = pipeline.call(1).await; + assert_eq!(result, (2, 3, "1 is the number!".to_string(), false)); + } +} diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs new file mode 100644 index 00000000..d77e491b --- /dev/null +++ b/rig-core/src/pipeline/try_op.rs @@ -0,0 +1,391 @@ +use std::future::Future; + +#[allow(unused_imports)] // Needed since this is used in a macro rule +use futures::try_join; + +use super::op::{map, then}; + +// ================================================================ +// Core TryOp trait +// ================================================================ +pub trait TryOp: Send + Sync { + type Input: Send + Sync; + type Output: Send + Sync; + type Error: Send + Sync; + + fn try_call( + &self, + input: Self::Input, + ) -> impl Future> + Send; + + fn map_ok(self, f: F) -> impl TryOp + where + F: Fn(Self::Output) -> T + Send + Sync, + T: Send + Sync, + Self: Sized, + { + MapOk::new(self, map(f)) + } + + fn map_err( + self, + f: F, + ) -> impl TryOp + where + F: Fn(Self::Error) -> E + Send + Sync, + E: Send + Sync, + Self: Sized, + { + MapErr::new(self, map(f)) + } + + fn and_then( + self, + f: F, + ) -> impl TryOp + where + F: Fn(Self::Output) -> Fut + Send + Sync, + Fut: Future> + Send + Sync, + T: Send + Sync, + Self: Sized, + { + AndThen::new(self, then(f)) + } + + fn or_else( + self, + f: F, + ) -> impl TryOp + where 
+ F: Fn(Self::Error) -> Fut + Send + Sync, + Fut: Future> + Send + Sync, + E: Send + Sync, + Self: Sized, + { + OrElse::new(self, then(f)) + } +} + +impl TryOp for Op +where + Op: super::Op>, + T: Send + Sync, + E: Send + Sync, +{ + type Input = Op::Input; + type Output = T; + type Error = E; + + async fn try_call(&self, input: Self::Input) -> Result { + self.call(input).await + } +} + +// ================================================================ +// TryOp combinators +// ================================================================ +pub struct MapOk { + prev: Op1, + op: Op2, +} + +impl MapOk { + pub fn new(prev: Op1, op: Op2) -> Self { + Self { prev, op } + } +} + +// Result -> Result +impl TryOp for MapOk +where + Op1: TryOp, + Op2: super::Op, +{ + type Input = Op1::Input; + type Output = Op2::Output; + type Error = Op1::Error; + + #[inline] + async fn try_call(&self, input: Self::Input) -> Result { + match self.prev.try_call(input).await { + Ok(output) => Ok(self.op.call(output).await), + Err(err) => Err(err), + } + } +} + +pub struct MapErr { + prev: Op1, + op: Op2, +} + +impl MapErr { + pub fn new(prev: Op1, op: Op2) -> Self { + Self { prev, op } + } +} + +// Result -> Result +impl TryOp for MapErr +where + Op1: TryOp, + Op2: super::Op, +{ + type Input = Op1::Input; + type Output = Op1::Output; + type Error = Op2::Output; + + #[inline] + async fn try_call(&self, input: Self::Input) -> Result { + match self.prev.try_call(input).await { + Ok(output) => Ok(output), + Err(err) => Err(self.op.call(err).await), + } + } +} + +pub struct AndThen { + prev: Op1, + op: Op2, +} + +impl AndThen { + pub fn new(prev: Op1, op: Op2) -> Self { + Self { prev, op } + } +} + +impl TryOp for AndThen +where + Op1: TryOp, + Op2: TryOp, +{ + type Input = Op1::Input; + type Output = Op2::Output; + type Error = Op1::Error; + + #[inline] + async fn try_call(&self, input: Self::Input) -> Result { + let output = self.prev.try_call(input).await?; + self.op.try_call(output).await + 
} +} + +pub struct OrElse { + prev: Op1, + op: Op2, +} + +impl OrElse { + pub fn new(prev: Op1, op: Op2) -> Self { + Self { prev, op } + } +} + +impl TryOp for OrElse +where + Op1: TryOp, + Op2: TryOp, +{ + type Input = Op1::Input; + type Output = Op1::Output; + type Error = Op2::Error; + + #[inline] + async fn try_call(&self, input: Self::Input) -> Result { + match self.prev.try_call(input).await { + Ok(output) => Ok(output), + Err(err) => self.op.try_call(err).await, + } + } +} + +// TODO: Implement TryParallel +// pub struct TryParallel { +// op1: Op1, +// op2: Op2, +// } + +// impl TryParallel { +// pub fn new(op1: Op1, op2: Op2) -> Self { +// Self { op1, op2 } +// } +// } + +// impl TryOp for TryParallel +// where +// Op1: TryOp, +// Op2: TryOp, +// { +// type Input = Op1::Input; +// type Output = (Op1::Output, Op2::Output); +// type Error = Op1::Error; + +// #[inline] +// async fn try_call(&self, input: Self::Input) -> Result { +// let (output1, output2) = tokio::join!(self.op1.try_call(input.clone()), self.op2.try_call(input)); +// Ok((output1?, output2?)) +// } +// } +#[macro_export] +macro_rules! 
try_parallel { + ($($ops:ident),+) => { + then(|input: i32| { + let ($($ops),+) = ($(&$ops),+); + async move { + try_join!($($ops.try_call(input.clone())),+) + } + }) + }; +} + +pub use try_parallel; + +#[cfg(test)] +mod tests { + use super::*; + use crate::pipeline::op::{map, then}; + + #[tokio::test] + async fn test_try_op() { + let op = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); + let result = op.try_call(1).await.unwrap(); + assert_eq!(result, 2); + } + + #[tokio::test] + async fn test_map_ok_constructor() { + let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); + let op2 = then(|x: i32| async move { x * 2 }); + let op3 = map(|x: i32| x - 1); + + let pipeline = MapOk::new(MapOk::new(op1, op2), op3); + + let result = pipeline.try_call(2).await.unwrap(); + assert_eq!(result, 3); + } + + #[tokio::test] + async fn test_map_ok_chain() { + let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + .map_ok(|x| x * 2) + .map_ok(|x| x - 1); + + let result = pipeline.try_call(2).await.unwrap(); + assert_eq!(result, 3); + } + + #[tokio::test] + async fn test_map_err_constructor() { + let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); + let op2 = then(|err: &str| async move { format!("Error: {}", err) }); + let op3 = map(|err: String| err.len()); + + let pipeline = MapErr::new(MapErr::new(op1, op2), op3); + + let result = pipeline.try_call(1).await; + assert_eq!(result, Err(15)); + } + + #[tokio::test] + async fn test_map_err_chain() { + let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + .map_err(|err| format!("Error: {}", err)) + .map_err(|err| err.len()); + + let result = pipeline.try_call(1).await; + assert_eq!(result, Err(15)); + } + + #[tokio::test] + async fn test_and_then_constructor() { + let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); + let op2 = then(|x: i32| async move { Ok(x * 2) }); + let op3 = map(|x: i32| Ok(x - 1)); + 
+ let pipeline = AndThen::new(AndThen::new(op1, op2), op3); + + let result = pipeline.try_call(2).await.unwrap(); + assert_eq!(result, 3); + } + + #[tokio::test] + async fn test_and_then_chain() { + let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + .and_then(|x| async move { Ok(x * 2) }) + .and_then(|x| async move { Ok(x - 1) }); + + let result = pipeline.try_call(2).await.unwrap(); + assert_eq!(result, 3); + } + + #[tokio::test] + async fn test_or_else_constructor() { + let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); + let op2 = then(|err: &str| async move { Err(format!("Error: {}", err)) }); + let op3 = map(|err: String| Ok::(err.len() as i32)); + + let pipeline = OrElse::new(OrElse::new(op1, op2), op3); + + let result = pipeline.try_call(1).await.unwrap(); + assert_eq!(result, 15); + } + + #[tokio::test] + async fn test_or_else_chain() { + let pipeline = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + .or_else(|err| async move { Err(format!("Error: {}", err)) }) + .or_else(|err| async move { Ok::(err.len() as i32) }); + + let result = pipeline.try_call(1).await.unwrap(); + assert_eq!(result, 15); + } + + #[tokio::test] + async fn test_try_parallel_ok() { + let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); + let op2 = map(|x: i32| { + if x % 3 == 0 { + Ok(x) + } else { + Err("x is not divisible by 3") + } + }); + let op3 = map(|x: i32| { + if x % 5 == 0 { + Ok(x) + } else { + Err("x is not divisible by 5") + } + }); + + let pipeline = try_parallel!(op1, op2, op3); + + let result = pipeline.try_call(30).await.unwrap(); + assert_eq!(result, (30, 30, 30)); + } + + #[tokio::test] + async fn test_try_parallel_err() { + let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); + let op2 = map(|x: i32| { + if x % 3 == 0 { + Ok(x) + } else { + Err("x is not divisible by 3") + } + }); + let op3 = map(|x: i32| { + if x % 5 == 0 { + Ok(x) + } else { + Err("x is 
not divisible by 5") + } + }); + + let pipeline = try_parallel!(op1, op2, op3); + + let result = pipeline.try_call(31).await; + assert_eq!(result, Err("x is odd")); + } +} From f46828d8fed25dae3a631458051acc9f4efe9865 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 14:17:29 -0500 Subject: [PATCH 04/26] docs: Update example --- rig-core/examples/chain.rs | 56 +++++--------------------------------- 1 file changed, 7 insertions(+), 49 deletions(-) diff --git a/rig-core/examples/chain.rs b/rig-core/examples/chain.rs index ed7e6c2f..1fa89e4c 100644 --- a/rig-core/examples/chain.rs +++ b/rig-core/examples/chain.rs @@ -5,7 +5,7 @@ use rig::{ parallel, pipeline::{self, agent_ops::lookup, passthrough, Op}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; #[tokio::main] @@ -16,17 +16,16 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Create vector store, compute embeddings and load them in the store - let mut vector_store = InMemoryVectorStore::default(); - + // Create embeddings for our documents let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .document("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")? + .document("Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")? 
+ .document("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")? .build() .await?; - vector_store.add_documents(embeddings).await?; + // Create vector store with the embeddings + let vector_store = InMemoryVectorStore::from_documents(embeddings); // Create vector store index let index = vector_store.index(embedding_model); @@ -71,44 +70,3 @@ async fn main() -> Result<(), anyhow::Error> { Ok(()) } - -// trait Foo { -// fn foo(&self); -// } - -// impl Foo<(T,)> for F -// where -// F: Fn(T) -> Out, -// { -// fn foo(&self) { -// todo!() -// } -// } - -// impl Foo<(T1, T2)> for F -// where -// F: Fn(T1, T2) -> Out, -// { -// fn foo(&self) { -// todo!() -// } -// } - -// impl Foo<(T1, T2, T3)> for F -// where -// F: Fn(T1, T2, T3) -> Out, -// { -// fn foo(&self) { -// todo!() -// } -// } - -// impl Foo<((Fut, T,),)> for F -// where -// F: Fn(T) -> Fut, -// Fut: Future, -// { -// fn foo(&self) { -// todo!() -// } -// } From 110554bc62f9fc8ba8dcc7ca310485ec77df9d72 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 14:27:34 -0500 Subject: [PATCH 05/26] feat: Add extraction pipeline op --- rig-core/src/pipeline/agent_ops.rs | 45 +++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index 755bd74f..1494119f 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -1,4 +1,4 @@ -use crate::{completion, vector_store}; +use crate::{completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, vector_store}; use super::Op; @@ -92,6 +92,49 @@ where Prompt::new(prompt) } +pub struct Extract +where + M: CompletionModel, + T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, +{ + extractor: Extractor, + _in: std::marker::PhantomData, +} + +impl Extract +where + M: CompletionModel, + T: schemars::JsonSchema + for<'a> 
serde::Deserialize<'a> + Send + Sync, +{ + pub fn new(extractor: Extractor) -> Self { + Self { extractor, _in: std::marker::PhantomData } + } +} + +impl Op for Extract +where + M: CompletionModel, + T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + In: Into + Send + Sync, +{ + type Input = In; + type Output = Result; + + async fn call(&self, input: Self::Input) -> Self::Output { + self.extractor.extract(&input.into()).await + } +} + +pub fn extract(extractor: Extractor) -> Extract +where + M: CompletionModel, + T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + In: Into + Send + Sync, +{ + Extract::new(extractor) +} + + #[cfg(test)] pub mod tests { use super::*; From d49e11edac1631c86ffb301717e8131463fa8629 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 14:40:57 -0500 Subject: [PATCH 06/26] docs: Add extraction pipeline example --- rig-core/examples/multi_extract.rs | 73 ++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 rig-core/examples/multi_extract.rs diff --git a/rig-core/examples/multi_extract.rs b/rig-core/examples/multi_extract.rs new file mode 100644 index 00000000..a88fc283 --- /dev/null +++ b/rig-core/examples/multi_extract.rs @@ -0,0 +1,73 @@ +use rig::{parallel, pipeline::{self, agent_ops, Op}, providers::openai}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +/// A record containing extracted names +pub struct Names { + /// The names extracted from the text + pub names: Vec, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +/// A record containing extracted topics +pub struct Topics { + /// The topics extracted from the text + pub topics: Vec, +} + +#[derive(Debug, Deserialize, JsonSchema, Serialize)] +/// A record containing extracted sentiment +pub struct Sentiment { + /// The sentiment of the text (-1 being negative, 1 being positive) + pub sentiment: f64, + /// The 
confidence of the sentiment + pub confidence: f64, +} + + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let openai = openai::Client::from_env(); + + let names_extractor = openai + .extractor::("gpt-4") + .preamble("Extract names (e.g.: of people, places) from the given text.") + .build(); + + let topics_extractor = openai + .extractor::("gpt-4") + .preamble("Extract topics from the given text.") + .build(); + + let sentiment_extractor = openai + .extractor::("gpt-4") + .preamble("Extract sentiment (and how confident you are of the sentiment) from the given text.") + .build(); + + let chain = pipeline::new() + .chain(parallel!( + agent_ops::extract(names_extractor), + agent_ops::extract(topics_extractor), + agent_ops::extract(sentiment_extractor), + )) + .map(|(names, topics, sentiment)| { + match (names, topics, sentiment) { + (Ok(names), Ok(topics), Ok(sentiment)) => { + format!( + "Extracted names: {names}\nExtracted topics: {topics}\nExtracted sentiment: {sentiment}", + names = names.names.join(", "), + topics = topics.topics.join(", "), + sentiment = sentiment.sentiment, + ) + } + _ => "Failed to extract names, topics, or sentiment".to_string(), + } + }); + + let response = chain.call("Screw you Putin!").await; + + println!("Text analysis:\n{response}"); + + Ok(()) +} \ No newline at end of file From fbfc2bced80e03d85e840c911a7096d7a4b23f74 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 15:18:40 -0500 Subject: [PATCH 07/26] feat: Add `try_parallel!` pipeline op macro --- rig-core/examples/multi_extract.rs | 25 ++--- rig-core/src/pipeline/parallel.rs | 170 ++++++++++++++++++++++++++++- rig-core/src/pipeline/try_op.rs | 127 ++++++++++----------- 3 files changed, 238 insertions(+), 84 deletions(-) diff --git a/rig-core/examples/multi_extract.rs b/rig-core/examples/multi_extract.rs index a88fc283..ccd7bc89 100644 --- a/rig-core/examples/multi_extract.rs +++ b/rig-core/examples/multi_extract.rs @@ -1,4 +1,4 @@ -use rig::{parallel, 
pipeline::{self, agent_ops, Op}, providers::openai}; +use rig::{pipeline::{self, agent_ops, TryOp}, providers::openai, try_parallel}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -46,26 +46,21 @@ async fn main() -> anyhow::Result<()> { .build(); let chain = pipeline::new() - .chain(parallel!( + .chain(try_parallel!( agent_ops::extract(names_extractor), agent_ops::extract(topics_extractor), agent_ops::extract(sentiment_extractor), )) - .map(|(names, topics, sentiment)| { - match (names, topics, sentiment) { - (Ok(names), Ok(topics), Ok(sentiment)) => { - format!( - "Extracted names: {names}\nExtracted topics: {topics}\nExtracted sentiment: {sentiment}", - names = names.names.join(", "), - topics = topics.topics.join(", "), - sentiment = sentiment.sentiment, - ) - } - _ => "Failed to extract names, topics, or sentiment".to_string(), - } + .map_ok(|(names, topics, sentiment)| { + format!( + "Extracted names: {names}\nExtracted topics: {topics}\nExtracted sentiment: {sentiment}", + names = names.names.join(", "), + topics = topics.topics.join(", "), + sentiment = sentiment.sentiment, + ) }); - let response = chain.call("Screw you Putin!").await; + let response = chain.try_call("Screw you Putin!").await?; println!("Text analysis:\n{response}"); diff --git a/rig-core/src/pipeline/parallel.rs b/rig-core/src/pipeline/parallel.rs index e02c5efb..f790d72e 100644 --- a/rig-core/src/pipeline/parallel.rs +++ b/rig-core/src/pipeline/parallel.rs @@ -1,6 +1,6 @@ -use futures::join; +use futures::{join, try_join}; -use super::Op; +use super::{Op, TryOp}; pub struct Parallel { op1: Op1, @@ -28,6 +28,22 @@ where } } +impl TryOp for Parallel +where + Op1: TryOp, + Op1::Input: Clone, + Op2: TryOp, +{ + type Input = Op1::Input; + type Output = (Op1::Output, Op2::Output); + type Error = Op1::Error; + + #[inline] + async fn try_call(&self, input: Self::Input) -> Result { + try_join!(self.op1.try_call(input.clone()), self.op2.try_call(input)) + } +} + // See 
https://doc.rust-lang.org/src/core/future/join.rs.html#48 #[macro_export] macro_rules! parallel_internal { @@ -153,6 +169,106 @@ macro_rules! parallel { }; } +// See https://doc.rust-lang.org/src/core/future/join.rs.html#48 +#[macro_export] +macro_rules! try_parallel_internal { + // Last recursive step + ( + // Accumulate a token for each future that has been expanded: "_ _ _". + current_position: [ + $($underscores:tt)* + ] + // Accumulate values and their positions in the tuple: `_0th () _1st ( _ ) …`. + values_and_positions: [ + $($acc:tt)* + ] + // Munch one value. + munching: [ + $current:tt + ] + ) => ( + $crate::try_parallel_internal! { + current_position: [ + $($underscores)* + _ + ] + values_and_positions: [ + $($acc)* + $current ( $($underscores)* + ) + ] + munching: [] + } + ); + + // Recursion step: map each value with its "position" (underscore count). + ( + // Accumulate a token for each future that has been expanded: "_ _ _". + current_position: [ + $($underscores:tt)* + ] + // Accumulate values and their positions in the tuple: `_0th () _1st ( _ ) …`. + values_and_positions: [ + $($acc:tt)* + ] + // Munch one value. + munching: [ + $current:tt + $($rest:tt)+ + ] + ) => ( + $crate::try_parallel_internal! { + current_position: [ + $($underscores)* + _ + ] + values_and_positions: [ + $($acc)* + $current ( $($underscores)* ) + ] + munching: [ + $($rest)* + ] + } + ); + + // End of recursion: flatten the values. + ( + current_position: [ + $($max:tt)* + ] + values_and_positions: [ + $( + $val:tt ( $($pos:tt)* ) + )* + ] + munching: [] + ) => ({ + $crate::parallel_op!($($val),*) + .map_ok(|output| { + ($( + { + let $crate::tuple_pattern!(x $($pos)*) = output; + x + } + ),+) + }) + }) +} + + +#[macro_export] +macro_rules! try_parallel { + ($($es:expr),+ $(,)?) => { + $crate::try_parallel_internal! 
{ + current_position: [] + values_and_positions: [] + munching: [ + $($es)+ + ] + } + }; +} + pub use parallel; pub use parallel_internal; @@ -260,4 +376,54 @@ mod tests { let result = pipeline.call(1).await; assert_eq!(result, (2, 3, "1 is the number!".to_string(), false)); } + + #[tokio::test] + async fn test_try_parallel_chain_compile_check() { + let chain = pipeline::new().chain( + Parallel::new( + map(|x: i32| Ok::<_, String>(x + 1)), + Parallel::new( + map(|x: i32| Ok::<_, String>(x * 3)), + Parallel::new( + map(|x: i32| Err::(format!("{} is the number!", x))), + map(|x: i32| Ok::<_, String>(x == 1)), + ), + ), + ) + .map_ok(|(r1, (r2, (r3, r4)))| (r1, r2, r3, r4)), + ); + + let response = chain.call(1).await; + assert_eq!(response, Err("1 is the number!".to_string())); + } + + #[tokio::test] + async fn test_try_parallel_macro_ok() { + let op2 = map(|x: i32| Ok::<_, String>(x * 2)); + + let pipeline = try_parallel!( + map(|x: i32| Ok::<_, String>(x)), + op2, + map(|x: i32| Ok::<_, String>(format!("{} is the number!", x))), + map(|x: i32| Ok::<_, String>(x == 1)) + ); + + let result = pipeline.try_call(1).await; + assert_eq!(result, Ok((2, 3, "1 is the number!".to_string(), false))); + } + + #[tokio::test] + async fn test_try_parallel_macro_err() { + let op2 = map(|x: i32| Ok::<_, String>(x * 2)); + + let pipeline = try_parallel!( + map(|x: i32| Ok::<_, String>(x)), + op2, + map(|x: i32| Err::(format!("{} is the number!", x))), + map(|x: i32| Ok::<_, String>(x == 1)) + ); + + let result = pipeline.try_call(1).await; + assert_eq!(result, Err("1 is the number!".to_string())); + } } diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index d77e491b..61876ba4 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -3,7 +3,7 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::try_join; -use super::op::{map, then}; +use super::op::{self, map, then}; // 
================================================================ // Core TryOp trait @@ -18,7 +18,7 @@ pub trait TryOp: Send + Sync { input: Self::Input, ) -> impl Future> + Send; - fn map_ok(self, f: F) -> impl TryOp + fn map_ok(self, f: F) -> impl op::Op> where F: Fn(Self::Output) -> T + Send + Sync, T: Send + Sync, @@ -64,6 +64,14 @@ pub trait TryOp: Send + Sync { { OrElse::new(self, then(f)) } + + fn chain_ok(self, op: T) -> impl TryOp + where + T: op::Op, + Self: Sized, + { + TrySequential::new(self, op) + } } impl TryOp for Op @@ -96,17 +104,34 @@ impl MapOk { } // Result -> Result -impl TryOp for MapOk +// impl TryOp for MapOk +// where +// Op1: TryOp, +// Op2: super::Op, +// { +// type Input = Op1::Input; +// type Output = Op2::Output; +// type Error = Op1::Error; + +// #[inline] +// async fn try_call(&self, input: Self::Input) -> Result { +// match self.prev.try_call(input).await { +// Ok(output) => Ok(self.op.call(output).await), +// Err(err) => Err(err), +// } +// } +// } + +impl op::Op for MapOk where Op1: TryOp, Op2: super::Op, { type Input = Op1::Input; - type Output = Op2::Output; - type Error = Op1::Error; + type Output = Result; #[inline] - async fn try_call(&self, input: Self::Input) -> Result { + async fn call(&self, input: Self::Input) -> Self::Output { match self.prev.try_call(input).await { Ok(output) => Ok(self.op.call(output).await), Err(err) => Err(err), @@ -200,6 +225,35 @@ where } } +pub struct TrySequential { + prev: Op1, + op: Op2, +} + +impl TrySequential { + pub fn new(prev: Op1, op: Op2) -> Self { + Self { prev, op } + } +} + +impl TryOp for TrySequential +where + Op1: TryOp, + Op2: op::Op, +{ + type Input = Op1::Input; + type Output = Op2::Output; + type Error = Op1::Error; + + #[inline] + async fn try_call(&self, input: Self::Input) -> Result { + match self.prev.try_call(input).await { + Ok(output) => Ok(self.op.call(output).await), + Err(err) => Err(err), + } + } +} + // TODO: Implement TryParallel // pub struct TryParallel { // 
op1: Op1, @@ -227,19 +281,6 @@ where // Ok((output1?, output2?)) // } // } -#[macro_export] -macro_rules! try_parallel { - ($($ops:ident),+) => { - then(|input: i32| { - let ($($ops),+) = ($(&$ops),+); - async move { - try_join!($($ops.try_call(input.clone())),+) - } - }) - }; -} - -pub use try_parallel; #[cfg(test)] mod tests { @@ -340,52 +381,4 @@ mod tests { let result = pipeline.try_call(1).await.unwrap(); assert_eq!(result, 15); } - - #[tokio::test] - async fn test_try_parallel_ok() { - let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); - let op2 = map(|x: i32| { - if x % 3 == 0 { - Ok(x) - } else { - Err("x is not divisible by 3") - } - }); - let op3 = map(|x: i32| { - if x % 5 == 0 { - Ok(x) - } else { - Err("x is not divisible by 5") - } - }); - - let pipeline = try_parallel!(op1, op2, op3); - - let result = pipeline.try_call(30).await.unwrap(); - assert_eq!(result, (30, 30, 30)); - } - - #[tokio::test] - async fn test_try_parallel_err() { - let op1 = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); - let op2 = map(|x: i32| { - if x % 3 == 0 { - Ok(x) - } else { - Err("x is not divisible by 3") - } - }); - let op3 = map(|x: i32| { - if x % 5 == 0 { - Ok(x) - } else { - Err("x is not divisible by 5") - } - }); - - let pipeline = try_parallel!(op1, op2, op3); - - let result = pipeline.try_call(31).await; - assert_eq!(result, Err("x is odd")); - } } From f32ec08dfb44ca4116e2b650f1e1c93a08056684 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 15:19:07 -0500 Subject: [PATCH 08/26] misc: Remove unused module --- rig-core/src/pipeline/builder.rs | 261 ------------------------------- 1 file changed, 261 deletions(-) delete mode 100644 rig-core/src/pipeline/builder.rs diff --git a/rig-core/src/pipeline/builder.rs b/rig-core/src/pipeline/builder.rs deleted file mode 100644 index 61ce3419..00000000 --- a/rig-core/src/pipeline/builder.rs +++ /dev/null @@ -1,261 +0,0 @@ -use std::future::Future; - -use 
crate::{completion, vector_store}; - -use super::{agent_ops, op}; - -// pub struct PipelineBuilder { -// _error: std::marker::PhantomData, -// } -pub struct PipelineBuilder; - -impl PipelineBuilder { - /// Chain a function to the current pipeline - /// - /// # Example - /// ```rust - /// use rig::pipeline::{self, Op}; - /// - /// let chain = pipeline::new() - /// .map(|(x, y)| x + y) - /// .map(|z| format!("Result: {z}!")); - /// - /// let result = chain.call((1, 2)).await; - /// assert_eq!(result, "Result: 3!"); - /// ``` - pub fn map(self, f: F) -> op::Map - where - F: Fn(In) -> T + Send + Sync, - In: Send + Sync, - T: Send + Sync, - Self: Sized, - { - op::Map::new(f) - } - - /// Same as `map` but for asynchronous functions - /// - /// # Example - /// ```rust - /// use rig::pipeline::{self, Op}; - /// - /// let chain = pipeline::new() - /// .then(|email: String| async move { - /// email.split('@').next().unwrap().to_string() - /// }) - /// .then(|username: String| async move { - /// format!("Hello, {}!", username) - /// }); - /// - /// let result = chain.call("bob@gmail.com".to_string()).await; - /// assert_eq!(result, "Hello, bob!"); - /// ``` - pub fn then(self, f: F) -> op::Then - where - F: Fn(In) -> Fut + Send + Sync, - In: Send + Sync, - Fut: Future + Send + Sync, - Fut::Output: Send + Sync, - Self: Sized, - { - op::Then::new(f) - } - - /// Chain an arbitrary operation to the current pipeline. - /// - /// # Example - /// ```rust - /// use rig::pipeline::{self, Op}; - /// - /// struct MyOp; - /// - /// impl Op for MyOp { - /// type Input = i32; - /// type Output = i32; - /// - /// async fn call(&self, input: Self::Input) -> Self::Output { - /// input + 1 - /// } - /// } - /// - /// let chain = pipeline::new() - /// .chain(MyOp); - /// - /// let result = chain.call(1).await; - /// assert_eq!(result, 2); - /// ``` - pub fn chain(self, op: T) -> T - where - T: op::Op, - Self: Sized, - { - op - } - - /// Chain a lookup operation to the current chain. 
The lookup operation expects the - /// current chain to output a query string. The lookup operation will use the query to - /// retrieve the top `n` documents from the index and return them with the query string. - /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let chain = chain::new() - /// .lookup(index, 2) - /// .chain(|(query, docs): (_, Vec)| async move { - /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) - /// }); - /// - /// let result = chain.call("What is a flurbo?".to_string()).await; - /// ``` - pub fn lookup( - self, - index: I, - n: usize, - ) -> agent_ops::Lookup - where - I: vector_store::VectorStoreIndex, - T: Send + Sync + for<'a> serde::Deserialize<'a>, - In: Into + Send + Sync, - Self: Sized, - { - agent_ops::Lookup::new(index, n) - } - - /// Chain a prompt operation to the current chain. The prompt operation expects the - /// current chain to output a string. The prompt operation will use the string to prompt - /// the given agent (or any other type that implements the `Prompt` trait) and return - /// the response. 
- /// - /// # Example - /// ```rust - /// use rig::chain::{self, Chain}; - /// - /// let agent = &openai_client.agent("gpt-4").build(); - /// - /// let chain = chain::new() - /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) - /// .prompt(agent); - /// - /// let result = chain.call("Alice".to_string()).await; - /// ``` - pub fn prompt( - self, - prompt: P, - ) -> impl op::Op> - where - P: completion::Prompt, - In: Into + Send + Sync, - Self: Sized, - { - agent_ops::prompt(prompt) - } -} - -#[derive(Debug, thiserror::Error)] -pub enum ChainError { - #[error("Failed to prompt agent: {0}")] - PromptError(#[from] completion::PromptError), - - #[error("Failed to lookup documents: {0}")] - LookupError(#[from] vector_store::VectorStoreError), -} - -// pub fn new() -> PipelineBuilder { -// PipelineBuilder { -// _error: std::marker::PhantomData, -// } -// } - -// pub fn with_error() -> PipelineBuilder { -// PipelineBuilder { -// _error: std::marker::PhantomData, -// } -// } - -pub fn new() -> PipelineBuilder { - PipelineBuilder -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::pipeline::{op::Op, parallel::{parallel, Parallel}}; - use agent_ops::tests::{Foo, MockIndex, MockModel}; - - #[tokio::test] - async fn test_prompt_pipeline() { - let model = MockModel; - - let chain = super::new() - .map(|input| format!("User query: {}", input)) - .prompt(model); - - let result = chain - .call("What is a flurbo?") - .await - .expect("Failed to run chain"); - - assert_eq!(result, "Mock response: User query: What is a flurbo?"); - } - - // #[tokio::test] - // async fn test_lookup_pipeline() { - // let index = MockIndex; - - // let chain = super::new() - // .lookup::<_, _, Foo>(index, 1) - // .map_ok(|docs| format!("Top documents:\n{}", docs[0].foo)); - - // let result = chain - // .try_call("What is a flurbo?") - // .await - // .expect("Failed to run chain"); - - // assert_eq!( - // result, - // "User query: What is a flurbo?\n\nTop 
documents:\nbar" - // ); - // } - - #[tokio::test] - async fn test_rag_pipeline() { - let index = MockIndex; - - let chain = super::new() - .chain(parallel!(op::passthrough(), agent_ops::lookup::<_, _, Foo>(index, 1),)) - .map(|(query, maybe_docs)| match maybe_docs { - Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo), - Err(err) => format!("Error: {}", err), - }) - .prompt(MockModel); - - let result = chain - .call("What is a flurbo?") - .await - .expect("Failed to run chain"); - - assert_eq!( - result, - "Mock response: User query: What is a flurbo?\n\nTop documents:\nbar" - ); - } - - #[tokio::test] - async fn test_parallel_chain_compile_check() { - let _ = super::new().chain( - Parallel::new( - op::map(|x: i32| x + 1), - Parallel::new( - op::map(|x: i32| x * 3), - Parallel::new( - op::map(|x: i32| format!("{} is the number!", x)), - op::map(|x: i32| x == 1), - ), - ), - ) - .map(|(r1, (r2, (r3, r4)))| (r1, r2, r3, r4)), - ); - } -} From 98af66578cb65bfebba31a9792649d11bc8ba5b5 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 15:29:23 -0500 Subject: [PATCH 09/26] style: cargo fmt --- rig-core/examples/multi_extract.rs | 15 ++++++++++----- rig-core/src/lib.rs | 2 +- rig-core/src/pipeline/agent_ops.rs | 16 +++++++++++----- rig-core/src/pipeline/parallel.rs | 1 - rig-core/src/pipeline/try_op.rs | 5 ++++- 5 files changed, 26 insertions(+), 13 deletions(-) diff --git a/rig-core/examples/multi_extract.rs b/rig-core/examples/multi_extract.rs index ccd7bc89..9bc212d4 100644 --- a/rig-core/examples/multi_extract.rs +++ b/rig-core/examples/multi_extract.rs @@ -1,4 +1,8 @@ -use rig::{pipeline::{self, agent_ops, TryOp}, providers::openai, try_parallel}; +use rig::{ + pipeline::{self, agent_ops, TryOp}, + providers::openai, + try_parallel, +}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -25,7 +29,6 @@ pub struct Sentiment { pub confidence: f64, } - #[tokio::main] async fn main() -> anyhow::Result<()> { let 
openai = openai::Client::from_env(); @@ -42,7 +45,9 @@ async fn main() -> anyhow::Result<()> { let sentiment_extractor = openai .extractor::("gpt-4") - .preamble("Extract sentiment (and how confident you are of the sentiment) from the given text.") + .preamble( + "Extract sentiment (and how confident you are of the sentiment) from the given text.", + ) .build(); let chain = pipeline::new() @@ -59,10 +64,10 @@ async fn main() -> anyhow::Result<()> { sentiment = sentiment.sentiment, ) }); - + let response = chain.try_call("Screw you Putin!").await?; println!("Text analysis:\n{response}"); Ok(()) -} \ No newline at end of file +} diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index cade9835..e90e3156 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -85,8 +85,8 @@ pub mod embeddings; pub mod extractor; pub(crate) mod json_utils; pub mod loaders; -pub mod pipeline; pub mod one_or_many; +pub mod pipeline; pub mod providers; pub mod tool; pub mod vector_store; diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index 1494119f..f51ce713 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -1,4 +1,8 @@ -use crate::{completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, vector_store}; +use crate::{ + completion::{self, CompletionModel}, + extractor::{ExtractionError, Extractor}, + vector_store, +}; use super::Op; @@ -92,7 +96,7 @@ where Prompt::new(prompt) } -pub struct Extract +pub struct Extract where M: CompletionModel, T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, @@ -101,13 +105,16 @@ where _in: std::marker::PhantomData, } -impl Extract +impl Extract where M: CompletionModel, T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, { pub fn new(extractor: Extractor) -> Self { - Self { extractor, _in: std::marker::PhantomData } + Self { + extractor, + _in: std::marker::PhantomData, + } } } @@ -134,7 +141,6 @@ where 
Extract::new(extractor) } - #[cfg(test)] pub mod tests { use super::*; diff --git a/rig-core/src/pipeline/parallel.rs b/rig-core/src/pipeline/parallel.rs index f790d72e..05c09fa0 100644 --- a/rig-core/src/pipeline/parallel.rs +++ b/rig-core/src/pipeline/parallel.rs @@ -255,7 +255,6 @@ macro_rules! try_parallel_internal { }) } - #[macro_export] macro_rules! try_parallel { ($($es:expr),+ $(,)?) => { diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index 61876ba4..b4f93483 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -65,7 +65,10 @@ pub trait TryOp: Send + Sync { OrElse::new(self, then(f)) } - fn chain_ok(self, op: T) -> impl TryOp + fn chain_ok( + self, + op: T, + ) -> impl TryOp where T: op::Op, Self: Sized, From 408b46c1a7a038cea0612c34d3e51dff8cdbb90b Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 15:31:49 -0500 Subject: [PATCH 10/26] test: fix typo in test --- rig-core/src/pipeline/parallel.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/pipeline/parallel.rs b/rig-core/src/pipeline/parallel.rs index 05c09fa0..b2943a1d 100644 --- a/rig-core/src/pipeline/parallel.rs +++ b/rig-core/src/pipeline/parallel.rs @@ -373,7 +373,7 @@ mod tests { ); let result = pipeline.call(1).await; - assert_eq!(result, (2, 3, "1 is the number!".to_string(), false)); + assert_eq!(result, (2, 3, "1 is the number!".to_string(), true)); } #[tokio::test] From a0a8534dd1f020a45274fc674c973b9de5c2b9d0 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 15:34:11 -0500 Subject: [PATCH 11/26] test: fix typo in test #2 --- rig-core/src/pipeline/parallel.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/pipeline/parallel.rs b/rig-core/src/pipeline/parallel.rs index b2943a1d..95907c31 100644 --- a/rig-core/src/pipeline/parallel.rs +++ b/rig-core/src/pipeline/parallel.rs @@ -373,7 +373,7 @@ mod tests { ); let result = 
pipeline.call(1).await; - assert_eq!(result, (2, 3, "1 is the number!".to_string(), true)); + assert_eq!(result, (1, 2, "1 is the number!".to_string(), true)); } #[tokio::test] From 90393f11d931e2bbaf20796fd13d48f612e23569 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 15:39:41 -0500 Subject: [PATCH 12/26] test: Fix broken lookup op test --- rig-core/src/pipeline/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 2cefeb9d..09077bc9 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -235,7 +235,7 @@ mod tests { assert_eq!( result, - "User query: What is a flurbo?\n\nTop documents:\nbar" + "Top documents:\nbar" ); } From fc695152a2f9f0f9ee612924aa0d071e371f6f72 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 16:30:30 -0500 Subject: [PATCH 13/26] feat: Add `Op::batch_call` and `TryOp::try_batch_call` --- rig-core/examples/multi_extract.rs | 19 +++++++++++++++++-- rig-core/src/pipeline/mod.rs | 5 +---- rig-core/src/pipeline/op.rs | 18 ++++++++++++++++++ rig-core/src/pipeline/try_op.rs | 24 ++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/rig-core/examples/multi_extract.rs b/rig-core/examples/multi_extract.rs index 9bc212d4..99e3c2fa 100644 --- a/rig-core/examples/multi_extract.rs +++ b/rig-core/examples/multi_extract.rs @@ -50,6 +50,9 @@ async fn main() -> anyhow::Result<()> { ) .build(); + // Create a chain that extracts names, topics, and sentiment from a given text + // using three different GPT-4 based extractors. + // The chain will output a formatted string containing the extracted information. 
let chain = pipeline::new() .chain(try_parallel!( agent_ops::extract(names_extractor), @@ -65,9 +68,21 @@ async fn main() -> anyhow::Result<()> { ) }); - let response = chain.try_call("Screw you Putin!").await?; + // Batch call the chain with up to 4 inputs concurrently + let response = chain + .try_batch_call( + 4, + vec![ + "Screw you Putin!", + "I love my dog, but I hate my cat.", + "I'm going to the store to buy some milk.", + ], + ) + .await?; - println!("Text analysis:\n{response}"); + for response in response { + println!("Text analysis:\n{response}"); + } Ok(()) } diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 09077bc9..98a0d449 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -233,10 +233,7 @@ mod tests { .await .expect("Failed to run chain"); - assert_eq!( - result, - "Top documents:\nbar" - ); + assert_eq!(result, "Top documents:\nbar"); } #[tokio::test] diff --git a/rig-core/src/pipeline/op.rs b/rig-core/src/pipeline/op.rs index 7208bc27..a4a6bc92 100644 --- a/rig-core/src/pipeline/op.rs +++ b/rig-core/src/pipeline/op.rs @@ -2,6 +2,7 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::join; +use futures::{stream, StreamExt}; // ================================================================ // Core Op trait @@ -12,6 +13,23 @@ pub trait Op: Send + Sync { fn call(&self, input: Self::Input) -> impl Future + Send; + /// Execute the current pipeline with the given inputs. `n` is the number of concurrent + /// inputs that will be processed concurrently. 
+ fn batch_call(&self, n: usize, input: I) -> impl Future> + Send + where + I: IntoIterator + Send, + I::IntoIter: Send, + Self: Sized, + { + async move { + stream::iter(input) + .map(|input| self.call(input)) + .buffered(n) + .collect() + .await + } + } + /// Chain a function to the current pipeline /// /// # Example diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index b4f93483..fa77229a 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -2,6 +2,7 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::try_join; +use futures::{stream, StreamExt, TryStreamExt}; use super::op::{self, map, then}; @@ -18,6 +19,29 @@ pub trait TryOp: Send + Sync { input: Self::Input, ) -> impl Future> + Send; + /// Execute the current pipeline with the given inputs. `n` is the number of concurrent + /// inputs that will be processed concurrently. + /// If one of the inputs fails, the entire operation will fail and the error will + /// be returned. 
+ fn try_batch_call( + &self, + n: usize, + input: I, + ) -> impl Future, Self::Error>> + Send + where + I: IntoIterator + Send, + I::IntoIter: Send, + Self: Sized, + { + async move { + stream::iter(input) + .map(|input| self.try_call(input)) + .buffered(n) + .try_collect() + .await + } + } + fn map_ok(self, f: F) -> impl op::Op> where F: Fn(Self::Output) -> T + Send + Sync, From c2c1996c20cd77747609007738842ef8d24eb18e Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 16:34:23 -0500 Subject: [PATCH 14/26] test: Fix tests --- rig-core/src/pipeline/parallel.rs | 2 +- rig-core/src/pipeline/try_op.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-core/src/pipeline/parallel.rs b/rig-core/src/pipeline/parallel.rs index 95907c31..6ab2b092 100644 --- a/rig-core/src/pipeline/parallel.rs +++ b/rig-core/src/pipeline/parallel.rs @@ -408,7 +408,7 @@ mod tests { ); let result = pipeline.try_call(1).await; - assert_eq!(result, Ok((2, 3, "1 is the number!".to_string(), false))); + assert_eq!(result, Ok((1, 2, "1 is the number!".to_string(), true))); } #[tokio::test] diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index fa77229a..b0885ae8 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -317,7 +317,7 @@ mod tests { #[tokio::test] async fn test_try_op() { let op = map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }); - let result = op.try_call(1).await.unwrap(); + let result = op.try_call(2).await.unwrap(); assert_eq!(result, 2); } From 56ca41c52a75454107987f0fe2f35bfc4974028f Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 5 Dec 2024 14:59:30 -0500 Subject: [PATCH 15/26] docs: Add more docstrings to agent pipeline ops --- rig-core/src/pipeline/op.rs | 10 +-- rig-core/src/pipeline/try_op.rs | 129 +++++++++++++++++++++++++------- 2 files changed, 108 insertions(+), 31 deletions(-) diff --git a/rig-core/src/pipeline/op.rs b/rig-core/src/pipeline/op.rs index 
a4a6bc92..0ddf32bf 100644 --- a/rig-core/src/pipeline/op.rs +++ b/rig-core/src/pipeline/op.rs @@ -30,7 +30,7 @@ pub trait Op: Send + Sync { } } - /// Chain a function to the current pipeline + /// Chain a function `f` to the current op. /// /// # Example /// ```rust @@ -79,15 +79,15 @@ pub trait Op: Send + Sync { Sequential::new(self, then(f)) } - /// Chain an arbitrary operation to the current pipeline. + /// Chain an arbitrary operation to the current op. /// /// # Example /// ```rust /// use rig::pipeline::{self, Op}; /// - /// struct MyOp; + /// struct AddOne; /// - /// impl Op for MyOp { + /// impl Op for AddOne { /// type Input = i32; /// type Output = i32; /// @@ -97,7 +97,7 @@ pub trait Op: Send + Sync { /// } /// /// let chain = pipeline::new() - /// .chain(MyOp); + /// .chain(AddOne); /// /// let result = chain.call(1).await; /// assert_eq!(result, 2); diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index b0885ae8..e4204408 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -14,15 +14,28 @@ pub trait TryOp: Send + Sync { type Output: Send + Sync; type Error: Send + Sync; + /// Execute the current op with the given input. fn try_call( &self, input: Self::Input, ) -> impl Future> + Send; - /// Execute the current pipeline with the given inputs. `n` is the number of concurrent + /// Execute the current op with the given inputs. `n` is the number of concurrent /// inputs that will be processed concurrently. - /// If one of the inputs fails, the entire operation will fail and the error will + /// If the op fails for one of the inputs, the entire operation will fail and the error will /// be returned. 
+ /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, TryOp}; + /// + /// let op = pipeline::new() + /// .map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") }); + /// + /// // Execute the pipeline concurrently with 2 inputs + /// let result = op.try_batch_call(2, vec![2, 4]).await; + /// assert_eq!(result, Ok(vec![3, 5])); + /// ``` fn try_batch_call( &self, n: usize, @@ -42,6 +55,20 @@ pub trait TryOp: Send + Sync { } } + /// Map the success return value (i.e., `Ok`) of the current op to a different value + /// using the provided closure. + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, TryOp}; + /// + /// let op = pipeline::new() + /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + /// .map_ok(|x| x * 2); + /// + /// let result = op.try_call(2).await; + /// assert_eq!(result, Ok(4)); + /// ``` fn map_ok(self, f: F) -> impl op::Op> where F: Fn(Self::Output) -> T + Send + Sync, @@ -51,10 +78,24 @@ pub trait TryOp: Send + Sync { MapOk::new(self, map(f)) } + /// Map the error return value (i.e., `Err`) of the current op to a different value + /// using the provided closure. + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, TryOp}; + /// + /// let op = pipeline::new() + /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + /// .map_err(|err| format!("Error: {}", err)); + /// + /// let result = op.try_call(1).await; + /// assert_eq!(result, Err("Error: x is odd".to_string())); + /// ``` fn map_err( self, f: F, - ) -> impl TryOp + ) -> impl op::Op> where F: Fn(Self::Error) -> E + Send + Sync, E: Send + Sync, @@ -63,6 +104,21 @@ pub trait TryOp: Send + Sync { MapErr::new(self, map(f)) } + /// Chain a function to the current op. The function will only be called + /// if the current op returns `Ok`. The function must return a `Future` with value + /// `Result<T, E>` where `E` is the same type as the error type of the current op. 
+ /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, TryOp}; + /// + /// let op = pipeline::new() + /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + /// .and_then(|x| async move { Ok(x * 2) }); + /// + /// let result = op.try_call(2).await; + /// assert_eq!(result, Ok(4)); + /// ``` fn and_then( self, f: F, @@ -76,6 +132,21 @@ pub trait TryOp: Send + Sync { AndThen::new(self, then(f)) } + /// Chain a function `f` to the current op. The function `f` will only be called + /// if the current op returns `Err`. `f` must return a `Future` with value + /// `Result` where `T` is the same type as the output type of the current op. + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, TryOp}; + /// + /// let op = pipeline::new() + /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + /// .or_else(|err| async move { Err(format!("Error: {}", err)) }); + /// + /// let result = op.try_call(1).await; + /// assert_eq!(result, Err("Error: x is odd".to_string())); + /// ``` fn or_else( self, f: F, @@ -89,6 +160,32 @@ pub trait TryOp: Send + Sync { OrElse::new(self, then(f)) } + /// Chain a new op `op` to the current op. The new op will be called with the success + /// return value of the current op (i.e.: `Ok` value). The chained op can be any type that + /// implements the `Op` trait. 
+ /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op, TryOp}; + /// + /// struct AddOne; + /// + /// impl Op for AddOne { + /// type Input = i32; + /// type Output = i32; + /// + /// async fn call(&self, input: Self::Input) -> Self::Output { + /// input + 1 + /// } + /// } + /// + /// let op = pipeline::new() + /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) + /// .chain_ok(AddOne); + /// + /// let result = op.try_call(2).await; + /// assert_eq!(result, Ok(3)); + /// ``` fn chain_ok( self, op: T, @@ -130,25 +227,6 @@ impl MapOk { } } -// Result -> Result -// impl TryOp for MapOk -// where -// Op1: TryOp, -// Op2: super::Op, -// { -// type Input = Op1::Input; -// type Output = Op2::Output; -// type Error = Op1::Error; - -// #[inline] -// async fn try_call(&self, input: Self::Input) -> Result { -// match self.prev.try_call(input).await { -// Ok(output) => Ok(self.op.call(output).await), -// Err(err) => Err(err), -// } -// } -// } - impl op::Op for MapOk where Op1: TryOp, @@ -178,17 +256,16 @@ impl MapErr { } // Result -> Result -impl TryOp for MapErr +impl op::Op for MapErr where Op1: TryOp, Op2: super::Op, { type Input = Op1::Input; - type Output = Op1::Output; - type Error = Op2::Output; + type Output = Result; #[inline] - async fn try_call(&self, input: Self::Input) -> Result { + async fn call(&self, input: Self::Input) -> Self::Output { match self.prev.try_call(input).await { Ok(output) => Ok(output), Err(err) => Err(self.op.call(err).await), From eb34c4bb584daf8e016ee6476bf2f8c867a2b54f Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 5 Dec 2024 17:14:55 -0500 Subject: [PATCH 16/26] docs: Add pipeline module level docs --- rig-core/src/pipeline/mod.rs | 97 ++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 98a0d449..e99be87a 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -1,3 +1,100 @@ +//! 
This module defines a flexible pipeline API for defining a sequence of operations that +//! may or may not use AI components (e.g.: semantic search, LLMs prompting, etc). +//! +//! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect, +//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box along +//! general combinators. +//! +//! Pipelines are made up of one or more "ops", each of which must implement the [Op] trait. +//! The [Op] trait requires the implementation of only one method: `call`, which takes an input +//! and returns an output. The trait provides a wide range of combinators for chaining operations together. +//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and +//! the edges represent the data flow between operations. When invoking the pipeline on some input, +//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline) and +//! the result of the leaf node is returned as the result of the full pipeline. +//! +//! ## Basic Example +//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats +//! the result as a string using the [map](Op::map) combinator method, which applies a function to the +//! output of the previous op: +//! ```rust +//! use rig::pipeline::{self, Op}; +//! +//! let pipeline = pipeline::new() +//! // op1: add two numbers +//! .map(|(x, y)| x + y) +//! // op2: format result +//! .map(|z| format!("Result: {z}!")); +//! +//! let result = pipeline.call((1, 2)).await; +//! assert_eq!(result, "Result: 3!"); +//! ``` +//! +//! This pipeline can be visualized as the following DAG: +//! ```text +//! Input +//! │ +//! ▼ +//! ┌─────────┐ +//! │ op1 │ +//! └────┬────┘ +//! │ +//! ▼ +//! ┌─────────┐ +//! │ op2 │ +//! └────┬────┘ +//! │ +//! ▼ +//! Output +//! ``` +//! +//! ## Parallel Operations +//! 
The pipeline API also provides a [parallel!](crate::parallel!) and macro for running operations in parallel. +//! The macro takes a list of ops and turns them into a single op that will duplicate the input +//! and run each op in parallel. The results of each op are then collected and returned as a tuple. +//! +//! For example, the pipeline below runs two operations in parallel: +//! ```rust +//! use rig::{pipeline::{self, Op, map}, parallel}; +//! +//! let pipeline = pipeline::new() +//! .chain(parallel!( +//! // op1: add 1 to input +//! map(|x| x + 1), +//! // op2: subtract 1 from input +//! map(|x| x - 1), +//! )) +//! // op3: format results +//! .map(|(a, b)| format!("Results: {a}, {b}")); +//! +//! let result = pipeline.call(1).await; +//! assert_eq!(result, "Result: 2, 0"); +//! ``` +//! +//! Notes: +//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows +//! for chaining arbitrary operations, as long as they implement the [Op] trait. +//! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op. +//! +//! The pipeline above can be visualized as the following DAG: +//! ```text +//! Input +//! │ +//! ┌──────┴──────┐ +//! ▼ ▼ +//! ┌─────────┐ ┌─────────┐ +//! │ op1 │ │ op2 │ +//! └────┬────┘ └────┬────┘ +//! └──────┬──────┘ +//! ▼ +//! ┌─────────┐ +//! │ op3 │ +//! └────┬────┘ +//! │ +//! ▼ +//! Output +//! ``` + pub mod agent_ops; pub mod op; pub mod try_op; From 64a75521750169cac1ff6dd3d943b6feeffe4240 Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 5 Dec 2024 17:30:51 -0500 Subject: [PATCH 17/26] docs: improve pipeline docs --- rig-core/src/pipeline/mod.rs | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index e99be87a..155fd651 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -8,10 +8,13 @@ //! 
Pipelines are made up of one or more "ops", each of which must implement the [Op] trait. //! The [Op] trait requires the implementation of only one method: `call`, which takes an input //! and returns an output. The trait provides a wide range of combinators for chaining operations together. +//! //! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and //! the edges represent the data flow between operations. When invoking the pipeline on some input, -//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline) and -//! the result of the leaf node is returned as the result of the full pipeline. +//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline). +//! The output of each op is then passed to the next op in the pipeline until the output reaches the +//! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned +//! as the result of the pipeline. //! //! ## Basic Example //! For example, the pipeline below takes a tuple of two integers, adds them together and then formats @@ -32,28 +35,17 @@ //! //! This pipeline can be visualized as the following DAG: //! ```text -//! Input -//! │ -//! ▼ -//! ┌─────────┐ -//! │ op1 │ -//! └────┬────┘ -//! │ -//! ▼ -//! ┌─────────┐ -//! │ op2 │ -//! └────┬────┘ -//! │ -//! ▼ -//! Output +//! ┌─────────┐ ┌─────────┐ +//! Input───►│ op1 ├──►│ op2 ├──►Output +//! └─────────┘ └─────────┘ //! ``` //! //! ## Parallel Operations //! The pipeline API also provides a [parallel!](crate::parallel!) and macro for running operations in parallel. //! The macro takes a list of ops and turns them into a single op that will duplicate the input -//! and run each op in parallel. The results of each op are then collected and returned as a tuple. +//! and run each op in concurently. The results of each op are then collected and returned as a tuple. //! -//! 
For example, the pipeline below runs two operations in parallel: +//! For example, the pipeline below runs two operations concurently: //! ```rust //! use rig::{pipeline::{self, Op, map}, parallel}; //! From cbbd1cc87a20c014edb3e52c30d899278d9de0eb Mon Sep 17 00:00:00 2001 From: Christophe Date: Thu, 5 Dec 2024 17:54:17 -0500 Subject: [PATCH 18/26] style: clippy+fmt --- rig-core/src/pipeline/mod.rs | 137 ++++++++++++++++++++------------ rig-core/src/pipeline/try_op.rs | 48 +++++------ 2 files changed, 108 insertions(+), 77 deletions(-) diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 155fd651..417ae762 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -1,54 +1,54 @@ -//! This module defines a flexible pipeline API for defining a sequence of operations that -//! may or may not use AI components (e.g.: semantic search, LLMs prompting, etc). -//! +//! This module defines a flexible pipeline API for defining a sequence of operations that +//! may or may not use AI components (e.g.: semantic search, LLMs prompting, etc). +//! //! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect, -//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box along +//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box along //! general combinators. -//! -//! Pipelines are made up of one or more "ops", each of which must implement the [Op] trait. +//! +//! Pipelines are made up of one or more operations, or "ops", each of which must implement the [Op] trait. //! The [Op] trait requires the implementation of only one method: `call`, which takes an input //! and returns an output. The trait provides a wide range of combinators for chaining operations together. -//! -//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and -//! 
the edges represent the data flow between operations. When invoking the pipeline on some input, +//! +//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and +//! the edges represent the data flow between operations. When invoking the pipeline on some input, //! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline). //! The output of each op is then passed to the next op in the pipeline until the output reaches the //! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned //! as the result of the pipeline. -//! +//! //! ## Basic Example -//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats -//! the result as a string using the [map](Op::map) combinator method, which applies a function to the -//! output of the previous op: +//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats +//! the result as a string using the [map](Op::map) combinator method, which applies a simple function +//! op to the output of the previous op: //! ```rust //! use rig::pipeline::{self, Op}; -//! +//! //! let pipeline = pipeline::new() //! // op1: add two numbers //! .map(|(x, y)| x + y) //! // op2: format result //! .map(|z| format!("Result: {z}!")); -//! +//! //! let result = pipeline.call((1, 2)).await; //! assert_eq!(result, "Result: 3!"); //! ``` -//! +//! //! This pipeline can be visualized as the following DAG: //! ```text //! ┌─────────┐ ┌─────────┐ //! Input───►│ op1 ├──►│ op2 ├──►Output //! └─────────┘ └─────────┘ //! ``` -//! +//! //! ## Parallel Operations //! The pipeline API also provides a [parallel!](crate::parallel!) and macro for running operations in parallel. -//! The macro takes a list of ops and turns them into a single op that will duplicate the input +//! 
The macro takes a list of ops and turns them into a single op that will duplicate the input //! and run each op in concurently. The results of each op are then collected and returned as a tuple. -//! +//! //! For example, the pipeline below runs two operations concurently: //! ```rust //! use rig::{pipeline::{self, Op, map}, parallel}; -//! +//! //! let pipeline = pipeline::new() //! .chain(parallel!( //! // op1: add 1 to input @@ -58,16 +58,16 @@ //! )) //! // op3: format results //! .map(|(a, b)| format!("Results: {a}, {b}")); -//! +//! //! let result = pipeline.call(1).await; //! assert_eq!(result, "Result: 2, 0"); //! ``` -//! -//! Notes: -//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows -//! for chaining arbitrary operations, as long as they implement the [Op] trait. +//! +//! Notes: +//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows +//! for chaining arbitrary operations, as long as they implement the [Op] trait. //! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op. -//! +//! //! The pipeline above can be visualized as the following DAG: //! ```text //! 
Input @@ -95,28 +95,27 @@ pub mod parallel; use std::future::Future; -use agent_ops::{lookup, prompt as prompt_op}; pub use op::{map, passthrough, then, Op}; pub use try_op::TryOp; -use crate::{completion, vector_store}; +use crate::{completion, extractor::Extractor, vector_store}; pub struct PipelineBuilder { _error: std::marker::PhantomData, } impl PipelineBuilder { - /// Chain a function to the current pipeline + /// Add a function to the current pipeline /// /// # Example /// ```rust /// use rig::pipeline::{self, Op}; /// - /// let chain = pipeline::new() + /// let pipeline = pipeline::new() /// .map(|(x, y)| x + y) /// .map(|z| format!("Result: {z}!")); /// - /// let result = chain.call((1, 2)).await; + /// let result = pipeline.call((1, 2)).await; /// assert_eq!(result, "Result: 3!"); /// ``` pub fn map(self, f: F) -> impl Op @@ -135,7 +134,7 @@ impl PipelineBuilder { /// ```rust /// use rig::pipeline::{self, Op}; /// - /// let chain = pipeline::new() + /// let pipeline = pipeline::new() /// .then(|email: String| async move { /// email.split('@').next().unwrap().to_string() /// }) @@ -143,7 +142,7 @@ impl PipelineBuilder { /// format!("Hello, {}!", username) /// }); /// - /// let result = chain.call("bob@gmail.com".to_string()).await; + /// let result = pipeline.call("bob@gmail.com".to_string()).await; /// assert_eq!(result, "Hello, bob!"); /// ``` pub fn then(self, f: F) -> impl Op @@ -157,7 +156,7 @@ impl PipelineBuilder { then(f) } - /// Chain an arbitrary operation to the current pipeline. + /// Add an arbitrary operation to the current pipeline. 
/// /// # Example /// ```rust @@ -174,10 +173,10 @@ impl PipelineBuilder { /// } /// } /// - /// let chain = pipeline::new() + /// let pipeline = pipeline::new() /// .chain(MyOp); /// - /// let result = chain.call(1).await; + /// let result = pipeline.call(1).await; /// assert_eq!(result, 2); /// ``` pub fn chain(self, op: T) -> impl Op @@ -194,15 +193,15 @@ impl PipelineBuilder { /// /// # Example /// ```rust - /// use rig::chain::{self, Chain}; + /// use rig::pipeline::{self, Op}; /// - /// let chain = chain::new() + /// let pipeline = pipeline::new() /// .lookup(index, 2) - /// .chain(|(query, docs): (_, Vec)| async move { + /// .then(|(query, docs): (_, Vec)| async move { /// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n")) /// }); /// - /// let result = chain.call("What is a flurbo?".to_string()).await; + /// let result = pipeline.call("What is a flurbo?".to_string()).await; /// ``` pub fn lookup( self, @@ -216,37 +215,66 @@ impl PipelineBuilder { // E: From + Send + Sync, Self: Sized, { - lookup(index, n) + agent_ops::lookup(index, n) } - /// Chain a prompt operation to the current chain. The prompt operation expects the - /// current chain to output a string. The prompt operation will use the string to prompt - /// the given agent (or any other type that implements the `Prompt` trait) and return + /// Add a prompt operation to the current pipeline/op. The prompt operation expects the + /// current pipeline to output a string. The prompt operation will use the string to prompt + /// the given `agent`, which must implement the [Prompt](completion::Prompt) trait and return /// the response. 
/// /// # Example /// ```rust - /// use rig::chain::{self, Chain}; + /// use rig::pipeline::{self, Op}; /// /// let agent = &openai_client.agent("gpt-4").build(); /// - /// let chain = chain::new() + /// let pipeline = pipeline::new() /// .map(|name| format!("Find funny nicknames for the following name: {name}!")) /// .prompt(agent); /// - /// let result = chain.call("Alice".to_string()).await; + /// let result = pipeline.call("Alice".to_string()).await; /// ``` - pub fn prompt( - self, - prompt: P, - ) -> impl Op> + pub fn prompt(self, agent: P) -> agent_ops::Prompt where P: completion::Prompt, In: Into + Send + Sync, // E: From + Send + Sync, Self: Sized, { - prompt_op(prompt) + agent_ops::prompt(agent) + } + + /// Add an extract operation to the current pipeline/op. The extract operation expects the + /// current pipeline to output a string. The extract operation will use the given `extractor` + /// to extract information from the string in the form of the type `T` and return it. + /// + /// # Example + /// ```rust + /// use rig::pipeline::{self, Op}; + /// + /// #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] + /// struct Sentiment { + /// /// The sentiment score of the text (0.0 = negative, 1.0 = positive) + /// score: f64, + /// } + /// + /// let extractor = &openai_client.extractor::("gpt-4").build(); + /// + /// let pipeline = pipeline::new() + /// .map(|text| format!("Analyze the sentiment of the following text: {text}!")) + /// .extract(extractor); + /// + /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?; + /// assert!(result.score > 0.5); + /// ``` + pub fn extract(self, extractor: Extractor) -> agent_ops::Extract + where + M: completion::CompletionModel, + T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + In: Into + Send + Sync, + { + agent_ops::extract(extractor) } } @@ -330,7 +358,10 @@ mod tests { let index = MockIndex; let chain = super::new() - .chain(parallel!(passthrough(), 
lookup::<_, _, Foo>(index, 1),)) + .chain(parallel!( + passthrough(), + agent_ops::lookup::<_, _, Foo>(index, 1), + )) .map(|(query, maybe_docs)| match maybe_docs { Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo), Err(err) => format!("Error: {}", err), diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index e4204408..e8d276e5 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -24,14 +24,14 @@ pub trait TryOp: Send + Sync { /// inputs that will be processed concurrently. /// If the op fails for one of the inputs, the entire operation will fail and the error will /// be returned. - /// + /// /// # Example /// ```rust /// use rig::pipeline::{self, TryOp}; - /// + /// /// let op = pipeline::new() /// .map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") }); - /// + /// /// // Execute the pipeline concurrently with 2 inputs /// let result = op.try_batch_call(2, vec![2, 4]).await; /// assert_eq!(result, Ok(vec![3, 5])); @@ -57,15 +57,15 @@ pub trait TryOp: Send + Sync { /// Map the success return value (i.e., `Ok`) of the current op to a different value /// using the provided closure. - /// + /// /// # Example /// ```rust /// use rig::pipeline::{self, TryOp}; - /// + /// /// let op = pipeline::new() /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) /// .map_ok(|x| x * 2); - /// + /// /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(4)); /// ``` @@ -80,15 +80,15 @@ pub trait TryOp: Send + Sync { /// Map the error return value (i.e., `Err`) of the current op to a different value /// using the provided closure. 
- /// + /// /// # Example /// ```rust /// use rig::pipeline::{self, TryOp}; - /// + /// /// let op = pipeline::new() /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) /// .map_err(|err| format!("Error: {}", err)); - /// + /// /// let result = op.try_call(1).await; /// assert_eq!(result, Err("Error: x is odd".to_string())); /// ``` @@ -104,18 +104,18 @@ pub trait TryOp: Send + Sync { MapErr::new(self, map(f)) } - /// Chain a function to the current op. The function will only be called - /// if the current op returns `Ok`. The function must return a `Future` with value - /// `Result` where `E` is the same type as the error type of the current. - /// + /// Chain a function to the current op. The function will only be called + /// if the current op returns `Ok`. The function must return a `Future` with value + /// `Result` where `E` is the same type as the error type of the current. + /// /// # Example /// ```rust /// use rig::pipeline::{self, TryOp}; - /// + /// /// let op = pipeline::new() /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) /// .and_then(|x| async move { Ok(x * 2) }); - /// + /// /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(4)); /// ``` @@ -135,15 +135,15 @@ pub trait TryOp: Send + Sync { /// Chain a function `f` to the current op. The function `f` will only be called /// if the current op returns `Err`. `f` must return a `Future` with value /// `Result` where `T` is the same type as the output type of the current op. - /// + /// /// # Example /// ```rust /// use rig::pipeline::{self, TryOp}; - /// + /// /// let op = pipeline::new() /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) /// .or_else(|err| async move { Err(format!("Error: {}", err)) }); - /// + /// /// let result = op.try_call(1).await; /// assert_eq!(result, Err("Error: x is odd".to_string())); /// ``` @@ -163,26 +163,26 @@ pub trait TryOp: Send + Sync { /// Chain a new op `op` to the current op. 
The new op will be called with the success /// return value of the current op (i.e.: `Ok` value). The chained op can be any type that /// implements the `Op` trait. - /// + /// /// # Example /// ```rust /// use rig::pipeline::{self, TryOp}; - /// + /// /// struct AddOne; - /// + /// /// impl Op for AddOne { /// type Input = i32; /// type Output = i32; - /// + /// /// async fn call(&self, input: Self::Input) -> Self::Output { /// input + 1 /// } /// } - /// + /// /// let op = pipeline::new() /// .map(|x: i32| if x % 2 == 0 { Ok(x) } else { Err("x is odd") }) /// .chain_ok(MyOp); - /// + /// /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(3)); /// ``` From 16970bc06293b3dc422923ad77b254e18b3ea87f Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 6 Dec 2024 13:31:10 -0500 Subject: [PATCH 19/26] fix(pipelines): Type errors --- rig-core/src/pipeline/agent_ops.rs | 34 ++++----- rig-core/src/pipeline/mod.rs | 44 ++++++------ rig-core/src/pipeline/op.rs | 106 ++++++++++++++++++----------- rig-core/src/pipeline/try_op.rs | 65 +++++++++--------- 4 files changed, 137 insertions(+), 112 deletions(-) diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index f51ce713..c7562737 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -17,7 +17,7 @@ impl Lookup where I: vector_store::VectorStoreIndex, { - pub fn new(index: I, n: usize) -> Self { + pub(crate) fn new(index: I, n: usize) -> Self { Self { index, n, @@ -66,7 +66,7 @@ pub struct Prompt { } impl Prompt { - pub fn new(prompt: P) -> Self { + pub(crate) fn new(prompt: P) -> Self { Self { prompt, _in: std::marker::PhantomData, @@ -96,21 +96,21 @@ where Prompt::new(prompt) } -pub struct Extract +pub struct Extract where M: CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, { - extractor: Extractor, - _in: 
std::marker::PhantomData, + extractor: Extractor, + _in: std::marker::PhantomData, } -impl Extract +impl Extract where M: CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, { - pub fn new(extractor: Extractor) -> Self { + pub(crate) fn new(extractor: Extractor) -> Self { Self { extractor, _in: std::marker::PhantomData, @@ -118,25 +118,25 @@ where } } -impl Op for Extract +impl Op for Extract where M: CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, - In: Into + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Input: Into + Send + Sync, { - type Input = In; - type Output = Result; + type Input = Input; + type Output = Result; async fn call(&self, input: Self::Input) -> Self::Output { self.extractor.extract(&input.into()).await } } -pub fn extract(extractor: Extractor) -> Extract +pub fn extract(extractor: Extractor) -> Extract where M: CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, - In: Into + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Input: Into + Send + Sync, { Extract::new(extractor) } diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 417ae762..27ea5b1f 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -118,14 +118,14 @@ impl PipelineBuilder { /// let result = pipeline.call((1, 2)).await; /// assert_eq!(result, "Result: 3!"); /// ``` - pub fn map(self, f: F) -> impl Op + pub fn map(self, f: F) -> op::Map where - F: Fn(In) -> T + Send + Sync, - In: Send + Sync, - T: Send + Sync, + F: Fn(Input) -> Output + Send + Sync, + Input: Send + Sync, + Output: Send + Sync, Self: Sized, { - map(f) + op::Map::new(f) } /// Same as `map` but for asynchronous functions @@ -145,15 +145,15 @@ impl PipelineBuilder { /// let 
result = pipeline.call("bob@gmail.com".to_string()).await; /// assert_eq!(result, "Hello, bob!"); /// ``` - pub fn then(self, f: F) -> impl Op + pub fn then(self, f: F) -> op::Then where - F: Fn(In) -> Fut + Send + Sync, - In: Send + Sync, + F: Fn(Input) -> Fut + Send + Sync, + Input: Send + Sync, Fut: Future + Send + Sync, Fut::Output: Send + Sync, Self: Sized, { - then(f) + op::Then::new(f) } /// Add an arbitrary operation to the current pipeline. @@ -179,7 +179,7 @@ impl PipelineBuilder { /// let result = pipeline.call(1).await; /// assert_eq!(result, 2); /// ``` - pub fn chain(self, op: T) -> impl Op + pub fn chain(self, op: T) -> T where T: Op, Self: Sized, @@ -203,19 +203,19 @@ impl PipelineBuilder { /// /// let result = pipeline.call("What is a flurbo?".to_string()).await; /// ``` - pub fn lookup( + pub fn lookup( self, index: I, n: usize, - ) -> impl Op, vector_store::VectorStoreError>> + ) -> agent_ops::Lookup where I: vector_store::VectorStoreIndex, - T: Send + Sync + for<'a> serde::Deserialize<'a>, - In: Into + Send + Sync, + Output: Send + Sync + for<'a> serde::Deserialize<'a>, + Input: Into + Send + Sync, // E: From + Send + Sync, Self: Sized, { - agent_ops::lookup(index, n) + agent_ops::Lookup::new(index, n) } /// Add a prompt operation to the current pipeline/op. The prompt operation expects the @@ -235,14 +235,14 @@ impl PipelineBuilder { /// /// let result = pipeline.call("Alice".to_string()).await; /// ``` - pub fn prompt(self, agent: P) -> agent_ops::Prompt + pub fn prompt(self, agent: P) -> agent_ops::Prompt where P: completion::Prompt, - In: Into + Send + Sync, + Input: Into + Send + Sync, // E: From + Send + Sync, Self: Sized, { - agent_ops::prompt(agent) + agent_ops::Prompt::new(agent) } /// Add an extract operation to the current pipeline/op. 
The extract operation expects the @@ -268,13 +268,13 @@ impl PipelineBuilder { /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?; /// assert!(result.score > 0.5); /// ``` - pub fn extract(self, extractor: Extractor) -> agent_ops::Extract + pub fn extract(self, extractor: Extractor) -> agent_ops::Extract where M: completion::CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, - In: Into + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Input: Into + Send + Sync, { - agent_ops::extract(extractor) + agent_ops::Extract::new(extractor) } } diff --git a/rig-core/src/pipeline/op.rs b/rig-core/src/pipeline/op.rs index 0ddf32bf..96ee67e8 100644 --- a/rig-core/src/pipeline/op.rs +++ b/rig-core/src/pipeline/op.rs @@ -2,7 +2,7 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::join; -use futures::{stream, StreamExt}; +use futures::stream; // ================================================================ // Core Op trait @@ -21,6 +21,8 @@ pub trait Op: Send + Sync { I::IntoIter: Send, Self: Sized, { + use futures::stream::StreamExt; + async move { stream::iter(input) .map(|input| self.call(input)) @@ -43,13 +45,13 @@ pub trait Op: Send + Sync { /// let result = chain.call((1, 2)).await; /// assert_eq!(result, "Result: 3!"); /// ``` - fn map(self, f: F) -> impl Op + fn map(self, f: F) -> Sequential> where - F: Fn(Self::Output) -> T + Send + Sync, - T: Send + Sync, + F: Fn(Self::Output) -> Input + Send + Sync, + Input: Send + Sync, Self: Sized, { - Sequential::new(self, map(f)) + Sequential::new(self, Map::new(f)) } /// Same as `map` but for asynchronous functions @@ -69,14 +71,14 @@ pub trait Op: Send + Sync { /// let result = chain.call("bob@gmail.com".to_string()).await; /// assert_eq!(result, "Hello, bob!"); /// ``` - fn then(self, f: F) -> impl Op + fn then(self, f: F) -> Sequential> where F: 
Fn(Self::Output) -> Fut + Send + Sync, Fut: Future + Send + Sync, Fut::Output: Send + Sync, Self: Sized, { - Sequential::new(self, then(f)) + Sequential::new(self, Then::new(f)) } /// Chain an arbitrary operation to the current op. @@ -102,7 +104,7 @@ pub trait Op: Send + Sync { /// let result = chain.call(1).await; /// assert_eq!(result, 2); /// ``` - fn chain(self, op: T) -> impl Op + fn chain(self, op: T) -> Sequential where T: Op, Self: Sized, @@ -126,14 +128,14 @@ pub trait Op: Send + Sync { /// /// let result = chain.call("What is a flurbo?".to_string()).await; /// ``` - fn lookup( + fn lookup( self, index: I, n: usize, - ) -> impl Op, vector_store::VectorStoreError>> + ) -> Sequential> where I: vector_store::VectorStoreIndex, - T: Send + Sync + for<'a> serde::Deserialize<'a>, + Input: Send + Sync + for<'a> serde::Deserialize<'a>, Self::Output: Into, Self: Sized, { @@ -160,7 +162,7 @@ pub trait Op: Send + Sync { fn prompt

( self, prompt: P, - ) -> impl Op> + ) -> Sequential> where P: completion::Prompt, Self::Output: Into, @@ -189,7 +191,7 @@ pub struct Sequential { } impl Sequential { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } @@ -216,13 +218,13 @@ use super::agent_ops::{Lookup, Prompt}; // ================================================================ // Core Op implementations // ================================================================ -pub struct Map { +pub struct Map { f: F, - _t: std::marker::PhantomData, + _t: std::marker::PhantomData, } -impl Map { - pub fn new(f: F) -> Self { +impl Map { + pub(crate) fn new(f: F) -> Self { Self { f, _t: std::marker::PhantomData, @@ -230,14 +232,14 @@ impl Map { } } -impl Op for Map +impl Op for Map where - F: Fn(T) -> Out + Send + Sync, - T: Send + Sync, - Out: Send + Sync, + F: Fn(Input) -> Output + Send + Sync, + Input: Send + Sync, + Output: Send + Sync, { - type Input = T; - type Output = Out; + type Input = Input; + type Output = Output; #[inline] async fn call(&self, input: Self::Input) -> Self::Output { @@ -245,29 +247,53 @@ where } } -pub fn map(f: F) -> impl Op +pub fn map(f: F) -> Map where - F: Fn(T) -> Out + Send + Sync, - T: Send + Sync, - Out: Send + Sync, + F: Fn(Input) -> Output + Send + Sync, + Input: Send + Sync, + Output: Send + Sync, { Map::new(f) } -pub fn passthrough() -> impl Op +pub struct Passthrough { + _t: std::marker::PhantomData, +} + +impl Passthrough { + pub(crate) fn new() -> Self { + Self { + _t: std::marker::PhantomData, + } + } +} + +impl Op for Passthrough +where + T: Send + Sync, +{ + type Input = T; + type Output = T; + + async fn call(&self, input: Self::Input) -> Self::Output { + input + } +} + +pub fn passthrough() -> Passthrough where T: Send + Sync, { - Map::new(|x| x) + Passthrough::new() } -pub struct Then { +pub struct Then { f: F, - _t: std::marker::PhantomData, + _t: std::marker::PhantomData, } -impl Then { - fn 
new(f: F) -> Self { +impl Then { + pub(crate) fn new(f: F) -> Self { Self { f, _t: std::marker::PhantomData, @@ -275,14 +301,14 @@ impl Then { } } -impl Op for Then +impl Op for Then where - F: Fn(T) -> Fut + Send + Sync, - T: Send + Sync, + F: Fn(Input) -> Fut + Send + Sync, + Input: Send + Sync, Fut: Future + Send, Fut::Output: Send + Sync, { - type Input = T; + type Input = Input; type Output = Fut::Output; #[inline] @@ -291,10 +317,10 @@ where } } -pub fn then(f: F) -> impl Op +pub fn then(f: F) -> Then where - F: Fn(T) -> Fut + Send + Sync, - T: Send + Sync, + F: Fn(Input) -> Fut + Send + Sync, + Input: Send + Sync, Fut: Future + Send, Fut::Output: Send + Sync, { diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index e8d276e5..76b41e91 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -2,9 +2,9 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::try_join; -use futures::{stream, StreamExt, TryStreamExt}; +use futures::stream; -use super::op::{self, map, then}; +use super::op::{self}; // ================================================================ // Core TryOp trait @@ -46,6 +46,8 @@ pub trait TryOp: Send + Sync { I::IntoIter: Send, Self: Sized, { + use stream::{StreamExt, TryStreamExt}; + async move { stream::iter(input) .map(|input| self.try_call(input)) @@ -69,13 +71,13 @@ pub trait TryOp: Send + Sync { /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(4)); /// ``` - fn map_ok(self, f: F) -> impl op::Op> + fn map_ok(self, f: F) -> MapOk> where - F: Fn(Self::Output) -> T + Send + Sync, - T: Send + Sync, + F: Fn(Self::Output) -> Output + Send + Sync, + Output: Send + Sync, Self: Sized, { - MapOk::new(self, map(f)) + MapOk::new(self, op::Map::new(f)) } /// Map the error return value (i.e., `Err`) of the current op to a different value @@ -95,13 +97,13 @@ pub trait TryOp: Send + Sync { fn map_err( self, f: F, - ) -> impl 
op::Op> + ) -> MapErr> where F: Fn(Self::Error) -> E + Send + Sync, E: Send + Sync, Self: Sized, { - MapErr::new(self, map(f)) + MapErr::new(self, op::Map::new(f)) } /// Chain a function to the current op. The function will only be called @@ -119,17 +121,17 @@ pub trait TryOp: Send + Sync { /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(4)); /// ``` - fn and_then( + fn and_then( self, f: F, - ) -> impl TryOp + ) -> AndThen> where F: Fn(Self::Output) -> Fut + Send + Sync, - Fut: Future> + Send + Sync, - T: Send + Sync, + Fut: Future> + Send + Sync, + Output: Send + Sync, Self: Sized, { - AndThen::new(self, then(f)) + AndThen::new(self, op::Then::new(f)) } /// Chain a function `f` to the current op. The function `f` will only be called @@ -150,14 +152,14 @@ pub trait TryOp: Send + Sync { fn or_else( self, f: F, - ) -> impl TryOp + ) -> OrElse> where F: Fn(Self::Error) -> Fut + Send + Sync, Fut: Future> + Send + Sync, E: Send + Sync, Self: Sized, { - OrElse::new(self, then(f)) + OrElse::new(self, op::Then::new(f)) } /// Chain a new op `op` to the current op. 
The new op will be called with the success @@ -189,7 +191,7 @@ pub trait TryOp: Send + Sync { fn chain_ok( self, op: T, - ) -> impl TryOp + ) -> TrySequential where T: op::Op, Self: Sized, @@ -222,7 +224,7 @@ pub struct MapOk { } impl MapOk { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } @@ -250,7 +252,7 @@ pub struct MapErr { } impl MapErr { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } @@ -279,22 +281,21 @@ pub struct AndThen { } impl AndThen { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } -impl TryOp for AndThen +impl op::Op for AndThen where Op1: TryOp, Op2: TryOp, { type Input = Op1::Input; - type Output = Op2::Output; - type Error = Op1::Error; + type Output = Result; #[inline] - async fn try_call(&self, input: Self::Input) -> Result { + async fn call(&self, input: Self::Input) -> Self::Output { let output = self.prev.try_call(input).await?; self.op.try_call(output).await } @@ -306,22 +307,21 @@ pub struct OrElse { } impl OrElse { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } -impl TryOp for OrElse +impl op::Op for OrElse where Op1: TryOp, Op2: TryOp, { type Input = Op1::Input; - type Output = Op1::Output; - type Error = Op2::Error; + type Output = Result; #[inline] - async fn try_call(&self, input: Self::Input) -> Result { + async fn call(&self, input: Self::Input) -> Self::Output { match self.prev.try_call(input).await { Ok(output) => Ok(output), Err(err) => self.op.try_call(err).await, @@ -335,22 +335,21 @@ pub struct TrySequential { } impl TrySequential { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } -impl TryOp for TrySequential +impl op::Op for TrySequential where Op1: TryOp, Op2: op::Op, { type Input = Op1::Input; 
- type Output = Op2::Output; - type Error = Op1::Error; + type Output = Result; #[inline] - async fn try_call(&self, input: Self::Input) -> Result { + async fn call(&self, input: Self::Input) -> Self::Output { match self.prev.try_call(input).await { Ok(output) => Ok(self.op.call(output).await), Err(err) => Err(err), From 758f6213c1249d321e928d417e2f44cceac839a5 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 6 Dec 2024 13:45:35 -0500 Subject: [PATCH 20/26] fix: Missing trait import in macros --- rig-core/src/pipeline/parallel.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rig-core/src/pipeline/parallel.rs b/rig-core/src/pipeline/parallel.rs index 6ab2b092..76da1520 100644 --- a/rig-core/src/pipeline/parallel.rs +++ b/rig-core/src/pipeline/parallel.rs @@ -118,6 +118,8 @@ macro_rules! parallel_internal { ] munching: [] ) => ({ + use $crate::pipeline::op::Op; + $crate::parallel_op!($($val),*) .map(|output| { ($( @@ -243,6 +245,7 @@ macro_rules! try_parallel_internal { ] munching: [] ) => ({ + use $crate::pipeline::try_op::TryOp; $crate::parallel_op!($($val),*) .map_ok(|output| { ($( From 8015b69fe0e185b105ea8641e95fe3e4378dcdef Mon Sep 17 00:00:00 2001 From: Christophe Date: Mon, 16 Dec 2024 10:30:03 -0500 Subject: [PATCH 21/26] feat(pipeline): Add id and score to `lookup` op result --- rig-core/examples/chain.rs | 2 +- rig-core/src/pipeline/agent_ops.rs | 11 ++++------- rig-core/src/pipeline/mod.rs | 4 ++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/rig-core/examples/chain.rs b/rig-core/examples/chain.rs index 1fa89e4c..d147ac9a 100644 --- a/rig-core/examples/chain.rs +++ b/rig-core/examples/chain.rs @@ -52,7 +52,7 @@ async fn main() -> Result<(), anyhow::Error> { .map(|(prompt, maybe_docs)| match maybe_docs { Ok(docs) => format!( "Non standard word definitions:\n{}\n\n{}", - docs.join("\n"), + docs.into_iter().map(|(_, _, doc)| doc).collect::>().join("\n"), prompt, ), Err(err) => { diff --git a/rig-core/src/pipeline/agent_ops.rs 
b/rig-core/src/pipeline/agent_ops.rs index c7562737..8f484c7e 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -1,7 +1,5 @@ use crate::{ - completion::{self, CompletionModel}, - extractor::{ExtractionError, Extractor}, - vector_store, + completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, vector_store }; use super::Op; @@ -34,7 +32,7 @@ where T: Send + Sync + for<'a> serde::Deserialize<'a>, { type Input = In; - type Output = Result, vector_store::VectorStoreError>; + type Output = Result, vector_store::VectorStoreError>; async fn call(&self, input: Self::Input) -> Self::Output { let query: String = input.into(); @@ -44,7 +42,6 @@ where .top_n::(&query, self.n) .await? .into_iter() - .map(|(_, _, doc)| doc) .collect(); Ok(docs) @@ -193,9 +190,9 @@ pub mod tests { let result = lookup.call("query".to_string()).await.unwrap(); assert_eq!( result, - vec![Foo { + vec![(1.0, "doc1".to_string(), Foo { foo: "bar".to_string() - }] + })] ); } diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 27ea5b1f..c3f25049 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -343,7 +343,7 @@ mod tests { let chain = super::new() .lookup::<_, _, Foo>(index, 1) - .map_ok(|docs| format!("Top documents:\n{}", docs[0].foo)); + .map_ok(|docs| format!("Top documents:\n{}", docs[0].2.foo)); let result = chain .try_call("What is a flurbo?") @@ -363,7 +363,7 @@ mod tests { agent_ops::lookup::<_, _, Foo>(index, 1), )) .map(|(query, maybe_docs)| match maybe_docs { - Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo), + Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo), Err(err) => format!("Error: {}", err), }) .prompt(MockModel); From bebbbca7bcf5a39e156beaf7cd9f77e75f739bdc Mon Sep 17 00:00:00 2001 From: Christophe Date: Mon, 16 Dec 2024 16:55:25 -0500 Subject: [PATCH 22/26] docs(pipelines): Add more docstrings --- 
rig-core/src/pipeline/agent_ops.rs | 28 ++++++++++++++++++++++------ rig-core/src/pipeline/mod.rs | 11 +++++------ rig-core/src/pipeline/op.rs | 11 ++++------- rig-core/src/pipeline/try_op.rs | 24 ++++++------------------ 4 files changed, 37 insertions(+), 37 deletions(-) diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index 8f484c7e..8274158d 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -1,5 +1,7 @@ use crate::{ - completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, vector_store + completion::{self, CompletionModel}, + extractor::{ExtractionError, Extractor}, + vector_store, }; use super::Op; @@ -48,6 +50,10 @@ where } } +/// Create a new lookup operation. +/// +/// The op will perform semantic search on the provided index and return the top `n` +/// results closest results to the input. pub fn lookup(index: I, n: usize) -> Lookup where I: vector_store::VectorStoreIndex, @@ -85,12 +91,15 @@ where } } -pub fn prompt(prompt: P) -> Prompt +/// Create a new prompt operation. +/// +/// The op will prompt the `model` with the input and return the response. +pub fn prompt(model: P) -> Prompt where P: completion::Prompt, In: Into + Send + Sync, { - Prompt::new(prompt) + Prompt::new(model) } pub struct Extract @@ -129,6 +138,9 @@ where } } +/// Create a new extract operation. +/// +/// The op will extract the structured data from the input using the provided `extractor`. 
pub fn extract(extractor: Extractor) -> Extract where M: CompletionModel, @@ -190,9 +202,13 @@ pub mod tests { let result = lookup.call("query".to_string()).await.unwrap(); assert_eq!( result, - vec![(1.0, "doc1".to_string(), Foo { - foo: "bar".to_string() - })] + vec![( + 1.0, + "doc1".to_string(), + Foo { + foo: "bar".to_string() + } + )] ); } diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index c3f25049..2c5c195c 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -203,11 +203,7 @@ impl PipelineBuilder { /// /// let result = pipeline.call("What is a flurbo?".to_string()).await; /// ``` - pub fn lookup( - self, - index: I, - n: usize, - ) -> agent_ops::Lookup + pub fn lookup(self, index: I, n: usize) -> agent_ops::Lookup where I: vector_store::VectorStoreIndex, Output: Send + Sync + for<'a> serde::Deserialize<'a>, @@ -268,7 +264,10 @@ impl PipelineBuilder { /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?; /// assert!(result.score > 0.5); /// ``` - pub fn extract(self, extractor: Extractor) -> agent_ops::Extract + pub fn extract( + self, + extractor: Extractor, + ) -> agent_ops::Extract where M: completion::CompletionModel, Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, diff --git a/rig-core/src/pipeline/op.rs b/rig-core/src/pipeline/op.rs index 96ee67e8..fda072dc 100644 --- a/rig-core/src/pipeline/op.rs +++ b/rig-core/src/pipeline/op.rs @@ -22,7 +22,7 @@ pub trait Op: Send + Sync { Self: Sized, { use futures::stream::StreamExt; - + async move { stream::iter(input) .map(|input| self.call(input)) @@ -159,10 +159,7 @@ pub trait Op: Send + Sync { /// /// let result = chain.call("Alice".to_string()).await; /// ``` - fn prompt

( - self, - prompt: P, - ) -> Sequential> + fn prompt

(self, prompt: P) -> Sequential> where P: completion::Prompt, Self::Output: Into, @@ -268,8 +265,8 @@ impl Passthrough { } } -impl Op for Passthrough -where +impl Op for Passthrough +where T: Send + Sync, { type Input = T; diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index 76b41e91..6e37af9a 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -1,8 +1,8 @@ use std::future::Future; +use futures::stream; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::try_join; -use futures::stream; use super::op::{self}; @@ -47,7 +47,7 @@ pub trait TryOp: Send + Sync { Self: Sized, { use stream::{StreamExt, TryStreamExt}; - + async move { stream::iter(input) .map(|input| self.try_call(input)) @@ -94,10 +94,7 @@ pub trait TryOp: Send + Sync { /// let result = op.try_call(1).await; /// assert_eq!(result, Err("Error: x is odd".to_string())); /// ``` - fn map_err( - self, - f: F, - ) -> MapErr> + fn map_err(self, f: F) -> MapErr> where F: Fn(Self::Error) -> E + Send + Sync, E: Send + Sync, @@ -121,10 +118,7 @@ pub trait TryOp: Send + Sync { /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(4)); /// ``` - fn and_then( - self, - f: F, - ) -> AndThen> + fn and_then(self, f: F) -> AndThen> where F: Fn(Self::Output) -> Fut + Send + Sync, Fut: Future> + Send + Sync, @@ -149,10 +143,7 @@ pub trait TryOp: Send + Sync { /// let result = op.try_call(1).await; /// assert_eq!(result, Err("Error: x is odd".to_string())); /// ``` - fn or_else( - self, - f: F, - ) -> OrElse> + fn or_else(self, f: F) -> OrElse> where F: Fn(Self::Error) -> Fut + Send + Sync, Fut: Future> + Send + Sync, @@ -188,10 +179,7 @@ pub trait TryOp: Send + Sync { /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(3)); /// ``` - fn chain_ok( - self, - op: T, - ) -> TrySequential + fn chain_ok(self, op: T) -> TrySequential where T: op::Op, Self: Sized, From e4bd652f1bce7324cef03eb4af8506dc04d57b96 Mon 
Sep 17 00:00:00 2001 From: Christophe Date: Mon, 16 Dec 2024 17:07:03 -0500 Subject: [PATCH 23/26] docs(pipelines): Update example --- rig-core/examples/chain.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/rig-core/examples/chain.rs b/rig-core/examples/chain.rs index d147ac9a..0c2c5c65 100644 --- a/rig-core/examples/chain.rs +++ b/rig-core/examples/chain.rs @@ -52,7 +52,10 @@ async fn main() -> Result<(), anyhow::Error> { .map(|(prompt, maybe_docs)| match maybe_docs { Ok(docs) => format!( "Non standard word definitions:\n{}\n\n{}", - docs.into_iter().map(|(_, _, doc)| doc).collect::>().join("\n"), + docs.into_iter() + .map(|(_, _, doc)| doc) + .collect::>() + .join("\n"), prompt, ), Err(err) => { From 9b2fa787ad95a0b71692af341539e2e2735464dd Mon Sep 17 00:00:00 2001 From: Christophe Date: Mon, 16 Dec 2024 17:13:27 -0500 Subject: [PATCH 24/26] test(mongodb): Fix flaky test again --- rig-mongodb/tests/integration_tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-mongodb/tests/integration_tests.rs b/rig-mongodb/tests/integration_tests.rs index 9b340ae2..3dc10e6c 100644 --- a/rig-mongodb/tests/integration_tests.rs +++ b/rig-mongodb/tests/integration_tests.rs @@ -71,7 +71,7 @@ async fn vector_search_test() { .await .unwrap(); - sleep(Duration::from_secs(15)).await; + sleep(Duration::from_secs(30)).await; // Query the index let results = index From bf52fd901f7b40e21762235f76ee6cb8576d025d Mon Sep 17 00:00:00 2001 From: Christophe Date: Mon, 16 Dec 2024 17:14:08 -0500 Subject: [PATCH 25/26] style: fmt --- rig-core/src/pipeline/agent_ops.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index 8274158d..320eed74 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -51,8 +51,8 @@ where } /// Create a new lookup operation. 
-/// -/// The op will perform semantic search on the provided index and return the top `n` +/// +/// The op will perform semantic search on the provided index and return the top `n` /// results closest results to the input. pub fn lookup(index: I, n: usize) -> Lookup where @@ -92,7 +92,7 @@ where } /// Create a new prompt operation. -/// +/// /// The op will prompt the `model` with the input and return the response. pub fn prompt(model: P) -> Prompt where @@ -139,7 +139,7 @@ where } /// Create a new extract operation. -/// +/// /// The op will extract the structured data from the input using the provided `extractor`. pub fn extract(extractor: Extractor) -> Extract where From f780473a07042b1d7af5d0362e64a285b8ba1160 Mon Sep 17 00:00:00 2001 From: Christophe Date: Mon, 16 Dec 2024 17:16:15 -0500 Subject: [PATCH 26/26] test(mongodb): fix --- rig-mongodb/tests/integration_tests.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rig-mongodb/tests/integration_tests.rs b/rig-mongodb/tests/integration_tests.rs index 3dc10e6c..78e8ea71 100644 --- a/rig-mongodb/tests/integration_tests.rs +++ b/rig-mongodb/tests/integration_tests.rs @@ -59,6 +59,9 @@ async fn vector_search_test() { collection.insert_many(embeddings).await.unwrap(); + // Wait for the new documents to be indexed + sleep(Duration::from_secs(30)).await; + // Create a vector index on our vector store. // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. // IMPORTANT: Reuse the same model that was used to generate the embeddings @@ -71,8 +74,6 @@ async fn vector_search_test() { .await .unwrap(); - sleep(Duration::from_secs(30)).await; - // Query the index let results = index .top_n::("What is a linglingdong?", 1)