Skip to content

Commit

Permalink
fix(rig-mongodb): remove embeddings from top_n lookup (#115)
Browse files Browse the repository at this point in the history
* fix(mongodb): remove embeddings from `top_n` lookup

* fix(mongodb): filter embeddings within agg pipeline

* style(mongodb): clippy moment

* fix(mongodb): dynamically get embedded fields from mongodb

* fix(mongodb): apply fixes from comments

* style(mongodb): fmt
  • Loading branch information
0xMochan authored Nov 22, 2024
1 parent f8bc80c commit b55075e
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 18 deletions.
19 changes: 15 additions & 4 deletions rig-mongodb/examples/vector_search_mongodb.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
use mongodb::bson;
use mongodb::{options::ClientOptions, Client as MongoClient, Collection};
use rig::vector_store::VectorStore;
use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
embeddings::EmbeddingsBuilder,
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::VectorStoreIndex,
};
use rig_mongodb::{MongoDbVectorStore, SearchParams};
use serde::{Deserialize, Serialize};
use std::env;

#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)]
pub struct DocumentResponse {
#[serde(rename = "_id")]
pub id: String,
pub document: serde_json::Value,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize OpenAI client
Expand All @@ -25,7 +34,7 @@ async fn main() -> Result<(), anyhow::Error> {
MongoClient::with_options(options).expect("MongoDB client options should be valid");

// Initialize MongoDB vector store
let collection: Collection<DocumentEmbeddings> = mongodb_client
let collection: Collection<bson::Document> = mongodb_client
.database("knowledgebase")
.collection("context");

Expand All @@ -49,11 +58,13 @@ async fn main() -> Result<(), anyhow::Error> {

// Create a vector index on our vector store
// IMPORTANT: Reuse the same model that was used to generate the embeddings
let index = vector_store.index(model, "vector_index", SearchParams::default());
let index = vector_store
.index(model, "vector_index", SearchParams::default())
.await?;

// Query the index
let results = index
.top_n::<DocumentEmbeddings>("What is a linglingdong?", 1)
.top_n::<DocumentResponse>("What is a linglingdong?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc.document))
Expand Down
104 changes: 90 additions & 14 deletions rig-mongodb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,56 @@ use rig::{
embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel},
vector_store::{VectorStore, VectorStoreError, VectorStoreIndex},
};
use serde::Deserialize;
use serde::{Deserialize, Serialize};

/// A MongoDB vector store.
pub struct MongoDbVectorStore {
collection: mongodb::Collection<DocumentEmbeddings>,
collection: mongodb::Collection<bson::Document>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct SearchIndex {
id: String,
name: String,
#[serde(rename = "type")]
index_type: String,
status: String,
queryable: bool,
latest_definition: LatestDefinition,
}

impl SearchIndex {
async fn get_search_index(
collection: mongodb::Collection<bson::Document>,
index_name: &str,
) -> Result<SearchIndex, VectorStoreError> {
collection
.list_search_indexes(index_name, None, None)
.await
.map_err(mongodb_to_rig_error)?
.with_type::<SearchIndex>()
.next()
.await
.transpose()
.map_err(mongodb_to_rig_error)?
.ok_or(VectorStoreError::DatastoreError("Index not found".into()))
}
}

#[derive(Debug, Serialize, Deserialize)]
struct LatestDefinition {
fields: Vec<Field>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Field {
#[serde(rename = "type")]
field_type: String,
path: String,
num_dimensions: i32,
similarity: String,
}

fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {
Expand All @@ -24,6 +69,7 @@ impl VectorStore for MongoDbVectorStore {
documents: Vec<DocumentEmbeddings>,
) -> Result<(), VectorStoreError> {
self.collection
.clone_with_type::<DocumentEmbeddings>()
.insert_many(documents, None)
.await
.map_err(mongodb_to_rig_error)?;
Expand All @@ -35,6 +81,7 @@ impl VectorStore for MongoDbVectorStore {
id: &str,
) -> Result<Option<DocumentEmbeddings>, VectorStoreError> {
self.collection
.clone_with_type::<DocumentEmbeddings>()
.find_one(doc! { "_id": id }, None)
.await
.map_err(mongodb_to_rig_error)
Expand Down Expand Up @@ -71,6 +118,7 @@ impl VectorStore for MongoDbVectorStore {
query: Self::Q,
) -> Result<Option<DocumentEmbeddings>, VectorStoreError> {
self.collection
.clone_with_type::<DocumentEmbeddings>()
.find_one(query, None)
.await
.map_err(mongodb_to_rig_error)
Expand All @@ -79,29 +127,30 @@ impl VectorStore for MongoDbVectorStore {

impl MongoDbVectorStore {
/// Create a new `MongoDbVectorStore` from a MongoDB collection.
pub fn new(collection: mongodb::Collection<DocumentEmbeddings>) -> Self {
pub fn new(collection: mongodb::Collection<bson::document::Document>) -> Self {
Self { collection }
}

/// Create a new `MongoDbVectorIndex` from an existing `MongoDbVectorStore`.
///
/// The index (of type "vector") must already exist for the MongoDB collection.
/// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes.
pub fn index<M: EmbeddingModel>(
pub async fn index<M: EmbeddingModel>(
&self,
model: M,
index_name: &str,
search_params: SearchParams,
) -> MongoDbVectorIndex<M> {
MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params)
) -> Result<MongoDbVectorIndex<M>, VectorStoreError> {
MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params).await
}
}

/// A vector index for a MongoDB collection.
pub struct MongoDbVectorIndex<M: EmbeddingModel> {
collection: mongodb::Collection<DocumentEmbeddings>,
collection: mongodb::Collection<bson::Document>,
model: M,
index_name: String,
embedded_field: String,
search_params: SearchParams,
}

Expand All @@ -118,7 +167,7 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
doc! {
"$vectorSearch": {
"index": &self.index_name,
"path": "embeddings.vec",
"path": self.embedded_field.clone(),
"queryVector": &prompt_embedding.vec,
"numCandidates": num_candidates.unwrap_or((n * 10) as u32),
"limit": n as u32,
Expand All @@ -140,22 +189,42 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
}

impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
pub fn new(
collection: mongodb::Collection<DocumentEmbeddings>,
pub async fn new(
collection: mongodb::Collection<bson::Document>,
model: M,
index_name: &str,
search_params: SearchParams,
) -> Self {
Self {
) -> Result<Self, VectorStoreError> {
let search_index = SearchIndex::get_search_index(collection.clone(), index_name).await?;

if !search_index.queryable {
return Err(VectorStoreError::DatastoreError(
"Index is not queryable".into(),
));
}

let embedded_field = search_index
.latest_definition
.fields
.into_iter()
.map(|field| field.path)
.next()
// This error shouldn't occur if the index is queryable
.ok_or(VectorStoreError::DatastoreError(
"No embedded fields found".into(),
))?;

Ok(Self {
collection,
model,
index_name: index_name.to_string(),
embedded_field,
search_params,
}
})
}
}

/// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information
/// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information
/// on each of the fields
pub struct SearchParams {
filter: mongodb::bson::Document,
Expand Down Expand Up @@ -219,6 +288,13 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbV
[
self.pipeline_search_stage(&prompt_embedding, n),
self.pipeline_score_stage(),
{
doc! {
"$project": {
self.embedded_field.clone(): 0,
},
}
},
],
None,
)
Expand Down

0 comments on commit b55075e

Please sign in to comment.