diff --git a/README.md b/README.md
index 0c2a7ab0..bb5e1bc7 100644
--- a/README.md
+++ b/README.md
@@ -39,8 +39,8 @@ Help us improve Rig by contributing to our [Feedback form](https://bit.ly/Rig-Fe
- [What is Rig?](#what-is-rig)
- [Table of contents](#table-of-contents)
- [High-level features](#high-level-features)
-- [Installation](#installation)
-- [Simple example:](#simple-example)
+- [Get Started](#get-started)
+ - [Simple example:](#simple-example)
- [Integrations](#integrations)
## High-level features
@@ -85,7 +85,7 @@ You can find more examples each crate's `examples` (ie. [`src/examples`](./src/e
| Model Providers | Vector Stores |
|:--------------:|:-------------:|
-| *(model provider logos)* | *(vector store logos)* |
+| *(model provider logos, now including xAI)* | *(vector store logos)* |
Vector stores are available as separate companion-crates:
diff --git a/rig-core/README.md b/rig-core/README.md
index 567bb9e4..f8fa2c71 100644
--- a/rig-core/README.md
+++ b/rig-core/README.md
@@ -4,10 +4,12 @@ Rig is a Rust library for building LLM-powered applications that focuses on ergo
More information about this crate can be found in the [crate documentation](https://docs.rs/rig-core/latest/rig/).
## Table of contents
-- [High-level features](#high-level-features)
-- [Installation](#)
-- [Simple Example](#simple-example)
-- [Integrations](#integrations)
+- [Rig](#rig)
+ - [Table of contents](#table-of-contents)
+ - [High-level features](#high-level-features)
+ - [Installation](#installation)
+ - [Simple example:](#simple-example)
+ - [Integrations](#integrations)
## High-level features
- Full support for LLM completion and embedding workflows
@@ -48,5 +50,7 @@ Rig supports the following LLM providers natively:
- OpenAI
- Cohere
- Google Gemini
+- xAI
+
Additionally, Rig currently has the following integration sub-libraries:
- MongoDB vector store: `rig-mongodb`
diff --git a/rig-core/examples/agent_with_grok.rs b/rig-core/examples/agent_with_grok.rs
new file mode 100644
index 00000000..12c8aeb9
--- /dev/null
+++ b/rig-core/examples/agent_with_grok.rs
@@ -0,0 +1,209 @@
+use std::env;
+
+use rig::{
+ agent::AgentBuilder,
+ completion::{Prompt, ToolDefinition},
+ loaders::FileLoader,
+ providers,
+ tool::Tool,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+/// Runs 4 agents based on grok (derived from the other examples)
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+ println!("Running basic agent with grok");
+ basic().await?;
+
+ println!("\nRunning grok agent with tools");
+ tools().await?;
+
+ println!("\nRunning grok agent with loaders");
+ loaders().await?;
+
+ println!("\nRunning grok agent with context");
+ context().await?;
+
+ println!("\n\nAll agents ran successfully");
+ Ok(())
+}
+
+fn client() -> providers::xai::Client {
+ providers::xai::Client::new(&env::var("XAI_API_KEY").expect("XAI_API_KEY not set"))
+}
+
+/// Create a partial xAI agent (grok)
+fn partial_agent() -> AgentBuilder<providers::xai::completion::CompletionModel> {
+ let client = client();
+ client.agent(providers::xai::GROK_BETA)
+}
+
+/// Create an xAI agent (grok) with a preamble
+/// Based upon the `agent` example
+///
+/// This example creates a comedian agent with a preamble
+async fn basic() -> Result<(), anyhow::Error> {
+ let comedian_agent = partial_agent()
+ .preamble("You are a comedian here to entertain the user using humour and jokes.")
+ .build();
+
+ // Prompt the agent and print the response
+ let response = comedian_agent.prompt("Entertain me!").await?;
+ println!("{}", response);
+
+ Ok(())
+}
+
+/// Create an xAI agent (grok) with tools
+/// Based upon the `tools` example
+///
+/// This example creates a calculator agent with two tools: add and subtract
+async fn tools() -> Result<(), anyhow::Error> {
+ // Create agent with a single context prompt and two tools
+ let calculator_agent = partial_agent()
+ .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
+ .max_tokens(1024)
+ .tool(Adder)
+ .tool(Subtract)
+ .build();
+
+ // Prompt the agent and print the response
+ println!("Calculate 2 - 5");
+ println!(
+ "Calculator Agent: {}",
+ calculator_agent.prompt("Calculate 2 - 5").await?
+ );
+
+ Ok(())
+}
+
+/// Create an xAI agent (grok) with loaders
+/// Based upon the `loaders` example
+///
+/// This example loads in all the Rust examples from the rig-core crate and uses them as
+/// context for the agent
+async fn loaders() -> Result<(), anyhow::Error> {
+ let model = client().completion_model(providers::xai::GROK_BETA);
+
+ // Load in all the rust examples
+ let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
+ .read_with_path()
+ .ignore_errors()
+ .into_iter();
+
+ // Create an agent with multiple context documents
+ let agent = examples
+ .fold(AgentBuilder::new(model), |builder, (path, content)| {
+ builder.context(format!("Rust Example {:?}:\n{}", path, content).as_str())
+ })
+ .build();
+
+ // Prompt the agent and print the response
+ let response = agent
+ .prompt("Which rust example is best suited for the operation 1 + 2")
+ .await?;
+
+ println!("{}", response);
+
+ Ok(())
+}
+
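+/// Create an xAI agent (grok) with static context
+/// Based upon the `context` example
+///
+/// This example creates an agent that answers questions using the
+/// provided context documents (made-up word definitions)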
+async fn context() -> Result<(), anyhow::Error> {
+ let model = client().completion_model(providers::xai::GROK_BETA);
+
+ // Create an agent with multiple context documents
+ let agent = AgentBuilder::new(model)
+ .context("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
+ .context("Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
+ .context("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
+ .build();
+
+ // Prompt the agent and print the response
+ let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;
+
+ println!("{}", response);
+
+ Ok(())
+}
+
+#[derive(Deserialize)]
+struct OperationArgs {
+ x: i32,
+ y: i32,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Math error")]
+struct MathError;
+
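+/// Simple tool that adds two numbers; together with `Subtract`, it powers
+/// the calculator agent in the `tools` example above.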
+#[derive(Deserialize, Serialize)]
+struct Adder;
+impl Tool for Adder {
+ const NAME: &'static str = "add";
+
+ type Error = MathError;
+ type Args = OperationArgs;
+ type Output = i32;
+
+ async fn definition(&self, _prompt: String) -> ToolDefinition {
+ ToolDefinition {
+ name: "add".to_string(),
+ description: "Add x and y together".to_string(),
+ parameters: json!({
+ "type": "object",
+ "properties": {
+ "x": {
+ "type": "number",
+ "description": "The first number to add"
+ },
+ "y": {
+ "type": "number",
+ "description": "The second number to add"
+ }
+ }
+ }),
+ }
+ }
+
+ async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+ let result = args.x + args.y;
+ Ok(result)
+ }
+}
+
+#[derive(Deserialize, Serialize)]
+struct Subtract;
+impl Tool for Subtract {
+ const NAME: &'static str = "subtract";
+
+ type Error = MathError;
+ type Args = OperationArgs;
+ type Output = i32;
+
+ async fn definition(&self, _prompt: String) -> ToolDefinition {
+ serde_json::from_value(json!({
+ "name": "subtract",
+ "description": "Subtract y from x (i.e.: x - y)",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "x": {
+ "type": "number",
+ "description": "The number to substract from"
+ },
+ "y": {
+ "type": "number",
+ "description": "The number to substract"
+ }
+ }
+ }
+ }))
+ .expect("Tool Definition")
+ }
+
+ async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+ let result = args.x - args.y;
+ Ok(result)
+ }
+}
diff --git a/rig-core/examples/xai_embeddings.rs b/rig-core/examples/xai_embeddings.rs
new file mode 100644
index 00000000..ba24a9b0
--- /dev/null
+++ b/rig-core/examples/xai_embeddings.rs
@@ -0,0 +1,19 @@
+use rig::providers::xai;
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+ // Initialize the xAI client
+ let client = xai::Client::from_env();
+
+ let embeddings = client
+ .embeddings(xai::embedding::EMBEDDING_V1)
+ .simple_document("doc0", "Hello, world!")
+ .simple_document("doc1", "Goodbye, world!")
+ .build()
+ .await
+ .expect("Failed to embed documents");
+
+ println!("{:?}", embeddings);
+
+ Ok(())
+}
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 9d774459..23d4d181 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -45,3 +45,4 @@ pub mod cohere;
pub mod gemini;
pub mod openai;
pub mod perplexity;
+pub mod xai;
diff --git a/rig-core/src/providers/xai/client.rs b/rig-core/src/providers/xai/client.rs
new file mode 100644
index 00000000..e03c6978
--- /dev/null
+++ b/rig-core/src/providers/xai/client.rs
@@ -0,0 +1,172 @@
+use crate::{
+ agent::AgentBuilder,
+ embeddings::{self},
+ extractor::ExtractorBuilder,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+use super::{completion::CompletionModel, embedding::EmbeddingModel, EMBEDDING_V1};
+
+// ================================================================
+// xAI Client
+// ================================================================
+const XAI_BASE_URL: &str = "https://api.x.ai";
+
+#[derive(Clone)]
+pub struct Client {
+ base_url: String,
+ http_client: reqwest::Client,
+}
+
+impl Client {
+ pub fn new(api_key: &str) -> Self {
+ Self::from_url(api_key, XAI_BASE_URL)
+ }
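+ /// Create a new xAI client that sends requests to the given base URL.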
+ fn from_url(api_key: &str, base_url: &str) -> Self {
+ Self {
+ base_url: base_url.to_string(),
+ http_client: reqwest::Client::builder()
+ .default_headers({
+ let mut headers = reqwest::header::HeaderMap::new();
+ headers.insert(
+ reqwest::header::CONTENT_TYPE,
+ "application/json".parse().unwrap(),
+ );
+ headers.insert(
+ "Authorization",
+ format!("Bearer {}", api_key)
+ .parse()
+ .expect("Bearer token should parse"),
+ );
+ headers
+ })
+ .build()
+ .expect("xAI reqwest client should build"),
+ }
+ }
+
+ /// Create a new xAI client from the `XAI_API_KEY` environment variable.
+ /// Panics if the environment variable is not set.
+ pub fn from_env() -> Self {
+ let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
+ Self::new(&api_key)
+ }
+
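+ /// Build a POST request for the given path, relative to the client's
+ /// base URL, with the JSON and auth headers preconfigured.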
+ pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
+ let url = format!("{}/{}", self.base_url, path).replace("//", "/");
+
+ tracing::debug!("POST {}", url);
+ self.http_client.post(url)
+ }
+
+ /// Create an embedding model with the given name.
+ /// Note: a default embedding dimension of 0 will be used if the model is not known.
+ /// In that case, use `embedding_model_with_ndims` instead.
+ ///
+ /// # Example
+ /// ```
+ /// use rig::providers::xai::{Client, self};
+ ///
+ /// // Initialize the xAI client
+ /// let xai = Client::new("your-xai-api-key");
+ ///
+ /// let embedding_model = xai.embedding_model(xai::embedding::EMBEDDING_V1);
+ /// ```
+ pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+ let ndims = match model {
+ EMBEDDING_V1 => 3072,
+ _ => 0,
+ };
+ EmbeddingModel::new(self.clone(), model, ndims)
+ }
+
+ /// Create an embedding model with the given name and the number of dimensions in the embedding
+ /// generated by the model.
+ ///
+ /// # Example
+ /// ```
+ /// use rig::providers::xai::{Client, self};
+ ///
+ /// // Initialize the xAI client
+ /// let xai = Client::new("your-xai-api-key");
+ ///
+ /// let embedding_model = xai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
+ /// ```
+ pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
+ EmbeddingModel::new(self.clone(), model, ndims)
+ }
+
+ /// Create an embedding builder with the given embedding model.
+ ///
+ /// # Example
+ /// ```
+ /// use rig::providers::xai::{Client, self};
+ ///
+ /// // Initialize the xAI client
+ /// let xai = Client::new("your-xai-api-key");
+ ///
+ /// let embeddings = xai.embeddings(xai::embedding::EMBEDDING_V1)
+ /// .simple_document("doc0", "Hello, world!")
+ /// .simple_document("doc1", "Goodbye, world!")
+ /// .build()
+ /// .await
+ /// .expect("Failed to embed documents");
+ /// ```
+ pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder {
+ embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
+ }
+
+ /// Create a completion model with the given name.
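+ ///
+ /// # Example
+ /// ```
+ /// use rig::providers::xai::{Client, self};
+ ///
+ /// // Initialize the xAI client
+ /// let xai = Client::new("your-xai-api-key");
+ ///
+ /// let completion_model = xai.completion_model(xai::completion::GROK_BETA);
+ /// ```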
+ pub fn completion_model(&self, model: &str) -> CompletionModel {
+ CompletionModel::new(self.clone(), model)
+ }
+
+ /// Create an agent builder with the given completion model.
+ /// # Example
+ /// ```
+ /// use rig::providers::xai::{Client, self};
+ ///
+ /// // Initialize the xAI client
+ /// let xai = Client::new("your-xai-api-key");
+ ///
+ /// let agent = xai.agent(xai::completion::GROK_BETA)
+ /// .preamble("You are comedian AI with a mission to make people laugh.")
+ /// .temperature(0.0)
+ /// .build();
+ /// ```
+ pub fn agent(&self, model: &str) -> AgentBuilder {
+ AgentBuilder::new(self.completion_model(model))
+ }
+
+ /// Create an extractor builder with the given completion model.
+ pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+ &self,
+ model: &str,
+ ) -> ExtractorBuilder<T, CompletionModel> {
+ ExtractorBuilder::new(self.completion_model(model))
+ }
+}
+
+pub mod xai_api_types {
+ use serde::Deserialize;
+
+ impl ApiErrorResponse {
+ pub fn message(&self) -> String {
+ format!("Code `{}`: {}", self.code, self.error)
+ }
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct ApiErrorResponse {
+ pub error: String,
+ pub code: String,
+ }
+
+ #[derive(Debug, Deserialize)]
+ #[serde(untagged)]
+ pub enum ApiResponse<T> {
+ Ok(T),
+ Error(ApiErrorResponse),
+ }
+}
diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs
new file mode 100644
index 00000000..560e35d9
--- /dev/null
+++ b/rig-core/src/providers/xai/completion.rs
@@ -0,0 +1,209 @@
+// ================================================================
+//! xAI Completion Integration
+//! From [xAI Reference](https://docs.x.ai/api/endpoints#chat-completions)
+// ================================================================
+
+use crate::{
+ completion::{self, CompletionError},
+ json_utils,
+};
+
+use serde_json::json;
+use xai_api_types::{CompletionResponse, ToolDefinition};
+
+use super::client::{xai_api_types::ApiResponse, Client};
+
+/// `grok-beta` completion model
+pub const GROK_BETA: &str = "grok-beta";
+
+// =================================================================
+// Rig Implementation Types
+// =================================================================
+
+#[derive(Clone)]
+pub struct CompletionModel {
+ client: Client,
+ pub model: String,
+}
+
+impl CompletionModel {
+ pub fn new(client: Client, model: &str) -> Self {
+ Self {
+ client,
+ model: model.to_string(),
+ }
+ }
+}
+
+impl completion::CompletionModel for CompletionModel {
+ type Response = CompletionResponse;
+
+ async fn completion(
+ &self,
+ mut completion_request: completion::CompletionRequest,
+ ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+ let mut messages = if let Some(preamble) = &completion_request.preamble {
+ vec![completion::Message {
+ role: "system".into(),
+ content: preamble.clone(),
+ }]
+ } else {
+ vec![]
+ };
+ messages.append(&mut completion_request.chat_history);
+
+ let prompt_with_context = completion_request.prompt_with_context();
+
+ messages.push(completion::Message {
+ role: "user".into(),
+ content: prompt_with_context,
+ });
+
+ let mut request = if completion_request.tools.is_empty() {
+ json!({
+ "model": self.model,
+ "messages": messages,
+ "temperature": completion_request.temperature,
+ })
+ } else {
+ json!({
+ "model": self.model,
+ "messages": messages,
+ "temperature": completion_request.temperature,
+ "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::>(),
+ "tool_choice": "auto",
+ })
+ };
+
+ request = if let Some(params) = completion_request.additional_params {
+ json_utils::merge(request, params)
+ } else {
+ request
+ };
+
+ let response = self
+ .client
+ .post("/v1/chat/completions")
+ .json(&request)
+ .send()
+ .await?;
+
+ if response.status().is_success() {
+ match response.json::<ApiResponse<CompletionResponse>>().await? {
+ ApiResponse::Ok(completion) => completion.try_into(),
+ ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
+ }
+ } else {
+ Err(CompletionError::ProviderError(response.text().await?))
+ }
+ }
+}
+
+pub mod xai_api_types {
+ use serde::{Deserialize, Serialize};
+
+ use crate::completion::{self, CompletionError};
+
+ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
+ type Error = CompletionError;
+
+ fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
+ match value.choices.as_slice() {
+ [Choice {
+ message:
+ Message {
+ content: Some(content),
+ ..
+ },
+ ..
+ }, ..] => Ok(completion::CompletionResponse {
+ choice: completion::ModelChoice::Message(content.to_string()),
+ raw_response: value,
+ }),
+ [Choice {
+ message:
+ Message {
+ tool_calls: Some(calls),
+ ..
+ },
+ ..
+ }, ..] => {
+ let call = calls.first().ok_or(CompletionError::ResponseError(
+ "Tool selection is empty".into(),
+ ))?;
+
+ Ok(completion::CompletionResponse {
+ choice: completion::ModelChoice::ToolCall(
+ call.function.name.clone(),
+ serde_json::from_str(&call.function.arguments)?,
+ ),
+ raw_response: value,
+ })
+ }
+ _ => Err(CompletionError::ResponseError(
+ "Response did not contain a message or tool call".into(),
+ )),
+ }
+ }
+ }
+
+ impl From<completion::ToolDefinition> for ToolDefinition {
+ fn from(tool: completion::ToolDefinition) -> Self {
+ Self {
+ r#type: "function".into(),
+ function: tool,
+ }
+ }
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct ToolCall {
+ pub id: String,
+ pub r#type: String,
+ pub function: Function,
+ }
+
+ #[derive(Clone, Debug, Deserialize, Serialize)]
+ pub struct ToolDefinition {
+ pub r#type: String,
+ pub function: completion::ToolDefinition,
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct Function {
+ pub name: String,
+ pub arguments: String,
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct CompletionResponse {
+ pub id: String,
+ pub model: String,
+ pub choices: Vec<Choice>,
+ pub created: i64,
+ pub object: String,
+ pub system_fingerprint: String,
+ pub usage: Usage,
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct Choice {
+ pub finish_reason: String,
+ pub index: i32,
+ pub message: Message,
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct Message {
+ pub role: String,
+ pub content: Option<String>,
+ pub tool_calls: Option<Vec<ToolCall>>,
+ }
+
+ #[derive(Debug, Deserialize)]
+ pub struct Usage {
+ pub completion_tokens: i32,
+ pub prompt_tokens: i32,
+ pub total_tokens: i32,
+ }
+}
diff --git a/rig-core/src/providers/xai/embedding.rs b/rig-core/src/providers/xai/embedding.rs
new file mode 100644
index 00000000..1c588071
--- /dev/null
+++ b/rig-core/src/providers/xai/embedding.rs
@@ -0,0 +1,123 @@
+// ================================================================
+//! xAI Embeddings Integration
+//! From [xAI Reference](https://docs.x.ai/api/endpoints#create-embeddings)
+// ================================================================
+
+use serde::Deserialize;
+use serde_json::json;
+
+use crate::embeddings::{self, EmbeddingError};
+
+use super::{
+ client::xai_api_types::{ApiErrorResponse, ApiResponse},
+ Client,
+};
+
+// ================================================================
+// xAI Embedding API
+// ================================================================
+/// `v1` embedding model
+pub const EMBEDDING_V1: &str = "v1";
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingResponse {
+ pub object: String,
+ pub data: Vec<EmbeddingData>,
+ pub model: String,
+ pub usage: Usage,
+}
+
+impl From<ApiErrorResponse> for EmbeddingError {
+ fn from(err: ApiErrorResponse) -> Self {
+ EmbeddingError::ProviderError(err.message())
+ }
+}
+
+impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
+ fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
+ match value {
+ ApiResponse::Ok(response) => Ok(response),
+ ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
+ }
+ }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingData {
+ pub object: String,
+ pub embedding: Vec<f64>,
+ pub index: usize,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Usage {
+ pub prompt_tokens: usize,
+ pub total_tokens: usize,
+}
+
+#[derive(Clone)]
+pub struct EmbeddingModel {
+ client: Client,
+ pub model: String,
+ ndims: usize,
+}
+
+impl embeddings::EmbeddingModel for EmbeddingModel {
+ const MAX_DOCUMENTS: usize = 1024;
+
+ fn ndims(&self) -> usize {
+ self.ndims
+ }
+
+ async fn embed_documents(
+ &self,
+ documents: impl IntoIterator<Item = String>,
+ ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
+ let documents = documents.into_iter().collect::<Vec<_>>();
+
+ let response = self
+ .client
+ .post("/v1/embeddings")
+ .json(&json!({
+ "model": self.model,
+ "input": documents,
+ }))
+ .send()
+ .await?;
+
+ if response.status().is_success() {
+ match response.json::<ApiResponse<EmbeddingResponse>>().await? {
+ ApiResponse::Ok(response) => {
+ if response.data.len() != documents.len() {
+ return Err(EmbeddingError::ResponseError(
+ "Response data length does not match input length".into(),
+ ));
+ }
+
+ Ok(response
+ .data
+ .into_iter()
+ .zip(documents.into_iter())
+ .map(|(embedding, document)| embeddings::Embedding {
+ document,
+ vec: embedding.embedding,
+ })
+ .collect())
+ }
+ ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
+ }
+ } else {
+ Err(EmbeddingError::ProviderError(response.text().await?))
+ }
+ }
+}
+
+impl EmbeddingModel {
+ pub fn new(client: Client, model: &str, ndims: usize) -> Self {
+ Self {
+ client,
+ model: model.to_string(),
+ ndims,
+ }
+ }
+}
diff --git a/rig-core/src/providers/xai/mod.rs b/rig-core/src/providers/xai/mod.rs
new file mode 100644
index 00000000..4150ff5a
--- /dev/null
+++ b/rig-core/src/providers/xai/mod.rs
@@ -0,0 +1,18 @@
+//! xAI API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::xai;
+//!
+//! let client = xai::Client::new("YOUR_API_KEY");
+//!
+//! let embedding_model = client.embedding_model(xai::EMBEDDING_V1);
+//! ```
+
+pub mod client;
+pub mod completion;
+pub mod embedding;
+
+pub use client::Client;
+pub use completion::GROK_BETA;
+pub use embedding::EMBEDDING_V1;
diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs
index 85a15e4a..5943ac3f 100644
--- a/rig-mongodb/examples/vector_search_mongodb.rs
+++ b/rig-mongodb/examples/vector_search_mongodb.rs
@@ -1,13 +1,22 @@
+use mongodb::bson;
use mongodb::{options::ClientOptions, Client as MongoClient, Collection};
use rig::vector_store::VectorStore;
use rig::{
- embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
+ embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::VectorStoreIndex,
};
use rig_mongodb::{MongoDbVectorStore, SearchParams};
+use serde::{Deserialize, Serialize};
use std::env;
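+/// Shape of the search results returned by the vector index:
+/// the MongoDB `_id` plus the original JSON document.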
+#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
+pub struct DocumentResponse {
+ #[serde(rename = "_id")]
+ pub id: String,
+ pub document: serde_json::Value,
+}
+
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize OpenAI client
@@ -25,7 +34,7 @@ async fn main() -> Result<(), anyhow::Error> {
MongoClient::with_options(options).expect("MongoDB client options should be valid");
// Initialize MongoDB vector store
- let collection: Collection<DocumentEmbeddings> = mongodb_client
+ let collection: Collection<bson::Document> = mongodb_client
.database("knowledgebase")
.collection("context");
@@ -49,11 +58,13 @@ async fn main() -> Result<(), anyhow::Error> {
// Create a vector index on our vector store
// IMPORTANT: Reuse the same model that was used to generate the embeddings
- let index = vector_store.index(model, "vector_index", SearchParams::default());
+ let index = vector_store
+ .index(model, "vector_index", SearchParams::default())
+ .await?;
// Query the index
let results = index
- .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
+ .top_n::<DocumentResponse>("What is a linglingdong?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc.document))
diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs
index 43869989..56c0009d 100644
--- a/rig-mongodb/src/lib.rs
+++ b/rig-mongodb/src/lib.rs
@@ -5,11 +5,56 @@ use rig::{
embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel},
vector_store::{VectorStore, VectorStoreError, VectorStoreIndex},
};
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
/// A MongoDB vector store.
pub struct MongoDbVectorStore {
- collection: mongodb::Collection<DocumentEmbeddings>,
+ collection: mongodb::Collection<mongodb::bson::Document>,
+}
+
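+/// Search-index metadata returned by the driver's `list_search_indexes` call,
+/// used to check that a vector index exists before querying it.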
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct SearchIndex {
+ id: String,
+ name: String,
+ #[serde(rename = "type")]
+ index_type: String,
+ status: String,
+ queryable: bool,
+ latest_definition: LatestDefinition,
+}
+
+impl SearchIndex {
+ async fn get_search_index(
+ collection: mongodb::Collection<mongodb::bson::Document>,
+ index_name: &str,
+ ) -> Result<SearchIndex, VectorStoreError> {
+ collection
+ .list_search_indexes(index_name, None, None)
+ .await
+ .map_err(mongodb_to_rig_error)?
+ .with_type::<SearchIndex>()
+ .next()
+ .await
+ .transpose()
+ .map_err(mongodb_to_rig_error)?
+ .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct LatestDefinition {
+ fields: Vec<Field>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct Field {
+ #[serde(rename = "type")]
+ field_type: String,
+ path: String,
+ num_dimensions: i32,
+ similarity: String,
}
fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
@@ -24,6 +69,7 @@ impl VectorStore for MongoDbVectorStore {
documents: Vec<DocumentEmbeddings>,
) -> Result<(), VectorStoreError> {
self.collection
+ .clone_with_type::<DocumentEmbeddings>()
.insert_many(documents, None)
.await
.map_err(mongodb_to_rig_error)?;
@@ -35,6 +81,7 @@ impl VectorStore for MongoDbVectorStore {
id: &str,
) -> Result