diff --git a/Cargo.lock b/Cargo.lock index 62e304f9..59e0a63b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "addr2line" @@ -4862,6 +4862,7 @@ dependencies = [ "serde_json", "testcontainers", "tokio", + "tokio-test", "tracing", ] diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index 16fd0861..3de846cf 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -21,6 +21,7 @@ tracing = "0.1.40" anyhow = "1.0.86" testcontainers = "0.23.1" tokio = { version = "1.38.0", features = ["macros"] } +tokio-test = "0.4.4" [[example]] name = "vector_search_mongodb" diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 8953e2b5..1e722b27 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -59,26 +59,39 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { /// A vector index for a MongoDB collection. /// # Example -/// ``` +/// ```rust /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; -/// use rig::embeddings::EmbeddingModel; +/// use rig::{providers::openai, vector_store::VectorStoreIndex}; /// -/// #[derive(serde::Serialize, Debug)] -/// struct Document { +/// # tokio_test::block_on(async { +/// #[derive(serde::Deserialize, serde::Serialize, Debug)] +/// struct WordDefinition { /// #[serde(rename = "_id")] /// id: String, /// definition: String, /// embedding: Vec, /// } /// -/// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. -/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. +/// let mongodb_client = mongodb::Client::with_uri_str("mongodb://localhost:27017").await?; // <-- replace with your mongodb uri. +/// let openai_client = openai::Client::from_env(); +/// +/// let collection = mongodb_client.database("db").collection::(""); // <-- replace with your mongodb collection. +/// +/// let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. /// let index = MongoDbVectorIndex::new( /// collection, /// model, /// "vector_index", // <-- replace with the name of the index in your mongodb collection. -/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. -/// ); +/// SearchParams::new(), // <-- field name in `Document` that contains the embeddings. +/// ) +/// .await?; +/// +/// // Query the index +/// let definitions = index +/// .top_n::("My boss says I zindle too much, what does that mean?", 1) +/// .await?; +/// # Ok::<_, anyhow::Error>(()) +/// # }).unwrap() /// ``` pub struct MongoDbVectorIndex { collection: mongodb::Collection, @@ -211,41 +224,6 @@ impl VectorStoreIndex for MongoDbVectorIndex { /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. - /// # Example - /// ``` - /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; - /// use rig::embeddings::EmbeddingModel; - /// - /// #[derive(serde::Serialize, Debug)] - /// struct Document { - /// #[serde(rename = "_id")] - /// id: String, - /// definition: String, - /// embedding: Vec, - /// } - /// - /// #[derive(serde::Deserialize, Debug)] - /// struct Definition { - /// #[serde(rename = "_id")] - /// id: String, - /// definition: String, - /// } - /// - /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. - /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. - /// - /// let vector_store_index = MongoDbVectorIndex::new( - /// collection, - /// model, - /// "vector_index", // <-- replace with the name of the index in your mongodb collection. - /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. - /// ); - /// - /// // Query the index - /// vector_store_index - /// .top_n::("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// ``` async fn top_n Deserialize<'a> + Send>( &self, query: &str, @@ -291,33 +269,6 @@ impl VectorStoreIndex } /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. - /// # Example - /// ``` - /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; - /// use rig::embeddings::EmbeddingModel; - /// - /// #[derive(serde::Serialize, Debug)] - /// struct Document { - /// #[serde(rename = "_id")] - /// id: String, - /// definition: String, - /// embedding: Vec, - /// } - /// - /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. - /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. - /// let vector_store_index = MongoDbVectorIndex::new( - /// collection, - /// model, - /// "vector_index", // <-- replace with the name of the index in your mongodb collection. - /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. - /// ); - /// - /// // Query the index - /// vector_store_index - /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) - /// .await?; - /// ``` async fn top_n_ids( &self, query: &str, diff --git a/rig-mongodb/tests/integration_tests.rs b/rig-mongodb/tests/integration_tests.rs index 6968ec1a..9b340ae2 100644 --- a/rig-mongodb/tests/integration_tests.rs +++ b/rig-mongodb/tests/integration_tests.rs @@ -71,7 +71,7 @@ async fn vector_search_test() { .await .unwrap(); - sleep(Duration::from_secs(5)).await; + sleep(Duration::from_secs(15)).await; // Query the index let results = index