Skip to content

Commit

Permalink
feat: Add Op::batch_call and TryOp::try_batch_call
Browse files Browse the repository at this point in the history
  • Loading branch information
cvauclair committed Nov 29, 2024
1 parent 90393f1 commit fc69515
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 6 deletions.
19 changes: 17 additions & 2 deletions rig-core/examples/multi_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ async fn main() -> anyhow::Result<()> {
)
.build();

// Create a chain that extracts names, topics, and sentiment from a given text
// using three different GPT-4 based extractors.
// The chain will output a formatted string containing the extracted information.
let chain = pipeline::new()
.chain(try_parallel!(
agent_ops::extract(names_extractor),
Expand All @@ -65,9 +68,21 @@ async fn main() -> anyhow::Result<()> {
)
});

let response = chain.try_call("Screw you Putin!").await?;
// Batch call the chain with up to 4 inputs concurrently
let response = chain
.try_batch_call(
4,
vec![
"Screw you Putin!",
"I love my dog, but I hate my cat.",
"I'm going to the store to buy some milk.",
],
)
.await?;

println!("Text analysis:\n{response}");
for response in response {
println!("Text analysis:\n{response}");
}

Ok(())
}
5 changes: 1 addition & 4 deletions rig-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,7 @@ mod tests {
.await
.expect("Failed to run chain");

assert_eq!(
result,
"Top documents:\nbar"
);
assert_eq!(result, "Top documents:\nbar");
}

#[tokio::test]
Expand Down
18 changes: 18 additions & 0 deletions rig-core/src/pipeline/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::future::Future;

#[allow(unused_imports)] // Needed since this is used in a macro rule
use futures::join;
use futures::{stream, StreamExt};

// ================================================================
// Core Op trait
Expand All @@ -12,6 +13,23 @@ pub trait Op: Send + Sync {

fn call(&self, input: Self::Input) -> impl Future<Output = Self::Output> + Send;

/// Execute the current pipeline with the given inputs. `n` is the number of concurrent
/// inputs that will be processed concurrently.
fn batch_call<I>(&self, n: usize, input: I) -> impl Future<Output = Vec<Self::Output>> + Send
where
I: IntoIterator<Item = Self::Input> + Send,
I::IntoIter: Send,
Self: Sized,
{
async move {
stream::iter(input)
.map(|input| self.call(input))
.buffered(n)
.collect()
.await
}
}

/// Chain a function to the current pipeline
///
/// # Example
Expand Down
24 changes: 24 additions & 0 deletions rig-core/src/pipeline/try_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::future::Future;

#[allow(unused_imports)] // Needed since this is used in a macro rule
use futures::try_join;
use futures::{stream, StreamExt, TryStreamExt};

use super::op::{self, map, then};

Expand All @@ -18,6 +19,29 @@ pub trait TryOp: Send + Sync {
input: Self::Input,
) -> impl Future<Output = Result<Self::Output, Self::Error>> + Send;

/// Execute the current pipeline with the given inputs. `n` is the number of concurrent
/// inputs that will be processed concurrently.
/// If one of the inputs fails, the entire operation will fail and the error will
/// be returned.
fn try_batch_call<I>(
&self,
n: usize,
input: I,
) -> impl Future<Output = Result<Vec<Self::Output>, Self::Error>> + Send
where
I: IntoIterator<Item = Self::Input> + Send,
I::IntoIter: Send,
Self: Sized,
{
async move {
stream::iter(input)
.map(|input| self.try_call(input))
.buffered(n)
.try_collect()
.await
}
}

fn map_ok<F, T>(self, f: F) -> impl op::Op<Input = Self::Input, Output = Result<T, Self::Error>>
where
F: Fn(Self::Output) -> T + Send + Sync,
Expand Down

0 comments on commit fc69515

Please sign in to comment.