Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(xai): initial xai (grok) implementation #106

Merged
merged 6 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions rig-core/examples/agent_with_grok.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
use std::env;

use rig::{
agent::AgentBuilder,
completion::{Prompt, ToolDefinition},
loaders::FileLoader,
providers,
tool::Tool,
};
use serde::{Deserialize, Serialize};
use serde_json::json;

fn client() -> providers::xai::Client {
providers::xai::Client::new(&env::var("XAI_API_KEY").expect("XAI_API_KEY not set"))
}

/// Create an xAI agent (grok)
fn partial_agent() -> AgentBuilder<providers::xai::completion::CompletionModel> {
let client = client();
client.agent(providers::xai::GROK_BETA)
}

async fn basic() -> Result<(), anyhow::Error> {
let comedian_agent = partial_agent()
.preamble("You are a comedian here to entertain the user using humour and jokes.")
.build();

// Prompt the agent and print the response
let response = comedian_agent.prompt("Entertain me!").await?;
println!("{}", response);

Ok(())
}

async fn tools() -> Result<(), anyhow::Error> {
// Create agent with a single context prompt and two tools
let calculator_agent = partial_agent()
.preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
.max_tokens(1024)
.tool(Adder)
.tool(Subtract)
.build();

// Prompt the agent and print the response
println!("Calculate 2 - 5");
println!(
"Calculator Agent: {}",
calculator_agent.prompt("Calculate 2 - 5").await?
);

Ok(())
}

async fn loaders() -> Result<(), anyhow::Error> {
let model = client().completion_model(providers::xai::GROK_BETA);

// Load in all the rust examples
let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
.read_with_path()
.ignore_errors()
.into_iter();

// Create an agent with multiple context documents
let agent = examples
.fold(AgentBuilder::new(model), |builder, (path, content)| {
builder.context(format!("Rust Example {:?}:\n{}", path, content).as_str())
})
.build();

// Prompt the agent and print the response
let response = agent
.prompt("Which rust example is best suited for the operation 1 + 2")
.await?;

println!("{}", response);

Ok(())
}

async fn context() -> Result<(), anyhow::Error> {
let model = client().completion_model(providers::xai::GROK_BETA);

// Create an agent with multiple context documents
let agent = AgentBuilder::new(model)
.context("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
.context("Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
.context("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
.build();

// Prompt the agent and print the response
let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;

println!("{}", response);

Ok(())
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
basic().await?;
tools().await?;
loaders().await?;
context().await?;

Ok(())
}

#[derive(Deserialize)]
struct OperationArgs {
x: i32,
y: i32,
}

#[derive(Debug, thiserror::Error)]
#[error("Math error")]
struct MathError;

#[derive(Deserialize, Serialize)]
struct Adder;
impl Tool for Adder {
const NAME: &'static str = "add";

type Error = MathError;
type Args = OperationArgs;
type Output = i32;

async fn definition(&self, _prompt: String) -> ToolDefinition {
ToolDefinition {
name: "add".to_string(),
description: "Add x and y together".to_string(),
parameters: json!({
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The first number to add"
},
"y": {
"type": "number",
"description": "The second number to add"
}
}
}),
}
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x + args.y;
Ok(result)
}
}

#[derive(Deserialize, Serialize)]
struct Subtract;
impl Tool for Subtract {
const NAME: &'static str = "subtract";

type Error = MathError;
type Args = OperationArgs;
type Output = i32;

async fn definition(&self, _prompt: String) -> ToolDefinition {
serde_json::from_value(json!({
"name": "subtract",
"description": "Subtract y from x (i.e.: x - y)",
"parameters": {
"type": "object",
"properties": {
"x": {
"type": "number",
"description": "The number to substract from"
},
"y": {
"type": "number",
"description": "The number to substract"
}
}
}
}))
.expect("Tool Definition")
}

async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
let result = args.x - args.y;
Ok(result)
}
}
19 changes: 19 additions & 0 deletions rig-core/examples/xai_embeddings.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use rig::providers::xai;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize the xAI client
let client = xai::Client::from_env();

let embeddings = client
.embeddings(xai::embedding::EMBEDDING_V1)
.simple_document("doc0", "Hello, world!")
.simple_document("doc1", "Goodbye, world!")
.build()
.await
.expect("Failed to embed documents");

println!("{:?}", embeddings);

Ok(())
}
1 change: 1 addition & 0 deletions rig-core/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ pub mod cohere;
pub mod gemini;
pub mod openai;
pub mod perplexity;
pub mod xai;
172 changes: 172 additions & 0 deletions rig-core/src/providers/xai/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
use crate::{
agent::AgentBuilder,
embeddings::{self},
extractor::ExtractorBuilder,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::{completion::CompletionModel, embedding::EmbeddingModel, EMBEDDING_V1};

// ================================================================
// Google Gemini Client
// ================================================================
const XAI_BASE_URL: &str = "https://api.x.ai";

#[derive(Clone)]
pub struct Client {
base_url: String,
http_client: reqwest::Client,
}

impl Client {
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, XAI_BASE_URL)
}
fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("xAI reqwest client should build"),
}
}

/// Create a new xAI client from the `XAI_API_KEY` environment variable.
/// Panics if the environment variable is not set.
pub fn from_env() -> Self {
let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
Self::new(&api_key)
}

pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");

tracing::debug!("POST {}", url);
self.http_client.post(url)
}

/// Create an embedding model with the given name.
/// Note: default embedding dimension of 0 will be used if model is not known.
/// If this is the case, it's better to use function `embedding_model_with_ndims`
///
/// # Example
/// ```
/// use rig::providers::xai::{Client, self};
///
/// // Initialize the xAI client
/// let xai = Client::new("your-xai-api-key");
///
/// let embedding_model = xai.embedding_model(xai::embedding::EMBEDDING_V1);
/// ```
pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
let ndims = match model {
EMBEDDING_V1 => 3072,
_ => 0,
};
EmbeddingModel::new(self.clone(), model, ndims)
}

/// Create an embedding model with the given name and the number of dimensions in the embedding
/// generated by the model.
///
/// # Example
/// ```
/// use rig::providers::xai::{Client, self};
///
/// // Initialize the xAI client
/// let xai = Client::new("your-xai-api-key");
///
/// let embedding_model = xai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
/// ```
pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, ndims)
}

/// Create an embedding builder with the given embedding model.
///
/// # Example
/// ```
/// use rig::providers::xai::{Client, self};
///
/// // Initialize the xAI client
/// let xai = Client::new("your-xai-api-key");
///
/// let embeddings = xai.embeddings(xai::embedding::EMBEDDING_V1)
/// .simple_document("doc0", "Hello, world!")
/// .simple_document("doc1", "Goodbye, world!")
/// .build()
/// .await
/// .expect("Failed to embed documents");
/// ```
pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> {
embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
}

/// Create a completion model with the given name.
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}

/// Create an agent builder with the given completion model.
/// # Example
/// ```
/// use rig::providers::xai::{Client, self};
///
/// // Initialize the xAI client
/// let xai = Client::new("your-xai-api-key");
///
/// let agent = xai.agent(xai::completion::GROK_BETA)
/// .preamble("You are comedian AI with a mission to make people laugh.")
/// .temperature(0.0)
/// .build();
/// ```
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}

/// Create an extractor builder with the given completion model.
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}

pub mod xai_api_types {
use serde::Deserialize;

impl ApiErrorResponse {
pub fn message(&self) -> String {
format!("Code `{}`: {}", self.code, self.error)
}
}

#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
pub error: String,
pub code: String,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum ApiResponse<T> {
Ok(T),
Error(ApiErrorResponse),
}
}
Loading
Loading