forked from 0xPlaygrounds/rig
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlocal_agent_with_tools.rs
146 lines (125 loc) · 4.03 KB
/
local_agent_with_tools.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
use anyhow::Result;
use rig::{
completion::{Chat, Message, Prompt, ToolDefinition},
providers,
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
/// Arguments shared by the `add` and `subtract` tools (used as their `Tool::Args`).
///
/// Deserialized via serde from an object with integer fields `x` and `y`.
#[derive(Deserialize)]
struct OperationArgs {
    /// First operand (the minuend for `subtract`).
    x: i32,
    /// Second operand (the subtrahend for `subtract`).
    y: i32,
}
/// Error type returned by the arithmetic tools (used as their `Tool::Error`).
#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}),
}
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
tracing::info!("Adding {} and {}", args.x, args.y);
let result = args.x + args.y;
Ok(result)
}
}
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";
type Error = MathError;
type Args = OperationArgs;
type Output = i32;
async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to substract from"
},
"y": {
"type": "number",
"description": "The number to substract"
}
}
}
}))
.expect("Tool Definition")
}
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
tracing::info!("Subtracting {} from {}", args.y, args.x);
let result = args.x - args.y;
Ok(result)
}
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create local client
let local = providers::openai::Client::from_url("", "http://192.168.0.10:11434/v1");
let span = info_span!("calculator_agent");
// Create agent with a single context prompt and two tools
let calculator_agent = local
.agent("c4ai-command-r7b-12-2024-abliterated")
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
.tool(Adder)
.tool(Subtract)
.max_tokens(1024)
.build();
// Initialize chat history
let mut chat_history = Vec::new();
println!("Calculator Agent: Ready to help with calculations! (Type 'quit' to exit)");
loop {
print!("\nYou: ");
let mut input = String::new();
std::io::stdin().read_line(&mut input)?;
let input = input.trim();
if input.to_lowercase() == "quit" {
break;
}
// Add user message to history
chat_history.push(Message {
role: "user".into(),
content: input.into(),
});
// Get response from agent
let response = calculator_agent
.chat(input, chat_history.clone())
.instrument(span.clone())
.await?;
// Add assistant's response to history
chat_history.push(Message {
role: "assistant".into(),
content: response.clone(),
});
println!("Calculator Agent: {}", response);
}
println!("\nGoodbye!");
Ok(())
}