Skip to content

Commit

Permalink
feat: call model with tool result
Browse files Browse the repository at this point in the history
  • Loading branch information
carlos-verdes committed Feb 10, 2025
1 parent b82ed37 commit f192a2b
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 25 deletions.
2 changes: 1 addition & 1 deletion rig-core/examples/agent_with_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl Tool for Subtract {
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt()
.with_max_level(tracing::Level::DEBUG)
.with_max_level(tracing::Level::INFO)
.with_target(false)
.init();

Expand Down
52 changes: 43 additions & 9 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ use crate::{
Chat, Completion, CompletionError, CompletionModel, CompletionRequestBuilder, Document,
Message, Prompt, PromptError,
},
message::AssistantContent,
message::{AssistantContent, UserContent},
streaming::{
StreamingChat, StreamingCompletion, StreamingCompletionModel, StreamingPrompt,
StreamingResult,
Expand Down Expand Up @@ -299,21 +299,55 @@ impl<M: CompletionModel> Prompt for &Agent<M> {
impl<M: CompletionModel> Chat for Agent<M> {
async fn chat(
&self,
prompt: impl Into<Message> + Send,
chat_history: Vec<Message>,
prompt: impl Into<Message>,
mut chat_history: Vec<Message>,
) -> Result<String, PromptError> {
let resp = self.completion(prompt, chat_history).await?.send().await?;
let prompt: Message = prompt.into();

// TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
match resp.choice.first() {
AssistantContent::Text(text) => Ok(text.text.clone()),
AssistantContent::ToolCall(tool_call) => Ok(self
let resp = self
.completion(prompt.clone(), chat_history.clone())
.await?
.send()
.await?;

let mut first_choice = resp.choice.first();

// loop to handle tool calls
while let AssistantContent::ToolCall(tool_call) = first_choice {
tracing::debug!("Inside chat Tool call loop: {:?}", tool_call);

// Call the tool
let tool_response = self
.tools
.call(
&tool_call.function.name,
tool_call.function.arguments.to_string(),
)
.await?),
.await?;

let tool_response_message =
UserContent::tool_result_from_text_response(tool_call.id.clone(), tool_response);

// add tool call and response into chat history
chat_history.push(tool_call.into());
chat_history.push(tool_response_message.into());

let resp = self
.completion(prompt.clone(), chat_history.clone())
.await?
.send()
.await?;

first_choice = resp.choice.first();
}

// TODO: consider returning a `Message` instead of `String` for parallel responses / tool calls
match first_choice {
AssistantContent::Text(text) => Ok(text.text.clone()),
AssistantContent::ToolCall(tool_call) => unreachable!(
"Tool call should have been handled in the loop: {:?}",
tool_call
),
}
}
}
Expand Down
29 changes: 29 additions & 0 deletions rig-core/src/completion/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,12 @@ impl UserContent {
content,
})
}

/// Builds a user-side tool result containing a single text entry, pairing
/// the given tool-call `id` with the tool's textual `content`.
pub fn tool_result_from_text_response(id: impl Into<String>, content: String) -> Self {
    Self::tool_result(id, OneOrMany::one(ToolResultContent::Text(content.into())))
}
}

impl AssistantContent {
Expand All @@ -306,6 +312,13 @@ impl AssistantContent {
},
})
}

/// Returns `true` only for a text variant whose string is empty; tool
/// calls (and any non-text content) are never considered empty.
pub fn is_empty(&self) -> bool {
    matches!(self, AssistantContent::Text(Text { text }) if text.is_empty())
}
}

impl ToolResultContent {
Expand Down Expand Up @@ -544,6 +557,22 @@ impl From<String> for UserContent {
}
}

impl From<ToolCall> for Message {
fn from(tool_call: ToolCall) -> Self {
Message::Assistant {
content: OneOrMany::one(AssistantContent::ToolCall(tool_call)),
}
}
}

impl From<UserContent> for Message {
fn from(content: UserContent) -> Self {
Message::User {
content: OneOrMany::one(content),
}
}
}

// ================================================================
// Error types
// ================================================================
Expand Down
45 changes: 30 additions & 15 deletions rig-core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::{json, Value};

// ================================================================
// Main OpenAI Client
Expand Down Expand Up @@ -469,7 +469,7 @@ pub enum Message {
#[serde(default, deserialize_with = "json_utils::null_or_vec")]
tool_calls: Vec<ToolCall>,
},
#[serde(rename = "Tool")]
#[serde(rename = "tool")]
ToolResult {
tool_call_id: String,
content: OneOrMany<ToolResultContent>,
Expand Down Expand Up @@ -535,6 +535,12 @@ pub struct InputAudio {
/// One content part of an OpenAI tool result message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
// The textual payload of the tool result.
text: String,
// OpenAI content-part discriminator; filled with "text" via
// `default_result_content` when absent from the wire payload.
#[serde(default = "default_result_content")]
r#type: String,
}

/// Serde `default` provider for `ToolResultContent::type`: tool result
/// content parts are `"text"` unless the payload says otherwise.
fn default_result_content() -> String {
    String::from("text")
}

impl FromStr for ToolResultContent {
Expand All @@ -547,7 +553,10 @@ impl FromStr for ToolResultContent {

impl From<String> for ToolResultContent {
    /// Wraps a plain string as a `"text"`-typed tool result content part.
    fn from(s: String) -> Self {
        ToolResultContent {
            text: s,
            r#type: default_result_content(),
        }
    }
}

Expand Down Expand Up @@ -897,7 +906,7 @@ impl completion::CompletionModel for CompletionModel {
full_history.extend(chat_history);
full_history.extend(prompt);

let request = if completion_request.tools.is_empty() {
let mut request = if completion_request.tools.is_empty() {
json!({
"model": self.model,
"messages": full_history,
Expand All @@ -913,24 +922,26 @@ impl completion::CompletionModel for CompletionModel {
})
};

// merge additional params into request
if let Some(params) = completion_request.additional_params {
request = json_utils::merge(request, params);
}

let response = self
.client
.post("/chat/completions")
.json(
&if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
},
)
.json(&request)
.send()
.await?;

if response.status().is_success() {
let t = response.text().await?;
tracing::debug!(target: "rig", "OpenAI completion error: {}", t);
let t: Value = response.json().await?;

tracing::debug!(target: "rig", "OpenAI completion success: \nRequest: \n{} \n\nResponse: \n {}",
serde_json::to_string_pretty(&request).unwrap(),
serde_json::to_string_pretty(&t).unwrap());

match serde_json::from_str::<ApiResponse<CompletionResponse>>(&t)? {
match serde_json::from_value::<ApiResponse<CompletionResponse>>(t)? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"OpenAI completion token usage: {:?}",
Expand All @@ -941,7 +952,11 @@ impl completion::CompletionModel for CompletionModel {
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
let t = response.text().await?;
tracing::debug!(target: "rig", "OpenAI completion error: \nRequest: \n{} \n\nResponse: \n {}",
serde_json::to_string_pretty(&request).unwrap(),
t);
Err(CompletionError::ProviderError(t))
}
}
}
Expand Down

0 comments on commit f192a2b

Please sign in to comment.