From b57c0026b510d0a1b91ba52d0010c9c8d371c882 Mon Sep 17 00:00:00 2001
From: 0xMochan
Date: Fri, 13 Dec 2024 17:47:34 -0800
Subject: [PATCH 1/2] feat(anthropic): update model params + better max_token handling

---
 .../src/providers/anthropic/completion.rs | 43 +++++++++++++++----
 1 file changed, 35 insertions(+), 8 deletions(-)

diff --git a/rig-core/src/providers/anthropic/completion.rs b/rig-core/src/providers/anthropic/completion.rs
index cc9c84dc..56c00bed 100644
--- a/rig-core/src/providers/anthropic/completion.rs
+++ b/rig-core/src/providers/anthropic/completion.rs
@@ -15,11 +15,14 @@ use super::client::Client;
 // ================================================================
 // Anthropic Completion API
 // ================================================================
-/// `claude-3-5-sonnet-20240620` completion model
-pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-20240620";
+/// `claude-3-5-sonnet-latest` completion model
+pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";
 
-/// `claude-3-5-haiku-20240620` completion model
-pub const CLAUDE_3_OPUS: &str = "claude-3-opus-20240229";
+/// `claude-3-5-haiku-latest` completion model
+pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";
+
+/// `claude-3-opus-latest` completion model
+pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";
 
 /// `claude-3-sonnet-20240229` completion model
 pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";
@@ -139,6 +142,7 @@ impl From<completion::Message> for Message {
 pub struct CompletionModel {
     client: Client,
     pub model: String,
+    default_max_tokens: Option<u64>,
 }
 
 impl CompletionModel {
@@ -146,10 +150,29 @@ impl CompletionModel {
         Self {
             client,
             model: model.to_string(),
+            default_max_tokens: calculate_max_tokens(model),
         }
     }
 }
 
+/// Anthropic requires a `max_tokens` parameter to be set, which is dependent on the model. If not
+/// set or if set too high, the request will fail. The following values are based on the models
+/// available at the time of writing.
+///
+/// Dev Note: This is really bad design; I'm not sure why Anthropic did it like this.
+fn calculate_max_tokens(model: &str) -> Option<u64> {
+    if model.starts_with("claude-3-5-sonnet") || model.starts_with("claude-3-5-haiku") {
+        Some(8192)
+    } else if model.starts_with("claude-3-opus")
+        || model.starts_with("claude-3-sonnet")
+        || model.starts_with("claude-3-haiku")
+    {
+        Some(4096)
+    } else {
+        None
+    }
+}
+
 #[derive(Debug, Deserialize, Serialize)]
 struct Metadata {
     user_id: Option<String>,
@@ -177,11 +200,15 @@ impl completion::CompletionModel for CompletionModel {
         let prompt_with_context = completion_request.prompt_with_context();
 
         // Check if max_tokens is set, required for Anthropic
-        if completion_request.max_tokens.is_none() {
+        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
+            tokens
+        } else if let Some(tokens) = self.default_max_tokens {
+            tokens
+        } else {
             return Err(CompletionError::RequestError(
-                "max_tokens must be set for Anthropic".into(),
+                "`max_tokens` must be set for Anthropic".into(),
             ));
-        }
+        };
 
         let mut request = json!({
             "model": self.model,
@@ -194,7 +221,7 @@ impl completion::CompletionModel for CompletionModel {
                     content: prompt_with_context,
                 }))
                 .collect::<Vec<_>>(),
-            "max_tokens": completion_request.max_tokens,
+            "max_tokens": max_tokens,
             "system": completion_request.preamble.unwrap_or("".to_string()),
         });
 

From 81ca1e0585051fa9855373f7f76431465b9a6efb Mon Sep 17 00:00:00 2001
From: 0xMochan
Date: Fri, 13 Dec 2024 17:49:45 -0800
Subject: [PATCH 2/2] test(anthropic): remove `max_tokens` argument

---
 rig-core/examples/anthropic_agent.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/rig-core/examples/anthropic_agent.rs b/rig-core/examples/anthropic_agent.rs
index 9e3ea255..2ef99ff7 100644
--- a/rig-core/examples/anthropic_agent.rs
+++ b/rig-core/examples/anthropic_agent.rs
@@ -18,7 +18,6 @@ async fn main() -> Result<(), anyhow::Error> {
         .agent(CLAUDE_3_5_SONNET)
         .preamble("Be precise and concise.")
         .temperature(0.5)
-        .max_tokens(8192)
         .build();
 
     // Prompt the agent and print the response
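For reviewers: the resolution order introduced by patch 1 is that an explicit `max_tokens` on the
request wins, otherwise the model's `default_max_tokens` is used, and only when both are absent does
the `CompletionError::RequestError` fire. Patch 2 relies on that fallback by dropping the explicit
`.max_tokens(8192)` call from the example agent. Below is a minimal test sketch of the prefix
matching; it is not part of the patches and assumes `calculate_max_tokens` stays a private free
function in completion.rs:

    // Hypothetical test module, to be placed at the bottom of completion.rs.
    #[cfg(test)]
    mod default_max_tokens_tests {
        use super::calculate_max_tokens;

        #[test]
        fn picks_default_by_model_family() {
            // Claude 3.5 Sonnet / Haiku families allow up to 8192 output tokens.
            assert_eq!(calculate_max_tokens("claude-3-5-sonnet-latest"), Some(8192));
            assert_eq!(calculate_max_tokens("claude-3-5-haiku-latest"), Some(8192));
            // Claude 3 Opus / Sonnet / Haiku cap out at 4096.
            assert_eq!(calculate_max_tokens("claude-3-opus-latest"), Some(4096));
            assert_eq!(calculate_max_tokens("claude-3-sonnet-20240229"), Some(4096));
            // Unknown models fall through to None, so the request-time error still fires.
            assert_eq!(calculate_max_tokens("claude-2.1"), None);
        }
    }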