feat: tool result callback #290

Open · wants to merge 8 commits into base: main
9 changes: 6 additions & 3 deletions rig-core/examples/agent_with_deepseek.rs
@@ -9,11 +9,12 @@ use serde_json::json;
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();

let client = providers::deepseek::Client::from_env();

let agent = client
.agent("deepseek-chat")
.preamble("You are a helpful assistant.")
@@ -23,6 +24,10 @@ async fn main() -> Result<(), anyhow::Error> {
println!("Answer: {}", answer);

// Create agent with a single context prompt and two tools
/* WARNING from DeepSeek documentation (https://api-docs.deepseek.com/guides/function_calling)
Notice
The current version of the deepseek-chat model's Function Calling capability is unstable, which may result in looped calls or empty responses.
*/
let calculator_agent = client
.agent(providers::deepseek::DEEPSEEK_CHAT)
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
@@ -81,7 +86,6 @@ impl Tool for Adder {
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Adding {} and {}", args.x, args.y);
let result = args.x + args.y;
Ok(result)
}
@@ -118,7 +122,6 @@ impl Tool for Subtract {
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Subtracting {} from {}", args.y, args.x);
let result = args.x - args.y;
Ok(result)
}
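Note: the two println!("[tool-call] ...") statements removed above were the only place the example surfaced tool activity; with the subscriber raised to DEBUG, that visibility now comes from the tracing events added in agent.rs and the DeepSeek provider below. If a tool still wants to report its own inputs, a hypothetical drop-in for Adder's call method (not part of this PR) could go through tracing instead:

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
    // Hypothetical replacement for the removed println!: log through tracing so
    // the message honors the subscriber's level instead of always printing.
    tracing::debug!("[tool-call] adding {} and {}", args.x, args.y);
    Ok(args.x + args.y)
}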
75 changes: 75 additions & 0 deletions rig-core/examples/agent_with_tools.rs
@@ -116,5 +116,80 @@ async fn main() -> Result<(), anyhow::Error> {
calculator_agent.prompt("Calculate 2 - 5").await?
);

// Create agent with a single context prompt and a search tool
let search_agent = openai_client
.agent(providers::openai::GPT_4O)
.preamble(
"You are an assistant helping to find useful information on the internet. \
If you can't find the information, you can use the search tool to find it. \
If the search tool returns an error, just notify the user that you could not find any results.",
)
.max_tokens(1024)
.tool(SearchTool)
.build();

// Prompt the agent and print the response
println!("Can you please let me know title and url of rig platform?");
println!(
"OpenAI Search Agent: {}",
search_agent
.prompt("Can you please let me know title and url of rig platform?")
.await?
);

Ok(())
}

#[derive(Deserialize, Serialize)]
struct SearchArgs {
pub query_string: String,
}

#[derive(Deserialize, Serialize)]
struct SearchResult {
pub title: String,
pub url: String,
}

#[derive(Debug, thiserror::Error)]
#[error("Search error")]
struct SearchError;

#[derive(Deserialize, Serialize)]
struct SearchTool;

impl Tool for SearchTool {
const NAME: &'static str = "search";

type Error = SearchError;
type Args = SearchArgs;
type Output = SearchResult;

async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "search",
"description": "Search for a website, it will return the title and URL",
"parameters": {
"type": "object",
"properties": {
"query_string": {
"type": "string",
"description": "The query string to search for"
},
}
}
}))
.expect("Tool Definition")
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
if args.query_string.to_lowercase().contains("rig") {
Ok(SearchResult {
title: "Rig Documentation".to_string(),
url: "https://docs.rig.ai".to_string(),
})
} else {
Err(SearchError)
}
}
}
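The SearchTool example relies on imports at the top of agent_with_tools.rs that fall outside this hunk. As a rough sketch of what the new code needs (the exact paths are assumptions, since the file header is not shown in the diff):

use rig::{
    completion::{Prompt, ToolDefinition},
    providers,
    tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;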
60 changes: 45 additions & 15 deletions rig-core/src/agent.rs
@@ -115,7 +115,7 @@ use crate::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
message::AssistantContent,
message::{AssistantContent, UserContent},
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
@@ -299,21 +299,51 @@ impl<M: CompletionModel> Prompt for &Agent<M> {
impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
prompt: impl Into<Message>,
mut chat_history: Vec<Message>,
) -> Result<String, PromptError> {
let resp = self.completion(prompt, chat_history).await?.send().await?;

// TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
match resp.choice.first() {
AssistantContent::Text(text) => Ok(text.text.clone()),
AssistantContent::ToolCall(tool_call) => Ok(self
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?),
let mut prompt: Message = prompt.into();

tracing::debug!("Chat prompt: {:?}", prompt);

loop {
// call model
let resp = self
.completion(prompt.clone(), chat_history.clone())
.await?
.send()
.await?;

if tracing::enabled!(tracing::Level::DEBUG) {
for (i, message) in resp.choice.iter().enumerate() {
tracing::debug!("Chat response message #{}: {:?}", i, message);
}
}

// Keep calling tools until the model returns a human-readable answer
match resp.choice.first() {
AssistantContent::Text(text) => break Ok(text.text.clone()),
AssistantContent::ToolCall(tool_call) => {
// Call the tool
let tool_response = self
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?;

let tool_response_message = UserContent::tool_result_from_text_response(
tool_call.id.clone(),
tool_response,
);

// add tool call and response into chat history and continue the loop
chat_history.push(prompt);
chat_history.push(tool_call.into());
prompt = tool_response_message.into();
}
}
}
}
}
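The reworked Chat::chat above resolves tool calls on its own: it sends the prompt, and whenever the model answers with a tool call it invokes the tool, pushes the prompt, the tool call, and the tool result into chat_history, then loops with the tool result as the next prompt until the model returns plain text. A minimal caller-side sketch (the prompt is illustrative; the calculator tools come from the examples above):

use rig::{
    agent::Agent,
    completion::{Chat, CompletionModel, PromptError},
};

// With the new loop, the caller only sees the final natural-language answer;
// any intermediate tool calls are executed and fed back inside chat() itself.
async fn ask<M: CompletionModel>(agent: &Agent<M>) -> Result<String, PromptError> {
    agent
        .chat("Calculate 2 + 5, then subtract 3 from the result", vec![])
        .await
}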
29 changes: 29 additions & 0 deletions rig-core/src/completion/message.rs
@@ -284,6 +284,12 @@ impl UserContent {
content,
})
}

/// Helper constructor to make creating a user tool result from a text response easier.
pub fn tool_result_from_text_response(id: impl Into<String>, content: String) -> Self {
let content = OneOrMany::one(ToolResultContent::Text(content.into()));
Self::tool_result(id, content)
}
}

impl AssistantContent {
@@ -306,6 +312,13 @@ impl AssistantContent {
},
})
}

pub fn is_empty(&self) -> bool {
match self {
AssistantContent::Text(Text { text }) => text.is_empty(),
_ => false,
}
}
}

impl ToolResultContent {
@@ -544,6 +557,22 @@ impl From<String> for UserContent {
}
}

impl From<ToolCall> for Message {
fn from(tool_call: ToolCall) -> Self {
Message::Assistant {
content: OneOrMany::one(AssistantContent::ToolCall(tool_call)),
}
}
}

impl From<UserContent> for Message {
fn from(content: UserContent) -> Self {
Message::User {
content: OneOrMany::one(content),
}
}
}

// ================================================================
// Error types
// ================================================================
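These additions are the glue the new chat loop in agent.rs depends on: a ToolCall converts into an assistant message, and a tool's text output becomes a user message via tool_result_from_text_response plus From<UserContent>. A small sketch of how they combine for one tool round-trip (module paths and field visibility are assumptions here; inside rig-core the loop uses crate-local paths):

use rig::completion::message::{Message, ToolCall, UserContent};

// Build the two history entries appended per tool round-trip: the assistant's
// tool call, then the user-side tool result correlated to it by id.
fn tool_round_trip(tool_call: ToolCall, tool_output: String) -> (Message, Message) {
    let id = tool_call.id.clone();
    let assistant: Message = tool_call.into();
    let user: Message = UserContent::tool_result_from_text_response(id, tool_output).into();
    (assistant, user)
}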
6 changes: 4 additions & 2 deletions rig-core/src/providers/deepseek.rs
@@ -143,7 +143,7 @@ pub enum Message {
#[serde(default, deserialize_with = "json_utils::null_or_vec")]
tool_calls: Vec<ToolCall>,
},
#[serde(rename = "Tool")]
#[serde(rename = "tool")]
ToolResult {
tool_call_id: String,
content: String,
@@ -332,7 +332,7 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
.iter()
.map(|call| {
completion::AssistantContent::tool_call(
&call.function.name,
&call.id,
&call.function.name,
call.function.arguments.clone(),
)
@@ -416,6 +416,8 @@ impl CompletionModel for DeepSeekCompletionModel {
})
};

tracing::debug!(target: "rig", "DeepSeek completion request: {}", serde_json::to_string_pretty(&request).unwrap());

let response = self
.client
.post("/chat/completions")
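The serde rename from "Tool" to "tool" matters because DeepSeek's chat endpoint follows the OpenAI-compatible convention of sending tool results back with a lowercase "tool" role and a tool_call_id that echoes the model's tool call; the conversion fix above likewise threads call.id through so that correlation survives. A hedged sketch of the wire shape this serialization targets (the id and content values are made up for illustration):

use serde_json::json;

fn main() {
    // Approximate JSON for a single tool-result message in the request body; the
    // surrounding "messages" array and other request fields are omitted.
    let tool_result = json!({
        "role": "tool",
        "tool_call_id": "call_abc123", // hypothetical id echoed from the model's tool call
        "content": "7"                 // the tool's output as plain text
    });
    println!("{tool_result}");
}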