Skip to content

Commit

Permalink
Merge pull request #67 from 0xPlaygrounds/fix/anthropic-bad-request
Browse files Browse the repository at this point in the history
fix(anthropic): bad request with tools, max_tokens & temp
  • Loading branch information
cvauclair authored Oct 23, 2024
2 parents 093c434 + 8d789ee commit 0d0083b
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 123 deletions.
31 changes: 7 additions & 24 deletions rig-core/examples/agent_with_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use rig::{
};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::env;

#[derive(Deserialize)]
struct OperationArgs {
Expand Down Expand Up @@ -92,38 +91,22 @@ impl Tool for Subtract {
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create OpenAI client
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = providers::openai::Client::new(&openai_api_key);
let openai_client = providers::openai::Client::from_env();

// Create agent with a single context prompt and two tools
let gpt4_calculator_agent = openai_client
.agent("gpt-4")
.context("You are a calculator here to help the user perform arithmetic operations.")
.tool(Adder)
.tool(Subtract)
.build();

// Create Cohere client
let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
let cohere_client = providers::cohere::Client::new(&cohere_api_key);

// Create agent with a single context prompt and two tools
let coral_calculator_agent = cohere_client
.agent("command-r")
.preamble("You are a calculator here to help the user perform arithmetic operations.")
let calculator_agent = openai_client
.agent(providers::openai::GPT_4O)
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();

// Prompt the agent and print the response
println!("Calculate 2 - 5");
println!(
"GPT-4: {}",
gpt4_calculator_agent.prompt("Calculate 2 - 5").await?
);
println!(
"Coral: {}",
coral_calculator_agent.prompt("Calculate 2 - 5").await?
"Calculator Agent: {}",
calculator_agent.prompt("Calculate 2 - 5").await?
);

Ok(())
Expand Down
20 changes: 14 additions & 6 deletions rig-core/src/json_utils.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
pub fn merge(a: serde_json::Value, b: serde_json::Value) -> serde_json::Value {
match (a.clone(), b) {
(serde_json::Value::Object(mut a), serde_json::Value::Object(b)) => {
b.into_iter().for_each(|(key, value)| {
a.insert(key.clone(), value.clone());
match (a, b) {
(serde_json::Value::Object(mut a_map), serde_json::Value::Object(b_map)) => {
b_map.into_iter().for_each(|(key, value)| {
a_map.insert(key, value);
});
serde_json::Value::Object(a)
serde_json::Value::Object(a_map)
}
_ => a,
(a, _) => a,
}
}

/// Merge the fields of JSON object `b` into JSON object `a`, mutating `a` in place.
/// Keys present in both objects take `b`'s value. If either value is not a JSON
/// object, `a` is left unchanged (mirrors the behavior of `merge`).
pub fn merge_inplace(a: &mut serde_json::Value, b: serde_json::Value) {
    if let (serde_json::Value::Object(target), serde_json::Value::Object(source)) = (a, b) {
        for (key, value) in source {
            target.insert(key, value);
        }
    }
}
7 changes: 7 additions & 0 deletions rig-core/src/providers/anthropic/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ impl Client {
}
}

/// Create a new Anthropic client from the `ANTHROPIC_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
    ClientBuilder::new(
        &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
    )
    .build()
}

pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
Expand Down
81 changes: 48 additions & 33 deletions rig-core/src/providers/anthropic/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,14 @@ pub struct CompletionResponse {
pub enum Content {
String(String),
Text {
r#type: String,
text: String,
#[serde(rename = "type")]
content_type: String,
},
ToolUse {
r#type: String,
id: String,
name: String,
input: String,
#[serde(rename = "type")]
content_type: String,
input: serde_json::Value,
},
}

Expand All @@ -73,7 +71,6 @@ pub struct ToolDefinition {
pub name: String,
pub description: Option<String>,
pub input_schema: serde_json::Value,
pub cache_control: Option<CacheControl>,
}

#[derive(Debug, Deserialize, Serialize)]
Expand All @@ -94,10 +91,7 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionRe
})
}
[Content::ToolUse { name, input, .. }, ..] => Ok(completion::CompletionResponse {
choice: completion::ModelChoice::ToolCall(
name.clone(),
serde_json::from_str(input)?,
),
choice: completion::ModelChoice::ToolCall(name.clone(), input.clone()),
raw_response: response,
}),
_ => Err(CompletionError::ResponseError(
Expand Down Expand Up @@ -157,9 +151,20 @@ impl completion::CompletionModel for CompletionModel {
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Note: Ideally we'd introduce provider-specific Request models to handle the
// specific requirements of each provider. For now, we just manually check while
// building the request as a raw JSON document.

let prompt_with_context = completion_request.prompt_with_context();

let request = json!({
// Check if max_tokens is set, required for Anthropic
if completion_request.max_tokens.is_none() {
return Err(CompletionError::RequestError(
"max_tokens must be set for Anthropic".into(),
));
}

let mut request = json!({
"model": self.model,
"messages": completion_request
.chat_history
Expand All @@ -172,38 +177,48 @@ impl completion::CompletionModel for CompletionModel {
.collect::<Vec<_>>(),
"max_tokens": completion_request.max_tokens,
"system": completion_request.preamble.unwrap_or("".to_string()),
"temperature": completion_request.temperature,
"tools": completion_request
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
cache_control: None,
})
.collect::<Vec<_>>(),
});

let request = if let Some(ref params) = completion_request.additional_params {
json_utils::merge(request, params.clone())
} else {
request
};
if let Some(temperature) = completion_request.temperature {
json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
}

if !completion_request.tools.is_empty() {
json_utils::merge_inplace(
&mut request,
json!({
"tools": completion_request
.tools
.into_iter()
.map(|tool| ToolDefinition {
name: tool.name,
description: Some(tool.description),
input_schema: tool.parameters,
})
.collect::<Vec<_>>(),
"tool_choice": ToolChoice::Auto,
}),
);
}

if let Some(ref params) = completion_request.additional_params {
json_utils::merge_inplace(&mut request, params.clone())
}

let response = self
.client
.post("/v1/messages")
.json(&request)
.send()
.await?
.error_for_status()?
.json::<ApiResponse<CompletionResponse>>()
.await?;

match response {
ApiResponse::Message(completion) => completion.try_into(),
ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Message(completion) => completion.try_into(),
ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
Expand Down
65 changes: 37 additions & 28 deletions rig-core/src/providers/cohere.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ impl Client {
}
}

/// Create a new Cohere client from the `COHERE_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
    Self::new(&std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"))
}

pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
Expand Down Expand Up @@ -203,32 +210,33 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
"input_type": self.input_type,
}))
.send()
.await?
.error_for_status()?
.json::<ApiResponse<EmbeddingResponse>>()
.await?;

match response {
ApiResponse::Ok(response) => {
if response.embeddings.len() != documents.len() {
return Err(EmbeddingError::DocumentError(format!(
"Expected {} embeddings, got {}",
documents.len(),
response.embeddings.len()
)));
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
if response.embeddings.len() != documents.len() {
return Err(EmbeddingError::DocumentError(format!(
"Expected {} embeddings, got {}",
documents.len(),
response.embeddings.len()
)));
}

Ok(response
.embeddings
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding,
})
.collect())
}

Ok(response
.embeddings
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding,
})
.collect())
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
}
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
} else {
Err(EmbeddingError::ProviderError(response.text().await?))
}
}
}
Expand Down Expand Up @@ -500,14 +508,15 @@ impl completion::CompletionModel for CompletionModel {
},
)
.send()
.await?
.error_for_status()?
.json::<ApiResponse<CompletionResponse>>()
.await?;

match response {
ApiResponse::Ok(completion) => Ok(completion.into()),
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(completion) => Ok(completion.into()),
ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
54 changes: 28 additions & 26 deletions rig-core/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,30 +251,31 @@ impl embeddings::EmbeddingModel for EmbeddingModel {
"input": documents,
}))
.send()
.await?
.error_for_status()?
.json::<ApiResponse<EmbeddingResponse>>()
.await?;

match response {
ApiResponse::Ok(response) => {
if response.data.len() != documents.len() {
return Err(EmbeddingError::ResponseError(
"Response data length does not match input length".into(),
));
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
if response.data.len() != documents.len() {
return Err(EmbeddingError::ResponseError(
"Response data length does not match input length".into(),
));
}

Ok(response
.data
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding.embedding,
})
.collect())
}

Ok(response
.data
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding.embedding,
})
.collect())
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
}
ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
} else {
Err(EmbeddingError::ProviderError(response.text().await?))
}
}
}
Expand Down Expand Up @@ -510,14 +511,15 @@ impl completion::CompletionModel for CompletionModel {
},
)
.send()
.await?
.error_for_status()?
.json::<ApiResponse<CompletionResponse>>()
.await?;

match response {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(response) => response.try_into(),
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
Loading

0 comments on commit 0d0083b

Please sign in to comment.