Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(providers): Integrate anthropic models #27

Merged
merged 5 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions rig-core/examples/anthropic_agent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use std::env;

use rig::{
completion::Prompt,
providers::anthropic::{self, CLAUDE_3_5_SONNET},
};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let client = anthropic::ClientBuilder::new(
&env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
)
.build();

// Create agent with a single context prompt
let agent = client
.agent(CLAUDE_3_5_SONNET)
.preamble("Be precise and concise.")
.temperature(0.5)
.max_tokens(8192)
.build();

// Prompt the agent and print the response
let response = agent
.prompt("When and where and what type is the next solar eclipse?")
.await?;
println!("{}", response);

Ok(())
}
13 changes: 13 additions & 0 deletions rig-core/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ pub struct Agent<M: CompletionModel> {
static_tools: Vec<String>,
/// Temperature of the model
temperature: Option<f64>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// List of vector store, with the sample number
Expand Down Expand Up @@ -238,6 +240,7 @@ impl<M: CompletionModel> Completion<M> for Agent<M> {
.documents([self.static_context.clone(), dynamic_context].concat())
.tools([static_tools.clone(), dynamic_tools].concat())
.temperature_opt(self.temperature)
.max_tokens_opt(self.max_tokens)
.additional_params_opt(self.additional_params.clone()))
}
}
Expand Down Expand Up @@ -295,6 +298,8 @@ pub struct AgentBuilder<M: CompletionModel> {
static_tools: Vec<String>,
/// Additional parameters to be passed to the model
additional_params: Option<serde_json::Value>,
/// Maximum number of tokens for the completion
max_tokens: Option<u64>,
/// List of vector store, with the sample number
dynamic_context: Vec<(usize, Box<dyn VectorStoreIndexDyn>)>,
/// Dynamic tools
Expand All @@ -313,6 +318,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
static_context: vec![],
static_tools: vec![],
temperature: None,
max_tokens: None,
additional_params: None,
dynamic_context: vec![],
dynamic_tools: vec![],
Expand Down Expand Up @@ -385,6 +391,12 @@ impl<M: CompletionModel> AgentBuilder<M> {
self
}

/// Set the maximum number of tokens for the completion
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}

/// Set additional parameters to be passed to the model
pub fn additional_params(mut self, params: serde_json::Value) -> Self {
self.additional_params = Some(params);
Expand All @@ -399,6 +411,7 @@ impl<M: CompletionModel> AgentBuilder<M> {
static_context: self.static_context,
static_tools: self.static_tools,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
dynamic_context: self.dynamic_context,
dynamic_tools: self.dynamic_tools,
Expand Down
19 changes: 19 additions & 0 deletions rig-core/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ pub struct CompletionRequest {
pub tools: Vec<ToolDefinition>,
/// The temperature to be sent to the completion model provider
pub temperature: Option<f64>,
/// The max tokens to be sent to the completion model provider
pub max_tokens: Option<u64>,
/// Additional provider-specific parameters to be sent to the completion model provider
pub additional_params: Option<serde_json::Value>,
}
Expand Down Expand Up @@ -293,6 +295,7 @@ pub struct CompletionRequestBuilder<M: CompletionModel> {
documents: Vec<Document>,
tools: Vec<ToolDefinition>,
temperature: Option<f64>,
max_tokens: Option<u64>,
additional_params: Option<serde_json::Value>,
}

Expand All @@ -306,6 +309,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
documents: Vec::new(),
tools: Vec::new(),
temperature: None,
max_tokens: None,
additional_params: None,
}
}
Expand Down Expand Up @@ -394,6 +398,20 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
self
}

/// Sets the max tokens for the completion request.
/// Only required for: [ Anthropic ]
pub fn max_tokens(mut self, max_tokens: u64) -> Self {
self.max_tokens = Some(max_tokens);
self
}

/// Sets the max tokens for the completion request.
/// Only required for: [ Anthropic ]
pub fn max_tokens_opt(mut self, max_tokens: Option<u64>) -> Self {
self.max_tokens = max_tokens;
self
}

/// Builds the completion request.
pub fn build(self) -> CompletionRequest {
CompletionRequest {
Expand All @@ -403,6 +421,7 @@ impl<M: CompletionModel> CompletionRequestBuilder<M> {
documents: self.documents,
tools: self.tools,
temperature: self.temperature,
max_tokens: self.max_tokens,
additional_params: self.additional_params,
}
}
Expand Down
143 changes: 143 additions & 0 deletions rig-core/src/providers/anthropic/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
//! Anthropic client api implementation

use crate::{agent::AgentBuilder, extractor::ExtractorBuilder};

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::completion::{CompletionModel, ANTHROPIC_VERSION_LATEST};

// ================================================================
// Main Cohere Client
// ================================================================
const COHERE_API_BASE_URL: &str = "https://api.anthropic.com";

#[derive(Clone)]
pub struct ClientBuilder<'a> {
api_key: &'a str,
base_url: &'a str,
anthropic_version: &'a str,
anthropic_betas: Option<Vec<&'a str>>,
}

/// Create a new anthropic client using the builder
///
/// # Example
/// ```
/// use rig::providers::anthropic::{ClientBuilder, self};
///
/// // Initialize the Anthropic client
/// let anthropic_client = ClientBuilder::new("your-claude-api-key")
/// .base_url("https://api.anthropic.com")
/// .anthropic_version(ANTHROPIC_VERSION_LATEST)
/// .anthropic_beta("prompt-caching-2024-07-31")
/// .build()
/// ```
impl<'a> ClientBuilder<'a> {
pub fn new(api_key: &'a str) -> Self {
Self {
api_key,
base_url: COHERE_API_BASE_URL,
anthropic_version: ANTHROPIC_VERSION_LATEST,
anthropic_betas: None,
}
}

pub fn base_url(mut self, base_url: &'a str) -> Self {
self.base_url = base_url;
self
}

pub fn anthropic_version(mut self, anthropic_version: &'a str) -> Self {
self.anthropic_version = anthropic_version;
self
}

pub fn anthropic_beta(mut self, anthropic_beta: &'a str) -> Self {
if let Some(mut betas) = self.anthropic_betas {
betas.push(anthropic_beta);
self.anthropic_betas = Some(betas);
} else {
self.anthropic_betas = Some(vec![anthropic_beta]);
}
self
}

pub fn build(self) -> Client {
Client::new(
self.api_key,
self.base_url,
self.anthropic_betas,
self.anthropic_version,
)
}
}

#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}

impl Client {
pub fn new(api_key: &str, base_url: &str, betas: Option<Vec<&str>>, version: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert("x-api-key", api_key.parse().expect("API key should parse"));
headers.insert(
"anthropic-version",
version.parse().expect("Anthropic version should parse"),
);
if let Some(betas) = betas {
headers.insert(
"anthropic-beta",
betas
.join(",")
.parse()
.expect("Anthropic betas should parse"),
);
}
headers
})
.build()
.expect("Cohere reqwest client should build"),
}
}
0xMochan marked this conversation as resolved.
Show resolved Hide resolved

pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}

pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::anthropic::{ClientBuilder, self};
///
/// // Initialize the Anthropic client
/// let anthropic = ClientBuilder::new("your-claude-api-key").build();
///
/// let agent = anthropic.agent(anthropic::CLAUDE_3_5_SONNET)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}
Loading