From 6d7d8931fdb70968d6dcc20867a0b9f50b039f81 Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Sun, 17 Nov 2024 20:42:24 -0800 Subject: [PATCH 1/6] fix(mongodb): remove embeddings from `top_n` lookup --- rig-mongodb/examples/vector_search_mongodb.rs | 4 ++-- rig-mongodb/src/lib.rs | 20 ++++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 85a15e4a..504bdc51 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -5,7 +5,7 @@ use rig::{ providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; -use rig_mongodb::{MongoDbVectorStore, SearchParams}; +use rig_mongodb::{DocumentResponse, MongoDbVectorStore, SearchParams}; use std::env; #[tokio::main] @@ -53,7 +53,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, id, doc)| (score, id, doc.document)) diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 43869989..1ee2fada 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -5,7 +5,10 @@ use rig::{ embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, }; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; + +const EMBEDDINGS_VECTOR_FIELD: &str = "embeddings.vec"; +const EMBEDDINGS_FIELD: &str = "embeddings"; /// A MongoDB vector store. pub struct MongoDbVectorStore { @@ -118,7 +121,7 @@ impl MongoDbVectorIndex { doc! { "$vectorSearch": { "index": &self.index_name, - "path": "embeddings.vec", + "path": EMBEDDINGS_VECTOR_FIELD, "queryVector": &prompt_embedding.vec, "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, @@ -228,9 +231,13 @@ impl VectorStoreIndex for MongoDbV let mut results = Vec::new(); while let Some(doc) = cursor.next().await { - let doc = doc.map_err(mongodb_to_rig_error)?; + let mut doc = doc.map_err(mongodb_to_rig_error)?; let score = doc.get("score").expect("score").as_f64().expect("f64"); let id = doc.get("_id").expect("_id").to_string(); + // Remove the embeddings field from the document + if let Some(val) = doc.get_mut(EMBEDDINGS_FIELD) { + val.take(); + } let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?; results.push((score, id, doc_t)); } @@ -291,3 +298,10 @@ impl VectorStoreIndex for MongoDbV Ok(results) } } + +#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] +pub struct DocumentResponse { + #[serde(rename = "_id")] + pub id: String, + pub document: serde_json::Value, +} From 205b3072bacbe606dfedae0936e9828f5b4c02a4 Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Sun, 17 Nov 2024 20:48:18 -0800 Subject: [PATCH 2/6] fix(mongodb): filter embeddings within agg pipeline --- rig-mongodb/src/lib.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 1ee2fada..e482935e 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -8,7 +8,6 @@ use rig::{ use serde::{Deserialize, Serialize}; const EMBEDDINGS_VECTOR_FIELD: &str = "embeddings.vec"; -const EMBEDDINGS_FIELD: &str = "embeddings"; /// A MongoDB vector store. pub struct MongoDbVectorStore { @@ -158,7 +157,7 @@ impl MongoDbVectorIndex { } } -/// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information +/// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information /// on each of the fields pub struct SearchParams { filter: mongodb::bson::Document, @@ -222,6 +221,11 @@ impl VectorStoreIndex for MongoDbV [ self.pipeline_search_stage(&prompt_embedding, n), self.pipeline_score_stage(), + doc! { + "$project": { + EMBEDDINGS_VECTOR_FIELD: 0, + }, + }, ], None, ) @@ -234,10 +238,10 @@ impl VectorStoreIndex for MongoDbV let mut doc = doc.map_err(mongodb_to_rig_error)?; let score = doc.get("score").expect("score").as_f64().expect("f64"); let id = doc.get("_id").expect("_id").to_string(); - // Remove the embeddings field from the document - if let Some(val) = doc.get_mut(EMBEDDINGS_FIELD) { - val.take(); - } + // // Remove the embeddings field from the document + // if let Some(val) = doc.get_mut(EMBEDDINGS_FIELD) { + // val.take(); + // } let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?; results.push((score, id, doc_t)); } From cba9f06e71bd85b19af9cf2c092052c99d21eee2 Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Sun, 17 Nov 2024 20:52:02 -0800 Subject: [PATCH 3/6] style(mongodb): clippy moment --- rig-mongodb/src/lib.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index e482935e..3dbfcb15 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -235,13 +235,9 @@ impl VectorStoreIndex for MongoDbV let mut results = Vec::new(); while let Some(doc) = cursor.next().await { - let mut doc = doc.map_err(mongodb_to_rig_error)?; + let doc = doc.map_err(mongodb_to_rig_error)?; let score = doc.get("score").expect("score").as_f64().expect("f64"); let id = doc.get("_id").expect("_id").to_string(); - // // Remove the embeddings field from the document - // if let Some(val) = doc.get_mut(EMBEDDINGS_FIELD) { - // val.take(); - // } let doc_t: T = serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?; results.push((score, id, doc_t)); } From 5c1b44506e32f29bbe105fd61d86269b25088b79 Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Mon, 18 Nov 2024 17:24:06 -0800 Subject: [PATCH 4/6] fix(mongodb): dynamically get embedded fields from mongodb --- rig-core/src/vector_store/mod.rs | 3 + rig-mongodb/examples/vector_search_mongodb.rs | 4 +- rig-mongodb/src/lib.rs | 94 ++++++++++++++++--- 3 files changed, 86 insertions(+), 15 deletions(-) diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 405ef83b..dbf4bb4e 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -17,6 +17,9 @@ pub enum VectorStoreError { #[error("Datastore error: {0}")] DatastoreError(#[from] Box), + + #[error("Vector store error: {0}")] + Error(String), } /// Trait for vector stores diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 504bdc51..6903a510 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -49,7 +49,9 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "vector_index", SearchParams::default()); + let index = vector_store + .index(model, "vector_index", SearchParams::default()) + .await?; // Query the index let results = index diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 3dbfcb15..1b2343bf 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -7,13 +7,56 @@ use rig::{ }; use serde::{Deserialize, Serialize}; -const EMBEDDINGS_VECTOR_FIELD: &str = "embeddings.vec"; - /// A MongoDB vector store. pub struct MongoDbVectorStore { collection: mongodb::Collection, } +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct SearchIndex { + id: String, + name: String, + #[serde(rename = "type")] + index_type: String, + status: String, + queryable: bool, + latest_definition: LatestDefinition, +} + +impl SearchIndex { + async fn get_search_index( + collection: mongodb::Collection, + index_name: &str, + ) -> Result { + collection + .list_search_indexes(index_name, None, None) + .await + .map_err(mongodb_to_rig_error)? + .with_type::() + .next() + .await + .transpose() + .map_err(mongodb_to_rig_error)? + .ok_or(VectorStoreError::Error("Index not found".to_string())) + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct LatestDefinition { + fields: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct Field { + #[serde(rename = "type")] + field_type: String, + path: String, + num_dimensions: i32, + similarity: String, +} + fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } @@ -89,13 +132,13 @@ impl MongoDbVectorStore { /// /// The index (of type "vector") must already exist for the MongoDB collection. /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. - pub fn index( + pub async fn index( &self, model: M, index_name: &str, search_params: SearchParams, - ) -> MongoDbVectorIndex { - MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) + ) -> Result, VectorStoreError> { + MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params).await } } @@ -104,6 +147,7 @@ pub struct MongoDbVectorIndex { collection: mongodb::Collection, model: M, index_name: String, + embedded_field: String, search_params: SearchParams, } @@ -120,7 +164,7 @@ impl MongoDbVectorIndex { doc! { "$vectorSearch": { "index": &self.index_name, - "path": EMBEDDINGS_VECTOR_FIELD, + "path": self.embedded_field.clone(), "queryVector": &prompt_embedding.vec, "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, @@ -142,18 +186,38 @@ impl MongoDbVectorIndex { } impl MongoDbVectorIndex { - pub fn new( + pub async fn new( collection: mongodb::Collection, model: M, index_name: &str, search_params: SearchParams, - ) -> Self { - Self { + ) -> Result { + let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?; + + if !search_index.queryable { + return Err(VectorStoreError::Error( + "Index is not queryable".to_string(), + )); + } + + let embedded_field = search_index + .latest_definition + .fields + .into_iter() + .map(|field| field.path) + .next() + // This error shouldn't occur if the index is queryable + .ok_or(VectorStoreError::Error( + "No embedded fields found".to_string(), + ))?; + + Ok(Self { collection, model, index_name: index_name.to_string(), + embedded_field, search_params, - } + }) } } @@ -221,10 +285,12 @@ impl VectorStoreIndex for MongoDbV [ self.pipeline_search_stage(&prompt_embedding, n), self.pipeline_score_stage(), - doc! { - "$project": { - EMBEDDINGS_VECTOR_FIELD: 0, - }, + { + doc! { + "$project": { + self.embedded_field.clone(): 0, + }, + } }, ], None, From 8eff66086437471a50746272d2bb6e1f027cc98a Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Tue, 19 Nov 2024 19:36:09 -0800 Subject: [PATCH 5/6] fix(mongodb): apply fixes from comments --- rig-core/src/vector_store/mod.rs | 3 -- rig-mongodb/examples/vector_search_mongodb.rs | 15 ++++++++-- rig-mongodb/src/lib.rs | 30 ++++++++----------- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index dbf4bb4e..405ef83b 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -17,9 +17,6 @@ pub enum VectorStoreError { #[error("Datastore error: {0}")] DatastoreError(#[from] Box), - - #[error("Vector store error: {0}")] - Error(String), } /// Trait for vector stores diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 6903a510..60904719 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,13 +1,22 @@ +use mongodb::bson; use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; use rig::vector_store::VectorStore; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::{EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; -use rig_mongodb::{DocumentResponse, MongoDbVectorStore, SearchParams}; +use rig_mongodb::{MongoDbVectorStore, SearchParams}; +use serde::{Deserialize, Serialize}; use std::env; +#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] +pub struct DocumentResponse { + #[serde(rename = "_id")] + pub id: String, + pub document: serde_json::Value, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -25,7 +34,7 @@ async fn main() -> Result<(), anyhow::Error> { MongoClient::with_options(options).expect("MongoDB client options should be valid"); // Initialize MongoDB vector store - let collection: Collection = mongodb_client + let collection: Collection = mongodb_client .database("knowledgebase") .collection("context"); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 1b2343bf..56c0009d 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; /// A MongoDB vector store. pub struct MongoDbVectorStore { - collection: mongodb::Collection, + collection: mongodb::Collection, } #[derive(Debug, Serialize, Deserialize)] @@ -26,7 +26,7 @@ struct SearchIndex { impl SearchIndex { async fn get_search_index( - collection: mongodb::Collection, + collection: mongodb::Collection, index_name: &str, ) -> Result { collection @@ -38,7 +38,7 @@ impl SearchIndex { .await .transpose() .map_err(mongodb_to_rig_error)? - .ok_or(VectorStoreError::Error("Index not found".to_string())) + .ok_or(VectorStoreError::DatastoreError("Index not found".into())) } } @@ -69,6 +69,7 @@ impl VectorStore for MongoDbVectorStore { documents: Vec, ) -> Result<(), VectorStoreError> { self.collection + .clone_with_type::() .insert_many(documents, None) .await .map_err(mongodb_to_rig_error)?; @@ -80,6 +81,7 @@ impl VectorStore for MongoDbVectorStore { id: &str, ) -> Result, VectorStoreError> { self.collection + .clone_with_type::() .find_one(doc! { "_id": id }, None) .await .map_err(mongodb_to_rig_error) @@ -116,6 +118,7 @@ impl VectorStore for MongoDbVectorStore { query: Self::Q, ) -> Result, VectorStoreError> { self.collection + .clone_with_type::() .find_one(query, None) .await .map_err(mongodb_to_rig_error) @@ -124,7 +127,7 @@ impl VectorStore for MongoDbVectorStore { impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. - pub fn new(collection: mongodb::Collection) -> Self { + pub fn new(collection: mongodb::Collection) -> Self { Self { collection } } @@ -144,7 +147,7 @@ impl MongoDbVectorStore { /// A vector index for a MongoDB collection. pub struct MongoDbVectorIndex { - collection: mongodb::Collection, + collection: mongodb::Collection, model: M, index_name: String, embedded_field: String, @@ -187,7 +190,7 @@ impl MongoDbVectorIndex { impl MongoDbVectorIndex { pub async fn new( - collection: mongodb::Collection, + collection: mongodb::Collection, model: M, index_name: &str, search_params: SearchParams, @@ -195,8 +198,8 @@ impl MongoDbVectorIndex { let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?; if !search_index.queryable { - return Err(VectorStoreError::Error( - "Index is not queryable".to_string(), + return Err(VectorStoreError::DatastoreError( + "Index is not queryable".into(), )); } @@ -207,8 +210,8 @@ impl MongoDbVectorIndex { .map(|field| field.path) .next() // This error shouldn't occur if the index is queryable - .ok_or(VectorStoreError::Error( - "No embedded fields found".to_string(), + .ok_or(VectorStoreError::DatastoreError( + "No embedded fields found".into(), ))?; Ok(Self { @@ -364,10 +367,3 @@ impl VectorStoreIndex for MongoDbV Ok(results) } } - -#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] -pub struct DocumentResponse { - #[serde(rename = "_id")] - pub id: String, - pub document: serde_json::Value, -} From 427e39d3485ef54b4a8dc7cc81dde0f45d85173c Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Tue, 19 Nov 2024 19:49:54 -0800 Subject: [PATCH 6/6] style(mongodb): fmt --- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 60904719..5943ac3f 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -2,7 +2,7 @@ use mongodb::bson; use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; use rig::vector_store::VectorStore; use rig::{ - embeddings::{EmbeddingsBuilder}, + embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, };