Skip to content

Commit

Permalink
fix(pipelines): Type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
cvauclair committed Dec 6, 2024
1 parent cbbd1cc commit 16970bc
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 112 deletions.
34 changes: 17 additions & 17 deletions rig-core/src/pipeline/agent_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl<I, In, T> Lookup<I, In, T>
where
I: vector_store::VectorStoreIndex,
{
pub fn new(index: I, n: usize) -> Self {
pub(crate) fn new(index: I, n: usize) -> Self {
Self {
index,
n,
Expand Down Expand Up @@ -66,7 +66,7 @@ pub struct Prompt<P, In> {
}

impl<P, In> Prompt<P, In> {
pub fn new(prompt: P) -> Self {
pub(crate) fn new(prompt: P) -> Self {
Self {
prompt,
_in: std::marker::PhantomData,
Expand Down Expand Up @@ -96,47 +96,47 @@ where
Prompt::new(prompt)
}

pub struct Extract<M, T, In>
pub struct Extract<M, Input, Output>
where
M: CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
{
extractor: Extractor<M, T>,
_in: std::marker::PhantomData<In>,
extractor: Extractor<M, Output>,
_in: std::marker::PhantomData<Input>,
}

impl<M, T, In> Extract<M, T, In>
impl<M, Input, Output> Extract<M, Input, Output>
where
M: CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
{
pub fn new(extractor: Extractor<M, T>) -> Self {
pub(crate) fn new(extractor: Extractor<M, Output>) -> Self {
Self {
extractor,
_in: std::marker::PhantomData,
}
}
}

impl<M, T, In> Op for Extract<M, T, In>
impl<M, Input, Output> Op for Extract<M, Input, Output>
where
M: CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
In: Into<String> + Send + Sync,
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
Input: Into<String> + Send + Sync,
{
type Input = In;
type Output = Result<T, ExtractionError>;
type Input = Input;
type Output = Result<Output, ExtractionError>;

async fn call(&self, input: Self::Input) -> Self::Output {
self.extractor.extract(&input.into()).await
}
}

pub fn extract<M, T, In>(extractor: Extractor<M, T>) -> Extract<M, T, In>
pub fn extract<M, Input, Output>(extractor: Extractor<M, Output>) -> Extract<M, Input, Output>
where
M: CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
In: Into<String> + Send + Sync,
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
Input: Into<String> + Send + Sync,
{
Extract::new(extractor)
}
Expand Down
44 changes: 22 additions & 22 deletions rig-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,14 @@ impl<E> PipelineBuilder<E> {
/// let result = pipeline.call((1, 2)).await;
/// assert_eq!(result, "Result: 3!");
/// ```
pub fn map<F, In, T>(self, f: F) -> impl Op<Input = In, Output = T>
pub fn map<F, Input, Output>(self, f: F) -> op::Map<F, Input>
where
F: Fn(In) -> T + Send + Sync,
In: Send + Sync,
T: Send + Sync,
F: Fn(Input) -> Output + Send + Sync,
Input: Send + Sync,
Output: Send + Sync,
Self: Sized,
{
map(f)
op::Map::new(f)
}

/// Same as `map` but for asynchronous functions
Expand All @@ -145,15 +145,15 @@ impl<E> PipelineBuilder<E> {
/// let result = pipeline.call("[email protected]".to_string()).await;
/// assert_eq!(result, "Hello, bob!");
/// ```
pub fn then<F, In, Fut>(self, f: F) -> impl Op<Input = In, Output = Fut::Output>
pub fn then<F, Input, Fut>(self, f: F) -> op::Then<F, Input>
where
F: Fn(In) -> Fut + Send + Sync,
In: Send + Sync,
F: Fn(Input) -> Fut + Send + Sync,
Input: Send + Sync,
Fut: Future + Send + Sync,
Fut::Output: Send + Sync,
Self: Sized,
{
then(f)
op::Then::new(f)
}

/// Add an arbitrary operation to the current pipeline.
Expand All @@ -179,7 +179,7 @@ impl<E> PipelineBuilder<E> {
/// let result = pipeline.call(1).await;
/// assert_eq!(result, 2);
/// ```
pub fn chain<T>(self, op: T) -> impl Op<Input = T::Input, Output = T::Output>
pub fn chain<T>(self, op: T) -> T
where
T: Op,
Self: Sized,
Expand All @@ -203,19 +203,19 @@ impl<E> PipelineBuilder<E> {
///

Check warning on line 203 in rig-core/src/pipeline/mod.rs

View workflow job for this annotation

GitHub Actions / stable / fmt

Diff in /home/runner/work/rig/rig/rig-core/src/pipeline/mod.rs
/// let result = pipeline.call("What is a flurbo?".to_string()).await;
/// ```
pub fn lookup<I, In, T>(
pub fn lookup<I, Input, Output>(
self,
index: I,
n: usize,
) -> impl Op<Input = In, Output = Result<Vec<T>, vector_store::VectorStoreError>>
) -> agent_ops::Lookup<I, Input, Output>
where
I: vector_store::VectorStoreIndex,
T: Send + Sync + for<'a> serde::Deserialize<'a>,
In: Into<String> + Send + Sync,
Output: Send + Sync + for<'a> serde::Deserialize<'a>,
Input: Into<String> + Send + Sync,
// E: From<vector_store::VectorStoreError> + Send + Sync,
Self: Sized,
{
agent_ops::lookup(index, n)
agent_ops::Lookup::new(index, n)
}

/// Add a prompt operation to the current pipeline/op. The prompt operation expects the
Expand All @@ -235,14 +235,14 @@ impl<E> PipelineBuilder<E> {
///
/// let result = pipeline.call("Alice".to_string()).await;
/// ```
pub fn prompt<P, In>(self, agent: P) -> agent_ops::Prompt<P, In>
pub fn prompt<P, Input>(self, agent: P) -> agent_ops::Prompt<P, Input>
where
P: completion::Prompt,
In: Into<String> + Send + Sync,
Input: Into<String> + Send + Sync,
// E: From<completion::PromptError> + Send + Sync,
Self: Sized,
{
agent_ops::prompt(agent)
agent_ops::Prompt::new(agent)
}

/// Add an extract operation to the current pipeline/op. The extract operation expects the
Expand All @@ -268,13 +268,13 @@ impl<E> PipelineBuilder<E> {
/// let result: Sentiment = pipeline.call("I love ice cream!".to_string()).await?;

Check warning on line 268 in rig-core/src/pipeline/mod.rs

View workflow job for this annotation

GitHub Actions / stable / fmt

Diff in /home/runner/work/rig/rig/rig-core/src/pipeline/mod.rs
/// assert!(result.score > 0.5);
/// ```
pub fn extract<M, T, In>(self, extractor: Extractor<M, T>) -> agent_ops::Extract<M, T, In>
pub fn extract<M, Input, Output>(self, extractor: Extractor<M, Output>) -> agent_ops::Extract<M, Input, Output>
where
M: completion::CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
In: Into<String> + Send + Sync,
Output: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
Input: Into<String> + Send + Sync,
{
agent_ops::extract(extractor)
agent_ops::Extract::new(extractor)
}
}

Expand Down
Loading

0 comments on commit 16970bc

Please sign in to comment.