feat: add local model support and examples
vacekj committed Dec 10, 2024
1 parent 412ea16 commit 9e9fcfc
Showing 7 changed files with 764 additions and 0 deletions.
4 changes: 4 additions & 0 deletions rig-core/Cargo.toml
@@ -64,3 +64,7 @@ required-features = ["derive"]
[[example]]
name = "xai_embeddings"
required-features = ["derive"]

[[example]]
name = "local_embeddings"
required-features = ["derive"]
99 changes: 99 additions & 0 deletions rig-core/examples/debate_local.rs
@@ -0,0 +1,99 @@
use anyhow::Result;
use rig::{
    agent::Agent,
    completion::{Chat, Message},
    providers::local,
};

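/// Two locally served agents primed to argue opposite sides of a topic.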
struct Debater {
    local_debater_1: Agent<local::CompletionModel>,
    local_debater_2: Agent<local::CompletionModel>,
}

impl Debater {
    fn new(position_a: &str, position_b: &str) -> Self {
        let local1 = local::Client::new();
        let local2 = local::Client::new();

        Self {
            local_debater_1: local1
                .agent("llama3.1:8b-instruct-q8_0")
                .preamble(position_a)
                .build(),
            local_debater_2: local2
                .agent("llama3.1:8b-instruct-q8_0")
                .preamble(position_b)
                .build(),
        }
    }

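    /// Run `n` debate rounds. Each agent keeps its own view of the
    /// conversation: what one agent says is pushed into the other's history
    /// as a "user" message, and each agent's own replies as "assistant" messages.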
    async fn rounds(&self, n: usize) -> Result<()> {
        let mut history_a: Vec<Message> = vec![];
        let mut history_b: Vec<Message> = vec![];

        let mut last_resp_b: Option<String> = None;

        for _ in 0..n {
            let prompt_a = if let Some(msg_b) = &last_resp_b {
                msg_b.clone()
            } else {
                "Plead your case!".into()
            };

            let resp_a = self
                .local_debater_1
                .chat(&prompt_a, history_a.clone())
                .await?;
            println!("Debater 1:\n{}", resp_a);
            history_a.push(Message {
                role: "user".into(),
                content: prompt_a.clone(),
            });
            history_a.push(Message {
                role: "assistant".into(),
                content: resp_a.clone(),
            });
            println!("================================================================");

            let resp_b = self
                .local_debater_2
                .chat(&resp_a, history_b.clone())
                .await?;
            println!("Debater 2:\n{}", resp_b);
            println!("================================================================");

            history_b.push(Message {
                role: "user".into(),
                content: resp_a.clone(),
            });
            history_b.push(Message {
                role: "assistant".into(),
                content: resp_b.clone(),
            });

            last_resp_b = Some(resp_b);
        }

        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Create the two debaters
    let debater = Debater::new(
        "\
        You believe that religion is a useful concept. \
        This could be for security, financial, ethical, philosophical, metaphysical, religious, or any other kind of reason. \
        You choose what your arguments are. \
        I will argue against you and you must rebut me and try to convince me that I am wrong. \
        Make your statements short and concise. \
        ",
        "\
        You believe that religion is a harmful concept. \
        This could be for security, financial, ethical, philosophical, metaphysical, religious, or any other kind of reason. \
        You choose what your arguments are. \
        I will argue against you and you must rebut me and try to convince me that I am wrong. \
        Make your statements short and concise. \
        ",
    );

    // Run the debate for 4 rounds
    debater.rounds(4).await?;

    Ok(())
}
16 changes: 16 additions & 0 deletions rig-core/examples/local.rs
@@ -0,0 +1,16 @@
use rig::{completion::Prompt, providers::local};

#[tokio::main]
async fn main() {
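    // Create a client for a locally running model server (an
    // Ollama-compatible endpoint on its default local port is assumed here).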
    let ollama_client = local::Client::new();

    let llama3 = ollama_client.agent("llama3.1:8b-instruct-q8_0").build();

    // Prompt the model and print its response
    let response = llama3
        .prompt("Who are you?")
        .await
        .expect("Failed to prompt ollama");

    println!("Ollama: {response}");
}
131 changes: 131 additions & 0 deletions rig-core/examples/local_agent_with_tools.rs
@@ -0,0 +1,131 @@
use anyhow::Result;
use rig::{
    completion::{Prompt, ToolDefinition},
    providers,
    tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{debug, info_span, Instrument};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

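/// Arguments shared by both math tools, deserialized from the model's JSON tool call.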
#[derive(Deserialize)]
struct OperationArgs {
    x: i32,
    y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

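// A tool the agent can invoke: `definition` advertises the JSON schema the
// model sees, and `call` runs when the model chooses the tool.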
#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: "add".to_string(),
            description: "Add x and y together".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                }
            }),
        }
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        tracing::info!("Adding {} and {}", args.x, args.y);
        let result = args.x + args.y;
        Ok(result)
    }
}

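// Same shape as `Adder`, but here the `ToolDefinition` is built by deserializing raw JSON.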
#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
    const NAME: &'static str = "subtract";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "subtract",
            "description": "Subtract y from x (i.e.: x - y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The number to subtract from"
                    },
                    "y": {
                        "type": "number",
                        "description": "The number to subtract"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        tracing::info!("Subtracting {} from {}", args.y, args.x);
        let result = args.x - args.y;
        Ok(result)
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Initialize tracing
    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(
            EnvFilter::from_default_env()
                .add_directive("rig=debug".parse()?)
                .add_directive("local_agent_with_tools=debug".parse()?),
        )
        .init();

    // Create local client
    let local = providers::local::Client::new();

    let span = info_span!("calculator_agent");

    // Create agent with a single context prompt and two tools
    let calculator_agent = local
        .agent("llama3.1:8b-instruct-q8_0")
        .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
        .tool(Adder)
        .tool(Subtract)
        .max_tokens(1024)
        .build();

    // Prompt the agent and print the response
    let prompt = "Calculate 2 - 5";
    debug!(?prompt, "Raw prompt");

    let response = calculator_agent.prompt(prompt).instrument(span).await?;

    debug!(?response, "Raw response");
    println!("Calculator Agent: {}", response);

    Ok(())
}
30 changes: 30 additions & 0 deletions rig-core/examples/local_embeddings.rs
@@ -0,0 +1,30 @@
use rig::providers::local;
use rig::Embed;

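// `#[embed]` marks the field whose text is sent to the embedding model.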
#[derive(Embed, Debug)]
struct Greetings {
    #[embed]
    message: String,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
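    // The "mxbai-embed-large" model is assumed to be available on the local
    // server (e.g. already pulled in Ollama).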
    // Initialize the local client
    let client = local::Client::new();

    let embeddings = client
        .embeddings("mxbai-embed-large")
        .document(Greetings {
            message: "Hello, world!".to_string(),
        })?
        .document(Greetings {
            message: "Goodbye, world!".to_string(),
        })?
        .build()
        .await
        .expect("Failed to embed documents");

    println!("{:?}", embeddings);

    Ok(())
}