Skip to content

Commit

Permalink
fix(mongodb): apply fixes from comments
Browse files: browse the repository at this point in the history
  • Loading branch information
0xMochan committed Nov 20, 2024
1 parent 5c1b445 commit 8eff660
Showing 3 changed files with 25 additions and 23 deletions.
3 changes: 0 additions & 3 deletions rig-core/src/vector_store/mod.rs
Original file line number Diff line number Diff line change
@@ -17,9 +17,6 @@ pub enum VectorStoreError {

#[error("Datastore error: {0}")]
DatastoreError(#[from] Box<dyn std::error::Error + Send + Sync>),

#[error("Vector store error: {0}")]
Error(String),
}

/// Trait for vector stores
15 changes: 12 additions & 3 deletions rig-mongodb/examples/vector_search_mongodb.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
use mongodb::bson;
use mongodb::{options::ClientOptions, Client as MongoClient, Collection};

Check warning on line 2 in rig-mongodb/examples/vector_search_mongodb.rs

GitHub Actions / stable / fmt

Diff in /home/runner/work/rig/rig/rig-mongodb/examples/vector_search_mongodb.rs
use rig::vector_store::VectorStore;
use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
embeddings::{EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::VectorStoreIndex,
};
use rig_mongodb::{DocumentResponse, MongoDbVectorStore, SearchParams};
use rig_mongodb::{MongoDbVectorStore, SearchParams};
use serde::{Deserialize, Serialize};
use std::env;

/// A record pairing a string identifier with an arbitrary JSON payload.
///
/// NOTE(review): presumably the shape that vector-search results are
/// deserialized into by this example — the use site is not visible here; confirm.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
pub struct DocumentResponse {
/// Identifier of the record; serialized and deserialized under the key `_id`.
#[serde(rename = "_id")]
pub id: String,
/// Payload of the record, held as an untyped JSON value.
pub document: serde_json::Value,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize OpenAI client
@@ -25,7 +34,7 @@ async fn main() -> Result<(), anyhow::Error> {
MongoClient::with_options(options).expect("MongoDB client options should be valid");

// Initialize MongoDB vector store
let collection: Collection<DocumentEmbeddings> = mongodb_client
let collection: Collection<bson::Document> = mongodb_client
.database("knowledgebase")
.collection("context");

30 changes: 13 additions & 17 deletions rig-mongodb/src/lib.rs
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize};

/// A MongoDB vector store.
pub struct MongoDbVectorStore {
collection: mongodb::Collection<DocumentEmbeddings>,
collection: mongodb::Collection<bson::Document>,
}

#[derive(Debug, Serialize, Deserialize)]
@@ -26,7 +26,7 @@ struct SearchIndex {

impl SearchIndex {
async fn get_search_index(
collection: mongodb::Collection<DocumentEmbeddings>,
collection: mongodb::Collection<bson::Document>,
index_name: &str,
) -> Result<SearchIndex, VectorStoreError> {
collection
@@ -38,7 +38,7 @@ impl SearchIndex {
.await
.transpose()
.map_err(mongodb_to_rig_error)?
.ok_or(VectorStoreError::Error("Index not found".to_string()))
.ok_or(VectorStoreError::DatastoreError("Index not found".into()))
}
}

@@ -69,6 +69,7 @@ impl VectorStore for MongoDbVectorStore {
documents: Vec<DocumentEmbeddings>,
) -> Result<(), VectorStoreError> {
self.collection
.clone_with_type::<DocumentEmbeddings>()
.insert_many(documents, None)
.await
.map_err(mongodb_to_rig_error)?;
@@ -80,6 +81,7 @@ impl VectorStore for MongoDbVectorStore {
id: &str,
) -> Result<Option<DocumentEmbeddings>, VectorStoreError> {
self.collection
.clone_with_type::<DocumentEmbeddings>()
.find_one(doc! { "_id": id }, None)
.await
.map_err(mongodb_to_rig_error)
@@ -116,6 +118,7 @@ impl VectorStore for MongoDbVectorStore {
query: Self::Q,
) -> Result<Option<DocumentEmbeddings>, VectorStoreError> {
self.collection
.clone_with_type::<DocumentEmbeddings>()
.find_one(query, None)
.await
.map_err(mongodb_to_rig_error)
@@ -124,7 +127,7 @@ impl VectorStore for MongoDbVectorStore {

impl MongoDbVectorStore {
/// Create a new `MongoDbVectorStore` from a MongoDB collection.
pub fn new(collection: mongodb::Collection<DocumentEmbeddings>) -> Self {
pub fn new(collection: mongodb::Collection<bson::document::Document>) -> Self {
Self { collection }
}

@@ -144,7 +147,7 @@ impl MongoDbVectorStore {

/// A vector index for a MongoDB collection.
pub struct MongoDbVectorIndex<M: EmbeddingModel> {
collection: mongodb::Collection<DocumentEmbeddings>,
collection: mongodb::Collection<bson::Document>,
model: M,
index_name: String,
embedded_field: String,
@@ -187,16 +190,16 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {

impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
pub async fn new(
collection: mongodb::Collection<DocumentEmbeddings>,
collection: mongodb::Collection<bson::Document>,
model: M,
index_name: &str,
search_params: SearchParams,
) -> Result<Self, VectorStoreError> {
let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;

if !search_index.queryable {
return Err(VectorStoreError::Error(
"Index is not queryable".to_string(),
return Err(VectorStoreError::DatastoreError(
"Index is not queryable".into(),
));
}

@@ -207,8 +210,8 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
.map(|field| field.path)
.next()
// This error shouldn't occur if the index is queryable
.ok_or(VectorStoreError::Error(
"No embedded fields found".to_string(),
.ok_or(VectorStoreError::DatastoreError(
"No embedded fields found".into(),
))?;

Ok(Self {
@@ -364,10 +367,3 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbV
Ok(results)
}
}

/// A record pairing a string identifier with an arbitrary JSON payload.
///
/// NOTE(review): this definition appears on the removed side of the diff
/// (the same struct is added to the example file in this commit) — confirm
/// before relying on it being exported from this crate.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
pub struct DocumentResponse {
/// Identifier of the record; serialized and deserialized under the key `_id`.
#[serde(rename = "_id")]
pub id: String,
/// Payload of the record, held as an untyped JSON value.
pub document: serde_json::Value,
}

0 comments on commit 8eff660

Please sign in to comment.