diff --git a/README.md b/README.md
index 0c2a7ab0..bb5e1bc7 100644
--- a/README.md
+++ b/README.md
@@ -39,8 +39,8 @@ Help us improve Rig by contributing to our [Feedback form](https://bit.ly/Rig-Fe
 - [What is Rig?](#what-is-rig)
 - [Table of contents](#table-of-contents)
 - [High-level features](#high-level-features)
-- [Installation](#installation)
-- [Simple example:](#simple-example)
+- [Get Started](#get-started)
+  - [Simple example:](#simple-example)
 - [Integrations](#integrations)
 
 ## High-level features
@@ -85,7 +85,7 @@ You can find more examples in each crate's `examples` (i.e. [`src/examples`](./src/e
 
 | Model Providers | Vector Stores |
 |:--------------:|:-------------:|
-| ChatGPT logo, Claude Anthropic logo, Cohere logo, Gemini logo, Perplexity logo | MongoDB logo, Neo4j logo, LanceDB logo |
+| ChatGPT logo, Claude Anthropic logo, Cohere logo, Gemini logo, xAI logo, Perplexity logo | MongoDB logo, Neo4j logo, LanceDB logo |
 
 Vector stores are available as separate companion-crates:
 
diff --git a/rig-core/README.md b/rig-core/README.md
index 567bb9e4..f8fa2c71 100644
--- a/rig-core/README.md
+++ b/rig-core/README.md
@@ -4,10 +4,12 @@ Rig is a Rust library for building LLM-powered applications that focuses on ergo
 More information about this crate can be found in the [crate documentation](https://docs.rs/rig-core/latest/rig/).
 
 ## Table of contents
-- [High-level features](#high-level-features)
-- [Installation](#)
-- [Simple Example](#simple-example)
-- [Integrations](#integrations)
+- [Rig](#rig)
+  - [Table of contents](#table-of-contents)
+  - [High-level features](#high-level-features)
+  - [Installation](#installation)
+  - [Simple example:](#simple-example)
+  - [Integrations](#integrations)
 
 ## High-level features
 - Full support for LLM completion and embedding workflows
@@ -48,5 +50,7 @@ Rig supports the following LLM providers natively:
 - OpenAI
 - Cohere
 - Google Gemini
+- xAI
+
 Additionally, Rig currently has the following integration sub-libraries:
 - MongoDB vector store: `rig-mongodb`
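With xAI added to the natively supported providers, the quickest way to try the new integration is through the high-level client. Below is a minimal sketch built only from APIs introduced in this diff (`Client::from_env`, `agent`, `GROK_BETA`); it assumes `XAI_API_KEY` is set and that `tokio` and `anyhow` are available, as in the bundled examples:

```rust
use rig::{completion::Prompt, providers::xai};

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // Reads XAI_API_KEY from the environment (panics if unset).
    let client = xai::Client::from_env();

    // Build a grok-beta agent with a short preamble.
    let agent = client
        .agent(xai::GROK_BETA)
        .preamble("You are a concise assistant.")
        .build();

    println!("{}", agent.prompt("Say hello in one sentence.").await?);
    Ok(())
}
```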
diff --git a/rig-core/examples/agent_with_grok.rs b/rig-core/examples/agent_with_grok.rs
new file mode 100644
index 00000000..12c8aeb9
--- /dev/null
+++ b/rig-core/examples/agent_with_grok.rs
@@ -0,0 +1,209 @@
+use std::env;
+
+use rig::{
+    agent::AgentBuilder,
+    completion::{Prompt, ToolDefinition},
+    loaders::FileLoader,
+    providers,
+    tool::Tool,
+};
+use serde::{Deserialize, Serialize};
+use serde_json::json;
+
+/// Runs 4 agents based on grok (derived from the other examples)
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    println!("Running basic agent with grok");
+    basic().await?;
+
+    println!("\nRunning grok agent with tools");
+    tools().await?;
+
+    println!("\nRunning grok agent with loaders");
+    loaders().await?;
+
+    println!("\nRunning grok agent with context");
+    context().await?;
+
+    println!("\n\nAll agents ran successfully");
+    Ok(())
+}
+
+fn client() -> providers::xai::Client {
+    providers::xai::Client::new(&env::var("XAI_API_KEY").expect("XAI_API_KEY not set"))
+}
+
+/// Create a partial xAI agent (grok)
+fn partial_agent() -> AgentBuilder<providers::xai::completion::CompletionModel> {
+    let client = client();
+    client.agent(providers::xai::GROK_BETA)
+}
+
+/// Create an xAI agent (grok) with a preamble
+/// Based upon the `agent` example
+///
+/// This example creates a comedian agent with a preamble
+async fn basic() -> Result<(), anyhow::Error> {
+    let comedian_agent = partial_agent()
+        .preamble("You are a comedian here to entertain the user using humour and jokes.")
+        .build();
+
+    // Prompt the agent and print the response
+    let response = comedian_agent.prompt("Entertain me!").await?;
+    println!("{}", response);
+
+    Ok(())
+}
+
+/// Create an xAI agent (grok) with tools
+/// Based upon the `tools` example
+///
+/// This example creates a calculator agent with two tools: add and subtract
+async fn tools() -> Result<(), anyhow::Error> {
+    // Create agent with a single context prompt and two tools
+    let calculator_agent = partial_agent()
+        .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
+        .max_tokens(1024)
+        .tool(Adder)
+        .tool(Subtract)
+        .build();
+
+    // Prompt the agent and print the response
+    println!("Calculate 2 - 5");
+    println!(
+        "Calculator Agent: {}",
+        calculator_agent.prompt("Calculate 2 - 5").await?
+    );
+
+    Ok(())
+}
+
+/// Create an xAI agent (grok) with loaders
+/// Based upon the `loaders` example
+///
+/// This example loads in all the rust examples from the rig-core crate and uses them as
+/// context for the agent
+async fn loaders() -> Result<(), anyhow::Error> {
+    let model = client().completion_model(providers::xai::GROK_BETA);
+
+    // Load in all the rust examples
+    let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
+        .read_with_path()
+        .ignore_errors()
+        .into_iter();
+
+    // Create an agent with multiple context documents
+    let agent = examples
+        .fold(AgentBuilder::new(model), |builder, (path, content)| {
+            builder.context(format!("Rust Example {:?}:\n{}", path, content).as_str())
+        })
+        .build();
+
+    // Prompt the agent and print the response
+    let response = agent
+        .prompt("Which rust example is best suited for the operation 1 + 2")
+        .await?;
+
+    println!("{}", response);
+
+    Ok(())
+}
+
+async fn context() -> Result<(), anyhow::Error> {
+    let model = client().completion_model(providers::xai::GROK_BETA);
+
+    // Create an agent with multiple context documents
+    let agent = AgentBuilder::new(model)
+        .context("Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets")
+        .context("Definition of a *glarb-glarb*: A glarb-glarb is an ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.")
+        .context("Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.")
+        .build();
+
+    // Prompt the agent and print the response
+    let response = agent.prompt("What does \"glarb-glarb\" mean?").await?;
+
+    println!("{}", response);
+
+    Ok(())
+}
+
+#[derive(Deserialize)]
+struct OperationArgs {
+    x: i32,
+    y: i32,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Math error")]
+struct MathError;
+
+#[derive(Deserialize, Serialize)]
+struct Adder;
+impl Tool for Adder {
+    const NAME: &'static str = "add";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        ToolDefinition {
+            name: "add".to_string(),
+            description: "Add x and y together".to_string(),
+            parameters: json!({
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The first number to add"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The second number to add"
+                    }
+                }
+            }),
+        }
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x + args.y;
+        Ok(result)
+    }
+}
+
+#[derive(Deserialize, Serialize)]
+struct Subtract;
+impl Tool for Subtract {
+    const NAME: &'static str = "subtract";
+
+    type Error = MathError;
+    type Args = OperationArgs;
+    type Output = i32;
+
+    async fn definition(&self, _prompt: String) -> ToolDefinition {
+        serde_json::from_value(json!({
+            "name": "subtract",
+            "description": "Subtract y from x (i.e.: x - y)",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "x": {
+                        "type": "number",
+                        "description": "The number to subtract from"
+                    },
+                    "y": {
+                        "type": "number",
+                        "description": "The number to subtract"
+                    }
+                }
+            }
+        }))
+        .expect("Tool Definition")
+    }
+
+    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
+        let result = args.x - args.y;
+        Ok(result)
+    }
+}
diff --git a/rig-core/examples/xai_embeddings.rs b/rig-core/examples/xai_embeddings.rs
new file mode 100644
index 00000000..ba24a9b0
--- /dev/null
+++ b/rig-core/examples/xai_embeddings.rs
@@ -0,0 +1,19 @@
+use rig::providers::xai;
+
+#[tokio::main]
+async fn main() -> Result<(), anyhow::Error> {
+    // Initialize the xAI client
+    let client = xai::Client::from_env();
+
+    let embeddings = client
+        .embeddings(xai::embedding::EMBEDDING_V1)
+        .simple_document("doc0", "Hello, world!")
+        .simple_document("doc1", "Goodbye, world!")
+        .build()
+        .await
+        .expect("Failed to embed documents");
+
+    println!("{:?}", embeddings);
+
+    Ok(())
+}
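The builder returns one embedding per `simple_document` call. As a quick sanity check on the vectors, here is a small cosine-similarity helper; the only assumption is that each returned embedding exposes its vector as a `Vec<f64>` (the `vec` field of `rig::embeddings::Embedding` in this diff):

```rust
/// Cosine similarity between two equal-length embedding vectors.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    assert_eq!(a.len(), b.len(), "embedding dimensions must match");
    let dot: f64 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f64>().sqrt();
    dot / (norm_a * norm_b)
}

fn main() {
    // With real data, `a` and `b` would be the `vec` fields of two
    // `rig::embeddings::Embedding` values returned by the builder above.
    let (a, b) = (vec![1.0, 0.0], vec![0.6, 0.8]);
    println!("similarity: {:.3}", cosine_similarity(&a, &b));
}
```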
diff --git a/rig-core/src/providers/mod.rs b/rig-core/src/providers/mod.rs
index 9d774459..23d4d181 100644
--- a/rig-core/src/providers/mod.rs
+++ b/rig-core/src/providers/mod.rs
@@ -45,3 +45,4 @@ pub mod cohere;
 pub mod gemini;
 pub mod openai;
 pub mod perplexity;
+pub mod xai;
diff --git a/rig-core/src/providers/xai/client.rs b/rig-core/src/providers/xai/client.rs
new file mode 100644
index 00000000..e03c6978
--- /dev/null
+++ b/rig-core/src/providers/xai/client.rs
@@ -0,0 +1,172 @@
+use crate::{
+    agent::AgentBuilder,
+    embeddings::{self},
+    extractor::ExtractorBuilder,
+};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+use super::{completion::CompletionModel, embedding::EmbeddingModel, EMBEDDING_V1};
+
+// ================================================================
+// xAI Client
+// ================================================================
+const XAI_BASE_URL: &str = "https://api.x.ai";
+
+#[derive(Clone)]
+pub struct Client {
+    base_url: String,
+    http_client: reqwest::Client,
+}
+
+impl Client {
+    pub fn new(api_key: &str) -> Self {
+        Self::from_url(api_key, XAI_BASE_URL)
+    }
+
+    fn from_url(api_key: &str, base_url: &str) -> Self {
+        Self {
+            base_url: base_url.to_string(),
+            http_client: reqwest::Client::builder()
+                .default_headers({
+                    let mut headers = reqwest::header::HeaderMap::new();
+                    headers.insert(
+                        reqwest::header::CONTENT_TYPE,
+                        "application/json".parse().unwrap(),
+                    );
+                    headers.insert(
+                        "Authorization",
+                        format!("Bearer {}", api_key)
+                            .parse()
+                            .expect("Bearer token should parse"),
+                    );
+                    headers
+                })
+                .build()
+                .expect("xAI reqwest client should build"),
+        }
+    }
+
+    /// Create a new xAI client from the `XAI_API_KEY` environment variable.
+    /// Panics if the environment variable is not set.
+    pub fn from_env() -> Self {
+        let api_key = std::env::var("XAI_API_KEY").expect("XAI_API_KEY not set");
+        Self::new(&api_key)
+    }
+
+    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
+        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
+
+        tracing::debug!("POST {}", url);
+        self.http_client.post(url)
+    }
+
+    /// Create an embedding model with the given name.
+    /// Note: default embedding dimension of 0 will be used if model is not known.
+    /// If this is the case, it's better to use function `embedding_model_with_ndims`.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::xai::{Client, self};
+    ///
+    /// // Initialize the xAI client
+    /// let xai = Client::new("your-xai-api-key");
+    ///
+    /// let embedding_model = xai.embedding_model(xai::embedding::EMBEDDING_V1);
+    /// ```
+    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
+        let ndims = match model {
+            EMBEDDING_V1 => 3072,
+            _ => 0,
+        };
+        EmbeddingModel::new(self.clone(), model, ndims)
+    }
+
+    /// Create an embedding model with the given name and the number of dimensions in the embedding
+    /// generated by the model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::xai::{Client, self};
+    ///
+    /// // Initialize the xAI client
+    /// let xai = Client::new("your-xai-api-key");
+    ///
+    /// let embedding_model = xai.embedding_model_with_ndims("model-unknown-to-rig", 1024);
+    /// ```
+    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
+        EmbeddingModel::new(self.clone(), model, ndims)
+    }
+
+    /// Create an embedding builder with the given embedding model.
+    ///
+    /// # Example
+    /// ```
+    /// use rig::providers::xai::{Client, self};
+    ///
+    /// // Initialize the xAI client
+    /// let xai = Client::new("your-xai-api-key");
+    ///
+    /// let embeddings = xai.embeddings(xai::embedding::EMBEDDING_V1)
+    ///     .simple_document("doc0", "Hello, world!")
+    ///     .simple_document("doc1", "Goodbye, world!")
+    ///     .build()
+    ///     .await
+    ///     .expect("Failed to embed documents");
+    /// ```
+    pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder<EmbeddingModel> {
+        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
+    }
+
+    /// Create a completion model with the given name.
+    pub fn completion_model(&self, model: &str) -> CompletionModel {
+        CompletionModel::new(self.clone(), model)
+    }
+
+    /// Create an agent builder with the given completion model.
+    /// # Example
+    /// ```
+    /// use rig::providers::xai::{Client, self};
+    ///
+    /// // Initialize the xAI client
+    /// let xai = Client::new("your-xai-api-key");
+    ///
+    /// let agent = xai.agent(xai::completion::GROK_BETA)
+    ///     .preamble("You are comedian AI with a mission to make people laugh.")
+    ///     .temperature(0.0)
+    ///     .build();
+    /// ```
+    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
+        AgentBuilder::new(self.completion_model(model))
+    }
+
+    /// Create an extractor builder with the given completion model.
+    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
+        &self,
+        model: &str,
+    ) -> ExtractorBuilder<T, CompletionModel> {
+        ExtractorBuilder::new(self.completion_model(model))
+    }
+}
+
+pub mod xai_api_types {
+    use serde::Deserialize;
+
+    impl ApiErrorResponse {
+        pub fn message(&self) -> String {
+            format!("Code `{}`: {}", self.code, self.error)
+        }
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct ApiErrorResponse {
+        pub error: String,
+        pub code: String,
+    }
+
+    #[derive(Debug, Deserialize)]
+    #[serde(untagged)]
+    pub enum ApiResponse<T> {
+        Ok(T),
+        Error(ApiErrorResponse),
+    }
+}
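The new `extractor` constructor also wires grok into Rig's structured-extraction workflow. A sketch of how it might be used, assuming Rig's existing `Extractor::extract` API; the `Person` type and prompt text are illustrative and not part of this diff:

```rust
use rig::providers::xai;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

// Illustrative target type; the derives satisfy the bounds on `extractor`.
#[derive(Debug, JsonSchema, Serialize, Deserialize)]
struct Person {
    name: String,
    age: u8,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    let client = xai::Client::from_env();
    let extractor = client.extractor::<Person>(xai::GROK_BETA).build();

    // Extract a typed `Person` from free-form text.
    let person = extractor.extract("John Doe is 42 years old.").await?;
    println!("{:?}", person);
    Ok(())
}
```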
diff --git a/rig-core/src/providers/xai/completion.rs b/rig-core/src/providers/xai/completion.rs
new file mode 100644
index 00000000..560e35d9
--- /dev/null
+++ b/rig-core/src/providers/xai/completion.rs
@@ -0,0 +1,209 @@
+// ================================================================
+//! xAI Completion Integration
+//! From [xAI Reference](https://docs.x.ai/api/endpoints#chat-completions)
+// ================================================================
+
+use crate::{
+    completion::{self, CompletionError},
+    json_utils,
+};
+
+use serde_json::json;
+use xai_api_types::{CompletionResponse, ToolDefinition};
+
+use super::client::{xai_api_types::ApiResponse, Client};
+
+/// `grok-beta` completion model
+pub const GROK_BETA: &str = "grok-beta";
+
+// =================================================================
+// Rig Implementation Types
+// =================================================================
+
+#[derive(Clone)]
+pub struct CompletionModel {
+    client: Client,
+    pub model: String,
+}
+
+impl CompletionModel {
+    pub fn new(client: Client, model: &str) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+        }
+    }
+}
+
+impl completion::CompletionModel for CompletionModel {
+    type Response = CompletionResponse;
+
+    async fn completion(
+        &self,
+        mut completion_request: completion::CompletionRequest,
+    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
+        let mut messages = if let Some(preamble) = &completion_request.preamble {
+            vec![completion::Message {
+                role: "system".into(),
+                content: preamble.clone(),
+            }]
+        } else {
+            vec![]
+        };
+        messages.append(&mut completion_request.chat_history);
+
+        let prompt_with_context = completion_request.prompt_with_context();
+
+        messages.push(completion::Message {
+            role: "user".into(),
+            content: prompt_with_context,
+        });
+
+        let mut request = if completion_request.tools.is_empty() {
+            json!({
+                "model": self.model,
+                "messages": messages,
+                "temperature": completion_request.temperature,
+            })
+        } else {
+            json!({
+                "model": self.model,
+                "messages": messages,
+                "temperature": completion_request.temperature,
+                "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
+                "tool_choice": "auto",
+            })
+        };
+
+        request = if let Some(params) = completion_request.additional_params {
+            json_utils::merge(request, params)
+        } else {
+            request
+        };
+
+        let response = self
+            .client
+            .post("/v1/chat/completions")
+            .json(&request)
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            match response.json::<ApiResponse<CompletionResponse>>().await? {
+                ApiResponse::Ok(completion) => completion.try_into(),
+                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message())),
+            }
+        } else {
+            Err(CompletionError::ProviderError(response.text().await?))
+        }
+    }
+}
+
+pub mod xai_api_types {
+    use serde::{Deserialize, Serialize};
+
+    use crate::completion::{self, CompletionError};
+
+    impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
+        type Error = CompletionError;
+
+        fn try_from(value: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
+            match value.choices.as_slice() {
+                [Choice {
+                    message:
+                        Message {
+                            content: Some(content),
+                            ..
+                        },
+                    ..
+                }, ..] => Ok(completion::CompletionResponse {
+                    choice: completion::ModelChoice::Message(content.to_string()),
+                    raw_response: value,
+                }),
+                [Choice {
+                    message:
+                        Message {
+                            tool_calls: Some(calls),
+                            ..
+                        },
+                    ..
+                }, ..] => {
+                    let call = calls.first().ok_or(CompletionError::ResponseError(
+                        "Tool selection is empty".into(),
+                    ))?;
+
+                    Ok(completion::CompletionResponse {
+                        choice: completion::ModelChoice::ToolCall(
+                            call.function.name.clone(),
+                            serde_json::from_str(&call.function.arguments)?,
+                        ),
+                        raw_response: value,
+                    })
+                }
+                _ => Err(CompletionError::ResponseError(
+                    "Response did not contain a message or tool call".into(),
+                )),
+            }
+        }
+    }
+
+    impl From<completion::ToolDefinition> for ToolDefinition {
+        fn from(tool: completion::ToolDefinition) -> Self {
+            Self {
+                r#type: "function".into(),
+                function: tool,
+            }
+        }
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct ToolCall {
+        pub id: String,
+        pub r#type: String,
+        pub function: Function,
+    }
+
+    #[derive(Clone, Debug, Deserialize, Serialize)]
+    pub struct ToolDefinition {
+        pub r#type: String,
+        pub function: completion::ToolDefinition,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct Function {
+        pub name: String,
+        pub arguments: String,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct CompletionResponse {
+        pub id: String,
+        pub model: String,
+        pub choices: Vec<Choice>,
+        pub created: i64,
+        pub object: String,
+        pub system_fingerprint: String,
+        pub usage: Usage,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct Choice {
+        pub finish_reason: String,
+        pub index: i32,
+        pub message: Message,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct Message {
+        pub role: String,
+        pub content: Option<String>,
+        pub tool_calls: Option<Vec<ToolCall>>,
+    }
+
+    #[derive(Debug, Deserialize)]
+    pub struct Usage {
+        pub completion_tokens: i32,
+        pub prompt_tokens: i32,
+        pub total_tokens: i32,
+    }
+}
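The `TryFrom` conversion above surfaces the first choice either as plain text or as a tool call. Downstream code can branch on the converted `ModelChoice`; a sketch, assuming the two variants used in this diff are the only ones that need handling:

```rust
use rig::completion::ModelChoice;

// `choice` is the field produced by the `TryFrom` conversion above.
fn handle(choice: ModelChoice) {
    match choice {
        // The model answered with plain text.
        ModelChoice::Message(text) => println!("assistant: {}", text),
        // The model asked to invoke a tool, with JSON arguments.
        ModelChoice::ToolCall(name, args) => {
            println!("tool call: {} with args {}", name, args)
        }
    }
}
```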
diff --git a/rig-core/src/providers/xai/embedding.rs b/rig-core/src/providers/xai/embedding.rs
new file mode 100644
index 00000000..1c588071
--- /dev/null
+++ b/rig-core/src/providers/xai/embedding.rs
@@ -0,0 +1,123 @@
+// ================================================================
+//! xAI Embeddings Integration
+//! From [xAI Reference](https://docs.x.ai/api/endpoints#create-embeddings)
+// ================================================================
+
+use serde::Deserialize;
+use serde_json::json;
+
+use crate::embeddings::{self, EmbeddingError};
+
+use super::{
+    client::xai_api_types::{ApiErrorResponse, ApiResponse},
+    Client,
+};
+
+// ================================================================
+// xAI Embedding API
+// ================================================================
+/// `v1` embedding model
+pub const EMBEDDING_V1: &str = "v1";
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingResponse {
+    pub object: String,
+    pub data: Vec<EmbeddingData>,
+    pub model: String,
+    pub usage: Usage,
+}
+
+impl From<ApiErrorResponse> for EmbeddingError {
+    fn from(err: ApiErrorResponse) -> Self {
+        EmbeddingError::ProviderError(err.message())
+    }
+}
+
+impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
+    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
+        match value {
+            ApiResponse::Ok(response) => Ok(response),
+            ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
+        }
+    }
+}
+
+#[derive(Debug, Deserialize)]
+pub struct EmbeddingData {
+    pub object: String,
+    pub embedding: Vec<f64>,
+    pub index: usize,
+}
+
+#[derive(Debug, Deserialize)]
+pub struct Usage {
+    pub prompt_tokens: usize,
+    pub total_tokens: usize,
+}
+
+#[derive(Clone)]
+pub struct EmbeddingModel {
+    client: Client,
+    pub model: String,
+    ndims: usize,
+}
+
+impl embeddings::EmbeddingModel for EmbeddingModel {
+    const MAX_DOCUMENTS: usize = 1024;
+
+    fn ndims(&self) -> usize {
+        self.ndims
+    }
+
+    async fn embed_documents(
+        &self,
+        documents: impl IntoIterator<Item = String>,
+    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
+        let documents = documents.into_iter().collect::<Vec<_>>();
+
+        let response = self
+            .client
+            .post("/v1/embeddings")
+            .json(&json!({
+                "model": self.model,
+                "input": documents,
+            }))
+            .send()
+            .await?;
+
+        if response.status().is_success() {
+            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
+                ApiResponse::Ok(response) => {
+                    if response.data.len() != documents.len() {
+                        return Err(EmbeddingError::ResponseError(
+                            "Response data length does not match input length".into(),
+                        ));
+                    }
+
+                    Ok(response
+                        .data
+                        .into_iter()
+                        .zip(documents.into_iter())
+                        .map(|(embedding, document)| embeddings::Embedding {
+                            document,
+                            vec: embedding.embedding,
+                        })
+                        .collect())
+                }
+                ApiResponse::Error(err) => Err(EmbeddingError::ProviderError(err.message())),
+            }
+        } else {
+            Err(EmbeddingError::ProviderError(response.text().await?))
+        }
+    }
+}
+
+impl EmbeddingModel {
+    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
+        Self {
+            client,
+            model: model.to_string(),
+            ndims,
+        }
+    }
+}
diff --git a/rig-core/src/providers/xai/mod.rs b/rig-core/src/providers/xai/mod.rs
new file mode 100644
index 00000000..4150ff5a
--- /dev/null
+++ b/rig-core/src/providers/xai/mod.rs
@@ -0,0 +1,18 @@
+//! xAI API client and Rig integration
+//!
+//! # Example
+//! ```
+//! use rig::providers::xai;
+//!
+//! let client = xai::Client::new("YOUR_API_KEY");
+//!
+//! let embedding_model = client.embedding_model(xai::EMBEDDING_V1);
+//! ```
+
+pub mod client;
+pub mod completion;
+pub mod embedding;
+
+pub use client::Client;
+pub use completion::GROK_BETA;
+pub use embedding::EMBEDDING_V1;
diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs
index 85a15e4a..5943ac3f 100644
--- a/rig-mongodb/examples/vector_search_mongodb.rs
+++ b/rig-mongodb/examples/vector_search_mongodb.rs
@@ -1,13 +1,22 @@
+use mongodb::bson;
 use mongodb::{options::ClientOptions, Client as MongoClient, Collection};
 use rig::vector_store::VectorStore;
 use rig::{
-    embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
+    embeddings::EmbeddingsBuilder,
     providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
     vector_store::VectorStoreIndex,
 };
 use rig_mongodb::{MongoDbVectorStore, SearchParams};
+use serde::{Deserialize, Serialize};
 use std::env;
 
+#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
+pub struct DocumentResponse {
+    #[serde(rename = "_id")]
+    pub id: String,
+    pub document: serde_json::Value,
+}
+
 #[tokio::main]
 async fn main() -> Result<(), anyhow::Error> {
     // Initialize OpenAI client
@@ -25,7 +34,7 @@ async fn main() -> Result<(), anyhow::Error> {
         MongoClient::with_options(options).expect("MongoDB client options should be valid");
 
     // Initialize MongoDB vector store
-    let collection: Collection<DocumentEmbeddings> = mongodb_client
+    let collection: Collection<bson::Document> = mongodb_client
         .database("knowledgebase")
         .collection("context");
 
@@ -49,11 +58,13 @@ async fn main() -> Result<(), anyhow::Error> {
 
     // Create a vector index on our vector store
    // IMPORTANT: Reuse the same model that was used to generate the embeddings
-    let index = vector_store.index(model, "vector_index", SearchParams::default());
+    let index = vector_store
+        .index(model, "vector_index", SearchParams::default())
+        .await?;
 
     // Query the index
     let results = index
-        .top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
+        .top_n::<DocumentResponse>("What is a linglingdong?", 1)
        .await?
        .into_iter()
        .map(|(score, id, doc)| (score, id, doc.document))
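Note that this example assumes an Atlas Search index named `vector_index` already exists on the collection. Judging from the `SearchIndex`/`Field` types introduced in `rig-mongodb` below, a compatible index definition might look like the following sketch; the `cosine` similarity and the 1536 dimensions (matching `text-embedding-ada-002`) are illustrative choices, not requirements of this diff:

```rust
use mongodb::bson::doc;

fn main() {
    // Index definition in the shape rig-mongodb's `LatestDefinition`/`Field`
    // structs expect: one "vector" field with a path, dimensions, and similarity.
    let definition = doc! {
        "fields": [{
            "type": "vector",
            // Where the embedding vectors live inside each stored document.
            "path": "embeddings.vec",
            // Must match the embedding model; ada-002 produces 1536 dimensions.
            "numDimensions": 1536,
            "similarity": "cosine",
        }]
    };
    println!("{}", definition);
}
```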
diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs
index 43869989..56c0009d 100644
--- a/rig-mongodb/src/lib.rs
+++ b/rig-mongodb/src/lib.rs
@@ -5,11 +5,56 @@ use rig::{
     embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel},
     vector_store::{VectorStore, VectorStoreError, VectorStoreIndex},
 };
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 /// A MongoDB vector store.
 pub struct MongoDbVectorStore {
-    collection: mongodb::Collection<DocumentEmbeddings>,
+    collection: mongodb::Collection<mongodb::bson::Document>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct SearchIndex {
+    id: String,
+    name: String,
+    #[serde(rename = "type")]
+    index_type: String,
+    status: String,
+    queryable: bool,
+    latest_definition: LatestDefinition,
+}
+
+impl SearchIndex {
+    async fn get_search_index(
+        collection: mongodb::Collection<mongodb::bson::Document>,
+        index_name: &str,
+    ) -> Result<SearchIndex, VectorStoreError> {
+        collection
+            .list_search_indexes(index_name, None, None)
+            .await
+            .map_err(mongodb_to_rig_error)?
+            .with_type::<SearchIndex>()
+            .next()
+            .await
+            .transpose()
+            .map_err(mongodb_to_rig_error)?
+            .ok_or(VectorStoreError::DatastoreError("Index not found".into()))
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct LatestDefinition {
+    fields: Vec<Field>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct Field {
+    #[serde(rename = "type")]
+    field_type: String,
+    path: String,
+    num_dimensions: i32,
+    similarity: String,
 }
 
 fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
@@ -24,6 +69,7 @@ impl VectorStore for MongoDbVectorStore {
         documents: Vec<DocumentEmbeddings>,
     ) -> Result<(), VectorStoreError> {
         self.collection
+            .clone_with_type::<DocumentEmbeddings>()
             .insert_many(documents, None)
             .await
             .map_err(mongodb_to_rig_error)?;
@@ -35,6 +81,7 @@ impl VectorStore for MongoDbVectorStore {
         id: &str,
     ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> {
         self.collection
+            .clone_with_type::<DocumentEmbeddings>()
             .find_one(doc! { "_id": id }, None)
             .await
             .map_err(mongodb_to_rig_error)
@@ -71,6 +118,7 @@ impl VectorStore for MongoDbVectorStore {
         query: Self::Q,
     ) -> Result<Option<DocumentEmbeddings>, VectorStoreError> {
         self.collection
+            .clone_with_type::<DocumentEmbeddings>()
             .find_one(query, None)
             .await
             .map_err(mongodb_to_rig_error)
@@ -79,7 +127,7 @@ impl VectorStore for MongoDbVectorStore {
 
 impl MongoDbVectorStore {
     /// Create a new `MongoDbVectorStore` from a MongoDB collection.
-    pub fn new(collection: mongodb::Collection<DocumentEmbeddings>) -> Self {
+    pub fn new(collection: mongodb::Collection<mongodb::bson::Document>) -> Self {
         Self { collection }
     }
 
@@ -87,21 +135,22 @@ impl MongoDbVectorStore {
     ///
     /// The index (of type "vector") must already exist for the MongoDB collection.
     /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
-    pub fn index<M: EmbeddingModel>(
+    pub async fn index<M: EmbeddingModel>(
         &self,
         model: M,
         index_name: &str,
         search_params: SearchParams,
-    ) -> MongoDbVectorIndex<M> {
-        MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params)
+    ) -> Result<MongoDbVectorIndex<M>, VectorStoreError> {
+        MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params).await
     }
 }
 
 /// A vector index for a MongoDB collection.
 pub struct MongoDbVectorIndex<M: EmbeddingModel> {
-    collection: mongodb::Collection<DocumentEmbeddings>,
+    collection: mongodb::Collection<mongodb::bson::Document>,
     model: M,
     index_name: String,
+    embedded_field: String,
     search_params: SearchParams,
 }
 
@@ -118,7 +167,7 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
         doc! {
             "$vectorSearch": {
                 "index": &self.index_name,
-                "path": "embeddings.vec",
+                "path": self.embedded_field.clone(),
                 "queryVector": &prompt_embedding.vec,
                 "numCandidates": num_candidates.unwrap_or((n * 10) as u32),
                 "limit": n as u32,
@@ -140,22 +189,42 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
 }
 
 impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
-    pub fn new(
-        collection: mongodb::Collection<DocumentEmbeddings>,
+    pub async fn new(
+        collection: mongodb::Collection<mongodb::bson::Document>,
         model: M,
         index_name: &str,
         search_params: SearchParams,
-    ) -> Self {
-        Self {
+    ) -> Result<Self, VectorStoreError> {
+        let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;
+
+        if !search_index.queryable {
+            return Err(VectorStoreError::DatastoreError(
+                "Index is not queryable".into(),
+            ));
+        }
+
+        let embedded_field = search_index
+            .latest_definition
+            .fields
+            .into_iter()
+            .map(|field| field.path)
+            .next()
+            // This error shouldn't occur if the index is queryable
+            .ok_or(VectorStoreError::DatastoreError(
+                "No embedded fields found".into(),
+            ))?;
+
+        Ok(Self {
             collection,
             model,
             index_name: index_name.to_string(),
+            embedded_field,
             search_params,
-        }
+        })
     }
 }
 
-/// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information
+/// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information
 /// on each of the fields
 pub struct SearchParams {
     filter: mongodb::bson::Document,
@@ -219,6 +288,13 @@ impl<M: EmbeddingModel> VectorStoreIndex for MongoDbVectorIndex<M> {
             [
                 self.pipeline_search_stage(&prompt_embedding, n),
                 self.pipeline_score_stage(),
+                {
+                    doc! {
+                        "$project": {
+                            self.embedded_field.clone(): 0,
+                        },
+                    }
+                },
             ],
             None,
         )