diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml
index 0af6b71d..fb2793e3 100644
--- a/rig-core/Cargo.toml
+++ b/rig-core/Cargo.toml
@@ -64,3 +64,7 @@ required-features = ["derive"]
 [[example]]
 name = "xai_embeddings"
 required-features = ["derive"]
+
+[[example]]
+name = "local_embeddings"
+required-features = ["derive"]
diff --git a/rig-core/examples/collab_local.rs b/rig-core/examples/collab_local.rs
new file mode 100644
index 00000000..c7edcb79
--- /dev/null
+++ b/rig-core/examples/collab_local.rs
@@ -0,0 +1,104 @@
+use anyhow::Result;
+use rig::{
+    agent::Agent,
+    completion::{Chat, Message},
+    providers::local,
+};
+
+struct Collaborator {
+    local_agent_1: Agent<local::CompletionModel>,
+    local_agent_2: Agent<local::CompletionModel>,
+}
+
+impl Collaborator {
+    fn new(position_a: &str, position_b: &str) -> Self {
+        let local1 = local::Client::new();
+        let local2 = local::Client::new();
+
+        Self {
+            local_agent_1: local1
+                .agent("llama3.1:8b-instruct-q8_0")
+                .preamble(position_a)
+                .build(),
+            local_agent_2: local2
+                .agent("llama3.1:8b-instruct-q8_0")
+                .preamble(position_b)
+                .build(),
+        }
+    }
+
+    async fn rounds(&self, n: usize) -> Result<()> {
+        let mut history_a: Vec<Message> = vec![];
+        let mut history_b: Vec<Message> = vec![];
+
+        let mut last_resp_b: Option<String> = None;
+
+        for _ in 0..n {
+            let prompt_a = if let Some(msg_b) = &last_resp_b {
+                msg_b.clone()
+            } else {
+                "Let's start improving prompts!".into()
+            };
+
+            let resp_a = self
+                .local_agent_1
+                .chat(&prompt_a, history_a.clone())
+                .await?;
+            println!("Agent 1:\n{}", resp_a);
+            history_a.push(Message {
+                role: "user".into(),
+                content: prompt_a.clone(),
+            });
+            history_a.push(Message {
+                role: "assistant".into(),
+                content: resp_a.clone(),
+            });
+            println!("================================================================");
+
+            let resp_b = self.local_agent_2.chat(&resp_a, history_b.clone()).await?;
+            println!("Agent 2:\n{}", resp_b);
+            println!("================================================================");
+
+            history_b.push(Message {
+                role: "user".into(),
+                content: resp_a.clone(),
+            });
+            history_b.push(Message {
+                role: "assistant".into(),
+                content: resp_b.clone(),
+            });
+
+            last_resp_b = Some(resp_b)
+        }
+
+        Ok(())
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Create the two collaborating agents
+    let collaborator = Collaborator::new(
+        "\
+        You are a prompt engineering expert focused on improving AI model performance. \
+        Your goal is to collaborate with another AI to iteratively refine and improve prompts. \
+        Analyze the previous response and suggest specific improvements to make prompts more effective. \
+        Consider aspects like clarity, specificity, context-setting, and task framing. \
+        Keep your suggestions focused and actionable. \
+        Format: Start with 'Suggested improvements:' followed by your specific recommendations. \
+        ",
+        "\
+        You are a prompt engineering expert focused on improving AI model performance. \
+        Your goal is to collaborate with another AI to iteratively refine and improve prompts. \
+        Review the suggested improvements and either build upon them or propose alternative approaches. \
+        Consider practical implementation and potential edge cases. \
+        Keep your response constructive and specific. \
+        Format: Start with 'Building on that:' followed by your refined suggestions. \
+        ",
+    );
+
+    // Run the collaboration for 4 rounds
+    collaborator.rounds(4).await?;
+
+    Ok(())
+}
diff --git a/rig-core/examples/local.rs b/rig-core/examples/local.rs
new file mode 100644
index 00000000..e76d452e
--- /dev/null
+++ b/rig-core/examples/local.rs
@@ -0,0 +1,16 @@
+use rig::{completion::Prompt, providers::local};
+
+#[tokio::main]
+async fn main() {
+    let ollama_client = local::Client::new();
+
+    let llama3 = ollama_client.agent("llama3.1:8b-instruct-q8_0").build();
+
+    // Prompt the model and print its response
+    let response = llama3
+        .prompt("Who are you?")
+        .await
+        .expect("Failed to prompt ollama");
+
+    println!("Ollama: {response}");
+}
diff --git a/rig-core/examples/local_agent_with_tools.rs b/rig-core/examples/local_agent_with_tools.rs
new file mode 100644
index 00000000..eb4973a8
--- /dev/null
+++ b/rig-core/examples/local_agent_with_tools.rs
@@ -0,0 +1,131 @@
+use anyhow::Result;
+use rig::{
+    completion::{Prompt, ToolDefinition},
+    providers,
+    tool::Tool,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+use tracing::{debug, info_span, Instrument};
+use tracing_subscriber::{fmt, prelude::*, EnvFilter};
+
+#[derive(Deserialize)]
+struct OperationArgs {
+    x: i32,
+    y: i32,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Math error")]
+struct MathError;
+
+#[derive(Deserialize, Serialize)]
+struct Adder;
+impl Tool for Adder {
+    const NAME: &'static str = "add";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        ToolDefinition {
+            name: "add".to_string(),
+            description: "Add x and y together".to_string(),
+            parameters: json!({
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The first number to add"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The second number to add"
+                    }
+                }
+            }),
+        }
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        tracing::info!("Adding {} and {}", args.x, args.y);
+        let result = args.x + args.y;
+        Ok(result)
+    }
+}
+
+#[derive(Deserialize, Serialize)]
+struct Subtract;
+impl Tool for Subtract {
+    const NAME: &'static str = "subtract";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        serde_json::from_value(json!({
+            "name": "subtract",
+            "description": "Subtract y from x (i.e.: x - y)",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The number to subtract from"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The number to subtract"
+                    }
+                }
+            }
+        }))
+        .expect("Tool Definition")
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        tracing::info!("Subtracting {} from {}", args.y, args.x);
+        let result = args.x - args.y;
+        Ok(result)
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Initialize tracing
+    tracing_subscriber::registry()
+        .with(fmt::layer())
+        .with(
+            EnvFilter::from_default_env()
+                .add_directive("rig=debug".parse()?)
+                .add_directive("local_agent_with_tools=debug".parse()?),
+        )
+        .init();
+
+    // Create local client
+    let local = providers::local::Client::new();
+
+    let span = info_span!("calculator_agent");
+
+    // Create agent with a single context prompt and two tools
+    let calculator_agent = local
+        .agent("llama3.1:8b-instruct-q8_0")
+        .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
+        .tool(Adder)
+        .tool(Subtract)
+        .max_tokens(1024)
+        .build();
+
+    // Prompt the agent and print the response
+    let prompt = "Calculate 2 - 5";
+    debug!(?prompt, "Raw prompt");
+
+    let response = calculator_agent.prompt(prompt).instrument(span).await?;
+
+    debug!(?response, "Raw response");
+    println!("Calculator Agent: {}", response);
+
+    Ok(())
+}
diff --git a/rig-core/examples/local_embeddings.rs b/rig-core/examples/local_embeddings.rs
new file mode 100644
index 00000000..3f71aa83
--- /dev/null
+++ b/rig-core/examples/local_embeddings.rs
@@ -0,0 +1,30 @@
+use rig::providers::local;
+use rig::Embed;
+
+#[derive(Embed, Debug)]
+struct Greetings {
+    #[embed]
+    message: String,
+}
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Initialize the local client
+    let client = local::Client::new();
+
+    let embeddings = client
+        .embeddings("mxbai-embed-large")
+        .document(Greetings {
+            message: "Hello, world!".to_string(),
+        })?
+        .document(Greetings {
+            message: "Goodbye, world!".to_string(),
+        })?
+        .build()
+        .await
+        .expect("Failed to embed documents");
+
+    println!("{:?}", embeddings);
+
+    Ok(())
+}
diff --git a/rig-core/src/providers/local.rs b/rig-core/src/providers/local.rs
new file mode 100644
index 00000000..73fca5a6
--- /dev/null
+++ b/rig-core/src/providers/local.rs
@@ -0,0 +1,491 @@
+//! Local API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::local;
+//!
+//! let client = local::Client::new();
+//!
+//! let model = client.completion_model("llama3.1:8b-instruct-q8_0");
+//! ```
+
+use crate::{
+    agent::AgentBuilder,
+    completion::{self, CompletionError, CompletionRequest},
+    embeddings::{self, EmbeddingError, EmbeddingsBuilder},
+    extractor::ExtractorBuilder,
+    json_utils, Embed,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+// ================================================================
+// Main Local Client
+// ================================================================
+const DEFAULT_API_BASE_URL: &str = "http://localhost:11434"; // Ollama and LM Studio endpoint
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    /// Create a new Local client using the default endpoint (Ollama). No API key is required.
+    pub fn new() -> Self {
+        Self::from_url("", DEFAULT_API_BASE_URL)
+    }
+
+    /// Create a new Local client with the given API key and base API URL.
+    pub fn from_url(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Bearer token should parse"),
+                    );
+                    headers
+                })
+                .build()
+                .expect("Local reqwest client should build"),
+        }
+    }
+
+    /// Create a new Local client from the environment.
+    /// Currently equivalent to [`Client::new`], since no API key is required.
+    pub fn from_env() -> Self {
+        Self::new()
+    }
+
+    fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
+        self.http_client.post(url)
+    }
+
+    /// Create an embedding model with the given name.
+    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, 1536) // Default to 1536 dimensions
+    }
+
+    /// Create an embedding builder with the given embedding model.
+    pub fn embeddings<D: Embed>(&self, model: &str) -> EmbeddingsBuilder<EmbeddingModel, D> {
+        EmbeddingsBuilder::new(self.embedding_model(model))
+    }
+
+    /// Create a completion model with the given name.
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    /// Create an agent builder with the given completion model.
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    /// Create an extractor builder with the given completion model.
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+}
+
+impl Default for Client {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[derive(Debug, Deserialize)]
+struct ApiErrorResponse {
+    message: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+enum ApiResponse<T> {
+    Ok(T),
+    Err(ApiErrorResponse),
+}
+
+// ================================================================
+// Local Embedding API
+// ================================================================
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingResponse {
+    pub object: String,
+    pub data: Vec<EmbeddingData>,
+    pub model: String,
+    pub usage: Usage,
+}
+
+impl From<ApiErrorResponse> for EmbeddingError {
+    fn from(err: ApiErrorResponse) -> Self {
+        EmbeddingError::ProviderError(err.message)
+    }
+}
+
+impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
+    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
+        match value {
+            ApiResponse::Ok(response) => Ok(response),
+            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingData {
+    pub object: String,
+    pub embedding: Vec<f64>,
+    pub index: usize,
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub total_tokens: usize,
+}
+
+impl std::fmt::Display for Usage {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "Prompt tokens: {} Total tokens: {}",
+            self.prompt_tokens, self.total_tokens
+        )
+    }
+}
+
+#[derive(Clone)]
+pub struct EmbeddingModel {
+    client: Client,
+    pub model: String,
+    ndims: usize,
+}
+
+impl embeddings::EmbeddingModel for EmbeddingModel {
+    const MAX_DOCUMENTS: usize = 1024;
+
+    fn ndims(&self) -> usize {
+        self.ndims
+    }
+
+    async fn embed_texts(
+        &self,
+        documents: impl IntoIterator<Item = String>,
+    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
+        let documents = documents.into_iter().collect::<Vec<_>>();
+
+        let response = self
+            .client
+            .post("/v1/embeddings")
+            .json(&json!({
+                "model": self.model,
+                "input": documents,
+            }))
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
+                ApiResponse::Ok(response) => {
+                    tracing::info!(target: "rig",
+                        "Local embedding token usage: {}",
+                        response.usage
+                    );
+
+                    if response.data.len() != documents.len() {
+                        return Err(EmbeddingError::ResponseError(
+                            "Response data length does not match input length".into(),
+                        ));
+                    }
+
+                    Ok(response
+                        .data
+                        .into_iter()
+                        .zip(documents.into_iter())
+                        .map(|(embedding, document)| embeddings::Embedding {
+                            document,
+                            vec: embedding.embedding,
+                        })
+                        .collect())
+                }
+                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
+            }
+        } else {
+            Err(EmbeddingError::ProviderError(response.text().await?))
+        }
+    }
+}
+
+impl EmbeddingModel {
+    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+            ndims,
+        }
+    }
+}
+
+// ================================================================
+// Local Completion API
+// ================================================================
+
+#[derive(Debug, Deserialize)]
+pub struct CompletionResponse {
+    pub id: String,
+    pub object: String,
+    pub created: u64,
+    pub model: String,
+    pub system_fingerprint: Option<String>,
+    pub choices: Vec<Choice>,
+    pub usage: Option<Usage>,
+}
+
+impl From<ApiErrorResponse> for CompletionError {
+    fn from(err: ApiErrorResponse) -> Self {
+        CompletionError::ProviderError(err.message)
+    }
+}
+
+impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
+    type Error = CompletionError;
+
+    fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
+        tracing::debug!(target: "rig", ?value, "Processing completion response");
+        match value.choices.as_slice() {
+            // First check for tool calls
+            [Choice {
+                message:
+                    Message {
+                        tool_calls: Some(calls),
+                        ..
+                    },
+                ..
+            }, ..] => {
+                tracing::debug!(target: "rig", ?calls, "Received tool calls");
+                let call = calls.first().ok_or_else(|| {
+                    CompletionError::ResponseError("Tool selection is empty".into())
+                })?;
+
+                tracing::info!(target: "rig",
+                    tool_name = ?call.function.name,
+                    args = ?call.function.arguments,
+                    "Processing tool call"
+                );
+
+                // Validate the arguments are valid JSON
+                let parsed_args: serde_json::Value = serde_json::from_str(&call.function.arguments)
+                    .map_err(|e| {
+                        CompletionError::ResponseError(format!(
+                            "Invalid tool arguments JSON: {}",
+                            e
+                        ))
+                    })?;
+
+                // Validate it's a JSON object
+                if !parsed_args.is_object() {
+                    return Err(CompletionError::ResponseError(
+                        "Tool arguments must be a JSON object".into(),
+                    ));
+                }
+
+                tracing::debug!(target: "rig", ?parsed_args, "Parsed and validated tool arguments");
+
+                Ok(completion::CompletionResponse {
+                    choice: completion::ModelChoice::ToolCall(
+                        call.function.name.clone(),
+                        parsed_args,
+                    ),
+                    raw_response: value,
+                })
+            }
+            // Then check for content
+            [Choice {
+                message:
+                    Message {
+                        content: Some(content),
+                        tool_calls: None,
+                        ..
+                    },
+                ..
+            }, ..] => {
+                tracing::debug!(target: "rig", content_length = content.len(), "Received text response");
+                Ok(completion::CompletionResponse {
+                    choice: completion::ModelChoice::Message(content.to_string()),
+                    raw_response: value,
+                })
+            }
+            _ => {
+                tracing::error!(
+                    target: "rig",
+                    choices = ?value.choices,
+                    "Response contained neither valid message nor tool calls"
+                );
+                Err(CompletionError::ResponseError(
+                    "Response did not contain a valid message or tool call".into(),
+                ))
+            }
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Choice {
+    pub index: usize,
+    pub message: Message,
+    pub logprobs: Option<serde_json::Value>,
+    pub finish_reason: String,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Message {
+    pub role: String,
+    pub content: Option<String>,
+    pub tool_calls: Option<Vec<ToolCall>>,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct ToolCall {
+    pub id: String,
+    pub r#type: String,
+    pub function: Function,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct ToolDefinition {
+    pub r#type: String,
+    pub function: completion::ToolDefinition,
+}
+
+impl From<completion::ToolDefinition> for ToolDefinition {
+    fn from(tool: completion::ToolDefinition) -> Self {
+        Self {
+            r#type: "function".into(),
+            function: tool,
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Function {
+    pub name: String,
+    pub arguments: String,
+}
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    async fn completion(
+        &self,
+        mut completion_request: CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        // Add preamble to chat history (if available)
+        let mut full_history = if let Some(preamble) = &completion_request.preamble {
+            tracing::debug!(target: "rig", system_prompt = ?preamble, "System prompt");
+            vec![completion::Message {
+                role: "system".into(),
+                content: preamble.clone(),
+            }]
+        } else {
+            vec![]
+        };
+
+        // Extend existing chat history
+        full_history.append(&mut completion_request.chat_history);
+
+        // Build the user prompt with any context documents attached
+        let prompt_with_context = completion_request.prompt_with_context();
+        tracing::debug!(target: "rig", prompt = ?prompt_with_context, "User prompt with context");
+
+        // Add the prompt to the chat history
+        full_history.push(completion::Message {
+            role: "user".into(),
+            content: prompt_with_context,
+        });
+
+        let request = if completion_request.tools.is_empty() {
+            tracing::debug!(target: "rig", "No tools provided in request");
+            json!({
+                "model": self.model,
+                "messages": full_history,
+                "temperature": completion_request.temperature,
+            })
+        } else {
+            let tools: Vec<ToolDefinition> = completion_request
+                .tools
+                .into_iter()
+                .map(ToolDefinition::from)
+                .collect();
+            tracing::info!(target: "rig", tool_count = tools.len(), ?tools, "Sending tools to model");
+            json!({
+                "model": self.model,
+                "messages": full_history,
+                "temperature": completion_request.temperature,
+                "tools": tools,
+                "tool_choice": "auto",
+            })
+        };
+
+        tracing::debug!(target: "rig", ?request, "Request to local model");
+
+        let response = self
+            .client
+            .post("/v1/chat/completions")
+            .json(
+                &if let Some(params) = completion_request.additional_params {
+                    json_utils::merge(request, params)
+                } else {
+                    request
+                },
+            )
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            match response.json::<ApiResponse<CompletionResponse>>().await? {
+                ApiResponse::Ok(response) => {
+                    tracing::info!(target: "rig",
+                        "Local completion token usage: {:?}",
+                        response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
+                    );
+                    tracing::debug!(target: "rig", ?response, "Raw response from local model");
+                    response.try_into()
+                }
+                ApiResponse::Err(err) => {
+                    tracing::error!(target: "rig", error = ?err.message, "Local model error");
+                    Err(CompletionError::ProviderError(err.message))
+                }
+            }
+        } else {
+            let error_text = response.text().await?;
+            tracing::error!(target: "rig", error = ?error_text, "Local model HTTP error");
+            Err(CompletionError::ProviderError(error_text))
+        }
+    }
+}
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 23d4d181..fecbe98d 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -6,6 +6,7 @@
 //! - Perplexity
 //! - Anthropic
 //! - Google Gemini
+//! - Local (with configurable endpoint)
 //!
 //! Each provider has its own module, which contains a `Client` implementation that can
 //! be used to initialize completion and embedding models and execute requests to those models.
@@ -43,6 +44,7 @@
 pub mod anthropic;
 pub mod cohere;
 pub mod gemini;
+pub mod local;
 pub mod openai;
 pub mod perplexity;
 pub mod xai;