feat: agentic chains #131

Merged: 31 commits, Dec 17, 2024
Changes from 17 commits
Commits (31)
e28bf0a
feat(chain): Initial prototype for agentic chain feature
cvauclair Nov 15, 2024
f3c81b3
feat: Add chain error handling
cvauclair Nov 19, 2024
1a59d77
feat: Add parallel ops and rename `chain` to `pipeline`
cvauclair Nov 21, 2024
6f29e01
Merge branch 'main' into feat/agentic-chains
cvauclair Nov 25, 2024
7f83c1e
Merge branch 'main' into feat/agentic-chains
cvauclair Nov 29, 2024
72b49bb
Merge branch 'main' into feat/agentic-chains
cvauclair Nov 29, 2024
f46828d
docs: Update example
cvauclair Nov 29, 2024
110554b
feat: Add extraction pipeline op
cvauclair Nov 29, 2024
d49e11e
docs: Add extraction pipeline example
cvauclair Nov 29, 2024
fbfc2bc
feat: Add `try_parallel!` pipeline op macro
cvauclair Nov 29, 2024
f32ec08
misc: Remove unused module
cvauclair Nov 29, 2024
98af665
style: cargo fmt
cvauclair Nov 29, 2024
408b46c
test: fix typo in test
cvauclair Nov 29, 2024
a0a8534
test: fix typo in test #2
cvauclair Nov 29, 2024
90393f1
test: Fix broken lookup op test
cvauclair Nov 29, 2024
fc69515
feat: Add `Op::batch_call` and `TryOp::try_batch_call`
cvauclair Nov 29, 2024
c2c1996
test: Fix tests
cvauclair Nov 29, 2024
0fd59d6
Merge branch 'main' into feat/agentic-chains
cvauclair Dec 5, 2024
56ca41c
docs: Add more docstrings to agent pipeline ops
cvauclair Dec 5, 2024
eb34c4b
docs: Add pipeline module level docs
cvauclair Dec 5, 2024
64a7552
docs: improve pipeline docs
cvauclair Dec 5, 2024
cbbd1cc
style: clippy+fmt
cvauclair Dec 5, 2024
16970bc
fix(pipelines): Type errors
cvauclair Dec 6, 2024
758f621
fix: Missing trait import in macros
cvauclair Dec 6, 2024
8015b69
feat(pipeline): Add id and score to `lookup` op result
cvauclair Dec 16, 2024
22375f6
Merge branch 'main' into feat/agentic-chains
cvauclair Dec 16, 2024
bebbbca
docs(pipelines): Add more docstrings
cvauclair Dec 16, 2024
e4bd652
docs(pipelines): Update example
cvauclair Dec 16, 2024
9b2fa78
test(mongodb): Fix flaky test again
cvauclair Dec 16, 2024
bf52fd9
style: fmt
cvauclair Dec 16, 2024
f780473
test(mongodb): fix
cvauclair Dec 16, 2024
72 changes: 72 additions & 0 deletions rig-core/examples/chain.rs
@@ -0,0 +1,72 @@
use std::env;

use rig::{
embeddings::EmbeddingsBuilder,
parallel,
pipeline::{self, agent_ops::lookup, passthrough, Op},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::in_memory_store::InMemoryVectorStore,
};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);

let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

// Create embeddings for our documents
let embeddings = EmbeddingsBuilder::new(embedding_model.clone())
.document("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")?
.document("Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")?
.document("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")?
.build()
.await?;

// Create vector store with the embeddings
let vector_store = InMemoryVectorStore::from_documents(embeddings);

// Create vector store index
let index = vector_store.index(embedding_model);

let agent = openai_client.agent("gpt-4")
.preamble("
You are a dictionary assistant here to assist the user in understanding the meaning of words.
")
.build();

let chain = pipeline::new()
// Chain a parallel operation to the current chain. The parallel operation will
// perform a lookup operation to retrieve additional context from the user prompt
// while simultaneously applying a passthrough operation. The latter will allow
// us to forward the initial prompt to the next operation in the chain.
.chain(parallel!(
passthrough(),
lookup::<_, _, String>(index, 1), // Required to specify document type
))
// Chain a "map" operation to the current chain, which will combine the user
// prompt with the retrieved context documents to create the final prompt.
// If an error occurs during the lookup operation, we will log the error and
// simply return the initial prompt.
.map(|(prompt, maybe_docs)| match maybe_docs {
Ok(docs) => format!(
"Non standard word definitions:\n{}\n\n{}",
docs.join("\n"),
prompt,
),
Err(err) => {
println!("Error: {}! Prompting without additional context", err);
format!("{prompt}")
}
})
// Chain a "prompt" operation which will prompt out agent with the final prompt
.prompt(agent);

// Prompt the agent and print the response
let response = chain.call("What does \"glarb-glarb\" mean?").await?;

println!("{:?}", response);

Ok(())
}
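
For reference, the same builder also composes without any agent ops. Below is a minimal sketch using only `map`, assuming (as the example above suggests) that `map` can be chained directly off `pipeline::new()`:

use rig::pipeline::{self, Op};

#[tokio::main]
async fn main() {
    // Normalize the input, then wrap it in a prompt template.
    let pipeline = pipeline::new()
        .map(|input: String| input.trim().to_lowercase())
        .map(|term: String| format!("Define the following term: {term}"));

    let prompt = pipeline.call("  Glarb-Glarb  ".to_string()).await;
    assert_eq!(prompt, "Define the following term: glarb-glarb");
}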
88 changes: 88 additions & 0 deletions rig-core/examples/multi_extract.rs
@@ -0,0 +1,88 @@
use rig::{
pipeline::{self, agent_ops, TryOp},
providers::openai,
try_parallel,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted names
pub struct Names {
/// The names extracted from the text
pub names: Vec<String>,
}

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted topics
pub struct Topics {
/// The topics extracted from the text
pub topics: Vec<String>,
}

#[derive(Debug, Deserialize, JsonSchema, Serialize)]
/// A record containing extracted sentiment
pub struct Sentiment {
/// The sentiment of the text (-1 being negative, 1 being positive)
pub sentiment: f64,
/// The confidence of the sentiment
pub confidence: f64,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let openai = openai::Client::from_env();

let names_extractor = openai
.extractor::<Names>("gpt-4")
.preamble("Extract names (e.g.: of people, places) from the given text.")
.build();

let topics_extractor = openai
.extractor::<Topics>("gpt-4")
.preamble("Extract topics from the given text.")
.build();

let sentiment_extractor = openai
.extractor::<Sentiment>("gpt-4")
.preamble(
"Extract sentiment (and how confident you are of the sentiment) from the given text.",
)
.build();

// Create a chain that extracts names, topics, and sentiment from a given text
// using three different GPT-4 based extractors.
// The chain will output a formatted string containing the extracted information.
let chain = pipeline::new()
.chain(try_parallel!(
agent_ops::extract(names_extractor),
agent_ops::extract(topics_extractor),
agent_ops::extract(sentiment_extractor),
))
.map_ok(|(names, topics, sentiment)| {
format!(
"Extracted names: {names}\nExtracted topics: {topics}\nExtracted sentiment: {sentiment}",
names = names.names.join(", "),
topics = topics.topics.join(", "),
sentiment = sentiment.sentiment,
)
});

// Batch call the chain with up to 4 inputs concurrently
let responses = chain
.try_batch_call(
4,
vec![
"Screw you Putin!",
"I love my dog, but I hate my cat.",
"I'm going to the store to buy some milk.",
],
)
.await?;

for response in responses {
println!("Text analysis:\n{response}");
}

Ok(())
}
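
Commit fc69515 also adds the infallible counterpart `Op::batch_call`. A minimal sketch of how it would be used, assuming it mirrors the `try_batch_call(limit, inputs)` shape shown above (the exact signature is an assumption):

use rig::pipeline::{self, Op};

#[tokio::main]
async fn main() {
    let pipeline = pipeline::new().map(|n: i32| n * 2);

    // Process up to 2 inputs concurrently (output order assumed to match input order).
    let doubled = pipeline.batch_call(2, vec![1, 2, 3]).await;
    assert_eq!(doubled, vec![2, 4, 6]);
}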
6 changes: 6 additions & 0 deletions rig-core/src/agent.rs
@@ -251,6 +251,12 @@ impl<M: CompletionModel> Prompt for Agent<M> {
}
}

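// Implementing `Prompt` for `&Agent` lets callers (such as pipeline ops) borrow an
// agent rather than take ownership of it, so one agent can serve several ops.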
impl<M: CompletionModel> Prompt for &Agent<M> {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
self.chat(prompt, vec![]).await
}
}

impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(&self, prompt: &str, chat_history: Vec<Message>) -> Result<String, PromptError> {
match self.completion(prompt, chat_history).await?.send().await? {
1 change: 1 addition & 0 deletions rig-core/src/lib.rs
@@ -86,6 +86,7 @@ pub mod extractor;
pub(crate) mod json_utils;
pub mod loaders;
pub mod one_or_many;
pub mod pipeline;
pub mod providers;
pub mod tool;
pub mod vector_store;
210 changes: 210 additions & 0 deletions rig-core/src/pipeline/agent_ops.rs
@@ -0,0 +1,210 @@
use crate::{
completion::{self, CompletionModel},
extractor::{ExtractionError, Extractor},
vector_store,
};

use super::Op;

pub struct Lookup<I, In, T> {
index: I,
n: usize,
_in: std::marker::PhantomData<In>,
_t: std::marker::PhantomData<T>,
}

impl<I, In, T> Lookup<I, In, T>
where
I: vector_store::VectorStoreIndex,
{
pub fn new(index: I, n: usize) -> Self {
Self {
index,
n,
_in: std::marker::PhantomData,
_t: std::marker::PhantomData,
}
}
}

impl<I, In, T> Op for Lookup<I, In, T>
where
I: vector_store::VectorStoreIndex,
In: Into<String> + Send + Sync,
T: Send + Sync + for<'a> serde::Deserialize<'a>,
{
type Input = In;
type Output = Result<Vec<T>, vector_store::VectorStoreError>;

async fn call(&self, input: Self::Input) -> Self::Output {
let query: String = input.into();

let docs = self
.index
.top_n::<T>(&query, self.n)
.await?
.into_iter()
.map(|(_, _, doc)| doc)
.collect();

Ok(docs)
}
}

pub fn lookup<I, In, T>(index: I, n: usize) -> Lookup<I, In, T>
where
I: vector_store::VectorStoreIndex,
In: Into<String> + Send + Sync,
T: Send + Sync + for<'a> serde::Deserialize<'a>,
{
Lookup::new(index, n)
}

pub struct Prompt<P, In> {
prompt: P,
_in: std::marker::PhantomData<In>,
}

impl<P, In> Prompt<P, In> {
pub fn new(prompt: P) -> Self {
Self {
prompt,
_in: std::marker::PhantomData,
}
}
}

impl<P, In> Op for Prompt<P, In>
where
P: completion::Prompt,
In: Into<String> + Send + Sync,
{
type Input = In;
type Output = Result<String, completion::PromptError>;

async fn call(&self, input: Self::Input) -> Self::Output {
let prompt: String = input.into();
self.prompt.prompt(&prompt).await
}
}

pub fn prompt<P, In>(prompt: P) -> Prompt<P, In>
where
P: completion::Prompt,
In: Into<String> + Send + Sync,
{
Prompt::new(prompt)
}

pub struct Extract<M, T, In>
where
M: CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
{
extractor: Extractor<M, T>,
_in: std::marker::PhantomData<In>,
}

impl<M, T, In> Extract<M, T, In>
where
M: CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
{
pub fn new(extractor: Extractor<M, T>) -> Self {
Self {
extractor,
_in: std::marker::PhantomData,
}
}
}

impl<M, T, In> Op for Extract<M, T, In>
where
M: CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
In: Into<String> + Send + Sync,
{
type Input = In;
type Output = Result<T, ExtractionError>;

async fn call(&self, input: Self::Input) -> Self::Output {
self.extractor.extract(&input.into()).await
}
}

pub fn extract<M, T, In>(extractor: Extractor<M, T>) -> Extract<M, T, In>
where
M: CompletionModel,
T: schemars::JsonSchema + for<'a> serde::Deserialize<'a> + Send + Sync,
In: Into<String> + Send + Sync,
{
Extract::new(extractor)
}

#[cfg(test)]
pub mod tests {
use super::*;
use completion::{Prompt, PromptError};
use vector_store::{VectorStoreError, VectorStoreIndex};

pub struct MockModel;

impl Prompt for MockModel {
async fn prompt(&self, prompt: &str) -> Result<String, PromptError> {
Ok(format!("Mock response: {}", prompt))
}
}

pub struct MockIndex;

impl VectorStoreIndex for MockIndex {
async fn top_n<T: for<'a> serde::Deserialize<'a> + std::marker::Send>(
&self,
_query: &str,
_n: usize,
) -> Result<Vec<(f64, String, T)>, VectorStoreError> {
let doc = serde_json::from_value(serde_json::json!({
"foo": "bar",
}))
.unwrap();

Ok(vec![(1.0, "doc1".to_string(), doc)])
}

async fn top_n_ids(
&self,
_query: &str,
_n: usize,
) -> Result<Vec<(f64, String)>, VectorStoreError> {
Ok(vec![(1.0, "doc1".to_string())])
}
}

#[derive(Debug, serde::Deserialize, PartialEq)]
pub struct Foo {
pub foo: String,
}

#[tokio::test]
async fn test_lookup() {
let index = MockIndex;
let lookup = lookup::<MockIndex, String, Foo>(index, 1);

let result = lookup.call("query".to_string()).await.unwrap();
assert_eq!(
result,
vec![Foo {
foo: "bar".to_string()
}]
);
}

#[tokio::test]
async fn test_prompt() {
let model = MockModel;
let prompt = prompt::<MockModel, String>(model);

let result = prompt.call("hello".to_string()).await.unwrap();
assert_eq!(result, "Mock response: hello");
}
}
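
One pattern worth noting in `agent_ops.rs`: `Lookup`, `Prompt`, and `Extract` are generic over an input type `In` that appears only in the `Op::Input` associated type, so each struct carries a zero-sized `PhantomData<In>` marker to anchor the parameter to the struct. A standalone sketch of the same pattern (illustrative only, not part of this diff):

use std::marker::PhantomData;

// An op-like struct whose input type parameter is not stored in any field.
struct Uppercase<In> {
    _in: PhantomData<In>,
}

impl<In: Into<String>> Uppercase<In> {
    fn new() -> Self {
        Self { _in: PhantomData }
    }

    fn call(&self, input: In) -> String {
        input.into().to_uppercase()
    }
}

fn main() {
    let op = Uppercase::<&str>::new();
    assert_eq!(op.call("glarb-glarb"), "GLARB-GLARB");
}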