From 8015b69fe0e185b105ea8641e95fe3e4378dcdef Mon Sep 17 00:00:00 2001 From: Christophe Date: Mon, 16 Dec 2024 10:30:03 -0500 Subject: [PATCH] feat(pipeline): Add id and score to `lookup` op result --- rig-core/examples/chain.rs | 2 +- rig-core/src/pipeline/agent_ops.rs | 11 ++++------- rig-core/src/pipeline/mod.rs | 4 ++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/rig-core/examples/chain.rs b/rig-core/examples/chain.rs index 1fa89e4c..d147ac9a 100644 --- a/rig-core/examples/chain.rs +++ b/rig-core/examples/chain.rs @@ -52,7 +52,7 @@ async fn main() -> Result<(), anyhow::Error> { .map(|(prompt, maybe_docs)| match maybe_docs { Ok(docs) => format!( "Non standard word definitions:\n{}\n\n{}", - docs.join("\n"), + docs.into_iter().map(|(_, _, doc)| doc).collect::>().join("\n"), prompt, ), Err(err) => { diff --git a/rig-core/src/pipeline/agent_ops.rs b/rig-core/src/pipeline/agent_ops.rs index c7562737..8f484c7e 100644 --- a/rig-core/src/pipeline/agent_ops.rs +++ b/rig-core/src/pipeline/agent_ops.rs @@ -1,7 +1,5 @@ use crate::{ - completion::{self, CompletionModel}, - extractor::{ExtractionError, Extractor}, - vector_store, + completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, vector_store }; use super::Op; @@ -34,7 +32,7 @@ where T: Send + Sync + for<'a> serde::Deserialize<'a>, { type Input = In; - type Output = Result, vector_store::VectorStoreError>; + type Output = Result, vector_store::VectorStoreError>; async fn call(&self, input: Self::Input) -> Self::Output { let query: String = input.into(); @@ -44,7 +42,6 @@ where .top_n::(&query, self.n) .await? .into_iter() - .map(|(_, _, doc)| doc) .collect(); Ok(docs) @@ -193,9 +190,9 @@ pub mod tests { let result = lookup.call("query".to_string()).await.unwrap(); assert_eq!( result, - vec![Foo { + vec![(1.0, "doc1".to_string(), Foo { foo: "bar".to_string() - }] + })] ); } diff --git a/rig-core/src/pipeline/mod.rs b/rig-core/src/pipeline/mod.rs index 27ea5b1f..c3f25049 100644 --- a/rig-core/src/pipeline/mod.rs +++ b/rig-core/src/pipeline/mod.rs @@ -343,7 +343,7 @@ mod tests { let chain = super::new() .lookup::<_, _, Foo>(index, 1) - .map_ok(|docs| format!("Top documents:\n{}", docs[0].foo)); + .map_ok(|docs| format!("Top documents:\n{}", docs[0].2.foo)); let result = chain .try_call("What is a flurbo?") @@ -363,7 +363,7 @@ mod tests { agent_ops::lookup::<_, _, Foo>(index, 1), )) .map(|(query, maybe_docs)| match maybe_docs { - Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo), + Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo), Err(err) => format!("Error: {}", err), }) .prompt(MockModel);