Skip to content

Commit

Permalink
feat(pipeline): Add id and score to lookup op result
Browse files Browse the repository at this point in the history
  • Loading branch information
cvauclair committed Dec 16, 2024
1 parent 758f621 commit 8015b69
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
2 changes: 1 addition & 1 deletion rig-core/examples/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async fn main() -> Result<(), anyhow::Error> {
.map(|(prompt, maybe_docs)| match maybe_docs {

Check warning on line 52 in rig-core/examples/chain.rs

View workflow job for this annotation

GitHub Actions / stable / fmt

Diff in /home/runner/work/rig/rig/rig-core/examples/chain.rs
Ok(docs) => format!(
"Non standard word definitions:\n{}\n\n{}",
docs.join("\n"),
docs.into_iter().map(|(_, _, doc)| doc).collect::<Vec<_>>().join("\n"),
prompt,
),
Err(err) => {
Expand Down
11 changes: 4 additions & 7 deletions rig-core/src/pipeline/agent_ops.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use crate::{

Check warning on line 1 in rig-core/src/pipeline/agent_ops.rs

View workflow job for this annotation

GitHub Actions / stable / fmt

Diff in /home/runner/work/rig/rig/rig-core/src/pipeline/agent_ops.rs
completion::{self, CompletionModel},
extractor::{ExtractionError, Extractor},
vector_store,
completion::{self, CompletionModel}, extractor::{ExtractionError, Extractor}, vector_store
};

use super::Op;
Expand Down Expand Up @@ -34,7 +32,7 @@ where
T: Send + Sync + for<'a> serde::Deserialize<'a>,
{
type Input = In;
type Output = Result<Vec<T>, vector_store::VectorStoreError>;
type Output = Result<Vec<(f64, String, T)>, vector_store::VectorStoreError>;

async fn call(&self, input: Self::Input) -> Self::Output {
let query: String = input.into();
Expand All @@ -44,7 +42,6 @@ where
.top_n::<T>(&query, self.n)
.await?
.into_iter()
.map(|(_, _, doc)| doc)
.collect();

Ok(docs)
Expand Down Expand Up @@ -193,9 +190,9 @@ pub mod tests {
let result = lookup.call("query".to_string()).await.unwrap();

Check warning on line 190 in rig-core/src/pipeline/agent_ops.rs

View workflow job for this annotation

GitHub Actions / stable / fmt

Diff in /home/runner/work/rig/rig/rig-core/src/pipeline/agent_ops.rs
assert_eq!(
result,
vec![Foo {
vec![(1.0, "doc1".to_string(), Foo {
foo: "bar".to_string()
}]
})]
);
}

Expand Down
4 changes: 2 additions & 2 deletions rig-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ mod tests {

let chain = super::new()
.lookup::<_, _, Foo>(index, 1)
.map_ok(|docs| format!("Top documents:\n{}", docs[0].foo));
.map_ok(|docs| format!("Top documents:\n{}", docs[0].2.foo));

let result = chain
.try_call("What is a flurbo?")
Expand All @@ -363,7 +363,7 @@ mod tests {
agent_ops::lookup::<_, _, Foo>(index, 1),
))
.map(|(query, maybe_docs)| match maybe_docs {
Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].foo),
Ok(docs) => format!("User query: {}\n\nTop documents:\n{}", query, docs[0].2.foo),
Err(err) => format!("Error: {}", err),
})
.prompt(MockModel);
Expand Down

0 comments on commit 8015b69

Please sign in to comment.