From ccea2d72e141d6c06b2b490eac88965f6f0dda66 Mon Sep 17 00:00:00 2001 From: Garance Buricatu Date: Mon, 2 Dec 2024 15:11:21 -0500 Subject: [PATCH 01/11] setup: integration test --- Cargo.lock | 251 ++++++++++++++++++++++++++- rig-neo4j/Cargo.toml | 5 + rig-neo4j/tests/integration_tests.rs | 72 ++++++++ 3 files changed, 327 insertions(+), 1 deletion(-) create mode 100644 rig-neo4j/tests/integration_tests.rs diff --git a/Cargo.lock b/Cargo.lock index 2d778dde..abbddf11 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -893,6 +893,56 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bollard" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d41711ad46fda47cd701f6908e59d1bd6b9a2b7464c0d0aeab95c6d37096ff8a" +dependencies = [ + "base64 0.22.1", + "bollard-stubs", + "bytes", + "futures-core", + "futures-util", + "hex", + "home", + "http 1.1.0", + "http-body-util", + "hyper 1.5.1", + "hyper-named-pipe", + "hyper-rustls 0.27.3", + "hyper-util", + "hyperlocal", + "log", + "pin-project-lite", + "rustls 0.23.18", + "rustls-native-certs 0.7.3", + "rustls-pemfile 2.2.0", + "rustls-pki-types", + "serde", + "serde_derive", + "serde_json", + "serde_repr", + "serde_urlencoded", + "thiserror 1.0.69", + "tokio", + "tokio-util", + "tower-service", + "url", + "winapi", +] + +[[package]] +name = "bollard-stubs" +version = "1.45.0-rc.26.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d7c5415e3a6bc6d3e99eff6268e488fd4ee25e7b28c10f08fa6760bd9de16e4" +dependencies = [ + "serde", + "serde_repr", + "serde_with 3.11.0", +] + [[package]] name = "bson" version = "2.13.0" @@ -1808,6 +1858,17 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" +[[package]] +name = "docker_credential" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31951f49556e34d90ed28342e1df7e1cb7a229c4cab0aecc627b5d91edd41d07" +dependencies = [ + "base64 0.21.7", + "serde", + "serde_json", +] + [[package]] name = "downcast-rs" version = "1.2.1" @@ -1872,6 +1933,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + [[package]] name = "event-listener" version = "2.5.3" @@ -1910,6 +1982,18 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +[[package]] +name = "filetime" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.59.0", +] + [[package]] name = "fixedbitset" version = "0.4.2" @@ -2269,6 +2353,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "hostname" version = "0.3.1" @@ -2405,6 +2498,21 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-named-pipe" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73b7d8abf35697b81a825e386fc151e0d503e8cb5fcb93cc8669c376dfd6f278" +dependencies = [ + "hex", + "hyper 1.5.1", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", + "winapi", +] + [[package]] name = "hyper-rustls" version = "0.24.2" @@ -2485,6 +2593,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "hyperlocal" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "986c5ce3b994526b3cd75578e62554abd09f0899d6206de48b3e96ab34ccc8c7" +dependencies = [ + "hex", + "http-body-util", + "hyper 1.5.1", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyperloglogplus" version = "0.4.1" @@ -3321,6 +3444,7 @@ checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ "bitflags 2.6.0", "libc", + "redox_syscall 0.5.7", ] [[package]] @@ -3919,11 +4043,36 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.7", "smallvec", "windows-targets 0.52.6", ] +[[package]] +name = "parse-display" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "914a1c2265c98e2446911282c6ac86d8524f495792c38c5bd884f80499c7538a" +dependencies = [ + "parse-display-derive", + "regex", + "regex-syntax", +] + +[[package]] +name = "parse-display-derive" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ae7800a4c974efd12df917266338e79a7a74415173caf7e70aa0a0707345281" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "regex-syntax", + "structmeta", + "syn 2.0.89", +] + [[package]] name = "parse-zoneinfo" version = "0.3.1" @@ -4417,6 +4566,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.5.7" @@ -4645,6 +4803,7 @@ dependencies = [ "serde", "serde_json", "term_size", + "testcontainers", "textwrap", "tokio", "tracing", @@ -5074,6 +5233,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_repr" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -5360,6 +5530,29 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "structmeta" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e1575d8d40908d70f6fd05537266b90ae71b15dbbe7a8b7dffa2b759306d329" +dependencies = [ + "proc-macro2", + "quote", + "structmeta-derive", + "syn 2.0.89", +] + +[[package]] +name = "structmeta-derive" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "strum" version = "0.26.3" @@ -5645,6 +5838,35 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" +[[package]] +name = "testcontainers" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f40cc2bd72e17f328faf8ca7687fe337e61bccd8acf9674fa78dd3792b045e1" +dependencies = [ + "async-trait", + "bollard", + "bollard-stubs", + "bytes", + "docker_credential", + "either", + "etcetera", + "futures", + "log", + "memchr", + "parse-display", + "pin-project-lite", + "serde", + "serde_json", + "serde_with 3.11.0", + "thiserror 1.0.69", + "tokio", + "tokio-stream", + "tokio-tar", + "tokio-util", + "url", +] + [[package]] name = "textwrap" version = "0.16.1" @@ -5842,6 +6064,21 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tar" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d5714c010ca3e5c27114c1cdeb9d14641ace49874aa5626d7149e47aedace75" +dependencies = [ + "filetime", + "futures-core", + "libc", + "redox_syscall 0.3.5", + "tokio", + "tokio-stream", + "xattr", +] + [[package]] name = "tokio-test" version = "0.4.4" @@ -6162,6 +6399,7 @@ dependencies = [ "form_urlencoded", "idna 1.0.3", "percent-encoding", + "serde", ] [[package]] @@ -6629,6 +6867,17 @@ dependencies = [ "tap", ] +[[package]] +name = "xattr" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +dependencies = [ + "libc", + "linux-raw-sys 0.4.14", + "rustix 0.38.41", +] + [[package]] name = "xmlparser" version = "0.13.6" diff --git a/rig-neo4j/Cargo.toml b/rig-neo4j/Cargo.toml index aae275c7..2a2c2fe2 100644 --- a/rig-neo4j/Cargo.toml +++ b/rig-neo4j/Cargo.toml @@ -22,7 +22,12 @@ anyhow = "1.0.86" tokio = { version = "1.38.0", features = ["macros"] } textwrap = { version = "0.16.1"} term_size = { version = "0.3.2"} +testcontainers = "0.23.1" [[example]] name = "vector_search_simple" +required-features = ["rig-core/derive"] + +[[test]] +name = "integration_tests" required-features = ["rig-core/derive"] \ No newline at end of file diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs new file mode 100644 index 00000000..c17e405d --- /dev/null +++ b/rig-neo4j/tests/integration_tests.rs @@ -0,0 +1,72 @@ +use testcontainers::{ + core::{IntoContainerPort, Mount, WaitFor}, + runners::AsyncRunner, + GenericImage, ImageExt, +}; + +use neo4rs::{ConfigBuilder, Graph}; +use rig_neo4j::{ + vector_index::{IndexConfig, SearchParams}, + Neo4jClient, ToBoltType, +}; +use rig::embeddings::EmbeddingsBuilder; + +const BOLT_PORT: u16 = 7687; +const HTTP_PORT: u16 = 7474; + +#[derive(Embed, Clone, serde::Deserialize, Debug, PartialEq)] +struct FakeDefinition { + id: String, + #[embed] + definition: String, +} + +#[tokio::test] +async fn vector_search_test() { + let mount = Mount::volume_mount("data", "./data"); + // Setup a local MongoDB Atlas container for testing. NOTE: docker service must be running. + let container = GenericImage::new("neo4j", "latest") + .with_wait_for(WaitFor::Duration { + length: std::time::Duration::from_secs(5), + }) + .with_exposed_port(BOLT_PORT.tcp()) + .with_exposed_port(HTTP_PORT.tcp()) + .with_mount(mount) + .with_env_var("NEO4J_AUTH", "none") + .start() + .await + .expect("Failed to start MongoDB Atlas container"); + + let port = container.get_host_port_ipv4(BOLT_PORT).await.unwrap(); + + let config = ConfigBuilder::default() + .uri(format!("neo4j://localhost:{port}")) + .build() + .unwrap(); + + let neo4j_client = Neo4jClient { + graph: Graph::connect(config).await.unwrap(), + }; + + // Initialize OpenAI client + let openai_client = openai::Client::from_env(); + + // Select the embedding model and generate our embeddings + let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); + + let embeddings = EmbeddingsBuilder::new(model.clone()) + .document(WordDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + })? + .document(WordDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + })? + .document(WordDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + })? + .build() + .await?; +} From c97b403722b6bdf6469cd9f00d4e5b69b8ad4eb8 Mon Sep 17 00:00:00 2001 From: Garance Buricatu Date: Mon, 2 Dec 2024 15:35:19 -0500 Subject: [PATCH 02/11] test: add to integration test --- rig-neo4j/tests/integration_tests.rs | 131 +++++++++++++++++++++++---- 1 file changed, 114 insertions(+), 17 deletions(-) diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index c17e405d..15ae103e 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -5,22 +5,33 @@ use testcontainers::{ }; use neo4rs::{ConfigBuilder, Graph}; +use rig::{embeddings::{EmbeddingsBuilder, Embedding}, Embed, providers::openai, OneOrMany}; +use rig::vector_store::VectorStoreIndex; use rig_neo4j::{ vector_index::{IndexConfig, SearchParams}, Neo4jClient, ToBoltType, }; -use rig::embeddings::EmbeddingsBuilder; +use futures::StreamExt; const BOLT_PORT: u16 = 7687; const HTTP_PORT: u16 = 7474; -#[derive(Embed, Clone, serde::Deserialize, Debug, PartialEq)] +#[derive(Embed, Clone, serde::Deserialize, Debug)] struct FakeDefinition { id: String, #[embed] definition: String, } +#[derive(serde::Deserialize)] +struct Document { + #[allow(dead_code)] + id: String, + document: String, + #[allow(dead_code)] + embedding: Vec, +} + #[tokio::test] async fn vector_search_test() { let mount = Mount::volume_mount("data", "./data"); @@ -54,19 +65,105 @@ async fn vector_search_test() { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(openai::TEXT_EMBEDDING_ADA_002); - let embeddings = EmbeddingsBuilder::new(model.clone()) - .document(WordDefinition { - id: "doc0".to_string(), - definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), - })? - .document(WordDefinition { - id: "doc1".to_string(), - definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), - })? - .document(WordDefinition { - id: "doc2".to_string(), - definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), - })? - .build() - .await?; + let embeddings = create_embeddings(model.clone()).await; + + let create_nodes = futures::stream::iter(embeddings) + .map(|(doc, embeddings)| { + neo4j_client.graph.run( + neo4rs::query( + " + CREATE + (document:DocumentEmbeddings { + id: $id, + document: $document, + embedding: $embedding}) + RETURN document", + ) + .param("id", doc.id) + // Here we use the first embedding but we could use any of them. + // Neo4j only takes primitive types or arrays as properties. + .param("embedding", embeddings.first().vec.clone()) + .param("document", doc.definition.to_bolt_type()), + ) + }) + .buffer_unordered(3) + .collect::>() + .await; + + // Unwrap the results in the vector _create_nodes + for result in create_nodes { + result.unwrap(); // or handle the error appropriately + } + + // Create a vector index on our vector store + println!("Creating vector index..."); + neo4j_client + .graph + .run(neo4rs::query( + "CREATE VECTOR INDEX vector_index IF NOT EXISTS + FOR (m:DocumentEmbeddings) + ON m.embedding + OPTIONS { indexConfig: { + `vector.dimensions`: 1536, + `vector.similarity_function`: 'cosine' + }}", + )) + .await.unwrap(); + + // ℹ️ The index name must be unique among both indexes and constraints. + // A newly created index is not immediately available but is created in the background. + + // Check if the index exists with db.awaitIndex(), the call timeouts if the index is not ready + let index_exists = neo4j_client + .graph + .run(neo4rs::query("CALL db.awaitIndex('vector_index')")) + .await; + if index_exists.is_err() { + println!("Index not ready, waiting for index..."); + std::thread::sleep(std::time::Duration::from_secs(5)); + } + + println!("Index exists: {:?}", index_exists); + + // Create a vector index on our vector store + // IMPORTANT: Reuse the same model that was used to generate the embeddings + let index = neo4j_client.index( + model, + IndexConfig::new("vector_index"), + SearchParams::default(), + ); + + // Query the index + let results = index + .top_n::("What is a glarb?", 1) + .await.unwrap(); + + println!("Results: {:?}", results); + + assert!(false) } + + +async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(FakeDefinition, OneOrMany)> { + let fake_definitions = vec![ + FakeDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + }, + FakeDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + }, + FakeDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + } + ]; + + EmbeddingsBuilder::new(model) + .documents(fake_definitions) + .unwrap() + .build() + .await + .unwrap() +} \ No newline at end of file From e6577a5573da754f69ab4787ee4015b13c93b484 Mon Sep 17 00:00:00 2001 From: Garance Buricatu Date: Mon, 2 Dec 2024 15:51:03 -0500 Subject: [PATCH 03/11] fix(test): fix bugs while testing --- rig-neo4j/tests/integration_tests.rs | 37 ++++++++++++++++------------ 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index 15ae103e..b14e95c3 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -4,14 +4,18 @@ use testcontainers::{ GenericImage, ImageExt, }; +use futures::StreamExt; use neo4rs::{ConfigBuilder, Graph}; -use rig::{embeddings::{EmbeddingsBuilder, Embedding}, Embed, providers::openai, OneOrMany}; use rig::vector_store::VectorStoreIndex; +use rig::{ + embeddings::{Embedding, EmbeddingsBuilder}, + providers::openai, + Embed, OneOrMany, +}; use rig_neo4j::{ vector_index::{IndexConfig, SearchParams}, Neo4jClient, ToBoltType, }; -use futures::StreamExt; const BOLT_PORT: u16 = 7687; const HTTP_PORT: u16 = 7474; @@ -34,7 +38,10 @@ struct Document { #[tokio::test] async fn vector_search_test() { - let mount = Mount::volume_mount("data", "./data"); + let mount = Mount::volume_mount( + "data", + "/home/garance/Documents/playgrounds_repos/rig/rig-neo4j/data", + ); // Setup a local MongoDB Atlas container for testing. NOTE: docker service must be running. let container = GenericImage::new("neo4j", "latest") .with_wait_for(WaitFor::Duration { @@ -49,15 +56,10 @@ async fn vector_search_test() { .expect("Failed to start MongoDB Atlas container"); let port = container.get_host_port_ipv4(BOLT_PORT).await.unwrap(); + let host = container.get_host().await.unwrap().to_string(); - let config = ConfigBuilder::default() - .uri(format!("neo4j://localhost:{port}")) - .build() - .unwrap(); - - let neo4j_client = Neo4jClient { - graph: Graph::connect(config).await.unwrap(), - }; + let neo4j_client = + Neo4jClient::connect(&format!("neo4j://{host}:{port}"), "", "").await.unwrap(); // Initialize OpenAI client let openai_client = openai::Client::from_env(); @@ -108,7 +110,8 @@ async fn vector_search_test() { `vector.similarity_function`: 'cosine' }}", )) - .await.unwrap(); + .await + .unwrap(); // ℹ️ The index name must be unique among both indexes and constraints. // A newly created index is not immediately available but is created in the background. @@ -136,15 +139,17 @@ async fn vector_search_test() { // Query the index let results = index .top_n::("What is a glarb?", 1) - .await.unwrap(); + .await + .unwrap(); println!("Results: {:?}", results); assert!(false) } - -async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(FakeDefinition, OneOrMany)> { +async fn create_embeddings( + model: openai::EmbeddingModel, +) -> Vec<(FakeDefinition, OneOrMany)> { let fake_definitions = vec![ FakeDefinition { id: "doc0".to_string(), @@ -166,4 +171,4 @@ async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(FakeDefinition .build() .await .unwrap() -} \ No newline at end of file +} From f526551adfbddb11c5eebedf66148895bb536c67 Mon Sep 17 00:00:00 2001 From: Garance Buricatu Date: Mon, 2 Dec 2024 16:06:00 -0500 Subject: [PATCH 04/11] fix: fix bugs after merging with main --- rig-neo4j/tests/integration_tests.rs | 42 ++++++++++++---------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index b14e95c3..ccd04ab3 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -5,17 +5,13 @@ use testcontainers::{ }; use futures::StreamExt; -use neo4rs::{ConfigBuilder, Graph}; use rig::vector_store::VectorStoreIndex; use rig::{ embeddings::{Embedding, EmbeddingsBuilder}, providers::openai, Embed, OneOrMany, }; -use rig_neo4j::{ - vector_index::{IndexConfig, SearchParams}, - Neo4jClient, ToBoltType, -}; +use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType}; const BOLT_PORT: u16 = 7687; const HTTP_PORT: u16 = 7474; @@ -27,20 +23,11 @@ struct FakeDefinition { definition: String, } -#[derive(serde::Deserialize)] -struct Document { - #[allow(dead_code)] - id: String, - document: String, - #[allow(dead_code)] - embedding: Vec, -} - #[tokio::test] async fn vector_search_test() { let mount = Mount::volume_mount( "data", - "/home/garance/Documents/playgrounds_repos/rig/rig-neo4j/data", + std::env::var("GITHUB_WORKSPACE").unwrap(), ); // Setup a local MongoDB Atlas container for testing. NOTE: docker service must be running. let container = GenericImage::new("neo4j", "latest") @@ -58,8 +45,9 @@ async fn vector_search_test() { let port = container.get_host_port_ipv4(BOLT_PORT).await.unwrap(); let host = container.get_host().await.unwrap().to_string(); - let neo4j_client = - Neo4jClient::connect(&format!("neo4j://{host}:{port}"), "", "").await.unwrap(); + let neo4j_client = Neo4jClient::connect(&format!("neo4j://{host}:{port}"), "", "") + .await + .unwrap(); // Initialize OpenAI client let openai_client = openai::Client::from_env(); @@ -130,11 +118,10 @@ async fn vector_search_test() { // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = neo4j_client.index( - model, - IndexConfig::new("vector_index"), - SearchParams::default(), - ); + let index = neo4j_client + .get_index(model, "vector_index", SearchParams::default()) + .await + .unwrap(); // Query the index let results = index @@ -142,9 +129,16 @@ async fn vector_search_test() { .await .unwrap(); - println!("Results: {:?}", results); + let (_, _, value) = &results.first().unwrap(); - assert!(false) + assert_eq!( + value, + &serde_json::json!({ + "id": "doc1", + "document": "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.", + "embedding": serde_json::Value::Null + }) + ) } async fn create_embeddings( From f13593b9cc426e1bba1f4e0955b89bef95fdef35 Mon Sep 17 00:00:00 2001 From: Garance Buricatu Date: Mon, 2 Dec 2024 16:08:32 -0500 Subject: [PATCH 05/11] cargo fmt --- rig-neo4j/tests/integration_tests.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index ccd04ab3..ed9ff365 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -25,10 +25,7 @@ struct FakeDefinition { #[tokio::test] async fn vector_search_test() { - let mount = Mount::volume_mount( - "data", - std::env::var("GITHUB_WORKSPACE").unwrap(), - ); + let mount = Mount::volume_mount("data", std::env::var("GITHUB_WORKSPACE").unwrap()); // Setup a local MongoDB Atlas container for testing. NOTE: docker service must be running. let container = GenericImage::new("neo4j", "latest") .with_wait_for(WaitFor::Duration { From 5f8a555f934cc3f48363729c7690be8f8149ed82 Mon Sep 17 00:00:00 2001 From: Garance Buricatu Date: Mon, 2 Dec 2024 16:15:18 -0500 Subject: [PATCH 06/11] fix: fix error in comment --- rig-neo4j/tests/integration_tests.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index ed9ff365..e21e35bc 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -26,7 +26,7 @@ struct FakeDefinition { #[tokio::test] async fn vector_search_test() { let mount = Mount::volume_mount("data", std::env::var("GITHUB_WORKSPACE").unwrap()); - // Setup a local MongoDB Atlas container for testing. NOTE: docker service must be running. + // Setup a local Neo 4J container for testing. NOTE: docker service must be running. let container = GenericImage::new("neo4j", "latest") .with_wait_for(WaitFor::Duration { length: std::time::Duration::from_secs(5), @@ -37,7 +37,7 @@ async fn vector_search_test() { .with_env_var("NEO4J_AUTH", "none") .start() .await - .expect("Failed to start MongoDB Atlas container"); + .expect("Failed to start Neo 4J container"); let port = container.get_host_port_ipv4(BOLT_PORT).await.unwrap(); let host = container.get_host().await.unwrap().to_string(); From b9da61fe4e4eddf7bd95014faceafbdfc54c63a8 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 3 Dec 2024 11:59:10 -0500 Subject: [PATCH 07/11] refactor: rename fakedefinition to definition --- rig-neo4j/examples/vector_search_simple.rs | 8 ++++---- rig-neo4j/tests/integration_tests.rs | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/rig-neo4j/examples/vector_search_simple.rs b/rig-neo4j/examples/vector_search_simple.rs index ecce7389..100abd5b 100644 --- a/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-neo4j/examples/vector_search_simple.rs @@ -18,7 +18,7 @@ use rig::{ use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType}; #[derive(Embed, Clone, Debug)] -pub struct WordDefinition { +pub struct Definition { pub id: String, #[embed] pub definition: String, @@ -41,15 +41,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .document(WordDefinition { + .document(Definition { id: "doc0".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), })? - .document(WordDefinition { + .document(Definition { id: "doc1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), })? - .document(WordDefinition { + .document(Definition { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), })? diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index e21e35bc..64877926 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -17,7 +17,7 @@ const BOLT_PORT: u16 = 7687; const HTTP_PORT: u16 = 7474; #[derive(Embed, Clone, serde::Deserialize, Debug)] -struct FakeDefinition { +struct Definition { id: String, #[embed] definition: String, @@ -140,17 +140,17 @@ async fn vector_search_test() { async fn create_embeddings( model: openai::EmbeddingModel, -) -> Vec<(FakeDefinition, OneOrMany)> { +) -> Vec<(Definition, OneOrMany)> { let fake_definitions = vec![ - FakeDefinition { + Definition { id: "doc0".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), }, - FakeDefinition { + Definition { id: "doc1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), }, - FakeDefinition { + Definition { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), } From 6adf5e75766449e803160c09ad8b2274bf1475cd Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 3 Dec 2024 12:15:54 -0500 Subject: [PATCH 08/11] fix: make PR requested change --- rig-neo4j/examples/vector_search_simple.rs | 9 ++------- rig-neo4j/src/lib.rs | 4 ++-- rig-neo4j/tests/integration_tests.rs | 9 ++------- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/rig-neo4j/examples/vector_search_simple.rs b/rig-neo4j/examples/vector_search_simple.rs index 100abd5b..6287666b 100644 --- a/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-neo4j/examples/vector_search_simple.rs @@ -8,7 +8,7 @@ //! 5. Returns the results use std::env; -use futures::StreamExt; +use futures::{StreamExt, TryStreamExt}; use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, @@ -76,14 +76,9 @@ async fn main() -> Result<(), anyhow::Error> { ) }) .buffer_unordered(3) - .collect::>() + .try_collect::>() .await; - // Unwrap the results in the vector _create_nodes - for result in create_nodes { - result.unwrap(); // or handle the error appropriately - } - // Create a vector index on our vector store println!("Creating vector index..."); neo4j_client diff --git a/rig-neo4j/src/lib.rs b/rig-neo4j/src/lib.rs index 830e329b..b13650bc 100644 --- a/rig-neo4j/src/lib.rs +++ b/rig-neo4j/src/lib.rs @@ -149,14 +149,14 @@ where } impl Neo4jClient { - const GET_INDEX_QUERY: &str = " + const GET_INDEX_QUERY: &'static str = " SHOW VECTOR INDEXES YIELD name, properties, options WHERE name=$index_name RETURN name, properties, options "; - const SHOW_INDEXES_QUERY: &str = "SHOW VECTOR INDEXES YIELD name RETURN name"; + const SHOW_INDEXES_QUERY: &'static str = "SHOW VECTOR INDEXES YIELD name RETURN name"; pub fn new(graph: Graph) -> Self { Self { graph } diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index 64877926..385a958c 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -4,7 +4,7 @@ use testcontainers::{ GenericImage, ImageExt, }; -use futures::StreamExt; +use futures::{StreamExt, TryStreamExt}; use rig::vector_store::VectorStoreIndex; use rig::{ embeddings::{Embedding, EmbeddingsBuilder}, @@ -74,14 +74,9 @@ async fn vector_search_test() { ) }) .buffer_unordered(3) - .collect::>() + .try_collect::>() .await; - // Unwrap the results in the vector _create_nodes - for result in create_nodes { - result.unwrap(); // or handle the error appropriately - } - // Create a vector index on our vector store println!("Creating vector index..."); neo4j_client From f05b5de4ed439c589c3c32db1ad15c564b824e67 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 3 Dec 2024 12:21:59 -0500 Subject: [PATCH 09/11] fixL cargo clippy --- rig-neo4j/examples/vector_search_simple.rs | 5 +++-- rig-neo4j/src/lib.rs | 8 ++++---- rig-neo4j/tests/integration_tests.rs | 5 +++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/rig-neo4j/examples/vector_search_simple.rs b/rig-neo4j/examples/vector_search_simple.rs index 6287666b..9d886eca 100644 --- a/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-neo4j/examples/vector_search_simple.rs @@ -56,7 +56,7 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - let create_nodes = futures::stream::iter(embeddings) + futures::stream::iter(embeddings) .map(|(doc, embeddings)| { neo4j_client.graph.run( neo4rs::query( @@ -77,7 +77,8 @@ async fn main() -> Result<(), anyhow::Error> { }) .buffer_unordered(3) .try_collect::>() - .await; + .await + .unwrap(); // Create a vector index on our vector store println!("Creating vector index..."); diff --git a/rig-neo4j/src/lib.rs b/rig-neo4j/src/lib.rs index b13650bc..09616938 100644 --- a/rig-neo4j/src/lib.rs +++ b/rig-neo4j/src/lib.rs @@ -70,11 +70,11 @@ //! .await //! .unwrap(); //! -//! let index = client.index( +//! let index = client.get_index( //! model, -//! IndexConfig::new("moviePlotsEmbedding"), -//! SearchParams::default(), -//! ); +//! "moviePlotsEmbedding", +//! SearchParams::default() +//! ).await.unwrap(); //! //! #[derive(Debug, Deserialize)] //! struct Movie { diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index 385a958c..38784abf 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -54,7 +54,7 @@ async fn vector_search_test() { let embeddings = create_embeddings(model.clone()).await; - let create_nodes = futures::stream::iter(embeddings) + futures::stream::iter(embeddings) .map(|(doc, embeddings)| { neo4j_client.graph.run( neo4rs::query( @@ -75,7 +75,8 @@ async fn vector_search_test() { }) .buffer_unordered(3) .try_collect::>() - .await; + .await + .unwrap(); // Create a vector index on our vector store println!("Creating vector index..."); From 97a3bf360d94fbeb83907a26fc4ea2ae5c6a49af Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 3 Dec 2024 12:29:39 -0500 Subject: [PATCH 10/11] fix: rename struct to Word --- rig-neo4j/examples/vector_search_simple.rs | 8 ++++---- rig-neo4j/tests/integration_tests.rs | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/rig-neo4j/examples/vector_search_simple.rs b/rig-neo4j/examples/vector_search_simple.rs index 9d886eca..f68096e6 100644 --- a/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-neo4j/examples/vector_search_simple.rs @@ -18,7 +18,7 @@ use rig::{ use rig_neo4j::{vector_index::SearchParams, Neo4jClient, ToBoltType}; #[derive(Embed, Clone, Debug)] -pub struct Definition { +pub struct Word { pub id: String, #[embed] pub definition: String, @@ -41,15 +41,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .document(Definition { + .document(Word { id: "doc0".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), })? - .document(Definition { + .document(Word { id: "doc1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), })? - .document(Definition { + .document(Word { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), })? diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index 38784abf..878866cd 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -17,7 +17,7 @@ const BOLT_PORT: u16 = 7687; const HTTP_PORT: u16 = 7474; #[derive(Embed, Clone, serde::Deserialize, Debug)] -struct Definition { +struct Word { id: String, #[embed] definition: String, @@ -136,24 +136,24 @@ async fn vector_search_test() { async fn create_embeddings( model: openai::EmbeddingModel, -) -> Vec<(Definition, OneOrMany)> { - let fake_definitions = vec![ - Definition { +) -> Vec<(Word, OneOrMany)> { + let words = vec![ + Word { id: "doc0".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), }, - Definition { + Word { id: "doc1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), }, - Definition { + Word { id: "doc2".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), } ]; EmbeddingsBuilder::new(model) - .documents(fake_definitions) + .documents(words) .unwrap() .build() .await From ef95d87d6f521944144d4ec31c15ecb2dac2b165 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 3 Dec 2024 12:31:05 -0500 Subject: [PATCH 11/11] fix: rename struct to Word --- rig-neo4j/tests/integration_tests.rs | 4 +--- rig-qdrant/examples/qdrant_vector_search.rs | 10 +++++----- rig-qdrant/tests/integration_tests.rs | 12 ++++++------ 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/rig-neo4j/tests/integration_tests.rs b/rig-neo4j/tests/integration_tests.rs index 878866cd..3db19f15 100644 --- a/rig-neo4j/tests/integration_tests.rs +++ b/rig-neo4j/tests/integration_tests.rs @@ -134,9 +134,7 @@ async fn vector_search_test() { ) } -async fn create_embeddings( - model: openai::EmbeddingModel, -) -> Vec<(Word, OneOrMany)> { +async fn create_embeddings(model: openai::EmbeddingModel) -> Vec<(Word, OneOrMany)> { let words = vec![ Word { id: "doc0".to_string(), diff --git a/rig-qdrant/examples/qdrant_vector_search.rs b/rig-qdrant/examples/qdrant_vector_search.rs index 3ea0e99d..7ce9679e 100644 --- a/rig-qdrant/examples/qdrant_vector_search.rs +++ b/rig-qdrant/examples/qdrant_vector_search.rs @@ -24,7 +24,7 @@ use rig::{ use rig_qdrant::QdrantVectorStore; #[derive(Embed, serde::Deserialize, serde::Serialize, Debug)] -struct Definition { +struct Word { id: String, #[embed] definition: String, @@ -56,15 +56,15 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let documents = EmbeddingsBuilder::new(model.clone()) - .document(Definition { + .document(Word { id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), })? - .document(Definition { + .document(Word { id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), })? - .document(Definition { + .document(Word { id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), })? @@ -91,7 +91,7 @@ async fn main() -> Result<(), anyhow::Error> { let vector_store = QdrantVectorStore::new(client, model, query_params.build()); let results = vector_store - .top_n::("What is a linglingdong?", 1) + .top_n::("What is a linglingdong?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-qdrant/tests/integration_tests.rs b/rig-qdrant/tests/integration_tests.rs index 4de1f4f4..f967076d 100644 --- a/rig-qdrant/tests/integration_tests.rs +++ b/rig-qdrant/tests/integration_tests.rs @@ -21,7 +21,7 @@ const QDRANT_PORT_SECONDARY: u16 = 6334; const COLLECTION_NAME: &str = "rig-collection"; #[derive(Embed, Clone, serde::Deserialize, serde::Serialize, Debug)] -struct Definition { +struct Word { id: String, #[embed] definition: String, @@ -95,23 +95,23 @@ async fn vector_search_test() { } async fn create_points(model: openai::EmbeddingModel) -> Vec { - let fake_definitions = vec![ - Definition { + let words = vec![ + Word { id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(), definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), }, - Definition { + Word { id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(), definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), }, - Definition { + Word { id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(), definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), } ]; let documents = EmbeddingsBuilder::new(model) - .documents(fake_definitions) + .documents(words) .unwrap() .build() .await