Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(anthropic): update model params + better max_token handling #151

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rig-core/examples/anthropic_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ async fn main() -> Result<(), anyhow::Error> {
.agent(CLAUDE_3_5_SONNET)
.preamble("Be precise and concise.")
.temperature(0.5)
.max_tokens(8192)
.build();

// Prompt the agent and print the response
Expand Down
43 changes: 35 additions & 8 deletions rig-core/src/providers/anthropic/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ use super::client::Client;
// ================================================================
// Anthropic Completion API
// ================================================================
/// `claude-3-5-sonnet-20240620` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-20240620";
/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-20240620` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-20240229";
/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";
Expand Down Expand Up @@ -139,17 +142,37 @@ impl From<completion::Message> for Message {
pub struct CompletionModel {
/// Anthropic client handle supplied to `new` (how it is used lies outside this view).
client: Client,
/// Model name exactly as passed to `new`; sent as the `model` field of each request.
pub model: String,
/// Fallback `max_tokens`, derived from the model name by `calculate_max_tokens` at
/// construction. Used only when a request does not set `max_tokens` itself; `None` when the
/// model name is not recognized.
default_max_tokens: Option<u64>,
}

impl CompletionModel {
    /// Creates a completion model for `model`, backed by `client`.
    ///
    /// The fallback `max_tokens` is resolved once here from the model name via
    /// `calculate_max_tokens`; it is `None` when the name is not recognized.
    pub fn new(client: Client, model: &str) -> Self {
        let default_max_tokens = calculate_max_tokens(model);
        let model = model.to_owned();

        Self {
            client,
            model,
            default_max_tokens,
        }
    }
}

/// Returns the default `max_tokens` for a known Anthropic model, or `None` if the model name is
/// not recognized.
///
/// Anthropic requires `max_tokens` on every request, and the allowed value is dependent on the
/// model: if it is missing or set too high, the request fails. Models are matched by name
/// prefix, and the limits below are based on the models available at the time of writing.
///
/// Dev Note: This is really bad design, I'm not sure why they did it like this..
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // (model-name prefix, default output-token limit)
    const DEFAULTS: [(&str, u64); 5] = [
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    DEFAULTS
        .iter()
        .find(|(prefix, _)| model.starts_with(*prefix))
        .map(|&(_, limit)| limit)
}

#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
user_id: Option<String>,
Expand Down Expand Up @@ -177,11 +200,15 @@ impl completion::CompletionModel for CompletionModel {
let prompt_with_context = completion_request.prompt_with_context();

// Check if max_tokens is set, required for Anthropic
if completion_request.max_tokens.is_none() {
let max_tokens = if let Some(tokens) = completion_request.max_tokens {
tokens
} else if let Some(tokens) = self.default_max_tokens {
tokens
} else {
return Err(CompletionError::RequestError(
"max_tokens must be set for Anthropic".into(),
"`max_tokens` must be set for Anthropic".into(),
));
}
};

let mut request = json!({
"model": self.model,
Expand All @@ -194,7 +221,7 @@ impl completion::CompletionModel for CompletionModel {
content: prompt_with_context,
}))
.collect::<Vec<_>>(),
"max_tokens": completion_request.max_tokens,
"max_tokens": max_tokens,
"system": completion_request.preamble.unwrap_or("".to_string()),
});

Expand Down
Loading