diff --git a/rig-core/examples/anthropic_agent.rs b/rig-core/examples/anthropic_agent.rs new file mode 100644 index 00000000..9e3ea255 --- /dev/null +++ b/rig-core/examples/anthropic_agent.rs @@ -0,0 +1,31 @@ +use std::env; + +use rig::{ + completion::Prompt, + providers::anthropic::{self, CLAUDE_3_5_SONNET}, +}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Create Anthropic client + let client = anthropic::ClientBuilder::new( + &env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"), + ) + .build(); + + // Create agent with a single context prompt + let agent = client + .agent(CLAUDE_3_5_SONNET) + .preamble("Be precise and concise.") + .temperature(0.5) + .max_tokens(8192) + .build(); + + // Prompt the agent and print the response + let response = agent + .prompt("When and where and what type is the next solar eclipse?") + .await?; + println!("{}", response); + + Ok(()) +} diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 8648a9eb..1607d339 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -150,6 +150,8 @@ pub struct Agent { static_tools: Vec, /// Temperature of the model temperature: Option, + /// Maximum number of tokens for the completion + max_tokens: Option, /// Additional parameters to be passed to the model additional_params: Option, /// List of vector store, with the sample number @@ -238,6 +240,7 @@ impl Completion for Agent { .documents([self.static_context.clone(), dynamic_context].concat()) .tools([static_tools.clone(), dynamic_tools].concat()) .temperature_opt(self.temperature) + .max_tokens_opt(self.max_tokens) .additional_params_opt(self.additional_params.clone())) } } @@ -295,6 +298,8 @@ pub struct AgentBuilder { static_tools: Vec, /// Additional parameters to be passed to the model additional_params: Option, + /// Maximum number of tokens for the completion + max_tokens: Option, /// List of vector store, with the sample number dynamic_context: Vec<(usize, Box)>, /// Dynamic tools @@ 
-313,6 +318,7 @@ impl AgentBuilder { static_context: vec![], static_tools: vec![], temperature: None, + max_tokens: None, additional_params: None, dynamic_context: vec![], dynamic_tools: vec![], @@ -385,6 +391,12 @@ impl AgentBuilder { self } + /// Set the maximum number of tokens for the completion + pub fn max_tokens(mut self, max_tokens: u64) -> Self { + self.max_tokens = Some(max_tokens); + self + } + /// Set additional parameters to be passed to the model pub fn additional_params(mut self, params: serde_json::Value) -> Self { self.additional_params = Some(params); @@ -399,6 +411,7 @@ impl AgentBuilder { static_context: self.static_context, static_tools: self.static_tools, temperature: self.temperature, + max_tokens: self.max_tokens, additional_params: self.additional_params, dynamic_context: self.dynamic_context, dynamic_tools: self.dynamic_tools, diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index 1dd6f6b9..7c4d27e4 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -237,6 +237,8 @@ pub struct CompletionRequest { pub tools: Vec, /// The temperature to be sent to the completion model provider pub temperature: Option, + /// The max tokens to be sent to the completion model provider + pub max_tokens: Option, /// Additional provider-specific parameters to be sent to the completion model provider pub additional_params: Option, } @@ -293,6 +295,7 @@ pub struct CompletionRequestBuilder { documents: Vec, tools: Vec, temperature: Option, + max_tokens: Option, additional_params: Option, } @@ -306,6 +309,7 @@ impl CompletionRequestBuilder { documents: Vec::new(), tools: Vec::new(), temperature: None, + max_tokens: None, additional_params: None, } } @@ -394,6 +398,20 @@ impl CompletionRequestBuilder { self } + /// Sets the max tokens for the completion request. 
+ /// Only required for: [ Anthropic ] + pub fn max_tokens(mut self, max_tokens: u64) -> Self { + self.max_tokens = Some(max_tokens); + self + } + + /// Sets the max tokens for the completion request. + /// Only required for: [ Anthropic ] + pub fn max_tokens_opt(mut self, max_tokens: Option) -> Self { + self.max_tokens = max_tokens; + self + } + /// Builds the completion request. pub fn build(self) -> CompletionRequest { CompletionRequest { @@ -403,6 +421,7 @@ impl CompletionRequestBuilder { documents: self.documents, tools: self.tools, temperature: self.temperature, + max_tokens: self.max_tokens, additional_params: self.additional_params, } } diff --git a/rig-core/src/providers/anthropic/client.rs b/rig-core/src/providers/anthropic/client.rs new file mode 100644 index 00000000..e2458d9d --- /dev/null +++ b/rig-core/src/providers/anthropic/client.rs @@ -0,0 +1,149 @@ +//! Anthropic client api implementation + +use crate::{agent::AgentBuilder, extractor::ExtractorBuilder}; + +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use super::completion::{CompletionModel, ANTHROPIC_VERSION_LATEST}; + +// ================================================================ +// Main Anthropic Client +// ================================================================ +const ANTHROPIC_API_BASE_URL: &str = "https://api.anthropic.com"; + +#[derive(Clone)] +pub struct ClientBuilder<'a> { + api_key: &'a str, + base_url: &'a str, + anthropic_version: &'a str, + anthropic_betas: Option>, +} + +/// Create a new anthropic client using the builder +/// +/// # Example +/// ``` +/// use rig::providers::anthropic::{ClientBuilder, self}; +/// +/// // Initialize the Anthropic client +/// let anthropic_client = ClientBuilder::new("your-claude-api-key") +/// .anthropic_version(anthropic::ANTHROPIC_VERSION_LATEST) +/// .anthropic_beta("prompt-caching-2024-07-31") +/// .build(); +/// ``` +impl<'a> ClientBuilder<'a> { + pub fn new(api_key: &'a str) -> Self { + Self { + api_key, + base_url: 
ANTHROPIC_API_BASE_URL, + anthropic_version: ANTHROPIC_VERSION_LATEST, + anthropic_betas: None, + } + } + + pub fn base_url(mut self, base_url: &'a str) -> Self { + self.base_url = base_url; + self + } + + pub fn anthropic_version(mut self, anthropic_version: &'a str) -> Self { + self.anthropic_version = anthropic_version; + self + } + + pub fn anthropic_beta(mut self, anthropic_beta: &'a str) -> Self { + if let Some(mut betas) = self.anthropic_betas { + betas.push(anthropic_beta); + self.anthropic_betas = Some(betas); + } else { + self.anthropic_betas = Some(vec![anthropic_beta]); + } + self + } + + pub fn build(self) -> Client { + Client::new( + self.api_key, + self.base_url, + self.anthropic_betas, + self.anthropic_version, + ) + } +} + +#[derive(Clone)] +pub struct Client { + base_url: String, + http_client: reqwest::Client, +} + +impl Client { + /// Create a new Anthropic client with the given API key, base URL, betas, and version. + /// Note, you probably want to use the `ClientBuilder` instead. + /// + /// Panics: + /// - If the API key or version cannot be parsed as a Json value from a String. + /// - This should really never happen. + /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized). 
+ pub fn new(api_key: &str, base_url: &str, betas: Option>, version: &str) -> Self { + Self { + base_url: base_url.to_string(), + http_client: reqwest::Client::builder() + .default_headers({ + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert("x-api-key", api_key.parse().expect("API key should parse")); + headers.insert( + "anthropic-version", + version.parse().expect("Anthropic version should parse"), + ); + if let Some(betas) = betas { + headers.insert( + "anthropic-beta", + betas + .join(",") + .parse() + .expect("Anthropic betas should parse"), + ); + } + headers + }) + .build() + .expect("Anthropic reqwest client should build"), + } + } + + pub fn post(&self, path: &str) -> reqwest::RequestBuilder { + let url = format!("{}/{}", self.base_url, path).replace("//", "/"); + self.http_client.post(url) + } + + pub fn completion_model(&self, model: &str) -> CompletionModel { + CompletionModel::new(self.clone(), model) + } + + /// Create an agent builder with the given completion model. + /// + /// # Example + /// ``` + /// use rig::providers::anthropic::{ClientBuilder, self}; + /// + /// // Initialize the Anthropic client + /// let anthropic = ClientBuilder::new("your-claude-api-key").build(); + /// + /// let agent = anthropic.agent(anthropic::CLAUDE_3_5_SONNET) + /// .preamble("You are comedian AI with a mission to make people laugh.") + /// .temperature(0.0) + /// .build(); + /// ``` + pub fn agent(&self, model: &str) -> AgentBuilder { + AgentBuilder::new(self.completion_model(model)) + } + + pub fn extractor Deserialize<'a> + Serialize + Send + Sync>( + &self, + model: &str, + ) -> ExtractorBuilder { + ExtractorBuilder::new(self.completion_model(model)) + } +} diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs new file mode 100644 index 00000000..ce4ce967 --- /dev/null +++ b/rig-core/src/providers/anthropic/completion.rs @@ -0,0 +1,223 @@ +//! 
Anthropic completion api implementation + +use std::iter; + +use crate::{ + completion::{self, CompletionError}, + json_utils, +}; + +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use super::client::Client; + +// ================================================================ +// Anthropic Completion API +// ================================================================ +/// `claude-3-5-sonnet-20240620` completion model +pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-20240620"; + +/// `claude-3-opus-20240229` completion model +pub const CLAUDE_3_OPUS: &str = "claude-3-opus-20240229"; + +/// `claude-3-sonnet-20240229` completion model +pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229"; + +/// `claude-3-haiku-20240307` completion model +pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307"; + +pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01"; +pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01"; +pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01; + +#[derive(Debug, Deserialize)] +pub struct CompletionResponse { + pub content: Vec, + pub id: String, + pub model: String, + pub role: String, + pub stop_reason: Option, + pub stop_sequence: Option, + pub usage: Usage, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum Content { + String(String), + Text { + text: String, + #[serde(rename = "type")] + content_type: String, + }, + ToolUse { + id: String, + name: String, + input: String, + #[serde(rename = "type")] + content_type: String, + }, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Usage { + pub input_tokens: u64, + pub cache_read_input_tokens: Option, + pub cache_creation_input_tokens: Option, + pub output_tokens: u64, +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct ToolDefinition { + pub name: String, + pub description: Option, + pub input_schema: serde_json::Value, + pub cache_control: Option, +} + +#[derive(Debug, Deserialize, 
Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum CacheControl { + Ephemeral, +} + +impl TryFrom for completion::CompletionResponse { + type Error = CompletionError; + + fn try_from(response: CompletionResponse) -> std::prelude::v1::Result { + match response.content.as_slice() { + [Content::String(text) | Content::Text { text, .. }, ..] => { + Ok(completion::CompletionResponse { + choice: completion::ModelChoice::Message(text.to_string()), + raw_response: response, + }) + } + [Content::ToolUse { name, input, .. }, ..] => Ok(completion::CompletionResponse { + choice: completion::ModelChoice::ToolCall( + name.clone(), + serde_json::from_str(input)?, + ), + raw_response: response, + }), + _ => Err(CompletionError::ResponseError( + "Response did not contain a message or tool call".into(), + )), + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct Message { + pub role: String, + pub content: String, +} + +impl From for Message { + fn from(message: completion::Message) -> Self { + Self { + role: message.role, + content: message.content, + } + } +} + +#[derive(Clone)] +pub struct CompletionModel { + client: Client, + pub model: String, +} + +impl CompletionModel { + pub fn new(client: Client, model: &str) -> Self { + Self { + client, + model: model.to_string(), + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +struct Metadata { + user_id: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ToolChoice { + Auto, + Any, + Tool { name: String }, +} + +impl completion::CompletionModel for CompletionModel { + type Response = CompletionResponse; + + async fn completion( + &self, + completion_request: completion::CompletionRequest, + ) -> Result, CompletionError> { + let request = json!({ + "model": self.model, + "messages": completion_request + .chat_history + .into_iter() + .map(Message::from) + .chain(completion_request.documents.into_iter().map(|doc| Message { + role: 
"system".to_owned(), + content: serde_json::to_string(&doc).expect("Document should serialize"), + })) + .chain(iter::once(Message { + role: "user".to_owned(), + content: completion_request.prompt, + })) + .collect::>(), + "max_tokens": completion_request.max_tokens, + "system": completion_request.preamble.unwrap_or("".to_string()), + "temperature": completion_request.temperature, + "tools": completion_request + .tools + .into_iter() + .map(|tool| ToolDefinition { + name: tool.name, + description: Some(tool.description), + input_schema: tool.parameters, + cache_control: None, + }) + .collect::>(), + }); + + let request = if let Some(ref params) = completion_request.additional_params { + json_utils::merge(request, params.clone()) + } else { + request + }; + + let response = self + .client + .post("/v1/messages") + .json(&request) + .send() + .await? + .error_for_status()? + .json::>() + .await?; + + match response { + ApiResponse::Message(completion) => completion.try_into(), + ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)), + } + } +} + +#[derive(Debug, Deserialize)] +struct ApiErrorResponse { + message: String, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ApiResponse { + Message(T), + Error(ApiErrorResponse), +} diff --git a/rig-core/src/providers/anthropic/mod.rs b/rig-core/src/providers/anthropic/mod.rs new file mode 100644 index 00000000..420e3a92 --- /dev/null +++ b/rig-core/src/providers/anthropic/mod.rs @@ -0,0 +1,19 @@ +//! Anthropic API client and Rig integration +//! +//! # Example +//! ``` +//! use rig::providers::anthropic; +//! +//! let client = anthropic::ClientBuilder::new("YOUR_API_KEY").build(); +//! +//! let sonnet = client.completion_model(anthropic::CLAUDE_3_5_SONNET); +//! 
``` + +pub mod client; +pub mod completion; + +pub use client::{Client, ClientBuilder}; +pub use completion::{ + ANTHROPIC_VERSION_2023_01_01, ANTHROPIC_VERSION_2023_06_01, ANTHROPIC_VERSION_LATEST, + CLAUDE_3_5_SONNET, CLAUDE_3_HAIKU, CLAUDE_3_OPUS, CLAUDE_3_SONNET, +}; diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs index 521ee1e4..84bee958 100644 --- a/rig-core/src/providers/mod.rs +++ b/rig-core/src/providers/mod.rs @@ -4,6 +4,7 @@ //! - Cohere //! - OpenAI //! - Perplexity +//! - Anthropic //! //! Each provider has its own module, which contains a `Client` implementation that can //! be used to initialize completion and embedding models and execute requests to those models. @@ -38,6 +39,7 @@ //! ``` //! Note: The example above uses the OpenAI provider client, but the same pattern can //! be used with the Cohere provider client. +pub mod anthropic; pub mod cohere; pub mod openai; pub mod perplexity;