feat: groq integration #263

Merged 6 commits on Feb 12, 2025
25 changes: 25 additions & 0 deletions rig-core/examples/agent_with_groq.rs
@@ -0,0 +1,25 @@
use std::env;

use rig::{
completion::Prompt,
providers::{self, groq::DEEPSEEK_R1_DISTILL_LLAMA_70B},
};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create Groq client from the GROQ_API_KEY environment variable
let client =
providers::groq::Client::new(&env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set"));

// Create agent with a single context prompt
let comedian_agent = client
.agent(DEEPSEEK_R1_DISTILL_LLAMA_70B)
.preamble("You are a comedian here to entertain the user using humour and jokes.")
.build();

// Prompt the agent and print the response
let response = comedian_agent.prompt("Entertain me!").await?;
println!("{}", response);

Ok(())
}
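Assuming the usual cargo examples layout, `GROQ_API_KEY=<your-key> cargo run --example agent_with_groq` from the `rig-core` crate directory should run this example; the program panics at startup if the key is unset.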
336 changes: 336 additions & 0 deletions rig-core/src/providers/groq.rs
@@ -0,0 +1,336 @@
//! Groq API client and Rig integration
//!
//! # Example
//! ```
//! use rig::providers::groq;
//!
//! let client = groq::Client::new("YOUR_API_KEY");
//!
//! let model = client.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
//! ```
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError, CompletionRequest},
extractor::ExtractorBuilder,
json_utils,
message::{self, MessageError},
providers::openai::ToolDefinition,
OneOrMany,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;

use super::openai::CompletionResponse;

// ================================================================
// Main Groq Client
// ================================================================
const GROQ_API_BASE_URL: &str = "https://api.groq.com/openai/v1";

#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}

impl Client {
/// Create a new Groq client with the given API key.
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, GROQ_API_BASE_URL)
}

/// Create a new Groq client with the given API key and base API URL.
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("Groq reqwest client should build"),
}
}

/// Create a new Groq client from the `GROQ_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY not set");
Self::new(&api_key)
}

fn post(&self, path: &str) -> reqwest::RequestBuilder {
// Join base URL and path without mangling the scheme's "//"
let url = format!("{}/{}", self.base_url.trim_end_matches('/'), path.trim_start_matches('/'));
self.http_client.post(url)
}

/// Create a completion model with the given name.
///
/// # Example
/// ```
/// use rig::providers::groq::{Client, self};
///
/// // Initialize the Groq client
/// let groq = Client::new("your-groq-api-key");
///
/// let model = groq.completion_model(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B);
/// ```
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
///
/// # Example
/// ```
/// use rig::providers::groq::{Client, self};
///
/// // Initialize the Groq client
/// let groq = Client::new("your-groq-api-key");
///
/// let agent = groq.agent(groq::DEEPSEEK_R1_DISTILL_LLAMA_70B)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder with the given completion model.
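///
/// # Example
/// ```no_run
/// use rig::providers::groq;
/// use schemars::JsonSchema;
/// use serde::{Deserialize, Serialize};
///
/// // Hypothetical extraction target, for illustration only
/// #[derive(Debug, JsonSchema, Serialize, Deserialize)]
/// struct Person {
/// name: String,
/// }
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// // Initialize the Groq client
/// let groq = groq::Client::new("your-groq-api-key");
///
/// // A sketch of structured extraction (assumes rig's `Extractor::extract` API)
/// let extractor = groq.extractor::<Person>(groq::LLAMA_3_1_8B_INSTANT).build();
/// let person = extractor.extract("Their name is John.").await?;
/// println!("{person:?}");
/// # Ok(())
/// # }
/// ```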
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
message: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ApiErrorResponse),
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: String,
pub content: Option<String>,
}

impl TryFrom<Message> for message::Message {
type Error = message::MessageError;

fn try_from(message: Message) -> Result<Self, Self::Error> {
match message.role.as_str() {
"user" => Ok(Self::User {
content: OneOrMany::one(
message
.content
.map(|content| message::UserContent::text(&content))
.ok_or_else(|| {
message::MessageError::ConversionError("Empty user message".to_string())
})?,
),
}),
"assistant" => Ok(Self::Assistant {
content: OneOrMany::one(
message
.content
.map(|content| message::AssistantContent::text(&content))
.ok_or_else(|| {
message::MessageError::ConversionError(
"Empty assistant message".to_string(),
)
})?,
),
}),
_ => Err(message::MessageError::ConversionError(format!(
"Unknown role: {}",
message.role
))),
}
}
}

impl TryFrom<message::Message> for Message {
type Error = message::MessageError;

fn try_from(message: message::Message) -> Result<Self, Self::Error> {
match message {
message::Message::User { content } => Ok(Self {
role: "user".to_string(),
content: content.iter().find_map(|c| match c {
message::UserContent::Text(text) => Some(text.text.clone()),
_ => None,
}),
}),
message::Message::Assistant { content } => {
let mut text_content: Option<String> = None;

for c in content.iter() {
match c {
message::AssistantContent::Text(text) => {
text_content = Some(
text_content
.map(|mut existing| {
existing.push('\n');
existing.push_str(&text.text);
existing
})
.unwrap_or_else(|| text.text.clone()),
);
}
message::AssistantContent::ToolCall(_tool_call) => {
return Err(MessageError::ConversionError(
"Tool calls are not supported in Groq assistant messages".into(),
))
}
}
}

Ok(Self {
role: "assistant".to_string(),
content: text_content,
})
}
}
}
}
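
// Note on the conversions above: the user-side conversion keeps only the
// first text part of the content, the assistant-side conversion joins
// multiple text parts with a newline, and assistant tool calls are rejected
// with a `ConversionError` rather than silently dropped.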

// ================================================================
// Groq Completion API
// ================================================================
/// The `deepseek-r1-distill-llama-70b` model. Used for chat completion.
pub const DEEPSEEK_R1_DISTILL_LLAMA_70B: &str = "deepseek-r1-distill-llama-70b";
/// The `gemma2-9b-it` model. Used for chat completion.
pub const GEMMA2_9B_IT: &str = "gemma2-9b-it";
/// The `llama-3.1-8b-instant` model. Used for chat completion.
pub const LLAMA_3_1_8B_INSTANT: &str = "llama-3.1-8b-instant";
/// The `llama-3.2-11b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_11B_VISION_PREVIEW: &str = "llama-3.2-11b-vision-preview";
/// The `llama-3.2-1b-preview` model. Used for chat completion.
pub const LLAMA_3_2_1B_PREVIEW: &str = "llama-3.2-1b-preview";
/// The `llama-3.2-3b-preview` model. Used for chat completion.
pub const LLAMA_3_2_3B_PREVIEW: &str = "llama-3.2-3b-preview";
/// The `llama-3.2-90b-vision-preview` model. Used for chat completion.
pub const LLAMA_3_2_90B_VISION_PREVIEW: &str = "llama-3.2-90b-vision-preview";
/// The `llama-3.3-70b-specdec` model. Used for chat completion.
pub const LLAMA_3_3_70B_SPECDEC: &str = "llama-3.3-70b-specdec";
/// The `llama-3.3-70b-versatile` model. Used for chat completion.
pub const LLAMA_3_3_70B_VERSATILE: &str = "llama-3.3-70b-versatile";
/// The `llama-guard-3-8b` model. Used for chat completion.
pub const LLAMA_GUARD_3_8B: &str = "llama-guard-3-8b";
/// The `llama3-70b-8192` model. Used for chat completion.
pub const LLAMA_3_70B_8192: &str = "llama3-70b-8192";
/// The `llama3-8b-8192` model. Used for chat completion.
pub const LLAMA_3_8B_8192: &str = "llama3-8b-8192";
/// The `mixtral-8x7b-32768` model. Used for chat completion.
pub const MIXTRAL_8X7B_32768: &str = "mixtral-8x7b-32768";
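
// These constants mirror Groq's OpenAI-compatible model IDs as published at
// the time of this PR. `completion_model` also accepts any raw model string,
// e.g. `client.completion_model("some-newer-model-id")` (hypothetical ID),
// so newer Groq models can be used without adding a constant here.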

#[derive(Clone)]
pub struct CompletionModel {
client: Client,
/// Name of the model (e.g.: deepseek-r1-distill-llama-70b)
pub model: String,
}

impl CompletionModel {
pub fn new(client: Client, model: &str) -> Self {
Self {
client,
model: model.to_string(),
}
}
}

impl completion::CompletionModel for CompletionModel {
type Response = CompletionResponse;

#[cfg_attr(feature = "worker", worker::send)]
async fn completion(
&self,
completion_request: CompletionRequest,
) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
// Add preamble to chat history (if available)
let mut full_history: Vec<Message> = match &completion_request.preamble {
Some(preamble) => vec![Message {
role: "system".to_string(),
content: Some(preamble.to_string()),
}],
None => vec![],
};

// Convert prompt to user message
let prompt: Message = completion_request.prompt_with_context().try_into()?;

// Convert existing chat history
let chat_history: Vec<Message> = completion_request
.chat_history
.into_iter()
.map(|message| message.try_into())
.collect::<Result<Vec<Message>, _>>()?;

// Combine all messages into a single history
full_history.extend(chat_history);
full_history.push(prompt);

let request = if completion_request.tools.is_empty() {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
})
} else {
json!({
"model": self.model,
"messages": full_history,
"temperature": completion_request.temperature,
"tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
"tool_choice": "auto",
})
};

let response = self
.client
.post("/chat/completions")
.json(
&if let Some(params) = completion_request.additional_params {
json_utils::merge(request, params)
} else {
request
},
)
.send()
.await?;

if response.status().is_success() {
match response.json::<ApiResponse<CompletionResponse>>().await? {
ApiResponse::Ok(response) => {
tracing::info!(target: "rig",
"groq completion token usage: {:?}",
response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
);
response.try_into()
}
ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
}
} else {
Err(CompletionError::ProviderError(response.text().await?))
}
}
}
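
For context on the request shape: the no-tools branch of `completion` above serializes to a body like the following sketch (placeholder values; `temperature` serializes as `null` when the request leaves it unset):

```rust
use serde_json::json;

fn main() {
    // Mirrors the `json!` construction in `completion` above,
    // with placeholder values filled in.
    let request = json!({
        "model": "deepseek-r1-distill-llama-70b",
        "messages": [
            { "role": "system", "content": "You are a comedian..." },
            { "role": "user", "content": "Entertain me!" }
        ],
        "temperature": 0.0,
    });
    println!("{}", serde_json::to_string_pretty(&request).unwrap());
}
```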
1 change: 1 addition & 0 deletions rig-core/src/providers/mod.rs
@@ -50,6 +50,7 @@ pub mod cohere;
pub mod deepseek;
pub mod galadriel;
pub mod gemini;
pub mod groq;
pub mod hyperbolic;
pub mod moonshot;
pub mod openai;