feat: add local model support and examples #147

Closed
wants to merge 4 commits
4 changes: 4 additions & 0 deletions rig-core/Cargo.toml
@@ -64,3 +64,7 @@ required-features = ["derive"]
[[example]]
name = "xai_embeddings"
required-features = ["derive"]

[[example]]
name = "local_embeddings"
required-features = ["derive"]
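
As with the existing xai_embeddings example, the `required-features = ["derive"]` entry means the new local_embeddings example is only built when the `derive` feature is enabled, since it relies on the `#[derive(Embed)]` macro used in the example below.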
104 changes: 104 additions & 0 deletions rig-core/examples/collab_local.rs
@@ -0,0 +1,104 @@
use anyhow::Result;
use rig::{
    agent::Agent,
    completion::{Chat, Message},
    providers::local,
};

struct Collaborator {
    local_agent_1: Agent<local::CompletionModel>,
    local_agent_2: Agent<local::CompletionModel>,
}

impl Collaborator {
    fn new(position_a: &str, position_b: &str) -> Self {
        let local1 = local::Client::new();
        let local2 = local::Client::new();

        Self {
            local_agent_1: local1
                .agent("llama3.1:8b-instruct-q8_0")
                .preamble(position_a)
                .build(),
            local_agent_2: local2
                .agent("llama3.1:8b-instruct-q8_0")
                .preamble(position_b)
                .build(),
        }
    }

    async fn rounds(&self, n: usize) -> Result<()> {
        let mut history_a: Vec<Message> = vec![];
        let mut history_b: Vec<Message> = vec![];

        let mut last_resp_b: Option<String> = None;

        for _ in 0..n {
            let prompt_a = if let Some(msg_b) = &last_resp_b {
                msg_b.clone()
            } else {
                "Let's start improving prompts!".into()
            };

            let resp_a = self
                .local_agent_1
                .chat(&prompt_a, history_a.clone())
                .await?;
            println!("Agent 1:\n{}", resp_a);
            history_a.push(Message {
                role: "user".into(),
                content: prompt_a.clone(),
            });
            history_a.push(Message {
                role: "assistant".into(),
                content: resp_a.clone(),
            });
            println!("================================================================");

            let resp_b = self.local_agent_2.chat(&resp_a, history_b.clone()).await?;
            println!("Agent 2:\n{}", resp_b);
            println!("================================================================");

            history_b.push(Message {
                role: "user".into(),
                content: resp_a.clone(),
            });
            history_b.push(Message {
                role: "assistant".into(),
                content: resp_b.clone(),
            });

            last_resp_b = Some(resp_b);
        }

        Ok(())
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Create the two collaborating agents
    let collaborator = Collaborator::new(
        "\
        You are a prompt engineering expert focused on improving AI model performance. \
        Your goal is to collaborate with another AI to iteratively refine and improve prompts. \
        Analyze the previous response and suggest specific improvements to make prompts more effective. \
        Consider aspects like clarity, specificity, context-setting, and task framing. \
        Keep your suggestions focused and actionable. \
        Format: Start with 'Suggested improvements:' followed by your specific recommendations. \
        ",
        "\
        You are a prompt engineering expert focused on improving AI model performance. \
        Your goal is to collaborate with another AI to iteratively refine and improve prompts. \
        Review the suggested improvements and either build upon them or propose alternative approaches. \
        Consider practical implementation and potential edge cases. \
        Keep your response constructive and specific. \
        Format: Start with 'Building on that:' followed by your refined suggestions. \
        ",
    );

    // Run the collaboration for 4 rounds
    collaborator.rounds(4).await?;

    Ok(())
}
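
A possible way to try this example, assuming the local provider points at an Ollama-compatible server on its default endpoint: pull the model with `ollama pull llama3.1:8b-instruct-q8_0`, then run `cargo run --example collab_local`.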
16 changes: 16 additions & 0 deletions rig-core/examples/local.rs
@@ -0,0 +1,16 @@
use rig::{completion::Prompt, providers::local};

#[tokio::main]
async fn main() {
    let ollama_client = local::Client::new();

    let llama3 = ollama_client.agent("llama3.1:8b-instruct-q8_0").build();

    // Prompt the model and print its response
    let response = llama3
        .prompt("Who are you?")
        .await
        .expect("Failed to prompt ollama");

    println!("Ollama: {response}");
}
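
This shows the minimal flow: build a `local::Client`, create an agent for a model tag, and prompt it. Assuming the model is already available locally, it can be run with `cargo run --example local`.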
131 changes: 131 additions & 0 deletions rig-core/examples/local_agent_with_tools.rs
@@ -0,0 +1,131 @@
use anyhow::Result;
use rig::{
    completion::{Prompt, ToolDefinition},
    providers,
    tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{debug, info_span, Instrument};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

#[derive(Deserialize)]
struct OperationArgs {
    x: i32,
    y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
    const NAME: &'static str = "add";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: "add".to_string(),
            description: "Add x and y together".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The first number to add"
                    },
                    "y": {
                        "type": "number",
                        "description": "The second number to add"
                    }
                }
            }),
        }
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        tracing::info!("Adding {} and {}", args.x, args.y);
        let result = args.x + args.y;
        Ok(result)
    }
}

#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
    const NAME: &'static str = "subtract";

    type Error = MathError;
    type Args = OperationArgs;
    type Output = i32;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        serde_json::from_value(json!({
            "name": "subtract",
            "description": "Subtract y from x (i.e.: x - y)",
            "parameters": {
                "type": "object",
                "properties": {
                    "x": {
                        "type": "number",
                        "description": "The number to subtract from"
                    },
                    "y": {
                        "type": "number",
                        "description": "The number to subtract"
                    }
                }
            }
        }))
        .expect("Tool Definition")
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        tracing::info!("Subtracting {} from {}", args.y, args.x);
        let result = args.x - args.y;
        Ok(result)
    }
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Initialize tracing
    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(
            EnvFilter::from_default_env()
                .add_directive("rig=debug".parse()?)
                .add_directive("local_agent_with_tools=debug".parse()?),
        )
        .init();

    // Create local client
    let local = providers::local::Client::new();

    let span = info_span!("calculator_agent");

    // Create agent with a single context prompt and two tools
    let calculator_agent = local
        .agent("llama3.1:8b-instruct-q8_0")
        .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
        .tool(Adder)
        .tool(Subtract)
        .max_tokens(1024)
        .build();

    // Prompt the agent and print the response
    let prompt = "Calculate 2 - 5";
    debug!(?prompt, "Raw prompt");

    let response = calculator_agent.prompt(prompt).instrument(span).await?;

    debug!(?response, "Raw response");
    println!("Calculator Agent: {}", response);

    Ok(())
}
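
Running this with `cargo run --example local_agent_with_tools` should emit the debug traces for the `rig` and example targets; whether the add/subtract tools are actually invoked depends on the local model's support for function calling.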
30 changes: 30 additions & 0 deletions rig-core/examples/local_embeddings.rs
@@ -0,0 +1,30 @@
use rig::providers::local;
use rig::Embed;

#[derive(Embed, Debug)]
struct Greetings {
    #[embed]
    message: String,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Initialize the local client
    let client = local::Client::new();

    let embeddings = client
        .embeddings("mxbai-embed-large")
        .document(Greetings {
            message: "Hello, world!".to_string(),
        })?
        .document(Greetings {
            message: "Goodbye, world!".to_string(),
        })?
        .build()
        .await
        .expect("Failed to embed documents");

    println!("{:?}", embeddings);

    Ok(())
}
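
Assuming the embeddings are served by a local Ollama-compatible endpoint, the `mxbai-embed-large` model needs to be available first (e.g. `ollama pull mxbai-embed-large`); the example can then be run with `cargo run --example local_embeddings --features derive`.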