Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support tools for DeepSeek provider #251

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rig-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ assert_fs = "1.1.2"
tokio = { version = "1.34.0", features = ["full"] }
tracing-subscriber = "0.3.18"
tokio-test = "0.4.4"
dotenvy = "0.15.7"

[features]
all = ["derive", "pdf", "rayon"]
Expand Down
3 changes: 3 additions & 0 deletions rig-core/examples/agent_with_deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ use rig::{completion::Prompt, providers};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// load env from .env file
dotenvy::dotenv().ok();

let client = providers::deepseek::Client::from_env();
let agent = client
.agent("deepseek-chat")
Expand Down
32 changes: 31 additions & 1 deletion rig-core/examples/agent_with_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ impl Tool for Adder {
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Adding {} and {}", args.x, args.y);
let result = args.x + args.y;
Ok(result)
}
Expand Down Expand Up @@ -83,13 +84,17 @@ impl Tool for Subtract {
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
println!("[tool-call] Subtracting {} from {}", args.y, args.x);
let result = args.x - args.y;
Ok(result)
}
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// load env from .env file
dotenvy::dotenv().ok();

// Create OpenAI client
let openai_client = providers::openai::Client::from_env();

Expand All @@ -103,9 +108,34 @@ async fn main() -> Result<(), anyhow::Error> {
.build();

// Prompt the agent and print the response
println!("\n\n####################################################");
println!("OpenAI example");
println!("####################################################");
println!("Calculate 2 - 5");
println!(
"OpenAI Calculator Agent: {}",
calculator_agent.prompt("Calculate 2 - 5").await?
);

// Create a DeepSeek client
let deepseek_client = providers::deepseek::Client::from_env();

// Create agent with a single context prompt and two tools
let calculator_agent = deepseek_client
.agent(providers::deepseek::DEEPSEEK_CHAT)
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();

// Prompt the agent and print the response
println!("\n\n####################################################");
println!("DeepSeek example");
println!("####################################################");
println!("Calculate 2 - 5");
println!(
"Calculator Agent: {}",
"DeepSeek Calculator Agent: {}",
calculator_agent.prompt("Calculate 2 - 5").await?
);

Expand Down
119 changes: 92 additions & 27 deletions rig-core/src/providers/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
json_utils,
};
use reqwest::Client as HttpClient;
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use serde_json::json;

// ================================================================
Expand Down Expand Up @@ -72,10 +72,54 @@ impl Client {
#[derive(Debug, Deserialize)]
pub struct DeepSeekResponse {
// We'll match the JSON:
pub choices: Option<Vec<Choice>>,
pub choices: Vec<Choice>,
// you may want usage or other fields
}

impl TryFrom<DeepSeekResponse> for CompletionResponse<DeepSeekResponse> {
type Error = crate::completion::CompletionError;

fn try_from(value: DeepSeekResponse) -> Result<Self, Self::Error> {
match value.choices.as_slice() {
[Choice {
message:
Some(DeepSeekMessage {
tool_calls: Some(calls),
..
}),
..
}, ..]
if !calls.is_empty() =>
{
let call = calls.first().unwrap();

Ok(crate::completion::CompletionResponse {
choice: crate::completion::ModelChoice::ToolCall(
call.function.name.clone(),
"".to_owned(),
serde_json::from_str(&call.function.arguments)?,
),
raw_response: value,
})
}
[Choice {
message:
Some(DeepSeekMessage {
content: Some(content),
..
}),
..
}, ..] => Ok(crate::completion::CompletionResponse {
choice: crate::completion::ModelChoice::Message(content.to_string()),
raw_response: value,
}),
_ => Err(crate::completion::CompletionError::ResponseError(
"Response did not contain a message or tool call".into(),
)),
}
}
}

#[derive(Debug, Deserialize)]
pub struct Choice {
pub message: Option<DeepSeekMessage>,
Expand All @@ -85,6 +129,35 @@ pub struct Choice {
pub struct DeepSeekMessage {
pub role: Option<String>,
pub content: Option<String>,
pub tool_calls: Option<Vec<DeepSeekToolCall>>,
}

#[derive(Debug, Deserialize)]
pub struct DeepSeekToolCall {
pub id: String,
pub r#type: String,
pub function: DeepSeekFunction,
}

#[derive(Debug, Deserialize)]
pub struct DeepSeekFunction {
pub name: String,
pub arguments: String,
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DeepSeekToolDefinition {
pub r#type: String,
pub function: crate::completion::ToolDefinition,
}

impl From<crate::completion::ToolDefinition> for DeepSeekToolDefinition {
fn from(tool: crate::completion::ToolDefinition) -> Self {
Self {
r#type: "function".into(),
function: tool,
}
}
}

/// The struct implementing the `CompletionModel` trait
Expand Down Expand Up @@ -145,11 +218,24 @@ impl CompletionModel for DeepSeekCompletionModel {
"presence_penalty": 0,
"temperature": request.temperature.unwrap_or(1.0),
"top_p": 1,
"tool_choice": "none",
"logprobs": false,
"stream": false,
});

// prepare tools
let tools = if request.tools.is_empty() {
json!({
"tool_choice": "none",
})
} else {
json!({
"tools": request.tools.into_iter().map(DeepSeekToolDefinition::from).collect::<Vec<_>>(),
"tool_choice": "auto",
})
};

let body = json_utils::merge(body, tools);

// if user set additional_params, merge them:
let final_body = if let Some(params) = request.additional_params {
json_utils::merge(body, params)
Expand All @@ -176,31 +262,10 @@ impl CompletionModel for DeepSeekCompletionModel {
)));
}

let json_resp: DeepSeekResponse = resp.json().await?;
// 4. Convert DeepSeekResponse -> rig’s `CompletionResponse<DeepSeekResponse>`
let deep_seek_response: DeepSeekResponse = resp.json().await?;

// If no choices or content, return an empty message
let content = if let Some(choices) = &json_resp.choices {
if let Some(choice) = choices.first() {
if let Some(msg) = &choice.message {
msg.content.clone().unwrap_or_default()
} else {
"".to_string()
}
} else {
"".to_string()
}
} else {
"".to_string()
};

// For now, we just treat it as a normal text message
let model_choice = crate::completion::ModelChoice::Message(content);

Ok(CompletionResponse {
choice: model_choice,
raw_response: json_resp,
})
// 4. Convert DeepSeekResponse -> rig’s `CompletionResponse<DeepSeekResponse>`
deep_seek_response.try_into()
}
}

Expand Down