Skip to content

Commit

Permalink
test(rig-mongodb): fix flaky test (#153)
Browse files Browse the repository at this point in the history
* docs(rig-mongodb): Fix docstrings

* docs(rig-mongodb): Improve docstrings more

* test(rig-mongodb): Fix flaky test

* style: cargo fmt
  • Loading branch information
cvauclair authored Dec 16, 2024
1 parent 1010a15 commit 619e8a9
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 72 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rig-mongodb/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ tracing = "0.1.40"
anyhow = "1.0.86"
testcontainers = "0.23.1"
tokio = { version = "1.38.0", features = ["macros"] }
tokio-test = "0.4.4"

[[example]]
name = "vector_search_mongodb"
Expand Down
91 changes: 21 additions & 70 deletions rig-mongodb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,39 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError {

/// A vector index for a MongoDB collection.
/// # Example
/// ```
/// ```rust
/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
/// use rig::embeddings::EmbeddingModel;
/// use rig::{providers::openai, vector_store::VectorStoreIndex};
///
/// #[derive(serde::Serialize, Debug)]
/// struct Document {
/// # tokio_test::block_on(async {
/// #[derive(serde::Deserialize, serde::Serialize, Debug)]
/// struct WordDefinition {
/// #[serde(rename = "_id")]
/// id: String,
/// definition: String,
/// embedding: Vec<f64>,
/// }
///
/// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection.
/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri.
/// let openai_client = openai::Client::from_env();
///
/// let collection = mongodb_client.database("db").collection::<WordDefinition>(""); // <-- replace with your mongodb collection.
///
/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
/// let index = MongoDbVectorIndex::new(
/// collection,
/// model,
/// "vector_index", // <-- replace with the name of the index in your mongodb collection.
/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings.
/// );
/// SearchParams::new(), // <-- field name in `Document` that contains the embeddings.
/// )
/// .await?;
///
/// // Query the index
/// let definitions = index
/// .top_n::<WordDefinition>("My boss says I zindle too much, what does that mean?", 1)
/// .await?;
/// # Ok::<_, anyhow::Error>(())
/// # }).unwrap()
/// ```
pub struct MongoDbVectorIndex<M: EmbeddingModel, C: Send + Sync> {
collection: mongodb::Collection<C>,
Expand Down Expand Up @@ -211,41 +224,6 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex
for MongoDbVectorIndex<M, C>
{
/// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
/// # Example
/// ```
/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
/// use rig::embeddings::EmbeddingModel;
///
/// #[derive(serde::Serialize, Debug)]
/// struct Document {
/// #[serde(rename = "_id")]
/// id: String,
/// definition: String,
/// embedding: Vec<f64>,
/// }
///
/// #[derive(serde::Deserialize, Debug)]
/// struct Definition {
/// #[serde(rename = "_id")]
/// id: String,
/// definition: String,
/// }
///
/// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection.
/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
///
/// let vector_store_index = MongoDbVectorIndex::new(
/// collection,
/// model,
/// "vector_index", // <-- replace with the name of the index in your mongodb collection.
/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings.
/// );
///
/// // Query the index
/// vector_store_index
/// .top_n::<Definition>("My boss says I zindle too much, what does that mean?", 1)
/// .await?;
/// ```
async fn top_n<T: for<'a> Deserialize<'a> + Send>(
&self,
query: &str,
Expand Down Expand Up @@ -291,33 +269,6 @@ impl<M: EmbeddingModel + Sync + Send, C: Sync + Send> VectorStoreIndex
}

/// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`.
/// # Example
/// ```
/// use rig_mongodb::{MongoDbVectorIndex, SearchParams};
/// use rig::embeddings::EmbeddingModel;
///
/// #[derive(serde::Serialize, Debug)]
/// struct Document {
/// #[serde(rename = "_id")]
/// id: String,
/// definition: String,
/// embedding: Vec<f64>,
/// }
///
/// let collection: collection: mongodb::Collection<Document> = mongodb_client.collection(""); // <-- replace with your mongodb collection.
/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model.
/// let vector_store_index = MongoDbVectorIndex::new(
/// collection,
/// model,
/// "vector_index", // <-- replace with the name of the index in your mongodb collection.
/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings.
/// );
///
/// // Query the index
/// vector_store_index
/// .top_n_ids("My boss says I zindle too much, what does that mean?", 1)
/// .await?;
/// ```
async fn top_n_ids(
&self,
query: &str,
Expand Down
2 changes: 1 addition & 1 deletion rig-mongodb/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ async fn vector_search_test() {
.await
.unwrap();

sleep(Duration::from_secs(5)).await;
sleep(Duration::from_secs(15)).await;

// Query the index
let results = index
Expand Down

0 comments on commit 619e8a9

Please sign in to comment.