-
Notifications
You must be signed in to change notification settings - Fork 173
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add local model support and examples
- Loading branch information
Showing
7 changed files
with
764 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
use anyhow::Result; | ||
use rig::{ | ||
agent::Agent, | ||
completion::{Chat, Message}, | ||
providers::local, | ||
}; | ||
|
||
struct Debater { | ||
local_debater_1: Agent<local::CompletionModel>, | ||
local_debater_2: Agent<local::CompletionModel>, | ||
} | ||
|
||
impl Debater { | ||
fn new(position_a: &str, position_b: &str) -> Self { | ||
let local1 = local::Client::new(); | ||
let local2 = local::Client::new(); | ||
|
||
Self { | ||
local_debater_1: local1 | ||
.agent("llama3.1:8b-instruct-q8_0") | ||
.preamble(position_a) | ||
.build(), | ||
local_debater_2: local2 | ||
.agent("llama3.1:8b-instruct-q8_0") | ||
.preamble(position_b) | ||
.build(), | ||
} | ||
} | ||
|
||
async fn rounds(&self, n: usize) -> Result<()> { | ||
let mut history_a: Vec<Message> = vec![]; | ||
let mut history_b: Vec<Message> = vec![]; | ||
|
||
let mut last_resp_b: Option<String> = None; | ||
|
||
for _ in 0..n { | ||
let prompt_a = if let Some(msg_b) = &last_resp_b { | ||
msg_b.clone() | ||
} else { | ||
"Plead your case!".into() | ||
}; | ||
|
||
let resp_a = self.local_debater_1.chat(&prompt_a, history_a.clone()).await?; | ||
println!("GPT-4:\n{}", resp_a); | ||
history_a.push(Message { | ||
role: "user".into(), | ||
content: prompt_a.clone(), | ||
}); | ||
history_a.push(Message { | ||
role: "assistant".into(), | ||
content: resp_a.clone(), | ||
}); | ||
println!("================================================================"); | ||
|
||
let resp_b = self.local_debater_2.chat(&resp_a, history_b.clone()).await?; | ||
println!("Coral:\n{}", resp_b); | ||
println!("================================================================"); | ||
|
||
history_b.push(Message { | ||
role: "user".into(), | ||
content: resp_a.clone(), | ||
}); | ||
history_b.push(Message { | ||
role: "assistant".into(), | ||
content: resp_b.clone(), | ||
}); | ||
|
||
last_resp_b = Some(resp_b) | ||
} | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
// Create model | ||
let debator = Debater::new( | ||
"\ | ||
You believe that religion is a useful concept. \ | ||
This could be for security, financial, ethical, philosophical, metaphysical, religious or any kind of other reason. \ | ||
You choose what your arguments are. \ | ||
I will argue against you and you must rebuke me and try to convince me that I am wrong. \ | ||
Make your statements short and concise. \ | ||
", | ||
"\ | ||
You believe that religion is a harmful concept. \ | ||
This could be for security, financial, ethical, philosophical, metaphysical, religious or any kind of other reason. \ | ||
You choose what your arguments are. \ | ||
I will argue against you and you must rebuke me and try to convince me that I am wrong. \ | ||
Make your statements short and concise. \ | ||
", | ||
); | ||
|
||
// Run the debate for 4 rounds | ||
debator.rounds(4).await?; | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
use rig::{completion::Prompt, providers::local}; | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let ollama_client = local::Client::new(); | ||
|
||
let llama3 = ollama_client.agent("llama3.1:8b-instruct-q8_0").build(); | ||
|
||
// Prompt the model and print its response | ||
let response = llama3 | ||
.prompt("Who are you?") | ||
.await | ||
.expect("Failed to prompt ollama"); | ||
|
||
println!("Ollama: {response}"); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
use anyhow::Result; | ||
use rig::{ | ||
completion::{Chat, Prompt, ToolDefinition}, | ||
providers, | ||
tool::Tool, | ||
}; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::json; | ||
use tracing::{debug, info_span, Instrument}; | ||
use tracing_subscriber::{fmt, prelude::*, EnvFilter}; | ||
|
||
#[derive(Deserialize)] | ||
struct OperationArgs { | ||
x: i32, | ||
y: i32, | ||
} | ||
|
||
#[derive(Debug, thiserror::Error)] | ||
#[error("Math error")] | ||
struct MathError; | ||
|
||
#[derive(Deserialize, Serialize)] | ||
struct Adder; | ||
impl Tool for Adder { | ||
const NAME: &'static str = "add"; | ||
|
||
type Error = MathError; | ||
type Args = OperationArgs; | ||
type Output = i32; | ||
|
||
async fn definition(&self, _prompt: String) -> ToolDefinition { | ||
ToolDefinition { | ||
name: "add".to_string(), | ||
description: "Add x and y together".to_string(), | ||
parameters: json!({ | ||
"type": "object", | ||
"properties": { | ||
"x": { | ||
"type": "number", | ||
"description": "The first number to add" | ||
}, | ||
"y": { | ||
"type": "number", | ||
"description": "The second number to add" | ||
} | ||
} | ||
}), | ||
} | ||
} | ||
|
||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { | ||
tracing::info!("Adding {} and {}", args.x, args.y); | ||
let result = args.x + args.y; | ||
Ok(result) | ||
} | ||
} | ||
|
||
#[derive(Deserialize, Serialize)] | ||
struct Subtract; | ||
impl Tool for Subtract { | ||
const NAME: &'static str = "subtract"; | ||
|
||
type Error = MathError; | ||
type Args = OperationArgs; | ||
type Output = i32; | ||
|
||
async fn definition(&self, _prompt: String) -> ToolDefinition { | ||
serde_json::from_value(json!({ | ||
"name": "subtract", | ||
"description": "Subtract y from x (i.e.: x - y)", | ||
"parameters": { | ||
"type": "object", | ||
"properties": { | ||
"x": { | ||
"type": "number", | ||
"description": "The number to substract from" | ||
}, | ||
"y": { | ||
"type": "number", | ||
"description": "The number to substract" | ||
} | ||
} | ||
} | ||
})) | ||
.expect("Tool Definition") | ||
} | ||
|
||
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> { | ||
tracing::info!("Subtracting {} from {}", args.y, args.x); | ||
let result = args.x - args.y; | ||
Ok(result) | ||
} | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
// Initialize tracing | ||
tracing_subscriber::registry() | ||
.with(fmt::layer()) | ||
.with( | ||
EnvFilter::from_default_env() | ||
.add_directive("rig=debug".parse()?) | ||
.add_directive("local_agent_with_tools=debug".parse()?), | ||
) | ||
.init(); | ||
|
||
// Create local client | ||
let local = providers::local::Client::new(); | ||
|
||
let span = info_span!("calculator_agent"); | ||
|
||
// Create agent with a single context prompt and two tools | ||
let calculator_agent = local | ||
.agent("llama3.1:8b-instruct-q8_0") | ||
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.") | ||
.tool(Adder) | ||
.tool(Subtract) | ||
.max_tokens(1024) | ||
.build(); | ||
|
||
// Prompt the agent and print the response | ||
let prompt = "Calculate 2 - 5"; | ||
debug!(?prompt, "Raw prompt"); | ||
|
||
let response = calculator_agent.prompt(prompt).instrument(span).await?; | ||
|
||
debug!(?response, "Raw response"); | ||
println!("Calculator Agent: {}", response); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
use rig::providers::local; | ||
use rig::Embed; | ||
|
||
#[derive(Embed, Debug)] | ||
struct Greetings { | ||
#[embed] | ||
message: String, | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), anyhow::Error> { | ||
// Initialize the local client | ||
let client = local::Client::new(); | ||
|
||
let embeddings = client | ||
.embeddings("mxbai-embed-large") | ||
.document(Greetings { | ||
message: "Hello, world!".to_string(), | ||
})? | ||
.document(Greetings { | ||
message: "Goodbye, world!".to_string(), | ||
})? | ||
.build() | ||
.await | ||
.expect("Failed to embed documents"); | ||
|
||
println!("{:?}", embeddings); | ||
|
||
Ok(()) | ||
} |
Oops, something went wrong.