From fc695152a2f9f0f9ee612924aa0d071e371f6f72 Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 29 Nov 2024 16:30:30 -0500 Subject: [PATCH] feat: Add `Op::batch_call` and `TryOp::try_batch_call` --- rig-core/examples/multi_extract.rs | 19 +++++++++++++++++-- rig-core/src/pipeline/mod.rs | 5 +---- rig-core/src/pipeline/op.rs | 18 ++++++++++++++++++ rig-core/src/pipeline/try_op.rs | 24 ++++++++++++++++++++++++ 4 files changed, 60 insertions(+), 6 deletions(-) diff --git a/rig-core/examples/multi_extract.rs b/rig-core/examples/multi_extract.rs index 9bc212d4..99e3c2fa 100644 --- a/rig-core/examples/multi_extract.rs +++ b/rig-core/examples/multi_extract.rs @@ -50,6 +50,9 @@ async fn main() -> anyhow::Result<()> { ) .build(); + // Create a chain that extracts names, topics, and sentiment from a given text + // using three different GPT-4 based extractors. + // The chain will output a formatted string containing the extracted information. let chain = pipeline::new() .chain(try_parallel!( agent_ops::extract(names_extractor), @@ -65,9 +68,21 @@ async fn main() -> anyhow::Result<()> { ) }); - let response = chain.try_call("Screw you Putin!").await?; + // Batch call the chain with up to 4 inputs concurrently + let response = chain + .try_batch_call( + 4, + vec![ + "Screw you Putin!", + "I love my dog, but I hate my cat.", + "I'm going to the store to buy some milk.", + ], + ) + .await?; - println!("Text analysis:\n{response}"); + for response in response { + println!("Text analysis:\n{response}"); + } Ok(()) } diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 09077bc9..98a0d449 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -233,10 +233,7 @@ mod tests { .await .expect("Failed to run chain"); - assert_eq!( - result, - "Top documents:\nbar" - ); + assert_eq!(result, "Top documents:\nbar"); } #[tokio::test] diff --git a/rig-core/src/pipeline/op.rs b/rig-core/src/pipeline/op.rs index 7208bc27..a4a6bc92 100644 --- 
a/rig-core/src/pipeline/op.rs +++ b/rig-core/src/pipeline/op.rs @@ -2,6 +2,7 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::join; +use futures::{stream, StreamExt}; // ================================================================ // Core Op trait @@ -12,6 +13,23 @@ pub trait Op: Send + Sync { fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + Send; + /// Execute the current pipeline with the given inputs. `n` is the maximum number of + /// inputs that will be processed concurrently. + fn batch_call<I>(&self, n: usize, input: I) -> impl Future<Output = Vec<Self::Output>> + Send + where + I: IntoIterator<Item = Self::Input> + Send, + I::IntoIter: Send, + Self: Sized, + { + async move { + stream::iter(input) + .map(|input| self.call(input)) + .buffered(n) + .collect() + .await + } + } + /// Chain a function to the current pipeline /// /// # Example diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index b4f93483..fa77229a 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -2,6 +2,7 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::try_join; +use futures::{stream, StreamExt, TryStreamExt}; use super::op::{self, map, then}; @@ -18,6 +19,29 @@ pub trait TryOp: Send + Sync { input: Self::Input, ) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send; + /// Execute the current pipeline with the given inputs. `n` is the maximum number of + /// inputs that will be processed concurrently. + /// If one of the inputs fails, the entire operation will fail and the error will + /// be returned. + fn try_batch_call<I>( + &self, + n: usize, + input: I, + ) -> impl Future<Output = Result<Vec<Self::Output>, Self::Error>> + Send + where + I: IntoIterator<Item = Self::Input> + Send, + I::IntoIter: Send, + Self: Sized, + { + async move { + stream::iter(input) + .map(|input| self.try_call(input)) + .buffered(n) + .try_collect() + .await + } + } + fn map_ok<F, T>(self, f: F) -> impl op::Op<Input = Self::Input, Output = Result<T, Self::Error>> where F: Fn(Self::Output) -> T + Send + Sync,