From 16970bc06293b3dc422923ad77b254e18b3ea87f Mon Sep 17 00:00:00 2001 From: Christophe Date: Fri, 6 Dec 2024 13:31:10 -0500 Subject: [PATCH] fix(pipelines): Type errors --- rig-core/src/pipeline/agent_ops.rs | 34 ++++----- rig-core/src/pipeline/mod.rs | 44 ++++++------ rig-core/src/pipeline/op.rs | 106 ++++++++++++++++++----------- rig-core/src/pipeline/try_op.rs | 65 +++++++++--------- 4 files changed, 137 insertions(+), 112 deletions(-) diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index f51ce713..c7562737 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -17,7 +17,7 @@ impl Lookup where I: vector_store::VectorStoreIndex, { - pub fn new(index: I, n: usize) -> Self { + pub(crate) fn new(index: I, n: usize) -> Self { Self { index, n, @@ -66,7 +66,7 @@ pub struct Prompt { } impl Prompt { - pub fn new(prompt: P) -> Self { + pub(crate) fn new(prompt: P) -> Self { Self { prompt, _in: std::marker::PhantomData, @@ -96,21 +96,21 @@ where Prompt::new(prompt) } -pub struct Extract +pub struct Extract where M: CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, { - extractor: Extractor, - _in: std::marker::PhantomData, + extractor: Extractor, + _in: std::marker::PhantomData, } -impl Extract +impl Extract where M: CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, { - pub fn new(extractor: Extractor) -> Self { + pub(crate) fn new(extractor: Extractor) -> Self { Self { extractor, _in: std::marker::PhantomData, @@ -118,25 +118,25 @@ where } } -impl Op for Extract +impl Op for Extract where M: CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, - In: Into + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Input: Into + Send + Sync, { - type Input = In; - type Output = Result; + type Input = Input; + type Output = Result; async fn call(&self, input: Self::Input) -> Self::Output { self.extractor.extract(&input.into()).await } } -pub fn extract(extractor: Extractor) -> Extract +pub fn extract(extractor: Extractor) -> Extract where M: CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, - In: Into + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Input: Into + Send + Sync, { Extract::new(extractor) } diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 417ae762..27ea5b1f 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -118,14 +118,14 @@ impl PipelineBuilder { /// let result = pipeline.call((1, 2)).await; /// assert_eq!(result, "Result: 3!"); /// ``` - pub fn map(self, f: F) -> impl Op + pub fn map(self, f: F) -> op::Map where - F: Fn(In) -> T + Send + Sync, - In: Send + Sync, - T: Send + Sync, + F: Fn(Input) -> Output + Send + Sync, + Input: Send + Sync, + Output: Send + Sync, Self: Sized, { - map(f) + op::Map::new(f) } /// Same as `map` but for asynchronous functions @@ -145,15 +145,15 @@ impl PipelineBuilder { /// let result = pipeline.call("bob@gmail.com".to_string()).await; /// assert_eq!(result, "Hello, bob!"); /// ``` - pub fn then(self, f: F) -> impl Op + pub fn then(self, f: F) -> op::Then where - F: Fn(In) -> Fut + Send + Sync, - In: Send + Sync, + F: Fn(Input) -> Fut + Send + Sync, + Input: Send + Sync, Fut: Future + Send + Sync, Fut::Output: Send + Sync, Self: Sized, { - then(f) + op::Then::new(f) } /// Add an arbitrary operation to the current pipeline. @@ -179,7 +179,7 @@ impl PipelineBuilder { /// let result = pipeline.call(1).await; /// assert_eq!(result, 2); /// ``` - pub fn chain(self, op: T) -> impl Op + pub fn chain(self, op: T) -> T where T: Op, Self: Sized, @@ -203,19 +203,19 @@ impl PipelineBuilder { /// /// let result = pipeline.call("What is a flurbo?".to_string()).await; /// ``` - pub fn lookup( + pub fn lookup( self, index: I, n: usize, - ) -> impl Op, vector_store::VectorStoreError>> + ) -> agent_ops::Lookup where I: vector_store::VectorStoreIndex, - T: Send + Sync + for<'a> serde::Deserialize<'a>, - In: Into + Send + Sync, + Output: Send + Sync + for<'a> serde::Deserialize<'a>, + Input: Into + Send + Sync, // E: From + Send + Sync, Self: Sized, { - agent_ops::lookup(index, n) + agent_ops::Lookup::new(index, n) } /// Add a prompt operation to the current pipeline/op. The prompt operation expects the @@ -235,14 +235,14 @@ impl PipelineBuilder { /// /// let result = pipeline.call("Alice".to_string()).await; /// ``` - pub fn prompt(self, agent: P) -> agent_ops::Prompt + pub fn prompt(self, agent: P) -> agent_ops::Prompt where P: completion::Prompt, - In: Into + Send + Sync, + Input: Into + Send + Sync, // E: From + Send + Sync, Self: Sized, { - agent_ops::prompt(agent) + agent_ops::Prompt::new(agent) } /// Add an extract operation to the current pipeline/op. The extract operation expects the @@ -268,13 +268,13 @@ impl PipelineBuilder { /// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?; /// assert!(result.score > 0.5); /// ``` - pub fn extract(self, extractor: Extractor) -> agent_ops::Extract + pub fn extract(self, extractor: Extractor) -> agent_ops::Extract where M: completion::CompletionModel, - T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, - In: Into + Send + Sync, + Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync, + Input: Into + Send + Sync, { - agent_ops::extract(extractor) + agent_ops::Extract::new(extractor) } } diff --git a/rig-core/src/pipeline/op.rs b/rig-core/src/pipeline/op.rs index 0ddf32bf..96ee67e8 100644 --- a/rig-core/src/pipeline/op.rs +++ b/rig-core/src/pipeline/op.rs @@ -2,7 +2,7 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::join; -use futures::{stream, StreamExt}; +use futures::stream; // ================================================================ // Core Op trait @@ -21,6 +21,8 @@ pub trait Op: Send + Sync { I::IntoIter: Send, Self: Sized, { + use futures::stream::StreamExt; + async move { stream::iter(input) .map(|input| self.call(input)) @@ -43,13 +45,13 @@ pub trait Op: Send + Sync { /// let result = chain.call((1, 2)).await; /// assert_eq!(result, "Result: 3!"); /// ``` - fn map(self, f: F) -> impl Op + fn map(self, f: F) -> Sequential> where - F: Fn(Self::Output) -> T + Send + Sync, - T: Send + Sync, + F: Fn(Self::Output) -> Input + Send + Sync, + Input: Send + Sync, Self: Sized, { - Sequential::new(self, map(f)) + Sequential::new(self, Map::new(f)) } /// Same as `map` but for asynchronous functions @@ -69,14 +71,14 @@ pub trait Op: Send + Sync { /// let result = chain.call("bob@gmail.com".to_string()).await; /// assert_eq!(result, "Hello, bob!"); /// ``` - fn then(self, f: F) -> impl Op + fn then(self, f: F) -> Sequential> where F: Fn(Self::Output) -> Fut + Send + Sync, Fut: Future + Send + Sync, Fut::Output: Send + Sync, Self: Sized, { - Sequential::new(self, then(f)) + Sequential::new(self, Then::new(f)) } /// Chain an arbitrary operation to the current op. @@ -102,7 +104,7 @@ pub trait Op: Send + Sync { /// let result = chain.call(1).await; /// assert_eq!(result, 2); /// ``` - fn chain(self, op: T) -> impl Op + fn chain(self, op: T) -> Sequential where T: Op, Self: Sized, @@ -126,14 +128,14 @@ pub trait Op: Send + Sync { /// /// let result = chain.call("What is a flurbo?".to_string()).await; /// ``` - fn lookup( + fn lookup( self, index: I, n: usize, - ) -> impl Op, vector_store::VectorStoreError>> + ) -> Sequential> where I: vector_store::VectorStoreIndex, - T: Send + Sync + for<'a> serde::Deserialize<'a>, + Input: Send + Sync + for<'a> serde::Deserialize<'a>, Self::Output: Into, Self: Sized, { @@ -160,7 +162,7 @@ pub trait Op: Send + Sync { fn prompt

( self, prompt: P, - ) -> impl Op> + ) -> Sequential> where P: completion::Prompt, Self::Output: Into, @@ -189,7 +191,7 @@ pub struct Sequential { } impl Sequential { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } @@ -216,13 +218,13 @@ use super::agent_ops::{Lookup, Prompt}; // ================================================================ // Core Op implementations // ================================================================ -pub struct Map { +pub struct Map { f: F, - _t: std::marker::PhantomData, + _t: std::marker::PhantomData, } -impl Map { - pub fn new(f: F) -> Self { +impl Map { + pub(crate) fn new(f: F) -> Self { Self { f, _t: std::marker::PhantomData, @@ -230,14 +232,14 @@ impl Map { } } -impl Op for Map +impl Op for Map where - F: Fn(T) -> Out + Send + Sync, - T: Send + Sync, - Out: Send + Sync, + F: Fn(Input) -> Output + Send + Sync, + Input: Send + Sync, + Output: Send + Sync, { - type Input = T; - type Output = Out; + type Input = Input; + type Output = Output; #[inline] async fn call(&self, input: Self::Input) -> Self::Output { @@ -245,29 +247,53 @@ where } } -pub fn map(f: F) -> impl Op +pub fn map(f: F) -> Map where - F: Fn(T) -> Out + Send + Sync, - T: Send + Sync, - Out: Send + Sync, + F: Fn(Input) -> Output + Send + Sync, + Input: Send + Sync, + Output: Send + Sync, { Map::new(f) } -pub fn passthrough() -> impl Op +pub struct Passthrough { + _t: std::marker::PhantomData, +} + +impl Passthrough { + pub(crate) fn new() -> Self { + Self { + _t: std::marker::PhantomData, + } + } +} + +impl Op for Passthrough +where + T: Send + Sync, +{ + type Input = T; + type Output = T; + + async fn call(&self, input: Self::Input) -> Self::Output { + input + } +} + +pub fn passthrough() -> Passthrough where T: Send + Sync, { - Map::new(|x| x) + Passthrough::new() } -pub struct Then { +pub struct Then { f: F, - _t: std::marker::PhantomData, + _t: std::marker::PhantomData, } -impl Then { - fn new(f: F) -> Self { +impl Then { + pub(crate) fn new(f: F) -> Self { Self { f, _t: std::marker::PhantomData, @@ -275,14 +301,14 @@ impl Then { } } -impl Op for Then +impl Op for Then where - F: Fn(T) -> Fut + Send + Sync, - T: Send + Sync, + F: Fn(Input) -> Fut + Send + Sync, + Input: Send + Sync, Fut: Future + Send, Fut::Output: Send + Sync, { - type Input = T; + type Input = Input; type Output = Fut::Output; #[inline] @@ -291,10 +317,10 @@ where } } -pub fn then(f: F) -> impl Op +pub fn then(f: F) -> Then where - F: Fn(T) -> Fut + Send + Sync, - T: Send + Sync, + F: Fn(Input) -> Fut + Send + Sync, + Input: Send + Sync, Fut: Future + Send, Fut::Output: Send + Sync, { diff --git a/rig-core/src/pipeline/try_op.rs b/rig-core/src/pipeline/try_op.rs index e8d276e5..76b41e91 100644 --- a/rig-core/src/pipeline/try_op.rs +++ b/rig-core/src/pipeline/try_op.rs @@ -2,9 +2,9 @@ use std::future::Future; #[allow(unused_imports)] // Needed since this is used in a macro rule use futures::try_join; -use futures::{stream, StreamExt, TryStreamExt}; +use futures::stream; -use super::op::{self, map, then}; +use super::op::{self}; // ================================================================ // Core TryOp trait @@ -46,6 +46,8 @@ pub trait TryOp: Send + Sync { I::IntoIter: Send, Self: Sized, { + use stream::{StreamExt, TryStreamExt}; + async move { stream::iter(input) .map(|input| self.try_call(input)) @@ -69,13 +71,13 @@ pub trait TryOp: Send + Sync { /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(4)); /// ``` - fn map_ok(self, f: F) -> impl op::Op> + fn map_ok(self, f: F) -> MapOk> where - F: Fn(Self::Output) -> T + Send + Sync, - T: Send + Sync, + F: Fn(Self::Output) -> Output + Send + Sync, + Output: Send + Sync, Self: Sized, { - MapOk::new(self, map(f)) + MapOk::new(self, op::Map::new(f)) } /// Map the error return value (i.e., `Err`) of the current op to a different value @@ -95,13 +97,13 @@ pub trait TryOp: Send + Sync { fn map_err( self, f: F, - ) -> impl op::Op> + ) -> MapErr> where F: Fn(Self::Error) -> E + Send + Sync, E: Send + Sync, Self: Sized, { - MapErr::new(self, map(f)) + MapErr::new(self, op::Map::new(f)) } /// Chain a function to the current op. The function will only be called @@ -119,17 +121,17 @@ pub trait TryOp: Send + Sync { /// let result = op.try_call(2).await; /// assert_eq!(result, Ok(4)); /// ``` - fn and_then( + fn and_then( self, f: F, - ) -> impl TryOp + ) -> AndThen> where F: Fn(Self::Output) -> Fut + Send + Sync, - Fut: Future> + Send + Sync, - T: Send + Sync, + Fut: Future> + Send + Sync, + Output: Send + Sync, Self: Sized, { - AndThen::new(self, then(f)) + AndThen::new(self, op::Then::new(f)) } /// Chain a function `f` to the current op. The function `f` will only be called @@ -150,14 +152,14 @@ pub trait TryOp: Send + Sync { fn or_else( self, f: F, - ) -> impl TryOp + ) -> OrElse> where F: Fn(Self::Error) -> Fut + Send + Sync, Fut: Future> + Send + Sync, E: Send + Sync, Self: Sized, { - OrElse::new(self, then(f)) + OrElse::new(self, op::Then::new(f)) } /// Chain a new op `op` to the current op. The new op will be called with the success @@ -189,7 +191,7 @@ pub trait TryOp: Send + Sync { fn chain_ok( self, op: T, - ) -> impl TryOp + ) -> TrySequential where T: op::Op, Self: Sized, @@ -222,7 +224,7 @@ pub struct MapOk { } impl MapOk { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } @@ -250,7 +252,7 @@ pub struct MapErr { } impl MapErr { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } @@ -279,22 +281,21 @@ pub struct AndThen { } impl AndThen { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } -impl TryOp for AndThen +impl op::Op for AndThen where Op1: TryOp, Op2: TryOp, { type Input = Op1::Input; - type Output = Op2::Output; - type Error = Op1::Error; + type Output = Result; #[inline] - async fn try_call(&self, input: Self::Input) -> Result { + async fn call(&self, input: Self::Input) -> Self::Output { let output = self.prev.try_call(input).await?; self.op.try_call(output).await } @@ -306,22 +307,21 @@ pub struct OrElse { } impl OrElse { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } -impl TryOp for OrElse +impl op::Op for OrElse where Op1: TryOp, Op2: TryOp, { type Input = Op1::Input; - type Output = Op1::Output; - type Error = Op2::Error; + type Output = Result; #[inline] - async fn try_call(&self, input: Self::Input) -> Result { + async fn call(&self, input: Self::Input) -> Self::Output { match self.prev.try_call(input).await { Ok(output) => Ok(output), Err(err) => self.op.try_call(err).await, @@ -335,22 +335,21 @@ pub struct TrySequential { } impl TrySequential { - pub fn new(prev: Op1, op: Op2) -> Self { + pub(crate) fn new(prev: Op1, op: Op2) -> Self { Self { prev, op } } } -impl TryOp for TrySequential +impl op::Op for TrySequential where Op1: TryOp, Op2: op::Op, { type Input = Op1::Input; - type Output = Op2::Output; - type Error = Op1::Error; + type Output = Result; #[inline] - async fn try_call(&self, input: Self::Input) -> Result { + async fn call(&self, input: Self::Input) -> Self::Output { match self.prev.try_call(input).await { Ok(output) => Ok(self.op.call(output).await), Err(err) => Err(err),