Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(integration test): MongoDB #126

Merged
merged 15 commits into from
Dec 3, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
style: cargo fmt
  • Loading branch information
marieaurore123 committed Nov 26, 2024
commit 5f8333d664a254a6d01a7814d9766f58758ab3f8
56 changes: 25 additions & 31 deletions rig-mongodb/src/lib.rs
Original file line number Diff line number Diff line change
@@ -93,13 +93,11 @@ impl VectorStore for MongoDbVectorStore {
Ok(self
.collection
.clone_with_type::<String>()
.aggregate(
[
doc! {"$match": { "_id": id}},
doc! {"$project": { "document": 1 }},
doc! {"$replaceRoot": { "newRoot": "$document" }},
],
)
.aggregate([
doc! {"$match": { "_id": id}},
doc! {"$project": { "document": 1 }},
doc! {"$replaceRoot": { "newRoot": "$document" }},
])
.await
.map_err(mongodb_to_rig_error)?
.with_type::<String>()
@@ -282,19 +280,17 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbV

let mut cursor = self
.collection
.aggregate(
[
self.pipeline_search_stage(&prompt_embedding, n),
self.pipeline_score_stage(),
{
doc! {
"$project": {
self.embedded_field.clone(): 0,
},
}
},
],
)
.aggregate([
self.pipeline_search_stage(&prompt_embedding, n),
self.pipeline_score_stage(),
{
doc! {
"$project": {
self.embedded_field.clone(): 0,
},
}
},
])
.await
.map_err(mongodb_to_rig_error)?
.with_type::<serde_json::Value>();
@@ -328,18 +324,16 @@ impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for MongoDbV

let mut cursor = self
.collection
.aggregate(
[
self.pipeline_search_stage(&prompt_embedding, n),
self.pipeline_score_stage(),
doc! {
"$project": {
"_id": 1,
"score": 1
},
.aggregate([
self.pipeline_search_stage(&prompt_embedding, n),
self.pipeline_score_stage(),
doc! {
"$project": {
"_id": 1,
"score": 1
},
],
)
},
])
.await
.map_err(mongodb_to_rig_error)?
.with_type::<serde_json::Value>();
77 changes: 47 additions & 30 deletions rig-mongodb/tests/integration_tests.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
use mongodb::{bson::{self, doc}, options::ClientOptions, Collection, SearchIndexModel};
use rig::{embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, providers::openai, vector_store::VectorStoreIndex};
use mongodb::{
bson::{self, doc},
options::ClientOptions,
Collection, SearchIndexModel,
};
use rig::{
embeddings::{DocumentEmbeddings, EmbeddingsBuilder},
providers::openai,
vector_store::VectorStoreIndex,
};
use rig_mongodb::MongoDbVectorIndex;
use testcontainers::{core::{IntoContainerPort, WaitFor}, runners::AsyncRunner, GenericImage, ImageExt};
use testcontainers::{
core::{IntoContainerPort, WaitFor},
runners::AsyncRunner,
GenericImage, ImageExt,
};

const VECTOR_SEARCH_INDEX_NAME: &str = "vector_index";

#[tokio::test]
async fn integration_test() {
// Initialize OpenAI client
let openai_client = openai::Client::from_env();

// Setup local MongoDB Atlas
let container = GenericImage::new("mongodb/mongodb-atlas-local", "latest")
let container = GenericImage::new("mongodb/mongodb-atlas-local", "latest")
.with_exposed_port(27017.tcp())
.with_wait_for(WaitFor::Duration { length: std::time::Duration::from_secs(10) })
.with_wait_for(WaitFor::Duration {
length: std::time::Duration::from_secs(10),
})
.with_env_var("MONGODB_INITDB_ROOT_USERNAME", "riguser")
.with_env_var("MONGODB_INITDB_ROOT_PASSWORD", "rigpassword")
.start()
@@ -23,9 +37,11 @@ async fn integration_test() {
let port = container.get_host_port_ipv4(27017).await.unwrap();

// Initialize MongoDB client
let options = ClientOptions::parse(format!("mongodb://riguser:rigpassword@localhost:{port}/?directConnection=true"))
.await
.expect("MongoDB connection string should be valid");
let options = ClientOptions::parse(format!(
"mongodb://riguser:rigpassword@localhost:{port}/?directConnection=true"
))
.await
.expect("MongoDB connection string should be valid");

let mongodb_client =
mongodb::Client::with_options(options).expect("MongoDB client options should be valid");
@@ -43,21 +59,23 @@ async fn integration_test() {
.collection("fake_definitions");

// Create a vector search index
collection.create_search_index(
SearchIndexModel::builder()
.name(Some(VECTOR_SEARCH_INDEX_NAME.to_string()))
.index_type(Some(mongodb::SearchIndexType::VectorSearch))
.definition(doc! {
"fields": [{
"numDimensions": 1536,
"path": "embeddings.vec",
"similarity": "cosine",
"type": "vector"
}]
})
.build()
).await
.expect("Failed to create search index");
collection
.create_search_index(
SearchIndexModel::builder()
.name(Some(VECTOR_SEARCH_INDEX_NAME.to_string()))
.index_type(Some(mongodb::SearchIndexType::VectorSearch))
.definition(doc! {
"fields": [{
"numDimensions": 1536,
"path": "embeddings.vec",
"similarity": "cosine",
"type": "vector"
}]
})
.build(),
)
.await
.expect("Failed to create search index");

// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002);
@@ -83,7 +101,9 @@ async fn integration_test() {
model,
VECTOR_SEARCH_INDEX_NAME,
rig_mongodb::SearchParams::new(),
).await.expect("Failed to create Rig vector index");
)
.await
.expect("Failed to create Rig vector index");

// Query the index
let results = vector_index
@@ -93,8 +113,5 @@ async fn integration_test() {

let result_string = &results.first().unwrap().1;

assert_eq!(
result_string,
"\"doc2\""
);
}
assert_eq!(result_string, "\"doc2\"");
}