From ada7f19693fc555f4d37acc996ddbb3e9b53825b Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 4 Dec 2024 16:37:42 -0500 Subject: [PATCH 1/7] feat: add embedding distance calculator module --- Cargo.lock | 300 +++++++++---------- rig-core/Cargo.toml | 18 +- rig-core/src/embeddings/distance.rs | 143 +++++++++ rig-core/src/embeddings/embedding.rs | 15 - rig-core/src/embeddings/mod.rs | 3 + rig-core/src/vector_store/in_memory_store.rs | 4 +- rig-core/src/vector_store/mod.rs | 1 + 7 files changed, 313 insertions(+), 171 deletions(-) create mode 100644 rig-core/src/embeddings/distance.rs diff --git a/Cargo.lock b/Cargo.lock index 2a1539d5..62e304f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 4 +version = 3 [[package]] name = "addr2line" @@ -42,9 +42,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.21" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" [[package]] name = "android-tzdata" @@ -224,7 +224,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.7.0", + "indexmap 2.6.0", "lexical-core", "num", "serde", @@ -361,7 +361,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -383,7 +383,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -394,7 +394,7 @@ checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -915,7 +915,7 @@ dependencies = [ "hyperlocal", "log", "pin-project-lite", - "rustls 0.23.19", + "rustls 0.23.18", "rustls-native-certs 0.7.3", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -953,7 +953,7 @@ dependencies = [ "base64 0.13.1", "bitvec", "hex", - "indexmap 2.7.0", + "indexmap 2.6.0", "js-sys", "once_cell", "rand", @@ -1000,9 +1000,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.9.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" dependencies = [ "serde", ] @@ -1028,9 +1028,9 @@ dependencies = [ [[package]] name = "cargo-platform" -version = "0.1.9" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e35af189006b9c0f00a064685c727031e3ed2d8020f7ba284d78cc2671bd36ea" +checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc" dependencies = [ "serde", ] @@ -1050,9 +1050,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.2" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f34d93e62b03caf570cccc334cbc6c2fceca82f39211051345108adcba3eebdc" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -1327,7 +1327,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -1338,7 +1338,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -1391,7 +1391,7 @@ dependencies = [ "glob", "half", "hashbrown 0.14.5", - "indexmap 2.7.0", + "indexmap 2.6.0", "itertools 0.12.1", "log", "num_cpus", @@ -1551,7 +1551,7 @@ dependencies = [ "datafusion-expr", "datafusion-physical-expr", "hashbrown 0.14.5", - "indexmap 2.7.0", + "indexmap 2.6.0", "itertools 0.12.1", "log", "paste", @@ -1580,7 +1580,7 @@ dependencies = [ "half", "hashbrown 0.14.5", "hex", - "indexmap 2.7.0", + "indexmap 2.6.0", "itertools 0.12.1", "log", "paste", @@ -1626,7 +1626,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.5", - "indexmap 2.7.0", + "indexmap 2.6.0", "itertools 0.12.1", "log", "once_cell", @@ -1742,7 +1742,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -1752,7 +1752,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -1765,7 +1765,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -1814,7 +1814,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -1870,7 +1870,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -1881,12 +1881,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.52.0", ] [[package]] @@ -2129,7 +2129,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -2233,7 +2233,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.7.0", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -2252,7 +2252,7 @@ dependencies = [ "futures-core", "futures-sink", "http 1.1.0", - "indexmap 2.7.0", + "indexmap 2.6.0", "slab", "tokio", "tokio-util", @@ -2288,9 +2288,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" dependencies = [ "allocator-api2", "equivalent", @@ -2501,7 +2501,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.5.8", + "socket2 0.5.7", "tokio", "tower-service", "tracing", @@ -2570,7 +2570,7 @@ dependencies = [ "http 1.1.0", "hyper 1.5.1", "hyper-util", - "rustls 0.23.19", + "rustls 0.23.18", "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", @@ -2618,7 +2618,7 @@ dependencies = [ "http-body 1.0.1", "hyper 1.5.1", "pin-project-lite", - "socket2 0.5.8", + "socket2 0.5.7", "tokio", "tower-service", "tracing", @@ -2786,7 +2786,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -2855,12 +2855,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.15.1", "serde", ] @@ -2899,7 +2899,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" dependencies = [ - "socket2 0.5.8", + "socket2 0.5.7", "widestring", "windows-sys 0.48.0", "winreg", @@ -2931,9 +2931,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2" [[package]] name = "jobserver" @@ -2946,11 +2946,10 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.74" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a865e038f7f6ed956f788f0d7d60c541fff74c7bd74272c5d4cf15c63743e705" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ - "once_cell", "wasm-bindgen", ] @@ -3457,9 +3456,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.167" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libm" @@ -3538,7 +3537,7 @@ dependencies = [ "chrono", "encoding_rs", "flate2", - "indexmap 2.7.0", + "indexmap 2.6.0", "itoa", "log", "md-5", @@ -3555,7 +3554,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.1", ] [[package]] @@ -3664,10 +3663,11 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.3" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ + "hermit-abi", "libc", "wasi", "windows-sys 0.52.0", @@ -3742,7 +3742,7 @@ dependencies = [ "serde_with", "sha-1", "sha2", - "socket2 0.5.8", + "socket2 0.5.7", "stringprep", "strsim", "take_mut", @@ -3763,7 +3763,7 @@ checksum = "3a6dbc533e93429a71c44a14c04547ac783b56d3f22e6c4f12b1b994cf93844e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -3830,7 +3830,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a0d57c55d2d1dc62a2b1d16a0a1079eb78d67c36bdf468d582ab4482ec7002" dependencies = [ "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -4018,7 +4018,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -4126,7 +4126,7 @@ dependencies = [ "regex", "regex-syntax 0.8.5", "structmeta", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -4178,7 +4178,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.7.0", + "indexmap 2.6.0", ] [[package]] @@ -4236,7 +4236,7 @@ checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -4322,7 +4322,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -4371,7 +4371,7 @@ dependencies = [ "prost 0.12.6", "prost-types 0.12.6", "regex", - "syn 2.0.90", + "syn 2.0.89", "tempfile", ] @@ -4385,7 +4385,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -4398,7 +4398,7 @@ dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -4490,10 +4490,10 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.1.0", - "rustls 0.23.19", - "socket2 0.5.8", - "thiserror 2.0.4", + "rustc-hash 2.0.0", + "rustls 0.23.18", + "socket2 0.5.7", + "thiserror 2.0.3", "tokio", "tracing", ] @@ -4508,11 +4508,11 @@ dependencies = [ "getrandom", "rand", "ring", - "rustc-hash 2.1.0", - "rustls 0.23.19", + "rustc-hash 2.0.0", + "rustls 0.23.18", "rustls-pki-types", "slab", - "thiserror 2.0.4", + "thiserror 2.0.3", "tinyvec", "tracing", "web-time", @@ -4527,7 +4527,7 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.8", + "socket2 0.5.7", "tracing", "windows-sys 0.59.0", ] @@ -4766,7 +4766,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.19", + "rustls 0.23.18", "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -4813,6 +4813,7 @@ dependencies = [ "glob", "lopdf", "ordered-float", + "rayon", "reqwest 0.11.27", "rig-derive", "schemars", @@ -4832,7 +4833,7 @@ dependencies = [ "indoc", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -4910,7 +4911,7 @@ dependencies = [ "tokio-rusqlite", "tracing", "tracing-subscriber", - "zerocopy 0.8.11", + "zerocopy 0.8.12", ] [[package]] @@ -4930,9 +4931,9 @@ dependencies = [ [[package]] name = "roaring" -version = "0.10.7" +version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f81dc953b2244ddd5e7860cb0bb2a790494b898ef321d4aff8e260efab60cc88" +checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" dependencies = [ "bytemuck", "byteorder", @@ -4976,9 +4977,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustc-hash" -version = "2.1.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" [[package]] name = "rustc_version" @@ -5040,9 +5041,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.19" +version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" +checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ "log", "once_cell", @@ -5198,7 +5199,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5288,7 +5289,7 @@ checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5299,7 +5300,7 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5308,7 +5309,7 @@ version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ - "indexmap 2.7.0", + "indexmap 2.6.0", "itoa", "memchr", "ryu", @@ -5323,7 +5324,7 @@ checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5348,7 +5349,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.7.0", + "indexmap 2.6.0", "serde", "serde_derive", "serde_json", @@ -5365,7 +5366,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5508,9 +5509,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" dependencies = [ "libc", "windows-sys 0.52.0", @@ -5549,7 +5550,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5602,7 +5603,7 @@ dependencies = [ "proc-macro2", "quote", "structmeta-derive", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5613,7 +5614,7 @@ checksum = "152a0b65a590ff6c3da95cabe2353ee04e6167c896b28e3b14478c2636c922fc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5635,7 +5636,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5657,9 +5658,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.90" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", @@ -5689,7 +5690,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5952,11 +5953,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.4" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f49a1853cf82743e3b7950f77e0f4d622ca36cf4317cba00c767838bac8d490" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" dependencies = [ - "thiserror-impl 2.0.4", + "thiserror-impl 2.0.3", ] [[package]] @@ -5967,18 +5968,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] name = "thiserror-impl" -version = "2.0.4" +version = "2.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8381894bb3efe0c4acac3ded651301ceee58a15d47c2e34885ed1908ad667061" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -5993,9 +5994,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.37" +version = "0.3.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" dependencies = [ "deranged", "itoa", @@ -6014,9 +6015,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.19" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" dependencies = [ "num-conv", "time-core", @@ -6058,9 +6059,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.42.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -6069,7 +6070,7 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2 0.5.8", + "socket2 0.5.7", "tokio-macros", "windows-sys 0.52.0", ] @@ -6082,7 +6083,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -6121,7 +6122,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.19", + "rustls 0.23.18", "rustls-pki-types", "tokio", ] @@ -6203,7 +6204,7 @@ dependencies = [ "prost 0.13.3", "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", - "socket2 0.5.8", + "socket2 0.5.7", "tokio", "tokio-rustls 0.26.0", "tokio-stream", @@ -6261,9 +6262,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ "pin-project-lite", "tracing-attributes", @@ -6272,20 +6273,20 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.28" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] name = "tracing-core" -version = "0.1.33" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" dependencies = [ "once_cell", "valuable", @@ -6304,9 +6305,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ "matchers", "nu-ansi-term", @@ -6525,9 +6526,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.97" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d15e63b4482863c109d70a7b8706c1e364eb6ea449b201a76c5b89cedcec2d5c" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -6536,37 +6537,36 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.97" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d36ef12e3aaca16ddd3f67922bc63e48e953f126de60bd33ccc0101ef9998cd" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.47" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dfaf8f50e5f293737ee323940c7d8b08a66a95a419223d9f41610ca08b0833d" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", - "once_cell", "wasm-bindgen", "web-sys", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.97" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705440e08b42d3e4b36de7d66c944be628d579796b8090bfa3471478a2260051" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6574,22 +6574,22 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.97" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98c9ae5a76e46f4deecd0f0255cc223cfa18dc9b261213b8aa0c7b36f61b3f1d" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.97" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" @@ -6606,9 +6606,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.74" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a98bc3c33f0fe7e59ad7cd041b89034fa82a7c2d4365ca538dda6cdaf513863c" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", @@ -6937,7 +6937,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", "synstructure", ] @@ -6953,11 +6953,11 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cce3b5629d87654b53a49002acc2ce64aa5aa7255f5c718374a37ac7fd98c218" +checksum = "e031087b26520ba76806365896f191416ce84873ed6c6910a9ab5fe0f98f8ed3" dependencies = [ - "zerocopy-derive 0.8.11", + "zerocopy-derive 0.8.12", ] [[package]] @@ -6968,18 +6968,18 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] name = "zerocopy-derive" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74a82c26c3986af2623ec9eb890ff4aa19c006e30a1133dc9bd1830ec1612e20" +checksum = "568244125ba0fc91ae949b97f2852f82cb1a65c3327bd68e6edadd29e67cca26" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] @@ -6999,7 +6999,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", "synstructure", ] @@ -7028,7 +7028,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.90", + "syn 2.0.89", ] [[package]] diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 978dd1dc..3a967c17 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -26,6 +26,7 @@ thiserror = "1.0.61" rig-derive = { version = "0.1.0", path = "./rig-core-derive", optional = true } glob = "0.3.1" lopdf = { version = "0.34.0", optional = true } +rayon = { version = "1.10.0", optional = true} [dev-dependencies] anyhow = "1.0.75" @@ -35,25 +36,34 @@ tracing-subscriber = "0.3.18" tokio-test = "0.4.4" [features] -all = ["derive", "pdf"] +all = ["derive", "pdf", "embedding-distance"] derive = ["dep:rig-derive"] pdf = ["dep:lopdf"] +embedding-distance = ["dep:rayon"] [[test]] name = "embed_macro" required-features = ["derive"] +[[example]] +name = "calculator_chatbot" +required-features = ["derive", "embedding-distance"] + +[[example]] +name = "rag_dynamic_tools" +required-features = ["derive", "embedding-distance"] + [[example]] name = "rag" -required-features = ["derive"] +required-features = ["derive", "embedding-distance"] [[example]] name = "vector_search" -required-features = ["derive"] +required-features = ["derive", "embedding-distance"] [[example]] name = "vector_search_cohere" -required-features = ["derive"] +required-features = ["derive", "embedding-distance"] [[example]] name = "gemini_embeddings" diff --git a/rig-core/src/embeddings/distance.rs b/rig-core/src/embeddings/distance.rs new file mode 100644 index 00000000..2390aec1 --- /dev/null +++ b/rig-core/src/embeddings/distance.rs @@ -0,0 +1,143 @@ +use crate::embeddings::Embedding; +use rayon::prelude::*; + +pub trait CalculateDistance { + /// Get dot product of two embedding vectors + fn dot_product(&self, other: &Self) -> f64; + + /// Get cosine similarity of two embedding vectors. + /// If `normalized` is true, the dot product is returned. + fn cosine_similarity(&self, other: &Self, normalized: bool) -> f64; + + /// Get angular distance of two embedding vectors. + fn angular_distance(&self, other: &Self, normalized: bool) -> f64; + + /// Get euclidean distance of two embedding vectors. + fn euclidean_distance(&self, other: &Self) -> f64; + + /// Get manhattan distance of two embedding vectors. + fn manhattan_distance(&self, other: &Self) -> f64; + + /// Get chebyshev distance of two embedding vectors. + fn chebyshev_distance(&self, other: &Self) -> f64; +} + +impl CalculateDistance for Embedding { + fn dot_product(&self, other: &Self) -> f64 { + self.vec + .par_iter() + .zip(other.vec.par_iter()) + .map(|(x, y)| x * y) + .sum() + } + + fn cosine_similarity(&self, other: &Self, normalized: bool) -> f64 { + let dot_product = self.dot_product(other); + + if normalized { + dot_product + } else { + let magnitude1: f64 = self.vec.par_iter().map(|x| x.powi(2)).sum::().sqrt(); + let magnitude2: f64 = other.vec.par_iter().map(|x| x.powi(2)).sum::().sqrt(); + + dot_product / (magnitude1 * magnitude2) + } + } + + fn angular_distance(&self, other: &Self, normalized: bool) -> f64 { + let cosine_sim = self.cosine_similarity(other, normalized); + cosine_sim.acos() / std::f64::consts::PI + } + + fn euclidean_distance(&self, other: &Self) -> f64 { + self.vec + .par_iter() + .zip(other.vec.par_iter()) + .map(|(x, y)| (x - y).powi(2)) + .sum::() + .sqrt() + } + + fn manhattan_distance(&self, other: &Self) -> f64 { + self.vec + .par_iter() + .zip(other.vec.par_iter()) + .map(|(x, y)| (x - y).abs()) + .sum() + } + + fn chebyshev_distance(&self, other: &Self) -> f64 { + self.vec + .iter() + .zip(other.vec.iter()) + .map(|(x, y)| (x - y).abs()) + .fold(0.0, f64::max) + } +} + +#[cfg(test)] +mod test { + use super::Embedding; + + fn embeddings() -> (Embedding, Embedding) { + let embedding_1 = Embedding { + document: "test".to_string(), + vec: vec![1.0, 2.0, 3.0], + }; + + let embedding_2 = Embedding { + document: "test".to_string(), + vec: vec![1.0, 5.0, 7.0], + }; + + (embedding_1, embedding_2) + } + + #[test] + fn test_dot_product() { + let (embedding_1, embedding_2) = embeddings(); + + assert_eq!(embedding_1.dot_product(&embedding_2), 32.0) + } + + #[test] + fn test_cosine_similarity() { + let (embedding_1, embedding_2) = embeddings(); + + assert_eq!( + embedding_1.cosine_similarity(&embedding_2, false), + 0.9875414397573881 + ) + } + + #[test] + fn test_angular_distance() { + let (embedding_1, embedding_2) = embeddings(); + + assert_eq!( + embedding_1.angular_distance(&embedding_2, false), + 0.0502980301830343 + ) + } + + #[test] + fn test_euclidean_distance() { + let (embedding_1, embedding_2) = embeddings(); + + assert_eq!(embedding_1.euclidean_distance(&embedding_2), 5.0) + } + + #[test] + fn test_manhattan_distance() { + let (embedding_1, embedding_2) = embeddings(); + + assert_eq!(embedding_1.manhattan_distance(&embedding_2), 7.0) + } + + #[test] + fn test_chebyshev_distance() { + let (embedding_1, embedding_2) = embeddings(); + + assert_eq!(embedding_1.chebyshev_distance(&embedding_2), 4.0) + } +} diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs index 7c8877d9..73cdfc2f 100644 --- a/rig-core/src/embeddings/embedding.rs +++ b/rig-core/src/embeddings/embedding.rs @@ -76,18 +76,3 @@ impl PartialEq for Embedding { } impl Eq for Embedding {} - -impl Embedding { - pub fn distance(&self, other: &Self) -> f64 { - let dot_product: f64 = self - .vec - .iter() - .zip(other.vec.iter()) - .map(|(x, y)| x * y) - .sum(); - - let product_of_lengths = (self.vec.len() * other.vec.len()) as f64; - - dot_product / product_of_lengths - } -} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index 1ae16436..a1fae0b3 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -8,6 +8,9 @@ pub mod embed; pub mod embedding; pub mod tool; +#[cfg(feature = "embedding-distance")] +pub mod distance; + pub use builder::EmbeddingsBuilder; pub use embed::{to_texts, Embed, EmbedError, TextEmbedder}; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index aae0a256..7c93757a 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use super::{VectorStoreError, VectorStoreIndex}; use crate::{ - embeddings::{Embedding, EmbeddingModel}, + embeddings::{distance::CalculateDistance, Embedding, EmbeddingModel}, OneOrMany, }; @@ -77,7 +77,7 @@ impl InMemoryVectorStore { .iter() .map(|embedding| { ( - OrderedFloat(embedding.distance(prompt_embedding)), + OrderedFloat(embedding.cosine_similarity(prompt_embedding, false)), &embedding.document, ) }) diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 3d6e8369..3d68430c 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -4,6 +4,7 @@ use serde_json::Value; use crate::embeddings::EmbeddingError; +#[cfg(feature = "embedding-distance")] pub mod in_memory_store; #[derive(Debug, thiserror::Error)] From 40122222240ef339c935ba74f5bd6b29edb49d73 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 4 Dec 2024 16:48:59 -0500 Subject: [PATCH 2/7] fix(tests): add missing dependency --- rig-core/src/embeddings/distance.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rig-core/src/embeddings/distance.rs b/rig-core/src/embeddings/distance.rs index 2390aec1..07600926 100644 --- a/rig-core/src/embeddings/distance.rs +++ b/rig-core/src/embeddings/distance.rs @@ -76,7 +76,8 @@ impl CalculateDistance for Embedding { } #[cfg(test)] -mod test { +mod tests { + use super::CalculateDistance; use super::Embedding; fn embeddings() -> (Embedding, Embedding) { From 88ba481ad9d25f621a81e353682130710ee8a006 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 4 Dec 2024 17:14:37 -0500 Subject: [PATCH 3/7] fix: failing tests --- rig-core/src/vector_store/in_memory_store.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 7c93757a..fda4126d 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -76,6 +76,7 @@ impl InMemoryVectorStore { if let Some((distance, embed_doc)) = embeddings .iter() .map(|embedding| { + println!("Embedding: {:?}: {:?}", embedding, embedding.cosine_similarity(prompt_embedding, false)); ( OrderedFloat(embedding.cosine_similarity(prompt_embedding, false)), &embedding.document, @@ -419,7 +420,7 @@ mod tests { }) .collect::>(), vec![( - 0.034444444444444444, + 0.9807965956109156, "doc1".to_string(), "glarb-garb".to_string() )] @@ -496,7 +497,7 @@ mod tests { }) .collect::>(), vec![( - 0.034444444444444444, + 0.9807965956109156, "doc1".to_string(), "glarb-garb".to_string() )] From 0e8f492b804c8846d6e11ef63c5c0423d470a3e5 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 4 Dec 2024 17:16:40 -0500 Subject: [PATCH 4/7] fix: cargo fmt --- rig-core/src/vector_store/in_memory_store.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index fda4126d..98b3436d 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -76,7 +76,11 @@ impl InMemoryVectorStore { if let Some((distance, embed_doc)) = embeddings .iter() .map(|embedding| { - println!("Embedding: {:?}: {:?}", embedding, embedding.cosine_similarity(prompt_embedding, false)); + println!( + "Embedding: {:?}: {:?}", + embedding, + embedding.cosine_similarity(prompt_embedding, false) + ); ( OrderedFloat(embedding.cosine_similarity(prompt_embedding, false)), &embedding.document, From 3a2a0a3a27eb5868af7d164f388727487b8d1334 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 4 Dec 2024 17:57:31 -0500 Subject: [PATCH 5/7] feat: rename feature flag, remove from vector store, all non feature implementation --- rig-core/Cargo.toml | 14 ++-- rig-core/src/embeddings/distance.rs | 87 ++++++++++++++++---- rig-core/src/embeddings/mod.rs | 2 - rig-core/src/vector_store/in_memory_store.rs | 7 +- rig-core/src/vector_store/mod.rs | 1 - 5 files changed, 80 insertions(+), 31 deletions(-) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 3a967c17..0f8cb689 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -36,10 +36,10 @@ tracing-subscriber = "0.3.18" tokio-test = "0.4.4" [features] -all = ["derive", "pdf", "embedding-distance"] +all = ["derive", "pdf", "rayon"] derive = ["dep:rig-derive"] pdf = ["dep:lopdf"] -embedding-distance = ["dep:rayon"] +rayon = ["dep:rayon"] [[test]] name = "embed_macro" @@ -47,23 +47,23 @@ required-features = ["derive"] [[example]] name = "calculator_chatbot" -required-features = ["derive", "embedding-distance"] +required-features = ["derive"] [[example]] name = "rag_dynamic_tools" -required-features = ["derive", "embedding-distance"] +required-features = ["derive"] [[example]] name = "rag" -required-features = ["derive", "embedding-distance"] +required-features = ["derive", "rayon"] [[example]] name = "vector_search" -required-features = ["derive", "embedding-distance"] +required-features = ["derive", "rayon"] [[example]] name = "vector_search_cohere" -required-features = ["derive", "embedding-distance"] +required-features = ["derive", "rayon"] [[example]] name = "gemini_embeddings" diff --git a/rig-core/src/embeddings/distance.rs b/rig-core/src/embeddings/distance.rs index 07600926..a2bc9060 100644 --- a/rig-core/src/embeddings/distance.rs +++ b/rig-core/src/embeddings/distance.rs @@ -1,7 +1,4 @@ -use crate::embeddings::Embedding; -use rayon::prelude::*; - -pub trait CalculateDistance { +pub trait VectorDistance { /// Get dot product of two embedding vectors fn dot_product(&self, other: &Self) -> f64; @@ -22,11 +19,12 @@ pub trait CalculateDistance { fn chebyshev_distance(&self, other: &Self) -> f64; } -impl CalculateDistance for Embedding { +#[cfg(not(feature = "rayon"))] +impl VectorDistance for crate::embeddings::Embedding { fn dot_product(&self, other: &Self) -> f64 { self.vec - .par_iter() - .zip(other.vec.par_iter()) + .iter() + .zip(other.vec.iter()) .map(|(x, y)| x * y) .sum() } @@ -37,8 +35,8 @@ impl CalculateDistance for Embedding { if normalized { dot_product } else { - let magnitude1: f64 = self.vec.par_iter().map(|x| x.powi(2)).sum::().sqrt(); - let magnitude2: f64 = other.vec.par_iter().map(|x| x.powi(2)).sum::().sqrt(); + let magnitude1: f64 = self.vec.iter().map(|x| x.powi(2)).sum::().sqrt(); + let magnitude2: f64 = other.vec.iter().map(|x| x.powi(2)).sum::().sqrt(); dot_product / (magnitude1 * magnitude2) } @@ -51,8 +49,8 @@ impl CalculateDistance for Embedding { fn euclidean_distance(&self, other: &Self) -> f64 { self.vec - .par_iter() - .zip(other.vec.par_iter()) + .iter() + .zip(other.vec.iter()) .map(|(x, y)| (x - y).powi(2)) .sum::() .sqrt() @@ -60,8 +58,8 @@ impl CalculateDistance for Embedding { fn manhattan_distance(&self, other: &Self) -> f64 { self.vec - .par_iter() - .zip(other.vec.par_iter()) + .iter() + .zip(other.vec.iter()) .map(|(x, y)| (x - y).abs()) .sum() } @@ -75,10 +73,69 @@ impl CalculateDistance for Embedding { } } +#[cfg(feature = "rayon")] +mod rayon { + use crate::embeddings::{distance::VectorDistance, Embedding}; + use rayon::prelude::*; + + impl VectorDistance for Embedding { + fn dot_product(&self, other: &Self) -> f64 { + self.vec + .par_iter() + .zip(other.vec.par_iter()) + .map(|(x, y)| x * y) + .sum() + } + + fn cosine_similarity(&self, other: &Self, normalized: bool) -> f64 { + let dot_product = self.dot_product(other); + + if normalized { + dot_product + } else { + let magnitude1: f64 = self.vec.par_iter().map(|x| x.powi(2)).sum::().sqrt(); + let magnitude2: f64 = other.vec.par_iter().map(|x| x.powi(2)).sum::().sqrt(); + + dot_product / (magnitude1 * magnitude2) + } + } + + fn angular_distance(&self, other: &Self, normalized: bool) -> f64 { + let cosine_sim = self.cosine_similarity(other, normalized); + cosine_sim.acos() / std::f64::consts::PI + } + + fn euclidean_distance(&self, other: &Self) -> f64 { + self.vec + .par_iter() + .zip(other.vec.par_iter()) + .map(|(x, y)| (x - y).powi(2)) + .sum::() + .sqrt() + } + + fn manhattan_distance(&self, other: &Self) -> f64 { + self.vec + .par_iter() + .zip(other.vec.par_iter()) + .map(|(x, y)| (x - y).abs()) + .sum() + } + + fn chebyshev_distance(&self, other: &Self) -> f64 { + self.vec + .iter() + .zip(other.vec.iter()) + .map(|(x, y)| (x - y).abs()) + .fold(0.0, f64::max) + } + } +} + #[cfg(test)] mod tests { - use super::CalculateDistance; - use super::Embedding; + use super::VectorDistance; + use crate::embeddings::Embedding; fn embeddings() -> (Embedding, Embedding) { let embedding_1 = Embedding { diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs index a1fae0b3..696634b7 100644 --- a/rig-core/src/embeddings/mod.rs +++ b/rig-core/src/embeddings/mod.rs @@ -8,9 +8,7 @@ pub mod embed; pub mod embedding; pub mod tool; -#[cfg(feature = "embedding-distance")] pub mod distance; - pub use builder::EmbeddingsBuilder; pub use embed::{to_texts, Embed, EmbedError, TextEmbedder}; pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 98b3436d..14080f0f 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use super::{VectorStoreError, VectorStoreIndex}; use crate::{ - embeddings::{distance::CalculateDistance, Embedding, EmbeddingModel}, + embeddings::{distance::VectorDistance, Embedding, EmbeddingModel}, OneOrMany, }; @@ -76,11 +76,6 @@ impl InMemoryVectorStore { if let Some((distance, embed_doc)) = embeddings .iter() .map(|embedding| { - println!( - "Embedding: {:?}: {:?}", - embedding, - embedding.cosine_similarity(prompt_embedding, false) - ); ( OrderedFloat(embedding.cosine_similarity(prompt_embedding, false)), &embedding.document, diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 3d68430c..3d6e8369 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -4,7 +4,6 @@ use serde_json::Value; use crate::embeddings::EmbeddingError; -#[cfg(feature = "embedding-distance")] pub mod in_memory_store; #[derive(Debug, thiserror::Error)] From c6e3f781300e1b844eb859b6902aef55360c2c63 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 4 Dec 2024 18:00:16 -0500 Subject: [PATCH 6/7] fix: remove requires rayon feature on examples --- rig-core/Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 0f8cb689..696bfebe 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -55,15 +55,15 @@ required-features = ["derive"] [[example]] name = "rag" -required-features = ["derive", "rayon"] +required-features = ["derive"] [[example]] name = "vector_search" -required-features = ["derive", "rayon"] +required-features = ["derive"] [[example]] name = "vector_search_cohere" -required-features = ["derive", "rayon"] +required-features = ["derive"] [[example]] name = "gemini_embeddings" From bce65f9fa61faaee966dca713188c110e18533fd Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 4 Dec 2024 18:03:46 -0500 Subject: [PATCH 7/7] fix: fix cargo.toml --- rig-core/Cargo.toml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 696bfebe..0af6b71d 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -45,14 +45,6 @@ rayon = ["dep:rayon"] name = "embed_macro" required-features = ["derive"] -[[example]] -name = "calculator_chatbot" -required-features = ["derive"] - -[[example]] -name = "rag_dynamic_tools" -required-features = ["derive"] - [[example]] name = "rag" required-features = ["derive"]