diff --git a/rig-core/examples/agent_with_deepseek.rs b/rig-core/examples/agent_with_deepseek.rs index df82442b..dbfd35ef 100644 --- a/rig-core/examples/agent_with_deepseek.rs +++ b/rig-core/examples/agent_with_deepseek.rs @@ -9,11 +9,12 @@ use serde_json::json; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { tracing_subscriber::fmt() - .with_max_level(tracing::Level::INFO) + .with_max_level(tracing::Level::DEBUG) .with_target(false) .init(); let client = providers::deepseek::Client::from_env(); + let agent = client .agent("deepseek-chat") .preamble("You are a helpful assistant.") @@ -23,6 +24,10 @@ async fn main() -> Result<(), anyhow::Error> { println!("Answer: {}", answer); // Create agent with a single context prompt and two tools + /* WARNING from DeepSeek documentation (https://api-docs.deepseek.com/guides/function_calling) + Notice + The current version of the deepseek-chat model's Function Calling capability is unstable, which may result in looped calls or empty responses. + */ let calculator_agent = client .agent(providers::deepseek::DEEPSEEK_CHAT) .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.") @@ -81,7 +86,6 @@ impl Tool for Adder { } async fn call(&self, args: Self::Args) -> Result { - println!("[tool-call] Adding {} and {}", args.x, args.y); let result = args.x + args.y; Ok(result) } @@ -118,7 +122,6 @@ impl Tool for Subtract { } async fn call(&self, args: Self::Args) -> Result { - println!("[tool-call] Subtracting {} from {}", args.y, args.x); let result = args.x - args.y; Ok(result) } diff --git a/rig-core/examples/agent_with_tools.rs b/rig-core/examples/agent_with_tools.rs index ab0ae862..c4c56857 100644 --- a/rig-core/examples/agent_with_tools.rs +++ b/rig-core/examples/agent_with_tools.rs @@ -116,5 +116,80 @@ async fn main() -> Result<(), anyhow::Error> { calculator_agent.prompt("Calculate 2 - 5").await? 
); + // Create agent with a single context prompt and a search tool + let search_agent = openai_client + .agent(providers::openai::GPT_4O) + .preamble( + "You are an assistant helping to find useful information on the internet. \ + If you can't find the information, you can use the search tool to find it. \ + If search tool return an error just notify the user saying you could not find any result.", + ) + .max_tokens(1024) + .tool(SearchTool) + .build(); + + // Prompt the agent and print the response + println!("Can you please let me know title and url of rig platform?"); + println!( + "OpenAI Search Agent: {}", + search_agent + .prompt("Can you please let me know title and url of rig platform?") + .await? + ); + Ok(()) } + +#[derive(Deserialize, Serialize)] +struct SearchArgs { + pub query_string: String, +} + +#[derive(Deserialize, Serialize)] +struct SearchResult { + pub title: String, + pub url: String, +} + +#[derive(Debug, thiserror::Error)] +#[error("Search error")] +struct SearchError; + +#[derive(Deserialize, Serialize)] +struct SearchTool; + +impl Tool for SearchTool { + const NAME: &'static str = "search"; + + type Error = SearchError; + type Args = SearchArgs; + type Output = SearchResult; + + async fn definition(&self, _prompt: String) -> ToolDefinition { + serde_json::from_value(json!({ + "name": "search", + "description": "Search for a website, it will return the title and URL", + "parameters": { + "type": "object", + "properties": { + "query_string": { + "type": "string", + "description": "The query string to search for" + }, + } + } + })) + .expect("Tool Definition") + } + + async fn call(&self, args: Self::Args) -> Result { + if args.query_string.to_lowercase().contains("rig") { + Ok(SearchResult { + title: "Rig Documentation".to_string(), + url: "https://docs.rig.ai".to_string(), + }) + } else { + Err(SearchError) + } + } +} diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index b2a299e6..3590882c 100644 --- a/rig-core/src/agent.rs +++ 
b/rig-core/src/agent.rs @@ -115,7 +115,7 @@ use crate::{ Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document, Message, Prompt, PromptError, }, - message::AssistantContent, + message::{AssistantContent, UserContent}, streaming::{ StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt, StreamingResult, @@ -299,21 +299,51 @@ impl Prompt for &Agent { impl Chat for Agent { async fn chat( &self, - prompt: impl Into + Send, - chat_history: Vec, + prompt: impl Into, + mut chat_history: Vec, ) -> Result { - let resp = self.completion(prompt, chat_history).await?.send().await?; - - // TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls - match resp.choice.first() { - AssistantContent::Text(text) => Ok(text.text.clone()), - AssistantContent::ToolCall(tool_call) => Ok(self - .tools - .call( - &tool_call.function.name, - tool_call.function.arguments.to_string(), - ) - .await?), + let mut prompt: Message = prompt.into(); + + tracing::debug!("Chat prompt: {:?}", prompt); + + loop { + // call model + let resp = self + .completion(prompt.clone(), chat_history.clone()) + .await? 
+ .send() + .await?; + + if tracing::enabled!(tracing::Level::DEBUG) { + for (i, message) in resp.choice.iter().enumerate() { + tracing::debug!("Chat response message #{}: {:?}", i, message); + } + } + + // keep calling tools until we get human readable answer from the model + match resp.choice.first() { + AssistantContent::Text(text) => break Ok(text.text.clone()), + AssistantContent::ToolCall(tool_call) => { + // Call the tool + let tool_response = self + .tools + .call( + &tool_call.function.name, + tool_call.function.arguments.to_string(), + ) + .await?; + + let tool_response_message = UserContent::tool_result_from_text_response( + tool_call.id.clone(), + tool_response, + ); + + // add tool call and response into chat history and continue the loop + chat_history.push(prompt); + chat_history.push(tool_call.into()); + prompt = tool_response_message.into(); + } + } } } } diff --git a/rig-core/src/completion/message.rs b/rig-core/src/completion/message.rs index dc9e6eb4..63b9c8d6 100644 --- a/rig-core/src/completion/message.rs +++ b/rig-core/src/completion/message.rs @@ -284,6 +284,12 @@ impl UserContent { content, }) } + + /// Helper constructor to make creating user tool result from text easier. 
+ pub fn tool_result_from_text_response(id: impl Into, content: String) -> Self { + let content = OneOrMany::one(ToolResultContent::Text(content.into())); + Self::tool_result(id, content) + } } impl AssistantContent { @@ -306,6 +312,13 @@ impl AssistantContent { }, }) } + + pub fn is_empty(&self) -> bool { + match self { + AssistantContent::Text(Text { text }) => text.is_empty(), + _ => false, + } + } } impl ToolResultContent { @@ -544,6 +557,22 @@ impl From for UserContent { } } +impl From for Message { + fn from(tool_call: ToolCall) -> Self { + Message::Assistant { + content: OneOrMany::one(AssistantContent::ToolCall(tool_call)), + } + } +} + +impl From for Message { + fn from(content: UserContent) -> Self { + Message::User { + content: OneOrMany::one(content), + } + } +} + // ================================================================ // Error types // ================================================================ diff --git a/rig-core/src/providers/deepseek.rs b/rig-core/src/providers/deepseek.rs index 30b1971f..dffb484f 100644 --- a/rig-core/src/providers/deepseek.rs +++ b/rig-core/src/providers/deepseek.rs @@ -143,7 +143,7 @@ pub enum Message { #[serde(default, deserialize_with = "json_utils::null_or_vec")] tool_calls: Vec, }, - #[serde(rename = "Tool")] + #[serde(rename = "tool")] ToolResult { tool_call_id: String, content: String, @@ -332,7 +332,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse { let mut content = content .iter() - .map(|c| match c { - AssistantContent::Text { text } => completion::AssistantContent::text(text), + .filter_map(|c| match c { + AssistantContent::Text { text } => { + if text.trim().is_empty() { + None + } else { + Some(completion::AssistantContent::text(text)) + } + } AssistantContent::Refusal { refusal } => { - completion::AssistantContent::text(refusal) + Some(completion::AssistantContent::text(refusal)) } }) .collect::>(); @@ -566,6 +572,15 @@ pub struct ToolResultContent { 
#[serde(default)] r#type: ToolResultContentType, text: String, + #[serde(default)] + r#type: ToolResultContentType, +} + +#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)] +#[serde(rename_all = "lowercase")] +pub enum ToolResultContentType { + #[default] + Text, } #[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)] @@ -586,8 +601,8 @@ impl FromStr for ToolResultContent { impl From for ToolResultContent { fn from(s: String) -> Self { ToolResultContent { - r#type: ToolResultContentType::default(), text: s, + r#type: ToolResultContentType::default(), } } } @@ -938,7 +953,7 @@ impl completion::CompletionModel for CompletionModel { full_history.extend(chat_history); full_history.extend(prompt); - let request = if completion_request.tools.is_empty() { + let mut request = if completion_request.tools.is_empty() { json!({ "model": self.model, "messages": full_history, @@ -953,37 +968,37 @@ impl completion::CompletionModel for CompletionModel { }) }; + // merge additional params into request + if let Some(params) = completion_request.additional_params { + request = json_utils::merge(request, params); + } + // only include temperature if it exists // because some models don't support temperature - let request = if let Some(temperature) = completion_request.temperature { - json_utils::merge( + if let Some(temperature) = completion_request.temperature { + request = json_utils::merge( request, json!({ "temperature": temperature, }), - ) - } else { - request - }; + ); + } let response = self .client .post("/chat/completions") - .json( - &if let Some(params) = completion_request.additional_params { - json_utils::merge(request, params) - } else { - request - }, - ) + .json(&request) .send() .await?; if response.status().is_success() { - let t = response.text().await?; - tracing::debug!(target: "rig", "OpenAI completion error: {}", t); + let t: Value = response.json().await?; + + tracing::debug!(target: "rig", "OpenAI completion success: \nRequest: 
\n{} \n\nResponse: \n {}", + serde_json::to_string_pretty(&request).unwrap(), + serde_json::to_string_pretty(&t).unwrap()); - match serde_json::from_str::>(&t)? { + match serde_json::from_value::>(t)? { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "OpenAI completion token usage: {:?}", @@ -994,7 +1009,11 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)), } } else { - Err(CompletionError::ProviderError(response.text().await?)) + let t = response.text().await?; + tracing::debug!(target: "rig", "OpenAI completion error: \nRequest: \n{} \n\nResponse: \n {}", + serde_json::to_string_pretty(&request).unwrap(), + t); + Err(CompletionError::ProviderError(t)) } } }