From 8eff66086437471a50746272d2bb6e1f027cc98a Mon Sep 17 00:00:00 2001 From: 0xMochan Date: Tue, 19 Nov 2024 19:36:09 -0800 Subject: [PATCH] fix(mongodb): apply fixes from comments --- rig-core/src/vector_store/mod.rs | 3 -- rig-mongodb/examples/vector_search_mongodb.rs | 15 ++++++++-- rig-mongodb/src/lib.rs | 30 ++++++++----------- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index dbf4bb4e..405ef83b 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -17,9 +17,6 @@ pub enum VectorStoreError { #[error("Datastore error: {0}")] DatastoreError(#[from] Box), - - #[error("Vector store error: {0}")] - Error(String), } /// Trait for vector stores diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 6903a510..60904719 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,13 +1,22 @@ +use mongodb::bson; use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; use rig::vector_store::VectorStore; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::{EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; -use rig_mongodb::{DocumentResponse, MongoDbVectorStore, SearchParams}; +use rig_mongodb::{MongoDbVectorStore, SearchParams}; +use serde::{Deserialize, Serialize}; use std::env; +#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] +pub struct DocumentResponse { + #[serde(rename = "_id")] + pub id: String, + pub document: serde_json::Value, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -25,7 +34,7 @@ async fn main() -> Result<(), anyhow::Error> { MongoClient::with_options(options).expect("MongoDB client options should be valid"); // Initialize MongoDB vector store - let collection: Collection = mongodb_client + let collection: Collection = mongodb_client .database("knowledgebase") .collection("context"); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 1b2343bf..56c0009d 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; /// A MongoDB vector store. pub struct MongoDbVectorStore { - collection: mongodb::Collection, + collection: mongodb::Collection, } #[derive(Debug, Serialize, Deserialize)] @@ -26,7 +26,7 @@ struct SearchIndex { impl SearchIndex { async fn get_search_index( - collection: mongodb::Collection, + collection: mongodb::Collection, index_name: &str, ) -> Result { collection @@ -38,7 +38,7 @@ impl SearchIndex { .await .transpose() .map_err(mongodb_to_rig_error)? - .ok_or(VectorStoreError::Error("Index not found".to_string())) + .ok_or(VectorStoreError::DatastoreError("Index not found".into())) } } @@ -69,6 +69,7 @@ impl VectorStore for MongoDbVectorStore { documents: Vec, ) -> Result<(), VectorStoreError> { self.collection + .clone_with_type::() .insert_many(documents, None) .await .map_err(mongodb_to_rig_error)?; @@ -80,6 +81,7 @@ impl VectorStore for MongoDbVectorStore { id: &str, ) -> Result, VectorStoreError> { self.collection + .clone_with_type::() .find_one(doc! { "_id": id }, None) .await .map_err(mongodb_to_rig_error) @@ -116,6 +118,7 @@ impl VectorStore for MongoDbVectorStore { query: Self::Q, ) -> Result, VectorStoreError> { self.collection + .clone_with_type::() .find_one(query, None) .await .map_err(mongodb_to_rig_error) @@ -124,7 +127,7 @@ impl VectorStore for MongoDbVectorStore { impl MongoDbVectorStore { /// Create a new `MongoDbVectorStore` from a MongoDB collection. - pub fn new(collection: mongodb::Collection) -> Self { + pub fn new(collection: mongodb::Collection) -> Self { Self { collection } } @@ -144,7 +147,7 @@ impl MongoDbVectorStore { /// A vector index for a MongoDB collection. pub struct MongoDbVectorIndex { - collection: mongodb::Collection, + collection: mongodb::Collection, model: M, index_name: String, embedded_field: String, @@ -187,7 +190,7 @@ impl MongoDbVectorIndex { impl MongoDbVectorIndex { pub async fn new( - collection: mongodb::Collection, + collection: mongodb::Collection, model: M, index_name: &str, search_params: SearchParams, @@ -195,8 +198,8 @@ impl MongoDbVectorIndex { let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?; if !search_index.queryable { - return Err(VectorStoreError::Error( - "Index is not queryable".to_string(), + return Err(VectorStoreError::DatastoreError( + "Index is not queryable".into(), )); } @@ -207,8 +210,8 @@ impl MongoDbVectorIndex { .map(|field| field.path) .next() // This error shouldn't occur if the index is queryable - .ok_or(VectorStoreError::Error( - "No embedded fields found".to_string(), + .ok_or(VectorStoreError::DatastoreError( + "No embedded fields found".into(), ))?; Ok(Self { @@ -364,10 +367,3 @@ impl VectorStoreIndex for MongoDbV Ok(results) } } - -#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] -pub struct DocumentResponse { - #[serde(rename = "_id")] - pub id: String, - pub document: serde_json::Value, -}