Skip to content

Commit

Permalink
Merge pull request #27 from 0xPlaygrounds/feat/anthropic
Browse files Browse the repository at this point in the history
feat(providers): Integrate anthropic models
  • Loading branch information
cvauclair authored Sep 23, 2024
2 parents e7233e6 + 53ec3c8 commit 278ea1e
Show file tree
Hide file tree
Showing 7 changed files with 456 additions and 0 deletions.
31 changes: 31 additions & 0 deletions rig-core/examples/anthropic_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use std::env;

use rig::{
completion::Prompt,
providers::anthropic::{self, CLAUDE_3_5_SONNET},
};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let client = anthropic::ClientBuilder::new(
&env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
)
.build();

// Create agent with a single context prompt
let agent = client
.agent(CLAUDE_3_5_SONNET)
.preamble("Be precise and concise.")
.temperature(0.5)
.max_tokens(8192)
.build();

// Prompt the agent and print the response
let response = agent
.prompt("When and where and what type is the next solar eclipse?")
.await?;
println!("{}", response);

Ok(())
}
13 changes: 13 additions & 0 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ pub struct Agent<M: CompletionModel> {
static_tools: Vec<String>,
/// Temperature of the model
temperature: Option<f64>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// List of vector store, with the sample number
Expand Down Expand Up @@ -238,6 +240,7 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
.documents([self.static_context.clone(), dynamic_context].concat())
.tools([static_tools.clone(), dynamic_tools].concat())
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone()))
}
}
Expand Down Expand Up @@ -295,6 +298,8 @@ pub struct AgentBuilder<M: CompletionModel> {
static_tools: Vec<String>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
Expand All @@ -313,6 +318,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
static_context: vec![],
static_tools: vec![],
temperature: None,
max_tokens: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
Expand Down Expand Up @@ -385,6 +391,12 @@ impl<M: CompletionModel> AgentBuilder<M> {
self
}

/// Set the maximum number of tokens for the completion
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}

/// Set additional parameters to be passed to the model
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
self.additional_params = Some(params);
Expand All @@ -399,6 +411,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
Expand Down
19 changes: 19 additions & 0 deletions rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ pub struct CompletionRequest {
pub tools: Vec<ToolDefinition>,
/// The temperature to be sent to the completion model provider
pub temperature: Option<f64>,
/// The max tokens to be sent to the completion model provider
pub max_tokens: Option<u64>,
/// Additional provider-specific parameters to be sent to the completion model provider
pub additional_params: Option<serde_json::Value>,
}
Expand Down Expand Up @@ -293,6 +295,7 @@ pub struct CompletionRequestBuilder<M: CompletionModel> {
documents: Vec<Document>,
tools: Vec<ToolDefinition>,
temperature: Option<f64>,
max_tokens: Option<u64>,
additional_params: Option<serde_json::Value>,
}

Expand All @@ -306,6 +309,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
documents: Vec::new(),
tools: Vec::new(),
temperature: None,
max_tokens: None,
additional_params: None,
}
}
Expand Down Expand Up @@ -394,6 +398,20 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
self
}

/// Sets the max tokens for the completion request.
/// Only required for: [ Anthropic ]
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}

/// Sets the max tokens for the completion request.
/// Only required for: [ Anthropic ]
pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
self.max_tokens = max_tokens;
self
}

/// Builds the completion request.
pub fn build(self) -> CompletionRequest {
CompletionRequest {
Expand All @@ -403,6 +421,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
documents: self.documents,
tools: self.tools,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
}
}
Expand Down
149 changes: 149 additions & 0 deletions rig-core/src/providers/anthropic/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
//! Anthropic client api implementation
use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::completion::{CompletionModel, ANTHROPIC_VERSION_LATEST};

// ================================================================
// Main Anthropic Client
// ================================================================
const ANTHROPIC_API_BASE_URL: &str = "https://api.anthropic.com";

#[derive(Clone)]
pub struct ClientBuilder<'a> {
api_key: &'a str,
base_url: &'a str,
anthropic_version: &'a str,
anthropic_betas: Option<Vec<&'a str>>,
}

/// Create a new anthropic client using the builder
///
/// # Example
/// ```
/// use rig::providers::anthropic::{ClientBuilder, self};
///
/// // Initialize the Anthropic client
/// let anthropic_client = ClientBuilder::new("your-claude-api-key")
/// .anthropic_version(ANTHROPIC_VERSION_LATEST)
/// .anthropic_beta("prompt-caching-2024-07-31")
/// .build()
/// ```
impl<'a> ClientBuilder<'a> {
pub fn new(api_key: &'a str) -> Self {
Self {
api_key,
base_url: ANTHROPIC_API_BASE_URL,
anthropic_version: ANTHROPIC_VERSION_LATEST,
anthropic_betas: None,
}
}

pub fn base_url(mut self, base_url: &'a str) -> Self {
self.base_url = base_url;
self
}

pub fn anthropic_version(mut self, anthropic_version: &'a str) -> Self {
self.anthropic_version = anthropic_version;
self
}

pub fn anthropic_beta(mut self, anthropic_beta: &'a str) -> Self {
if let Some(mut betas) = self.anthropic_betas {
betas.push(anthropic_beta);
self.anthropic_betas = Some(betas);
} else {
self.anthropic_betas = Some(vec![anthropic_beta]);
}
self
}

pub fn build(self) -> Client {
Client::new(
self.api_key,
self.base_url,
self.anthropic_betas,
self.anthropic_version,
)
}
}

#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}

impl Client {
/// Create a new Anthropic client with the given API key, base URL, betas, and version.
/// Note, you proably want to use the `ClientBuilder` instead.
///
/// Panics:
/// - If the API key or version cannot be parsed as a Json value from a String.
/// - This should really never happen.
/// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
pub fn new(api_key: &str, base_url: &str, betas: Option<Vec<&str>>, version: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("x-api-key", api_key.parse().expect("API key should parse"));
headers.insert(
"anthropic-version",
version.parse().expect("Anthropic version should parse"),
);
if let Some(betas) = betas {
headers.insert(
"anthropic-beta",
betas
.join(",")
.parse()
.expect("Anthropic betas should parse"),
);
}
headers
})
.build()
.expect("Anthropic reqwest client should build"),
}
}

pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}

pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::anthropic::{ClientBuilder, self};
///
/// // Initialize the Anthropic client
/// let anthropic = ClientBuilder::new("your-claude-api-key").build();
///
/// let agent = anthropic.agent(anthropic::CLAUDE_3_5_SONNET)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}
Loading

0 comments on commit 278ea1e

Please sign in to comment.