Skip to content

Commit

Permalink
fix: error caused by container coming out of scope
Browse files Browse the repository at this point in the history
  • Loading branch information
marieaurore123 committed Dec 2, 2024
1 parent d9c86ee commit f71fd5d
Showing 1 changed file with 27 additions and 31 deletions.
58 changes: 27 additions & 31 deletions rig-mongodb/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use rig::{
embeddings::EmbeddingsBuilder, providers::openai, vector_store::VectorStoreIndex, Embed,
};
use rig_mongodb::{MongoDbVectorIndex, SearchParams};
use serde_json::json;
use testcontainers::{
core::{IntoContainerPort, WaitFor},
runners::AsyncRunner,
Expand All @@ -26,27 +27,33 @@ const MONGODB_PORT: u16 = 27017;
const COLLECTION_NAME: &str = "fake_definitions";
const DATABASE_NAME: &str = "rig";

/// Setup a local MongoDB Atlas container for testing. NOTE: docker service must be running.
/// This includes running the container with `testcontainers`, and creating a database and collection
/// that will be used by integration tests.
async fn setup_mongo_server() -> Collection<bson::Document> {
// Setup local MongoDB Atlas
#[tokio::test]
async fn vector_search_test() {
// Setup a local MongoDB Atlas container for testing. NOTE: docker service must be running.
let container = GenericImage::new("mongodb/mongodb-atlas-local", "latest")
.with_exposed_port(MONGODB_PORT.tcp())
.with_wait_for(WaitFor::Duration {
length: std::time::Duration::from_secs(10),
length: std::time::Duration::from_secs(5),
})
.with_env_var("MONGODB_INITDB_ROOT_USERNAME", "riguser")
.with_env_var("MONGODB_INITDB_ROOT_PASSWORD", "rigpassword")
.start()
.await
.expect("Failed to start MongoDB Atlas container");

// Initialize OpenAI client
let openai_client = openai::Client::from_env();

// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002);

let port = container.get_host_port_ipv4(MONGODB_PORT).await.unwrap();

let host = container.get_host().await.unwrap().to_string();

// Initialize MongoDB client
let options = ClientOptions::parse(format!(
"mongodb://riguser:rigpassword@localhost:{port}/?directConnection=true"
"mongodb://riguser:rigpassword@{host}:{port}/?directConnection=true"
))
.await
.expect("MongoDB connection string should be valid");
Expand Down Expand Up @@ -75,7 +82,7 @@ async fn setup_mongo_server() -> Collection<bson::Document> {
.definition(doc! {
"fields": [{
"numDimensions": 1536,
"path": "embeddings.vec",
"path": "embedding",
"similarity": "cosine",
"type": "vector"
}]
Expand All @@ -85,24 +92,6 @@ async fn setup_mongo_server() -> Collection<bson::Document> {
.await
.expect("Failed to create search index");

collection
}

#[tokio::test]
async fn vector_search_test() {
// Initialize OpenAI client
let openai_client = openai::Client::from_env();

let collection = setup_mongo_server().await;

// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002);

let linglingdong = FakeDefinition {
id: "doc2".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
};

let fake_definitions = vec![
FakeDefinition {
id: "doc0".to_string(),
Expand All @@ -112,7 +101,10 @@ async fn vector_search_test() {
id: "doc1".to_string(),
definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(),
},
linglingdong.clone()
FakeDefinition {
id: "doc2".to_string(),
definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
}
];

let embeddings = EmbeddingsBuilder::new(model.clone())
Expand All @@ -126,7 +118,7 @@ async fn vector_search_test() {
.iter()
.map(|(FakeDefinition { id, definition, .. }, embedding)| {
doc! {
"id": id.clone(),
"_id": id.clone(),
"definition": definition.clone(),
"embedding": embedding.first().vec.clone(),
}
Expand Down Expand Up @@ -160,10 +152,14 @@ async fn vector_search_test() {
.expect("Failed to query vector index");
}

let result_string = &results.first().unwrap();
let (score, _, value) = &results.first().unwrap();

assert_eq!(
result_string.2,
serde_json::to_value(&linglingdong).unwrap()
*value,
json!({
"_id": "doc2".to_string(),
"definition": "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(),
"score": score
})
);
}

0 comments on commit f71fd5d

Please sign in to comment.