Skip to content

Commit

Permalink
feat: send tool response back to model
Browse files Browse the repository at this point in the history
  • Loading branch information
carlos-verdes committed Feb 6, 2025
1 parent 0a45ac2 commit 235bddb
Show file tree
Hide file tree
Showing 6 changed files with 703 additions and 70 deletions.
84 changes: 83 additions & 1 deletion rig-core/examples/agent_with_deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@ use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_target(false)
.init();

let client = providers::deepseek::Client::from_env();
let agent = client
.agent("deepseek-chat")
.agent(providers::deepseek::DEEPSEEK_CHAT)
.preamble("You are a helpful assistant.")
.build();

Expand All @@ -33,6 +38,27 @@ async fn main() -> Result<(), anyhow::Error> {
calculator_agent.prompt("Calculate 2 - 5").await?
);

// Create agent with a single context prompt and a search tool
let search_agent = client
.agent(providers::deepseek::DEEPSEEK_CHAT)
.preamble(
"You are an assistant helping to find useful information on the internet. \
If you can't find the information, you can use the search tool to find it. \
If search tool return an error just notify the user saying you could not find any result.",
)
.max_tokens(1024)
.tool(SearchTool)
.build();

// Prompt the agent and print the response
println!("Can you please let me know title and url of rig platform?");
println!(
"DeepSeek Search Agent: {}",
search_agent
.prompt("Can you please let me know title and url of rig platform?")
.await?
);

Ok(())
}

Expand Down Expand Up @@ -118,3 +144,59 @@ impl Tool for Subtract {
Ok(result)
}
}

#[derive(Deserialize, Serialize)]
struct SearchArgs {
pub query_string: String,
}

#[derive(Deserialize, Serialize)]
struct SearchResult {
pub title: String,
pub url: String,
}

#[derive(Debug, thiserror::Error)]
#[error("Search error")]
struct SearchError;

#[derive(Deserialize, Serialize)]
struct SearchTool;

impl Tool for SearchTool {
const NAME: &'static str = "search";

type Error = SearchError;
type Args = SearchArgs;
type Output = SearchResult;

async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "search",
"description": "Search for a website, it will return the title and URL",
"parameters": {
"type": "object",
"properties": {
"query_string": {
"type": "string",
"description": "The query string to search for"
},
}
}
}))
.expect("Tool Definition")
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Searching for: '{}'", args.query_string);

if args.query_string.to_lowercase().contains("rig") {
Ok(SearchResult {
title: "Rig Documentation".to_string(),
url: "https://docs.rig.ai".to_string(),
})
} else {
Err(SearchError)
}
}
}
5 changes: 1 addition & 4 deletions rig-core/examples/agent_with_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,13 @@ impl Tool for SearchTool {
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();

// Create OpenAI client
let openai_client = providers::openai::Client::from_env();

/*
// Create agent with a single context prompt and two tools
let calculator_agent = openai_client
.agent(providers::openai::GPT_4O)
Expand All @@ -173,8 +172,6 @@ async fn main() -> Result<(), anyhow::Error> {
calculator_agent.prompt("Calculate 2 - 5").await?
);

*/

// Create agent with a single context prompt and a search tool
let search_agent = openai_client
.agent(providers::openai::GPT_4O)
Expand Down
111 changes: 102 additions & 9 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ use crate::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
message::AssistantContent,
message::{AssistantContent, UserContent},
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
Expand Down Expand Up @@ -296,24 +296,117 @@ impl<M: CompletionModel> Prompt for &Agent<M> {
}
}

impl<M: CompletionModel> Chat for Agent<M> {
impl<M: CompletionModel> Agent<M> {
async fn chat(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
prompt: impl Into<Message>,
mut chat_history: Vec<Message>,
) -> Result<String, PromptError> {
let resp = self.completion(prompt, chat_history).await?.send().await?;
let prompt: Message = prompt.into();

let resp = self
.completion(prompt.clone(), chat_history.clone())
.await?
.send()
.await?;

let mut first_choice = resp.choice.first();

tracing::debug!("Chat response choices: {:?}", resp.choice);

// loop to handle tool calls
while let AssistantContent::ToolCall(tool_call) = first_choice {
tracing::debug!("Inside chat Tool call loop: {:?}", tool_call);

// Call the tool
let tool_response = self
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?;

let tool_response_message =
UserContent::tool_result_from_text_response(tool_call.id.clone(), tool_response);

// add tool call and response into chat history
chat_history.push(tool_call.into());
chat_history.push(tool_response_message.into());

let resp = self
.completion(prompt.clone(), chat_history.clone())
.await?
.send()
.await?;

first_choice = resp.choice.first();
}

// TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
match resp.choice.first() {
AssistantContent::Text(text) => Ok(text.text.clone()),
AssistantContent::ToolCall(tool_call) => Ok(self
if let AssistantContent::Text(text) = first_choice {
Ok(text.text.clone())
} else {
unreachable!(
"Tool call should have been handled in the loop: {:?}",
first_choice
)
}
}
}

impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(
&self,
prompt: impl Into<Message>,
mut chat_history: Vec<Message>,
) -> Result<String, PromptError> {
let prompt: Message = prompt.into();

let resp = self
.completion(prompt.clone(), chat_history.clone())
.await?
.send()
.await?;

let mut first_choice = resp.choice.first();

// loop to handle tool calls
while let AssistantContent::ToolCall(tool_call) = first_choice {
tracing::debug!("Inside chat Tool call loop: {:?}", tool_call);

// Call the tool
let tool_response = self
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?),
.await?;

let tool_response_message =
UserContent::tool_result_from_text_response(tool_call.id.clone(), tool_response);

// add tool call and response into chat history
chat_history.push(tool_call.into());
chat_history.push(tool_response_message.into());

let resp = self
.completion(prompt.clone(), chat_history.clone())
.await?
.send()
.await?;

first_choice = resp.choice.first();
}

// TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
match resp.choice.first() {
AssistantContent::Text(text) => Ok(text.text.clone()),
AssistantContent::ToolCall(tool_call) => unreachable!(
"Tool call should have been handled in the loop: {:?}",
tool_call
),
}
}
}
Expand Down
34 changes: 34 additions & 0 deletions rig-core/src/completion/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ pub enum AssistantContent {
ToolCall(ToolCall),
}

impl AssistantContent {
pub fn is_empty(&self) -> bool {
match self {
AssistantContent::Text(Text { text }) => text.is_empty(),
_ => false,
}
}
}

/// Tool result content containing information about a tool call and it's resulting content.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
pub struct ToolResult {
Expand Down Expand Up @@ -284,6 +293,15 @@ impl UserContent {
content,
})
}

/// Helper constructor to make creating user tool result from text easier.
pub fn tool_result_from_text_response(id: impl Into<String>, content: String) -> Self {
let content = OneOrMany::one(ToolResultContent::Text(content.into()));
UserContent::ToolResult(ToolResult {
id: id.into(),
content,
})
}
}

impl AssistantContent {
Expand Down Expand Up @@ -526,6 +544,22 @@ impl From<Document> for Message {
}
}

impl From<ToolCall> for Message {
fn from(tool_call: ToolCall) -> Self {
Message::Assistant {
content: OneOrMany::one(AssistantContent::ToolCall(tool_call)),
}
}
}

impl From<UserContent> for Message {
fn from(content: UserContent) -> Self {
Message::User {
content: OneOrMany::one(content),
}
}
}

impl From<String> for ToolResultContent {
fn from(text: String) -> Self {
ToolResultContent::text(text)
Expand Down
Loading

0 comments on commit 235bddb

Please sign in to comment.