Skip to content

Commit

Permalink
style: clippy+fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
cvauclair committed Dec 5, 2024
1 parent 64a7552 commit cbbd1cc
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 77 deletions.
137 changes: 84 additions & 53 deletions rig-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,54 @@
//! This module defines a flexible pipeline API for defining a sequence of operations that
//! may or may not use AI components (e.g.: semantic search, LLMs prompting, etc).
//!
//! This module defines a flexible pipeline API for defining a sequence of operations that
//! may or may not use AI components (e.g.: semantic search, LLMs prompting, etc).
//!
//! The pipeline API was inspired by general orchestration pipelines such as Airflow, Dagster and Prefect,
//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box along
//! but implemented with idiomatic Rust patterns and providing some AI-specific ops out-of-the-box along
//! with general combinators.
//!
//! Pipelines are made up of one or more "ops", each of which must implement the [Op] trait.
//!
//! Pipelines are made up of one or more operations, or "ops", each of which must implement the [Op] trait.
//! The [Op] trait requires the implementation of only one method: `call`, which takes an input
//! and returns an output. The trait provides a wide range of combinators for chaining operations together.
//!
//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and
//! the edges represent the data flow between operations. When invoking the pipeline on some input,
//!
//! One can think of a pipeline as a DAG (Directed Acyclic Graph) where each node is an operation and
//! the edges represent the data flow between operations. When invoking the pipeline on some input,
//! the input is passed to the root node of the DAG (i.e.: the first op defined in the pipeline).
//! The output of each op is then passed to the next op in the pipeline until the output reaches the
//! leaf node (i.e.: the last op defined in the pipeline). The output of the leaf node is then returned
//! as the result of the pipeline.
//!
//!
//! ## Basic Example
//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats
//! the result as a string using the [map](Op::map) combinator method, which applies a function to the
//! output of the previous op:
//! For example, the pipeline below takes a tuple of two integers, adds them together and then formats
//! the result as a string using the [map](Op::map) combinator method, which applies a simple function
//! op to the output of the previous op:
//! ```rust
//! use rig::pipeline::{self, Op};
//!
//!
//! let pipeline = pipeline::new()
//! // op1: add two numbers
//! .map(|(x, y)| x + y)
//! // op2: format result
//! .map(|z| format!("Result: {z}!"));
//!
//!
//! let result = pipeline.call((1, 2)).await;
//! assert_eq!(result, "Result: 3!");
//! ```
//!
//!
//! This pipeline can be visualized as the following DAG:
//! ```text
//! ┌─────────┐ ┌─────────┐
//! Input───►│ op1 ├──►│ op2 ├──►Output
//! └─────────┘ └─────────┘
//! ```
//!
//!
//! ## Parallel Operations
//! The pipeline API also provides a [parallel!](crate::parallel!) macro for running operations in parallel.
//! The macro takes a list of ops and turns them into a single op that will duplicate the input
//! The macro takes a list of ops and turns them into a single op that will duplicate the input
//! and run each op concurrently. The results of each op are then collected and returned as a tuple.
//!
//!
//! For example, the pipeline below runs two operations concurrently:
//! ```rust
//! use rig::{pipeline::{self, Op, map}, parallel};
//!
//!
//! let pipeline = pipeline::new()
//! .chain(parallel!(
//! // op1: add 1 to input
Expand All @@ -58,16 +58,16 @@
//! ))
//! // op3: format results
//! .map(|(a, b)| format!("Results: {a}, {b}"));
//!
//!
//! let result = pipeline.call(1).await;
//! assert_eq!(result, "Results: 2, 0");
//! ```
//!
//! Notes:
//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows
//! for chaining arbitrary operations, as long as they implement the [Op] trait.
//!
//! Notes:
//! - The [chain](Op::chain) method is similar to the [map](Op::map) method but it allows
//! for chaining arbitrary operations, as long as they implement the [Op] trait.
//! - [map] is a function that initializes a standalone [Map](self::op::Map) op without an existing pipeline/op.
//!
//!
//! The pipeline above can be visualized as the following DAG:
//! ```text
//! Input
Expand Down Expand Up @@ -95,28 +95,27 @@ pub mod parallel;

use std::future::Future;

use agent_ops::{lookup, prompt as prompt_op};
pub use op::{map, passthrough, then, Op};
pub use try_op::TryOp;

use crate::{completion, vector_store};
use crate::{completion, extractor::Extractor, vector_store};

/// Builder type whose methods (`map`, `then`, `chain`, `lookup`, `prompt`, `extract`, ...)
/// each create the first op of a pipeline.
///
/// The builder stores no runtime data: the type parameter `E` is only recorded at the
/// type level through a [`PhantomData`](std::marker::PhantomData) marker.
pub struct PipelineBuilder<E> {
// Zero-sized marker tying the builder to `E` without storing a value of that type.
_error: std::marker::PhantomData<E>,
}

impl<E> PipelineBuilder<E> {
/// Chain a function to the current pipeline
/// Add a function to the current pipeline
///
/// # Example
/// ```rust
/// use rig::pipeline::{self, Op};
///
/// let chain = pipeline::new()
/// let pipeline = pipeline::new()
/// .map(|(x, y)| x + y)
/// .map(|z| format!("Result: {z}!"));
///
/// let result = chain.call((1, 2)).await;
/// let result = pipeline.call((1, 2)).await;
/// assert_eq!(result, "Result: 3!");
/// ```
pub fn map<F, In, T>(self, f: F) -> impl Op<Input = In, Output = T>
Expand All @@ -135,15 +134,15 @@ impl<E> PipelineBuilder<E> {
/// ```rust
/// use rig::pipeline::{self, Op};
///
/// let chain = pipeline::new()
/// let pipeline = pipeline::new()
/// .then(|email: String| async move {
/// email.split('@').next().unwrap().to_string()
/// })
/// .then(|username: String| async move {
/// format!("Hello, {}!", username)
/// });
///
/// let result = chain.call("bob@example.com".to_string()).await;
/// let result = pipeline.call("bob@example.com".to_string()).await;
/// assert_eq!(result, "Hello, bob!");
/// ```
pub fn then<F, In, Fut>(self, f: F) -> impl Op<Input = In, Output = Fut::Output>
Expand All @@ -157,7 +156,7 @@ impl<E> PipelineBuilder<E> {
then(f)
}

/// Chain an arbitrary operation to the current pipeline.
/// Add an arbitrary operation to the current pipeline.
///
/// # Example
/// ```rust
Expand All @@ -174,10 +173,10 @@ impl<E> PipelineBuilder<E> {
/// }
/// }
///
/// let chain = pipeline::new()
/// let pipeline = pipeline::new()
/// .chain(MyOp);
///
/// let result = chain.call(1).await;
/// let result = pipeline.call(1).await;
/// assert_eq!(result, 2);
/// ```
pub fn chain<T>(self, op: T) -> impl Op<Input = T::Input, Output = T::Output>
Expand All @@ -194,15 +193,15 @@ impl<E> PipelineBuilder<E> {
///
/// # Example
/// ```rust
/// use rig::chain::{self, Chain};
/// use rig::pipeline::{self, Op};
///
/// let chain = chain::new()
/// let pipeline = pipeline::new()
/// .lookup(index, 2)
/// .chain(|(query, docs): (_, Vec<String>)| async move {
/// .then(|(query, docs): (_, Vec<String>)| async move {
/// format!("User query: {}\n\nTop documents:\n{}", query, docs.join("\n"))
/// });
///
/// let result = chain.call("What is a flurbo?".to_string()).await;
/// let result = pipeline.call("What is a flurbo?".to_string()).await;
/// ```
pub fn lookup<I, In, T>(
self,
Expand All @@ -216,37 +215,66 @@ impl<E> PipelineBuilder<E> {
// E: From<vector_store::VectorStoreError> + Send + Sync,
Self: Sized,
{
lookup(index, n)
agent_ops::lookup(index, n)
}

/// Chain a prompt operation to the current chain. The prompt operation expects the
/// current chain to output a string. The prompt operation will use the string to prompt
/// the given agent (or any other type that implements the `Prompt` trait) and return
/// Add a prompt operation to the current pipeline/op. The prompt operation expects the
/// current pipeline to output a string. The prompt operation will use the string to prompt
/// the given `agent`, which must implement the [Prompt](completion::Prompt) trait, and return
/// the response.
///
/// # Example
/// ```rust
/// use rig::chain::{self, Chain};
/// use rig::pipeline::{self, Op};
///
/// let agent = &openai_client.agent("gpt-4").build();
///
/// let chain = chain::new()
/// let pipeline = pipeline::new()
/// .map(|name| format!("Find funny nicknames for the following name: {name}!"))
/// .prompt(agent);
///
/// let result = chain.call("Alice".to_string()).await;
/// let result = pipeline.call("Alice".to_string()).await;
/// ```
pub fn prompt<P, In>(
self,
prompt: P,
) -> impl Op<Input = In, Output = Result<String, completion::PromptError>>
pub fn prompt<P, In>(self, agent: P) -> agent_ops::Prompt<P, In>
where
P: completion::Prompt,
In: Into<String> + Send + Sync,
// E: From<completion::PromptError> + Send + Sync,
Self: Sized,
{
prompt_op(prompt)
agent_ops::prompt(agent)
}

/// Add an extract operation to the current pipeline/op. The extract operation expects the
/// current pipeline to output a value convertible into a string (`In: Into<String>`). The
/// extract operation will use the given `extractor` to extract information from the string
/// in the form of the type `T` and return it.
///
/// The builder (`self`) is consumed but not otherwise used: the returned op is created
/// directly from `extractor` via `agent_ops::extract`.
///
/// Note that `extractor` is taken by value.
///
/// # Example
/// ```rust
/// use rig::pipeline::{self, Op};
///
/// #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
/// struct Sentiment {
/// /// The sentiment score of the text (0.0 = negative, 1.0 = positive)
/// score: f64,
/// }
///
/// let extractor = openai_client.extractor::<Sentiment>("gpt-4").build();
///
/// let pipeline = pipeline::new()
/// .map(|text| format!("Analyze the sentiment of the following text: {text}!"))
/// .extract(extractor);
///
/// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?;
/// assert!(result.score > 0.5);
/// ```
pub fn extract<M, T, In>(self, extractor: Extractor<M, T>) -> agent_ops::Extract<M, T, In>
where
M: completion::CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
In: Into<String> + Send + Sync,
{
// Thin wrapper: delegate to the standalone constructor in `agent_ops`.
agent_ops::extract(extractor)
}
}

Expand Down Expand Up @@ -330,7 +358,10 @@ mod tests {
let index = MockIndex;

let chain = super::new()
.chain(parallel!(passthrough(), lookup::<_, _, Foo>(index, 1),))
.chain(parallel!(
passthrough(),
agent_ops::lookup::<_, _, Foo>(index, 1),
))
.map(|(query, maybe_docs)| match maybe_docs {
Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo),
Err(err) => format!("Error: {}", err),
Expand Down
Loading

0 comments on commit cbbd1cc

Please sign in to comment.