diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b65da281..3a58d0ed 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -4,7 +4,7 @@ name: Lint & Test on: pull_request: branches: - - main + - "**" workflow_call: env: diff --git a/Cargo.lock b/Cargo.lock index aefa6a32..2d778dde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 4 +version = 3 [[package]] name = "addr2line" @@ -42,9 +42,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.18" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" [[package]] name = "android-tzdata" @@ -63,15 +63,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.8" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" [[package]] name = "anyhow" -version = "1.0.89" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" +checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" [[package]] name = "arc-swap" @@ -361,7 +361,7 @@ checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -383,7 +383,7 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -394,7 +394,7 @@ checksum = 
"721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -426,9 +426,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.8" +version = "1.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7198e6f03240fdceba36656d8be440297b6b82270325908c7381f37d826a74f6" +checksum = "9b49afaa341e8dd8577e1a2200468f98956d6eda50bcf4a53246cc00174ba924" dependencies = [ "aws-credential-types", "aws-runtime", @@ -443,7 +443,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.1", + "fastrand 2.2.0", "hex", "http 0.2.12", "ring", @@ -481,7 +481,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.1", + "fastrand 2.2.0", "http 0.2.12", "http-body 0.4.6", "once_cell", @@ -493,9 +493,9 @@ dependencies = [ [[package]] name = "aws-sdk-dynamodb" -version = "1.49.0" +version = "1.54.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab0ade000608877169533a54326badd6b5a707d2faf876cfc3976a7f9d7e5329" +checksum = "8efdda6a491bb4640d35b99b0a4b93f75ce7d6e3a1937c3e902d3cb23d0a179c" dependencies = [ "aws-credential-types", "aws-runtime", @@ -507,7 +507,7 @@ dependencies = [ "aws-smithy-types", "aws-types", "bytes", - "fastrand 2.1.1", + "fastrand 2.2.0", "http 0.2.12", "once_cell", "regex-lite", @@ -516,9 +516,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.45.0" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e33ae899566f3d395cbf42858e433930682cc9c1889fa89318896082fef45efb" +checksum = "09677244a9da92172c8dc60109b4a9658597d4d298b188dd0018b6a66b410ca4" dependencies = [ "aws-credential-types", "aws-runtime", @@ -538,9 +538,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.46.0" +version = "1.50.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f39c09e199ebd96b9f860b0fce4b6625f211e064ad7c8693b72ecf7ef03881e0" +checksum = "81fea2f3a8bb3bd10932ae7ad59cc59f65f270fc9183a7e91f501dc5efbef7ee" dependencies = [ "aws-credential-types", "aws-runtime", @@ -560,9 +560,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.45.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d95f93a98130389eb6233b9d615249e543f6c24a68ca1f109af9ca5164a8765" +checksum = "6ada54e5f26ac246dc79727def52f7f8ed38915cb47781e2a72213957dc3a7d5" dependencies = [ "aws-credential-types", "aws-runtime", @@ -583,9 +583,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8db6904450bafe7473c6ca9123f88cc11089e41a025408f992db4e22d3be68" +checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -656,22 +656,22 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a065c0fe6fdbdf9f11817eb68582b2ab4aff9e9c39e986ae48f7ec576c6322db" +checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" dependencies = [ "aws-smithy-async", "aws-smithy-http", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", - "fastrand 2.1.1", + "fastrand 2.2.0", "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "http-body 1.0.1", "httparse", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-rustls 0.24.2", "once_cell", "pin-project-lite", @@ -683,9 +683,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.2" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" +checksum = 
"92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" dependencies = [ "aws-smithy-async", "aws-smithy-types", @@ -700,9 +700,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.7" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147100a7bea70fa20ef224a6bad700358305f5dc0f84649c53769761395b355b" +checksum = "4fbd94a32b3a7d55d3806fe27d98d3ad393050439dd05eb53ece36ec5e3d3510" dependencies = [ "base64-simd", "bytes", @@ -749,9 +749,9 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.7" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" dependencies = [ "async-trait", "axum-core", @@ -768,7 +768,7 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tower 0.5.1", "tower-layer", "tower-service", @@ -789,7 +789,7 @@ dependencies = [ "mime", "pin-project-lite", "rustversion", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tower-layer", "tower-service", ] @@ -916,9 +916,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +checksum = "1a68f1f47cdf0ec8ee4b941b2eee2a80cb796db73118c0dd09ac63fbe405be22" dependencies = [ "memchr", "serde", @@ -938,9 +938,9 @@ checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] name = "bytemuck" -version = "1.18.0" +version = "1.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" +checksum = "8b37c88a63ffd85d15b406896cc343916d7cf57838a847b3a6f2ca5d39a5695a" [[package]] name = "byteorder" @@ 
-950,9 +950,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.7.2" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" dependencies = [ "serde", ] @@ -1000,9 +1000,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.28" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e80e3b6a3ab07840e1cae9b0666a63970dc28e8ed5ffbcdacbfc760c281bfc1" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -1021,6 +1021,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.38" @@ -1082,13 +1088,13 @@ dependencies = [ [[package]] name = "comfy-table" -version = "7.1.1" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" +checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ "strum", "strum_macros", - "unicode-width", + "unicode-width 0.2.0", ] [[package]] @@ -1136,6 +1142,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" 
version = "0.8.7" @@ -1144,9 +1160,9 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "cpufeatures" -version = "0.2.14" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" +checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" dependencies = [ "libc", ] @@ -1221,9 +1237,9 @@ dependencies = [ [[package]] name = "csv" -version = "1.3.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +checksum = "acdc4883a9c96732e4733212c01447ebd805833b7275a73ca3ee080fd77afdaf" dependencies = [ "csv-core", "itoa", @@ -1285,7 +1301,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1307,7 +1323,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core 0.20.10", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1711,7 +1727,7 @@ dependencies = [ "darling 0.20.10", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1721,7 +1737,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1734,7 +1750,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version 0.4.1", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -1775,6 +1791,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "doc-comment" 
version = "0.3.3" @@ -1801,9 +1828,9 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "encoding_rs" -version = "0.8.34" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" dependencies = [ "cfg-if", ] @@ -1864,9 +1891,9 @@ dependencies = [ [[package]] name = "fastdivide" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59668941c55e5c186b8b58c391629af56774ec768f73c08bbcd56f09348eb00b" +checksum = "9afc2bd4d5a73106dd53d10d73d3401c2f32730ba2c0b93ddb888a8983680471" [[package]] name = "fastrand" @@ -1879,9 +1906,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.1.1" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" +checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" [[package]] name = "fixedbitset" @@ -1901,9 +1928,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", @@ -1951,7 +1978,7 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7e180ac76c23b45e767bd7ae9579bc0bb458618c4bc71835926e098e61d15f8" dependencies = [ - "rustix 0.38.37", + "rustix 0.38.41", "windows-sys 0.52.0", ] @@ -2041,7 +2068,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ 
-2154,9 +2181,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" dependencies = [ "atomic-waker", "bytes", @@ -2200,9 +2227,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" dependencies = [ "allocator-api2", "equivalent", @@ -2335,9 +2362,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.30" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a152ddd61dfaec7273fe8419ab357f33aee0d914c5f4efbf0d96fa749eea5ec9" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", @@ -2359,14 +2386,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.4.1" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -2386,7 +2413,7 @@ checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", "http 0.2.12", - "hyper 0.14.30", + "hyper 0.14.31", "log", "rustls 0.21.12", "rustls-native-certs 0.6.3", @@ -2402,15 +2429,15 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies 
= [ "futures-util", "http 1.1.0", - "hyper 1.4.1", + "hyper 1.5.1", "hyper-util", - "rustls 0.23.14", - "rustls-native-certs 0.8.0", + "rustls 0.23.18", + "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", "tower-service", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", ] [[package]] @@ -2419,7 +2446,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" dependencies = [ - "hyper 1.4.1", + "hyper 1.5.1", "hyper-util", "pin-project-lite", "tokio", @@ -2433,7 +2460,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper 0.14.30", + "hyper 0.14.31", "native-tls", "tokio", "tokio-native-tls", @@ -2450,7 +2477,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.4.1", + "hyper 1.5.1", "pin-project-lite", "socket2 0.5.7", "tokio", @@ -2490,6 +2517,124 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] 
+name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn 
2.0.89", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -2509,12 +2654,23 @@ dependencies = [ [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] @@ -2551,10 +2707,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.15.0", + "hashbrown 0.15.1", "serde", ] +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + [[package]] name = "instant" version = "0.1.13" @@ -2616,9 +2778,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2" [[package]] name = "jobserver" @@ -2631,9 +2793,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.71" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cb94a0ffd3f3ee755c20f7d8752f45cac88605a4dcf808abcff72873296ec7b" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] 
@@ -3141,15 +3303,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.159" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libredox" @@ -3179,6 +3341,12 @@ version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +[[package]] +name = "litemap" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" + [[package]] name = "lock_api" version = "0.4.12" @@ -3221,7 +3389,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.0", + "hashbrown 0.15.1", ] [[package]] @@ -3366,7 +3534,7 @@ dependencies = [ "skeptic", "smallvec", "tagptr", - "thiserror", + "thiserror 1.0.69", "triomphe", "uuid", ] @@ -3407,7 +3575,7 @@ dependencies = [ "stringprep", "strsim 0.10.0", "take_mut", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-rustls 0.24.1", "tokio-util", @@ -3442,7 +3610,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework", + "security-framework 2.11.1", "security-framework-sys", "tempfile", ] @@ -3468,11 +3636,11 @@ dependencies = [ "rustls-native-certs 0.7.3", "rustls-pemfile 2.2.0", "serde", - "thiserror", + "thiserror 1.0.69", "tokio", "tokio-rustls 0.26.0", "url", - 
"webpki-roots 0.26.6", + "webpki-roots 0.26.7", ] [[package]] @@ -3482,7 +3650,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a0d57c55d2d1dc62a2b1d16a0a1079eb78d67c36bdf468d582ab4482ec7002" dependencies = [ "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -3616,14 +3784,14 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.4.1", + "hyper 1.5.1", "itertools 0.13.0", "md-5", "parking_lot", "percent-encoding", "quick-xml", "rand", - "reqwest 0.12.8", + "reqwest 0.12.9", "ring", "rustls-pemfile 2.2.0", "serde", @@ -3649,9 +3817,9 @@ checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" [[package]] name = "openssl" -version = "0.10.66" +version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ "bitflags 2.6.0", "cfg-if", @@ -3670,7 +3838,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -3681,9 +3849,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" -version = "0.9.103" +version = "0.9.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" dependencies = [ "cc", "libc", @@ -3699,9 +3867,9 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "ordered-float" -version = "4.3.0" +version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537" +checksum = 
"c65ee1f9701bf938026630b455d5315f490640234259037edb259798b3bcf85e" dependencies = [ "num-traits", ] @@ -3848,29 +4016,29 @@ dependencies = [ [[package]] name = "pin-project" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf123a161dde1e524adf36f90bc5d8d3462824a9c43553ad07a8183161189ec" +checksum = "be57f64e946e500c8ee36ef6331845d40a93055567ec57e8fae13efd33759b95" dependencies = [ "pin-project-internal", ] [[package]] name = "pin-project-internal" -version = "1.1.6" +version = "1.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4502d8515ca9f32f1fb543d987f63d95a14934883db45bdb48060b6b69257f8" +checksum = "3c0f5fad0874fc7abcd4d750e76917eaebbecaa2c20bde22e1dbeeba8beb758c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -3944,19 +4112,19 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.22" +version = "0.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" +checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" dependencies = [ "proc-macro2", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] name = "proc-macro2" -version = "1.0.87" +version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] @@ -3998,7 +4166,7 @@ dependencies = [ "prost 0.12.6", "prost-types 
0.12.6", "regex", - "syn 2.0.79", + "syn 2.0.89", "tempfile", ] @@ -4012,7 +4180,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -4025,7 +4193,7 @@ dependencies = [ "itertools 0.13.0", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -4068,10 +4236,10 @@ dependencies = [ "futures-util", "prost 0.13.3", "prost-types 0.13.3", - "reqwest 0.12.8", + "reqwest 0.12.9", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "tonic", ] @@ -4109,45 +4277,49 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" dependencies = [ "bytes", "pin-project-lite", "quinn-proto", "quinn-udp", "rustc-hash 2.0.0", - "rustls 0.23.14", + "rustls 0.23.18", "socket2 0.5.7", - "thiserror", + "thiserror 2.0.3", "tokio", "tracing", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" dependencies = [ "bytes", + "getrandom", "rand", "ring", "rustc-hash 2.0.0", - "rustls 0.23.14", + "rustls 0.23.18", + "rustls-pki-types", "slab", - "thiserror", + "thiserror 2.0.3", "tinyvec", "tracing", + "web-time", ] [[package]] name = "quinn-udp" -version = "0.5.5" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fe68c2e9e1a1234e218683dbdf9f9dfcb094113c5ac2b938dfcb9bab4c4140b" +checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" dependencies = [ + "cfg_aliases", "libc", "once_cell", "socket2 0.5.7", @@ -4262,14 +4434,14 @@ checksum = 
"ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", - "thiserror", + "thiserror 1.0.69", ] [[package]] name = "regex" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38200e5ee88914975b69f657f0801b6f6dccafd44fd9326302a4aaeecfacb1d8" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -4279,9 +4451,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", @@ -4314,7 +4486,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.30", + "hyper 0.14.31", "hyper-tls", "ipnet", "js-sys", @@ -4342,19 +4514,19 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.12.8" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f713147fbe92361e52392c73b8c9e48c04c6625bce969ef54dc901e58e042a7b" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ "base64 0.22.1", "bytes", "futures-core", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.1", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -4365,14 +4537,14 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.14", - "rustls-native-certs 0.8.0", + "rustls 0.23.18", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper 1.0.1", + "sync_wrapper 1.0.2", "tokio", "tokio-rustls 0.26.0", "tokio-util", @@ -4382,7 +4554,7 @@ dependencies = [ 
"wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots 0.26.6", + "webpki-roots 0.26.7", "windows-registry", ] @@ -4413,15 +4585,27 @@ dependencies = [ "lopdf", "ordered-float", "reqwest 0.11.27", + "rig-derive", "schemars", "serde", "serde_json", - "thiserror", + "thiserror 1.0.69", "tokio", + "tokio-test", "tracing", "tracing-subscriber", ] +[[package]] +name = "rig-derive" +version = "0.1.0" +dependencies = [ + "indoc", + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "rig-lancedb" version = "0.1.2" @@ -4575,9 +4759,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.37" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8acb788b847c24f28525660c4d7758620a7210875711f79e7f663cc152726811" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -4600,9 +4784,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.14" +version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415d9944693cb90382053259f89fbb077ea730ad7273047ec63b19bc9b160ba8" +checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ "log", "once_cell", @@ -4622,7 +4806,7 @@ dependencies = [ "openssl-probe", "rustls-pemfile 1.0.4", "schannel", - "security-framework", + "security-framework 2.11.1", ] [[package]] @@ -4635,20 +4819,19 @@ dependencies = [ "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", - "security-framework", + "security-framework 2.11.1", ] [[package]] name = "rustls-native-certs" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" dependencies = [ "openssl-probe", - "rustls-pemfile 2.2.0", "rustls-pki-types", "schannel", - 
"security-framework", + "security-framework 3.0.1", ] [[package]] @@ -4671,9 +4854,12 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e696e35370c65c9c541198af4543ccd580cf17fc25d8e05c5a242b202488c55" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +dependencies = [ + "web-time", +] [[package]] name = "rustls-webpki" @@ -4698,9 +4884,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "ryu" @@ -4719,9 +4905,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -4756,7 +4942,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -4782,7 +4968,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", - "core-foundation", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8" +dependencies = [ + "bitflags 2.6.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ 
-4790,9 +4989,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" dependencies = [ "core-foundation-sys", "libc", @@ -4824,9 +5023,9 @@ checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.210" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f" dependencies = [ "serde_derive", ] @@ -4842,13 +5041,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.210" +version = "1.0.215" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -4859,14 +5058,14 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] name = "serde_json" -version = "1.0.128" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "indexmap 2.6.0", "itoa", @@ -4936,7 +5135,7 @@ dependencies = [ "darling 0.20.10", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5111,7 +5310,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" 
dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5180,7 +5379,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5202,9 +5401,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.79" +version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89132cd0bf050864e1d38dc3bbc07a0eb8e7530af26344d3d2bbbef83499f590" +checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", "quote", @@ -5219,13 +5418,24 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "system-configuration" version = "0.5.1" @@ -5233,7 +5443,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -5304,7 +5514,7 @@ dependencies = [ "tantivy-stacker", "tantivy-tokenizer-api", "tempfile", - "thiserror", + "thiserror 1.0.69", "time", "uuid", "winapi", @@ -5408,14 +5618,14 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.13.0" +version = "3.14.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" +checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" dependencies = [ "cfg-if", - "fastrand 2.1.1", + "fastrand 2.2.0", "once_cell", - "rustix 0.38.37", + "rustix 0.38.41", "windows-sys 0.59.0", ] @@ -5443,27 +5653,47 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" dependencies = [ "smawk", "unicode-linebreak", - "unicode-width", + "unicode-width 0.1.14", ] [[package]] name = "thiserror" -version = "1.0.65" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c006c85c7651b3cf2ada4584faa36773bd07bac24acfb39f3c431b36d7e667aa" +dependencies = [ + "thiserror-impl 2.0.3", ] [[package]] name = "thiserror-impl" -version = "1.0.65" +version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f077553d607adc1caf65430528a576c757a71ed73944b66ebb58ef2bbd243568" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", ] [[package]] @@ -5516,6 +5746,16 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -5533,9 +5773,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.40.0" +version = "1.41.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" +checksum = "22cfb5bee7a6a52939ca9224d6ac897bb669134078daa8735560897f69de4d33" dependencies = [ "backtrace", "bytes", @@ -5557,7 +5797,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", ] [[package]] @@ -5586,7 +5826,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.14", + "rustls 0.23.18", "rustls-pki-types", "tokio", ] @@ -5602,6 +5842,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.12" @@ -5628,17 +5881,17 @@ dependencies = [ "base64 0.22.1", "bytes", "flate2", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.4.1", + "hyper 1.5.1", "hyper-timeout", "hyper-util", "percent-encoding", "pin-project", "prost 0.13.3", - "rustls-native-certs 0.8.0", + "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "socket2 0.5.7", "tokio", @@ -5715,7 +5968,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + 
"syn 2.0.89", ] [[package]] @@ -5778,7 +6031,7 @@ dependencies = [ "log", "rand", "smallvec", - "thiserror", + "thiserror 1.0.69", "tinyvec", "tokio", "url", @@ -5799,7 +6052,7 @@ dependencies = [ "parking_lot", "resolv-conf", "smallvec", - "thiserror", + "thiserror 1.0.69", "tokio", "trust-dns-proto", ] @@ -5839,12 +6092,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicase" -version = "2.7.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" -dependencies = [ - "version_check", -] +checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" [[package]] name = "unicode-bidi" @@ -5854,9 +6104,9 @@ checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-linebreak" @@ -5891,6 +6141,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-width" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" + [[package]] name = "untrusted" version = "0.9.0" @@ -5899,12 +6155,12 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.2" +version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" +checksum = 
"32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", - "idna 0.5.0", + "idna 1.0.3", "percent-encoding", ] @@ -5914,17 +6170,29 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + [[package]] name = "utf8-ranges" version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcfc827f90e53a02eaef5e535ee14266c1d569214c6aa70133a624d8a3164ba" +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "serde", @@ -5987,9 +6255,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef073ced962d62984fb38a36e5fdc1a2b23c9e0e1fa0689bb97afa4202ef6887" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", "once_cell", @@ -5998,24 +6266,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4bfab14ef75323f4eb75fa52ee0a3fb59611977fd3240da19b2cf36ff85030e" +checksum = 
"cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.44" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65471f79c1022ffa5291d33520cbbb53b7687b01c2f8e83b57d102eed7ed479d" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -6025,9 +6293,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7bec9830f60924d9ceb3ef99d55c155be8afa76954edffbb5936ff4509474e7" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6035,28 +6303,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c74f6e152a76a2ad448e223b0fc0b6b5747649c3d769cc6bf45737bf97d0ed6" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.94" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a42f6c679374623f295a8623adfe63d9284091245c3504bde47c17a3ce2777d9" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e072d4e72f700fb3443d8fe94a39315df013eef1104903cdb0a2abd322bbecd" +checksum = 
"15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", @@ -6067,9 +6335,19 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.71" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44188d185b5bdcae1052d08bcbcf9091a5524038d4572cc4f4f2bb9d5554ddd9" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" dependencies = [ "js-sys", "wasm-bindgen", @@ -6083,9 +6361,9 @@ checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" [[package]] name = "webpki-roots" -version = "0.26.6" +version = "0.26.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c67bff177718f1d4dfefde8d8f0e78f9b6589319ba88312f567fc5841a958" +checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" dependencies = [ "rustls-pki-types", ] @@ -6330,6 +6608,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "wyz" version = "0.5.1" @@ -6345,6 +6635,30 @@ version = "0.13.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" +[[package]] +name = "yoke" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.35" @@ -6363,7 +6677,28 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.79", + "syn 2.0.89", +] + +[[package]] +name = "zerofrom" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", + "synstructure", ] [[package]] @@ -6372,6 +6707,28 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.89", +] + [[package]] name = "zstd" version = "0.13.2" diff --git a/Cargo.toml 
b/Cargo.toml index c8d75273..2f6d642c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,5 +3,5 @@ resolver = "2" members = [ "rig-core", "rig-lancedb", "rig-mongodb", "rig-neo4j", - "rig-qdrant", + "rig-qdrant", "rig-core/rig-core-derive" ] diff --git a/rig-core/Cargo.toml b/rig-core/Cargo.toml index 63cd23e5..f8c44e8f 100644 --- a/rig-core/Cargo.toml +++ b/rig-core/Cargo.toml @@ -23,6 +23,7 @@ futures = "0.3.29" ordered-float = "4.2.0" schemars = "0.8.16" thiserror = "1.0.61" +rig-derive = { path = "./rig-core-derive", optional = true } glob = "0.3.1" lopdf = { version = "0.34.0", optional = true } @@ -31,7 +32,33 @@ anyhow = "1.0.75" assert_fs = "1.1.2" tokio = { version = "1.34.0", features = ["full"] } tracing-subscriber = "0.3.18" +tokio-test = "0.4.4" [features] -all = ["pdf"] +all = ["derive", "pdf"] +derive = ["dep:rig-derive"] pdf = ["dep:lopdf"] + +[[test]] +name = "embed_macro" +required-features = ["derive"] + +[[example]] +name = "rag" +required-features = ["derive"] + +[[example]] +name = "vector_search" +required-features = ["derive"] + +[[example]] +name = "vector_search_cohere" +required-features = ["derive"] + +[[example]] +name = "gemini_embeddings" +required-features = ["derive"] + +[[example]] +name = "xai_embeddings" +required-features = ["derive"] diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 04d26dc3..149b1ce4 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -5,7 +5,7 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -247,13 +247,13 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let 
embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .tools(&toolset)? + .documents(toolset.schemas()?)? .build() .await?; - let mut store = InMemoryVectorStore::default(); - store.add_documents(embeddings).await?; - let index = store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents_with_id(embeddings, |tool| tool.name.clone())? + .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/gemini_embeddings.rs b/rig-core/examples/gemini_embeddings.rs index 4ce24636..6f8badbe 100644 --- a/rig-core/examples/gemini_embeddings.rs +++ b/rig-core/examples/gemini_embeddings.rs @@ -1,4 +1,11 @@ use rig::providers::gemini; +use rig::Embed; + +#[derive(Embed, Debug)] +struct Greetings { + #[embed] + message: String, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -8,8 +15,12 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = client .embeddings(gemini::embedding::EMBEDDING_001) - .simple_document("doc0", "Hello, world!") - .simple_document("doc1", "Goodbye, world!") + .document(Greetings { + message: "Hello, world!".to_string(), + })? + .document(Greetings { + message: "Goodbye, world!".to_string(), + })? .build() .await .expect("Failed to embed documents"); diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 3abd8ee9..376c37db 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -1,34 +1,73 @@ -use std::env; +use std::{env, vec}; use rig::{ completion::Prompt, embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, + Embed, }; +use serde::Serialize; + +// Data to be RAGged. 
+// A vector search needs to be performed on the `definitions` field, so we derive the `Embed` trait for `WordDefinition` +// and tag that field with `#[embed]`. +#[derive(Embed, Serialize, Clone, Debug, Eq, PartialEq, Default)] +struct WordDefinition { + id: String, + word: String, + #[embed] + definitions: Vec, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { + // Initialize tracing + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .with_target(false) + .init(); + // Create OpenAI client let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Create vector store, compute embeddings and load them in the store - let mut vector_store = InMemoryVectorStore::default(); - + // Generate embeddings for the definitions of all the documents using the specified embedding model. let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(vec![ + WordDefinition { + id: "doc0".to_string(), + word: "flurbo".to_string(), + definitions: vec![ + "1. *flurbo* (name): A flurbo is a green alien that lives on cold planets.".to_string(), + "2. *flurbo* (name): A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + WordDefinition { + id: "doc1".to_string(), + word: "glarb-glarb".to_string(), + definitions: vec![ + "1. 
*glarb-glarb* (noun): A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ] + }, + WordDefinition { + id: "doc2".to_string(), + word: "linglingdong".to_string(), + definitions: vec![ + "1. *linglingdong* (noun): A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + "2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() + ] + }, + ])? .build() .await?; - vector_store.add_documents(embeddings).await?; - - // Create vector store index - let index = vector_store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents_with_id(embeddings, |definition| definition.id.clone())? + .index(embedding_model); let rag_agent = openai_client.agent("gpt-4") .preamble(" diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 6e45730b..bc92f7c5 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -4,7 +4,7 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, + vector_store::in_memory_store::InMemoryVectorStore, }; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -137,11 +137,6 @@ async fn main() -> Result<(), anyhow::Error> { .with_max_level(tracing::Level::INFO) // disable printing the name of the module in every log line. .with_target(false) - // this needs to be set to false, otherwise ANSI color codes will - // show up in a confusing manner in CloudWatch logs. - .with_ansi(false) - // disabling time is handy because CloudWatch will add the ingestion time. 
- .without_time() .init(); // Create OpenAI client @@ -150,23 +145,19 @@ async fn main() -> Result<(), anyhow::Error> { let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Create vector store, compute tool embeddings and load them in the store - let mut vector_store = InMemoryVectorStore::default(); - let toolset = ToolSet::builder() .dynamic_tool(Add) .dynamic_tool(Subtract) .build(); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) - .tools(&toolset)? + .documents(toolset.schemas()?)? .build() .await?; - vector_store.add_documents(embeddings).await?; - - // Create vector store index - let index = vector_store.index(embedding_model); + let index = InMemoryVectorStore::default() + .add_documents_with_id(embeddings, |tool| tool.name.clone())? + .index(embedding_model); // Create RAG agent with a single context prompt and a dynamic tool source let calculator_rag = openai_client diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 1328a3a5..e8cbf894 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -3,8 +3,20 @@ use std::env; use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embed, }; +use serde::{Deserialize, Serialize}; + +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. 
+#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +struct WordDefinition { + id: String, + word: String, + #[embed] + definitions: Vec, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -15,25 +27,50 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(vec![ + WordDefinition { + id: "doc0".to_string(), + word: "flurbo".to_string(), + definitions: vec![ + "A green alien that lives on cold planets.".to_string(), + "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + WordDefinition { + id: "doc1".to_string(), + word: "glarb-glarb".to_string(), + definitions: vec![ + "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ] + }, + WordDefinition { + id: "doc2".to_string(), + word: "linglingdong".to_string(), + definitions: vec![ + "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), + "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() + ] + }, + ])? .build() .await?; - let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; + let index = InMemoryVectorStore::default() + .add_documents_with_id(embeddings, |definition| definition.id.clone())? 
+ .index(model); let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::("I need to buy something in a fictional universe. What type of money can I use for this?", 1) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc)) + .map(|(score, id, doc)| (score, id, doc.word)) .collect::>(); println!("Results: {:?}", results); let id_results = index - .top_n_ids("What is a linglingdong?", 1) + .top_n_ids("I need to buy something in a fictional universe. What type of money can I use for this?", 1) .await? .into_iter() .collect::>(); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index a49ac231..aace89fa 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -1,10 +1,22 @@ use std::env; use rig::{ - embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, + embeddings::EmbeddingsBuilder, providers::cohere::{Client, EMBED_ENGLISH_V3}, - vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, + vector_store::{in_memory_store::InMemoryVectorStore, VectorStoreIndex}, + Embed, }; +use serde::{Deserialize, Serialize}; + +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. 
+#[derive(Embed, Clone, Deserialize, Debug, Serialize, Eq, PartialEq, Default)] +struct WordDefinition { + id: String, + word: String, + #[embed] + definitions: Vec, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -15,24 +27,48 @@ async fn main() -> Result<(), anyhow::Error> { let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); - let mut vector_store = InMemoryVectorStore::default(); - - let embeddings = EmbeddingsBuilder::new(document_model) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + let embeddings = EmbeddingsBuilder::new(document_model.clone()) + .documents(vec![ + WordDefinition { + id: "doc0".to_string(), + word: "flurbo".to_string(), + definitions: vec![ + "A green alien that lives on cold planets.".to_string(), + "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + WordDefinition { + id: "doc1".to_string(), + word: "glarb-glarb".to_string(), + definitions: vec![ + "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ] + }, + WordDefinition { + id: "doc2".to_string(), + word: "linglingdong".to_string(), + definitions: vec![ + "A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), + "A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet 
Quarm.".to_string() + ] + }, + ])? .build() .await?; - vector_store.add_documents(embeddings).await?; - - let index = vector_store.index(search_model); + let index = InMemoryVectorStore::default() + .add_documents_with_id(embeddings, |definition| definition.id.clone())? + .index(search_model); let results = index - .top_n::("What is a linglingdong?", 1) + .top_n::( + "Which instrument is found in the Nebulon Mountain Ranges?", + 1, + ) .await? .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) + .map(|(score, id, doc)| (score, id, doc.word)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-core/examples/xai_embeddings.rs b/rig-core/examples/xai_embeddings.rs index ba24a9b0..a127c389 100644 --- a/rig-core/examples/xai_embeddings.rs +++ b/rig-core/examples/xai_embeddings.rs @@ -1,4 +1,11 @@ use rig::providers::xai; +use rig::Embed; + +#[derive(Embed, Debug)] +struct Greetings { + #[embed] + message: String, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -7,8 +14,12 @@ async fn main() -> Result<(), anyhow::Error> { let embeddings = client .embeddings(xai::embedding::EMBEDDING_V1) - .simple_document("doc0", "Hello, world!") - .simple_document("doc1", "Goodbye, world!") + .document(Greetings { + message: "Hello, world!".to_string(), + })? + .document(Greetings { + message: "Goodbye, world!".to_string(), + })? 
.build() .await .expect("Failed to embed documents"); diff --git a/rig-core/rig-core-derive/Cargo.toml b/rig-core/rig-core-derive/Cargo.toml new file mode 100644 index 00000000..1ab5e5ac --- /dev/null +++ b/rig-core/rig-core-derive/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "rig-derive" +version = "0.1.0" +edition = "2021" + +[dependencies] +indoc = "2.0.5" +proc-macro2 = { version = "1.0.87", features = ["proc-macro"] } +quote = "1.0.37" +syn = { version = "2.0.79", features = ["full"]} + +[lib] +proc-macro = true diff --git a/rig-core/rig-core-derive/src/basic.rs b/rig-core/rig-core-derive/src/basic.rs new file mode 100644 index 00000000..b9c1e5c4 --- /dev/null +++ b/rig-core/rig-core-derive/src/basic.rs @@ -0,0 +1,25 @@ +use syn::{parse_quote, Attribute, DataStruct, Meta}; + +use crate::EMBED; + +/// Finds and returns fields with simple `#[embed]` attribute tags only. +pub(crate) fn basic_embed_fields(data_struct: &DataStruct) -> impl Iterator { + data_struct.fields.iter().filter(|field| { + field.attrs.iter().any(|attribute| match attribute { + Attribute { + meta: Meta::Path(path), + .. + } => path.is_ident(EMBED), + _ => false, + }) + }) +} + +/// Adds bounds to where clause that force all fields tagged with `#[embed]` to implement the `Embed` trait. +pub(crate) fn add_struct_bounds(generics: &mut syn::Generics, field_type: &syn::Type) { + let where_clause = generics.make_where_clause(); + + where_clause.predicates.push(parse_quote! { + #field_type: Embed + }); +} diff --git a/rig-core/rig-core-derive/src/custom.rs b/rig-core/rig-core-derive/src/custom.rs new file mode 100644 index 00000000..f29ed20e --- /dev/null +++ b/rig-core/rig-core-derive/src/custom.rs @@ -0,0 +1,123 @@ +use quote::ToTokens; +use syn::{meta::ParseNestedMeta, ExprPath}; + +use crate::EMBED; + +const EMBED_WITH: &str = "embed_with"; + +/// Finds and returns fields with #[embed(embed_with = "...")] attribute tags only. +/// Also returns the "..." part of the tag (ie. 
the custom function). +pub(crate) fn custom_embed_fields( + data_struct: &syn::DataStruct, +) -> syn::Result> { + data_struct + .fields + .iter() + .filter_map(|field| { + field + .attrs + .iter() + .filter_map(|attribute| match attribute.is_custom() { + Ok(true) => match attribute.expand_tag() { + Ok(path) => Some(Ok((field, path))), + Err(e) => Some(Err(e)), + }, + Ok(false) => None, + Err(e) => Some(Err(e)), + }) + .next() + }) + .collect::, _>>() +} + +trait CustomAttributeParser { + // Determine if field is tagged with an #[embed(embed_with = "...")] attribute. + fn is_custom(&self) -> syn::Result; + + // Get the "..." part of the #[embed(embed_with = "...")] attribute. + // Ex: If attribute is tagged with #[embed(embed_with = "my_embed")], returns "my_embed". + fn expand_tag(&self) -> syn::Result; +} + +impl CustomAttributeParser for syn::Attribute { + fn is_custom(&self) -> syn::Result { + // Check that the attribute is a list. + match &self.meta { + syn::Meta::List(meta) => { + if meta.tokens.is_empty() { + return Ok(false); + } + } + _ => return Ok(false), + }; + + // Check the first attribute tag (the first "embed") + if !self.path().is_ident(EMBED) { + return Ok(false); + } + + self.parse_nested_meta(|meta| { + // Parse the meta attribute as an expression. Need this to compile. + meta.value()?.parse::()?; + + if meta.path.is_ident(EMBED_WITH) { + Ok(()) + } else { + let path = meta.path.to_token_stream().to_string().replace(' ', ""); + Err(syn::Error::new_spanned( + meta.path, + format_args!("unknown embedding field attribute `{}`", path), + )) + } + })?; + + Ok(true) + } + + fn expand_tag(&self) -> syn::Result { + fn function_path(meta: &ParseNestedMeta<'_>) -> syn::Result { + // #[embed(embed_with = "...")] + let expr = meta.value()?.parse::().unwrap(); + let mut value = &expr; + while let syn::Expr::Group(e) = value { + value = &e.expr; + } + let string = if let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. 
+ }) = value + { + let suffix = lit_str.suffix(); + if !suffix.is_empty() { + return Err(syn::Error::new_spanned( + lit_str, + format!("unexpected suffix `{}` on string literal", suffix), + )); + } + lit_str.clone() + } else { + return Err(syn::Error::new_spanned( + value, + format!( + "expected {} attribute to be a string: `{} = \"...\"`", + EMBED_WITH, EMBED_WITH + ), + )); + }; + + string.parse() + } + + let mut custom_func_path = None; + + self.parse_nested_meta(|meta| match function_path(&meta) { + Ok(path) => { + custom_func_path = Some(path); + Ok(()) + } + Err(e) => Err(e), + })?; + + Ok(custom_func_path.unwrap()) + } +} diff --git a/rig-core/rig-core-derive/src/embed.rs b/rig-core/rig-core-derive/src/embed.rs new file mode 100644 index 00000000..73b89205 --- /dev/null +++ b/rig-core/rig-core-derive/src/embed.rs @@ -0,0 +1,110 @@ +use proc_macro2::TokenStream; +use quote::quote; +use syn::DataStruct; + +use crate::{ + basic::{add_struct_bounds, basic_embed_fields}, + custom::custom_embed_fields, +}; + +pub(crate) fn expand_derive_embedding(input: &mut syn::DeriveInput) -> syn::Result { + let name = &input.ident; + let data = &input.data; + let generics = &mut input.generics; + + let target_stream = match data { + syn::Data::Struct(data_struct) => { + let (basic_targets, basic_target_size) = data_struct.basic(generics); + let (custom_targets, custom_target_size) = data_struct.custom()?; + + // If there are no fields tagged with `#[embed]` or `#[embed(embed_with = "...")]`, return an empty TokenStream. + // ie. do not implement `Embed` trait for the struct. + if basic_target_size + custom_target_size == 0 { + return Err(syn::Error::new_spanned( + name, + "Add at least one field tagged with #[embed] or #[embed(embed_with = \"...\")].", + )); + } + + quote! 
{ + #basic_targets; + #custom_targets; + } + } + _ => { + return Err(syn::Error::new_spanned( + input, + "Embed derive macro should only be used on structs", + )) + } + }; + + let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); + + let gen = quote! { + // Note: `Embed` trait is imported with the macro. + + impl #impl_generics Embed for #name #ty_generics #where_clause { + fn embed(&self, embedder: &mut rig::embeddings::embed::TextEmbedder) -> Result<(), rig::embeddings::embed::EmbedError> { + #target_stream; + + Ok(()) + } + } + }; + + Ok(gen) +} + +trait StructParser { + // Handles fields tagged with `#[embed]` + fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize); + + // Handles fields tagged with `#[embed(embed_with = "...")]` + fn custom(&self) -> syn::Result<(TokenStream, usize)>; +} + +impl StructParser for DataStruct { + fn basic(&self, generics: &mut syn::Generics) -> (TokenStream, usize) { + let embed_targets = basic_embed_fields(self) + // Iterate over every field tagged with `#[embed]` + .map(|field| { + add_struct_bounds(generics, &field.ty); + + let field_name = &field.ident; + + quote! { + self.#field_name + } + }) + .collect::>(); + + ( + quote! { + #(#embed_targets.embed(embedder)?;)* + }, + embed_targets.len(), + ) + } + + fn custom(&self) -> syn::Result<(TokenStream, usize)> { + let embed_targets = custom_embed_fields(self)? + // Iterate over every field tagged with `#[embed(embed_with = "...")]` + .into_iter() + .map(|(field, custom_func_path)| { + let field_name = &field.ident; + + quote! { + #custom_func_path(embedder, self.#field_name.clone())?; + } + }) + .collect::>(); + + Ok(( + quote! 
{ + #(#embed_targets)* + }, + embed_targets.len(), + )) + } +} diff --git a/rig-core/rig-core-derive/src/lib.rs b/rig-core/rig-core-derive/src/lib.rs new file mode 100644 index 00000000..4ce20cfa --- /dev/null +++ b/rig-core/rig-core-derive/src/lib.rs @@ -0,0 +1,21 @@ +extern crate proc_macro; +use proc_macro::TokenStream; +use syn::{parse_macro_input, DeriveInput}; + +mod basic; +mod custom; +mod embed; + +pub(crate) const EMBED: &str = "embed"; + +/// References: +/// +/// +#[proc_macro_derive(Embed, attributes(embed))] +pub fn derive_embedding_trait(item: TokenStream) -> TokenStream { + let mut input = parse_macro_input!(item as DeriveInput); + + embed::expand_derive_embedding(&mut input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} diff --git a/rig-core/src/completion.rs b/rig-core/src/completion.rs index e766fb27..f13f316b 100644 --- a/rig-core/src/completion.rs +++ b/rig-core/src/completion.rs @@ -82,7 +82,7 @@ pub enum CompletionError { /// Error building the completion request #[error("RequestError: {0}")] - RequestError(#[from] Box), + RequestError(#[from] Box), /// Error parsing the completion response #[error("ResponseError: {0}")] diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs deleted file mode 100644 index eaced08b..00000000 --- a/rig-core/src/embeddings.rs +++ /dev/null @@ -1,335 +0,0 @@ -//! This module provides functionality for working with embeddings and embedding models. -//! Embeddings are numerical representations of documents or other objects, typically used in -//! natural language processing (NLP) tasks such as text classification, information retrieval, -//! and document similarity. -//! -//! The module defines the [EmbeddingModel] trait, which represents an embedding model that can -//! generate embeddings for documents. It also provides an implementation of the [EmbeddingsBuilder] -//! struct, which allows users to build collections of document embeddings using different embedding -//! 
models and document sources. -//! -//! The module also defines the [Embedding] struct, which represents a single document embedding, -//! and the [DocumentEmbeddings] struct, which represents a document along with its associated -//! embeddings. These structs are used to store and manipulate collections of document embeddings. -//! -//! Finally, the module defines the [EmbeddingError] enum, which represents various errors that -//! can occur during embedding generation or processing. -//! -//! # Example -//! ```rust -//! use rig::providers::openai::{Client, self}; -//! use rig::embeddings::{EmbeddingModel, EmbeddingsBuilder}; -//! -//! // Initialize the OpenAI client -//! let openai = Client::new("your-openai-api-key"); -//! -//! // Create an instance of the `text-embedding-ada-002` model -//! let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_ADA_002); -//! -//! // Create an embeddings builder and add documents -//! let embeddings = EmbeddingsBuilder::new(embedding_model) -//! .simple_document("doc1", "This is the first document.") -//! .simple_document("doc2", "This is the second document.") -//! .build() -//! .await -//! .expect("Failed to build embeddings."); -//! -//! // Use the generated embeddings -//! // ... -//! ``` - -use std::{cmp::max, collections::HashMap}; - -use futures::{stream, StreamExt, TryStreamExt}; -use serde::{Deserialize, Serialize}; - -use crate::tool::{ToolEmbedding, ToolSet, ToolType}; - -#[derive(Debug, thiserror::Error)] -pub enum EmbeddingError { - /// Http error (e.g.: connection error, timeout, etc.) 
- #[error("HttpError: {0}")] - HttpError(#[from] reqwest::Error), - - /// Json error (e.g.: serialization, deserialization) - #[error("JsonError: {0}")] - JsonError(#[from] serde_json::Error), - - /// Error processing the document for embedding - #[error("DocumentError: {0}")] - DocumentError(String), - - /// Error parsing the completion response - #[error("ResponseError: {0}")] - ResponseError(String), - - /// Error returned by the embedding model provider - #[error("ProviderError: {0}")] - ProviderError(String), -} - -/// Trait for embedding models that can generate embeddings for documents. -pub trait EmbeddingModel: Clone + Sync + Send { - /// The maximum number of documents that can be embedded in a single request. - const MAX_DOCUMENTS: usize; - - /// The number of dimensions in the embedding vector. - fn ndims(&self) -> usize; - - /// Embed a single document - fn embed_document( - &self, - document: &str, - ) -> impl std::future::Future> + Send - where - Self: Sync, - { - async { - Ok(self - .embed_documents(vec![document.to_string()]) - .await? - .first() - .cloned() - .expect("One embedding should be present")) - } - } - - /// Embed multiple documents in a single request - fn embed_documents( - &self, - documents: impl IntoIterator + Send, - ) -> impl std::future::Future, EmbeddingError>> + Send; -} - -/// Struct that holds a single document and its embedding. 
-#[derive(Clone, Default, Deserialize, Serialize, Debug)] -pub struct Embedding { - /// The document that was embedded - pub document: String, - /// The embedding vector - pub vec: Vec, -} - -impl PartialEq for Embedding { - fn eq(&self, other: &Self) -> bool { - self.document == other.document - } -} - -impl Eq for Embedding {} - -impl Embedding { - pub fn distance(&self, other: &Self) -> f64 { - let dot_product: f64 = self - .vec - .iter() - .zip(other.vec.iter()) - .map(|(x, y)| x * y) - .sum(); - - let product_of_lengths = (self.vec.len() * other.vec.len()) as f64; - - dot_product / product_of_lengths - } -} - -/// Struct that holds a document and its embeddings. -/// -/// The struct is designed to model any kind of documents that can be serialized to JSON -/// (including a simple string). -/// -/// Moreover, it can hold multiple embeddings for the same document, thus allowing a -/// large document to be retrieved from a query that matches multiple smaller and -/// distinct text documents. For example, if the document is a textbook, a summary of -/// each chapter could serve as the book's embeddings. -#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] -pub struct DocumentEmbeddings { - #[serde(rename = "_id")] - pub id: String, - pub document: serde_json::Value, - pub embeddings: Vec, -} - -type Embeddings = Vec; - -/// Builder for creating a collection of embeddings -pub struct EmbeddingsBuilder { - model: M, - documents: Vec<(String, serde_json::Value, Vec)>, -} - -impl EmbeddingsBuilder { - /// Create a new embedding builder with the given embedding model - pub fn new(model: M) -> Self { - Self { - model, - documents: vec![], - } - } - - /// Add a simple document to the embedding collection. - /// The provided document string will be used for the embedding. 
- pub fn simple_document(mut self, id: &str, document: &str) -> Self { - self.documents.push(( - id.to_string(), - serde_json::Value::String(document.to_string()), - vec![document.to_string()], - )); - self - } - - /// Add multiple documents to the embedding collection. - /// Each element of the vector is a tuple of the form (id, document). - pub fn simple_documents(mut self, documents: Vec<(String, String)>) -> Self { - self.documents - .extend(documents.into_iter().map(|(id, document)| { - ( - id, - serde_json::Value::String(document.clone()), - vec![document], - ) - })); - self - } - - /// Add a tool to the embedding collection. - /// The `tool.context()` corresponds to the document being stored while - /// `tool.embedding_docs()` corresponds to the documents that will be used to generate the embeddings. - pub fn tool(mut self, tool: impl ToolEmbedding + 'static) -> Result { - self.documents.push(( - tool.name(), - serde_json::to_value(tool.context())?, - tool.embedding_docs(), - )); - Ok(self) - } - - /// Add the tools from the given toolset to the embedding collection. - pub fn tools(mut self, toolset: &ToolSet) -> Result { - for (name, tool) in toolset.tools.iter() { - if let ToolType::Embedding(tool) = tool { - self.documents.push(( - name.clone(), - tool.context().map_err(|e| { - EmbeddingError::DocumentError(format!( - "Failed to generate context for tool {}: {}", - name, e - )) - })?, - tool.embedding_docs(), - )); - } - } - Ok(self) - } - - /// Add a document to the embedding collection. - /// `embed_documents` are the documents that will be used to generate the embeddings - /// for `document`. - pub fn document( - mut self, - id: &str, - document: T, - embed_documents: Vec, - ) -> Self { - self.documents.push(( - id.to_string(), - serde_json::to_value(document).expect("Document should serialize"), - embed_documents, - )); - self - } - - /// Add multiple documents to the embedding collection. 
- /// Each element of the vector is a tuple of the form (id, document, embed_documents). - pub fn documents(mut self, documents: Vec<(String, T, Vec)>) -> Self { - self.documents.extend( - documents - .into_iter() - .map(|(id, document, embed_documents)| { - ( - id, - serde_json::to_value(document).expect("Document should serialize"), - embed_documents, - ) - }), - ); - self - } - - /// Add a json document to the embedding collection. - pub fn json_document( - mut self, - id: &str, - document: serde_json::Value, - embed_documents: Vec, - ) -> Self { - self.documents - .push((id.to_string(), document, embed_documents)); - self - } - - /// Add multiple json documents to the embedding collection. - pub fn json_documents( - mut self, - documents: Vec<(String, serde_json::Value, Vec)>, - ) -> Self { - self.documents.extend(documents); - self - } - - /// Generate the embeddings for the given documents - pub async fn build(self) -> Result { - // Create a temporary store for the documents - let documents_map = self - .documents - .into_iter() - .map(|(id, document, docs)| (id, (document, docs))) - .collect::>(); - - let embeddings = stream::iter(documents_map.iter()) - // Flatten the documents - .flat_map(|(id, (_, docs))| { - stream::iter(docs.iter().map(|doc| (id.clone(), doc.clone()))) - }) - // Chunk them into N (the emebdding API limit per request) - .chunks(M::MAX_DOCUMENTS) - // Generate the embeddings - .map(|docs| async { - let (ids, docs): (Vec<_>, Vec<_>) = docs.into_iter().unzip(); - Ok::<_, EmbeddingError>( - ids.into_iter() - .zip(self.model.embed_documents(docs).await?.into_iter()) - .collect::>(), - ) - }) - .boxed() - // Parallelize the embeddings generation over 10 concurrent requests - .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) - .try_fold(vec![], |mut acc, mut embeddings| async move { - Ok({ - acc.append(&mut embeddings); - acc - }) - }) - .await?; - - // Assemble the DocumentEmbeddings - let mut document_embeddings: HashMap = HashMap::new(); - 
embeddings.into_iter().for_each(|(id, embedding)| { - let (document, _) = documents_map.get(&id).expect("Document not found"); - let document_embedding = - document_embeddings - .entry(id.clone()) - .or_insert_with(|| DocumentEmbeddings { - id: id.clone(), - document: document.clone(), - embeddings: vec![], - }); - - document_embedding.embeddings.push(embedding); - }); - - Ok(document_embeddings.into_values().collect()) - } -} diff --git a/rig-core/src/embeddings/builder.rs b/rig-core/src/embeddings/builder.rs new file mode 100644 index 00000000..f9e80779 --- /dev/null +++ b/rig-core/src/embeddings/builder.rs @@ -0,0 +1,387 @@ +//! The module defines the [EmbeddingsBuilder] struct which accumulates objects to be embedded +//! and batch generates the embeddings for each object when built. +//! Only types that implement the [Embed] trait can be added to the [EmbeddingsBuilder]. + +use std::{cmp::max, collections::HashMap}; + +use futures::{stream, StreamExt}; + +use crate::{ + embeddings::{ + embed::TextEmbedder, Embed, EmbedError, Embedding, EmbeddingError, EmbeddingModel, + }, + OneOrMany, +}; + +/// Builder for creating embeddings from one or more documents of type `T`. +/// Note: `T` can be any type that implements the [Embed] trait. +/// +/// Using the builder is preferred over using [EmbeddingModel::embed_text] directly as +/// it will batch the documents in a single request to the model provider. +/// +/// # Example +/// ```rust +/// use std::env; +/// +/// use rig::{ +/// embeddings::EmbeddingsBuilder, +/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +/// }; +/// use serde::{Deserialize, Serialize}; +/// +/// // Create OpenAI client +/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +/// let openai_client = Client::new(&openai_api_key); +/// +/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); +/// +/// let embeddings = EmbeddingsBuilder::new(model.clone()) +/// .documents(vec![ +/// "1. 
*flurbo* (noun): A green alien that lives on cold planets.".to_string(), +/// "2. *flurbo* (noun): A fictional digital currency that originated in the animated series Rick and Morty.".to_string(), +/// "1. *glarb-glarb* (noun): An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), +/// "2. *glarb-glarb* (noun): A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string(), +/// "1. *linglingdong* (noun): A term used by inhabitants of the sombrero galaxy to describe humans.".to_string(), +/// "2. *linglingdong* (noun): A rare, mystical instrument crafted by the ancient monks of the Nebulon Mountain Ranges on the planet Quarm.".to_string() +/// ])? +/// .build() +/// .await?; +/// ``` +pub struct EmbeddingsBuilder { + model: M, + documents: Vec<(T, Vec)>, +} + +impl EmbeddingsBuilder { + /// Create a new embedding builder with the given embedding model + pub fn new(model: M) -> Self { + Self { + model, + documents: vec![], + } + } + + /// Add a document to be embedded to the builder. `document` must implement the [Embed] trait. + pub fn document(mut self, document: T) -> Result { + let mut embedder = TextEmbedder::default(); + document.embed(&mut embedder)?; + + self.documents.push((document, embedder.texts)); + + Ok(self) + } + + /// Add multiple documents to be embedded to the builder. `documents` must be iterable, + /// with items that implement the [Embed] trait. + pub fn documents(self, documents: impl IntoIterator) -> Result { + let builder = documents + .into_iter() + .try_fold(self, |builder, doc| builder.document(doc))?; + + Ok(builder) + } +} + +impl EmbeddingsBuilder { + /// Generate embeddings for all documents in the builder. + /// Returns a vector of tuples, where the first element is the document and the second element is the embeddings (either one embedding or many).
+ pub async fn build(self) -> Result)>, EmbeddingError> { + use stream::TryStreamExt; + + // Store the documents and their texts in a HashMap for easy access. + let mut docs = HashMap::new(); + let mut texts = HashMap::new(); + + // Iterate over all documents in the builder and insert their docs and texts into the lookup stores. + for (i, (doc, doc_texts)) in self.documents.into_iter().enumerate() { + docs.insert(i, doc); + texts.insert(i, doc_texts); + } + + // Compute the embeddings. + let mut embeddings = stream::iter(texts.into_iter()) + // Merge the texts of each document into a single list of texts. + .flat_map(|(i, texts)| stream::iter(texts.into_iter().map(move |text| (i, text)))) + // Chunk them into batches. Each batch size is at most the embedding API limit per request. + .chunks(M::MAX_DOCUMENTS) + // Generate the embeddings for each batch. + .map(|text| async { + let (ids, docs): (Vec<_>, Vec<_>) = text.into_iter().unzip(); + + let embeddings = self.model.embed_texts(docs).await?; + Ok::<_, EmbeddingError>(ids.into_iter().zip(embeddings).collect::>()) + }) + // Run the batch requests concurrently, with at most max(1, 1024 / M::MAX_DOCUMENTS) in flight at once + .buffer_unordered(max(1, 1024 / M::MAX_DOCUMENTS)) + // Collect the embeddings into a HashMap.
+ .try_fold( + HashMap::new(), + |mut acc: HashMap<_, OneOrMany>, embeddings| async move { + embeddings.into_iter().for_each(|(i, embedding)| { + acc.entry(i) + .and_modify(|embeddings| embeddings.push(embedding.clone())) + .or_insert(OneOrMany::one(embedding.clone())); + }); + + Ok(acc) + }, + ) + .await?; + + // Merge the embeddings with their respective documents + Ok(docs + .into_iter() + .map(|(i, doc)| { + ( + doc, + embeddings.remove(&i).expect("Document should be present"), + ) + }) + .collect()) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + embeddings::{embed::EmbedError, embed::TextEmbedder, Embedding, EmbeddingModel}, + Embed, + }; + + use super::EmbeddingsBuilder; + + #[derive(Clone)] + struct Model; + + impl EmbeddingModel for Model { + const MAX_DOCUMENTS: usize = 5; + + fn ndims(&self) -> usize { + 10 + } + + async fn embed_texts( + &self, + documents: impl IntoIterator + Send, + ) -> Result, crate::embeddings::EmbeddingError> { + Ok(documents + .into_iter() + .map(|doc| Embedding { + document: doc.to_string(), + vec: vec![0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + }) + .collect()) + } + } + + #[derive(Clone, Debug)] + struct WordDefinition { + id: String, + definitions: Vec, + } + + impl Embed for WordDefinition { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + for definition in &self.definitions { + embedder.embed(definition.clone()); + } + Ok(()) + } + } + + fn definitions_multiple_text() -> Vec { + vec![ + WordDefinition { + id: "doc0".to_string(), + definitions: vec![ + "A green alien that lives on cold planets.".to_string(), + "A fictional digital currency that originated in the animated series Rick and Morty.".to_string() + ] + }, + WordDefinition { + id: "doc1".to_string(), + definitions: vec![ + "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the 
Andromeda galaxy.".to_string() + ] + } + ] + } + + fn definitions_multiple_text_2() -> Vec { + vec![ + WordDefinition { + id: "doc2".to_string(), + definitions: vec!["Another fake definitions".to_string()], + }, + WordDefinition { + id: "doc3".to_string(), + definitions: vec!["Some fake definition".to_string()], + }, + ] + } + + #[derive(Clone, Debug)] + struct WordDefinitionSingle { + id: String, + definition: String, + } + + impl Embed for WordDefinitionSingle { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.definition.clone()); + Ok(()) + } + } + + fn definitions_single_text() -> Vec { + vec![ + WordDefinitionSingle { + id: "doc0".to_string(), + definition: "A green alien that lives on cold planets.".to_string(), + }, + WordDefinitionSingle { + id: "doc1".to_string(), + definition: "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + } + ] + } + + #[tokio::test] + async fn test_build_multiple_text() { + let fake_definitions = definitions_multiple_text(); + + let fake_model = Model; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.id.cmp(&fake_definition_2.id) + }); + + assert_eq!(result.len(), 2); + + let first_definition = &result[0]; + assert_eq!(first_definition.0.id, "doc0"); + assert_eq!(first_definition.1.len(), 2); + assert_eq!( + first_definition.1.first().document, + "A green alien that lives on cold planets.".to_string() + ); + + let second_definition = &result[1]; + assert_eq!(second_definition.0.id, "doc1"); + assert_eq!(second_definition.1.len(), 2); + assert_eq!( + second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ) + } + + #[tokio::test] + async fn 
test_build_single_text() { + let fake_definitions = definitions_single_text(); + + let fake_model = Model; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.id.cmp(&fake_definition_2.id) + }); + + assert_eq!(result.len(), 2); + + let first_definition = &result[0]; + assert_eq!(first_definition.0.id, "doc0"); + assert_eq!(first_definition.1.len(), 1); + assert_eq!( + first_definition.1.first().document, + "A green alien that lives on cold planets.".to_string() + ); + + let second_definition = &result[1]; + assert_eq!(second_definition.0.id, "doc1"); + assert_eq!(second_definition.1.len(), 1); + assert_eq!( + second_definition.1.first().document, "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + ) + } + + #[tokio::test] + async fn test_build_multiple_and_single_text() { + let fake_definitions = definitions_multiple_text(); + let fake_definitions_single = definitions_multiple_text_2(); + + let fake_model = Model; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .documents(fake_definitions_single) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.id.cmp(&fake_definition_2.id) + }); + + assert_eq!(result.len(), 4); + + let second_definition = &result[1]; + assert_eq!(second_definition.0.id, "doc1"); + assert_eq!(second_definition.1.len(), 2); + assert_eq!( + second_definition.1.first().document, "An ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string() + ); + + let third_definition = &result[2]; + assert_eq!(third_definition.0.id, "doc2"); + assert_eq!(third_definition.1.len(), 1); + assert_eq!( + third_definition.1.first().document, + "Another fake 
definitions".to_string() + ) + } + + #[tokio::test] + async fn test_build_string() { + let bindings = definitions_multiple_text(); + let fake_definitions = bindings.iter().map(|def| def.definitions.clone()); + + let fake_model = Model; + let mut result = EmbeddingsBuilder::new(fake_model) + .documents(fake_definitions) + .unwrap() + .build() + .await + .unwrap(); + + result.sort_by(|(fake_definition_1, _), (fake_definition_2, _)| { + fake_definition_1.cmp(&fake_definition_2) + }); + + assert_eq!(result.len(), 2); + + let first_definition = &result[0]; + assert_eq!(first_definition.1.len(), 2); + assert_eq!( + first_definition.1.first().document, + "A green alien that lives on cold planets.".to_string() + ); + + let second_definition = &result[1]; + assert_eq!(second_definition.1.len(), 2); + assert_eq!( + second_definition.1.rest()[0].document, "A fictional creature found in the distant, swampy marshlands of the planet Glibbo in the Andromeda galaxy.".to_string() + ) + } +} diff --git a/rig-core/src/embeddings/embed.rs b/rig-core/src/embeddings/embed.rs new file mode 100644 index 00000000..e8acc613 --- /dev/null +++ b/rig-core/src/embeddings/embed.rs @@ -0,0 +1,190 @@ +//! The module defines the [Embed] trait, which must be implemented for types +//! that can be embedded by the [crate::embeddings::EmbeddingsBuilder]. +//! +//! The module also defines the [EmbedError] struct which is used for when the [Embed::embed] +//! method of the [Embed] trait fails. +//! +//! The module also defines the [TextEmbedder] struct which accumulates string values that need to be embedded. +//! It is used directly with the [Embed] trait. +//! +//! Finally, the module implements [Embed] for many common primitive types. + +/// Error type used for when the [Embed::embed] method fo the [Embed] trait fails. +/// Used by default implementations of [Embed] for common types. 
+#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct EmbedError(#[from] Box); + +impl EmbedError { + pub fn new(error: E) -> Self { + EmbedError(Box::new(error)) + } +} + +/// Derive this trait for objects that need to be converted to vector embeddings. +/// The [Embed::embed] method accumulates string values that need to be embedded by adding them to the [TextEmbedder]. +/// If an error occurs, the method should return [EmbedError]. +/// # Example +/// ```rust +/// use std::env; +/// +/// use serde::{Deserialize, Serialize}; +/// use rig::{Embed, embeddings::{TextEmbedder, EmbedError}}; +/// +/// struct WordDefinition { +/// id: String, +/// word: String, +/// definitions: String, +/// } +/// +/// impl Embed for WordDefinition { +/// fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { +/// // Embeddings only need to be generated for `definition` field. +/// // Split the definitions by comma and collect them into a vector of strings. +/// // That way, different embeddings can be generated for each definition in the `definitions` string. +/// self.definitions +/// .split(",") +/// .for_each(|s| { +/// embedder.embed(s.to_string()); +/// }); +/// +/// Ok(()) +/// } +/// } +/// +/// let fake_definition = WordDefinition { +/// id: "1".to_string(), +/// word: "apple".to_string(), +/// definitions: "a fruit, a tech company".to_string(), +/// }; +/// +/// assert_eq!(embeddings::to_texts(fake_definition).unwrap(), vec!["a fruit", " a tech company"]); +/// ``` +pub trait Embed { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError>; +} + +/// Accumulates string values that need to be embedded. +/// Used by the [Embed] trait. +#[derive(Default)] +pub struct TextEmbedder { + pub(crate) texts: Vec, +} + +impl TextEmbedder { + /// Adds input `text` string to the list of texts in the [TextEmbedder] that need to be embedded. 
+ pub fn embed(&mut self, text: String) { + self.texts.push(text); + } +} + +/// Utility function that returns a vector of strings that need to be embedded for a +/// given object that implements the [Embed] trait. +pub fn to_texts(item: impl Embed) -> Result, EmbedError> { + let mut embedder = TextEmbedder::default(); + item.embed(&mut embedder)?; + Ok(embedder.texts) +} + +// ================================================================ +// Implementations of Embed for common types +// ================================================================ + +impl Embed for String { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.clone()); + Ok(()) + } +} + +impl Embed for &str { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i8 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i16 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i32 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i64 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for i128 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for f32 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for f64 { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for bool { + fn embed(&self, embedder: &mut TextEmbedder) -> 
Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for char { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(self.to_string()); + Ok(()) + } +} + +impl Embed for serde_json::Value { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(self).map_err(EmbedError::new)?); + Ok(()) + } +} + +impl Embed for &T { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + (*self).embed(embedder) + } +} + +impl Embed for Vec { + fn embed(&self, embedder: &mut TextEmbedder) -> Result<(), EmbedError> { + for item in self { + item.embed(embedder).map_err(EmbedError::new)?; + } + Ok(()) + } +} diff --git a/rig-core/src/embeddings/embedding.rs b/rig-core/src/embeddings/embedding.rs new file mode 100644 index 00000000..7c8877d9 --- /dev/null +++ b/rig-core/src/embeddings/embedding.rs @@ -0,0 +1,93 @@ +//! The module defines the [EmbeddingModel] trait, which represents an embedding model that can +//! generate embeddings for documents. +//! +//! The module also defines the [Embedding] struct, which represents a single document embedding. +//! +//! Finally, the module defines the [EmbeddingError] enum, which represents various errors that +//! can occur during embedding generation or processing. + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, thiserror::Error)] +pub enum EmbeddingError { + /// Http error (e.g.: connection error, timeout, etc.) 
+ #[error("HttpError: {0}")] + HttpError(#[from] reqwest::Error), + + /// Json error (e.g.: serialization, deserialization) + #[error("JsonError: {0}")] + JsonError(#[from] serde_json::Error), + + /// Error processing the document for embedding + #[error("DocumentError: {0}")] + DocumentError(Box), + + /// Error parsing the completion response + #[error("ResponseError: {0}")] + ResponseError(String), + + /// Error returned by the embedding model provider + #[error("ProviderError: {0}")] + ProviderError(String), +} + +/// Trait for embedding models that can generate embeddings for documents. +pub trait EmbeddingModel: Clone + Sync + Send { + /// The maximum number of documents that can be embedded in a single request. + const MAX_DOCUMENTS: usize; + + /// The number of dimensions in the embedding vector. + fn ndims(&self) -> usize; + + /// Embed multiple text documents in a single request + fn embed_texts( + &self, + texts: impl IntoIterator + Send, + ) -> impl std::future::Future, EmbeddingError>> + Send; + + /// Embed a single text document. + fn embed_text( + &self, + text: &str, + ) -> impl std::future::Future> + Send { + async { + Ok(self + .embed_texts(vec![text.to_string()]) + .await? + .pop() + .expect("There should be at least one embedding")) + } + } +} + +/// Struct that holds a single document and its embedding. +#[derive(Clone, Default, Deserialize, Serialize, Debug)] +pub struct Embedding { + /// The document that was embedded. Used for debugging. 
+ pub document: String, + /// The embedding vector + pub vec: Vec, +} + +impl PartialEq for Embedding { + fn eq(&self, other: &Self) -> bool { + self.document == other.document + } +} + +impl Eq for Embedding {} + +impl Embedding { + pub fn distance(&self, other: &Self) -> f64 { + let dot_product: f64 = self + .vec + .iter() + .zip(other.vec.iter()) + .map(|(x, y)| x * y) + .sum(); + + let product_of_lengths = (self.vec.len() * other.vec.len()) as f64; + + dot_product / product_of_lengths + } +} diff --git a/rig-core/src/embeddings/mod.rs b/rig-core/src/embeddings/mod.rs new file mode 100644 index 00000000..1ae16436 --- /dev/null +++ b/rig-core/src/embeddings/mod.rs @@ -0,0 +1,14 @@ +//! This module provides functionality for working with embeddings. +//! Embeddings are numerical representations of documents or other objects, typically used in +//! natural language processing (NLP) tasks such as text classification, information retrieval, +//! and document similarity. + +pub mod builder; +pub mod embed; +pub mod embedding; +pub mod tool; + +pub use builder::EmbeddingsBuilder; +pub use embed::{to_texts, Embed, EmbedError, TextEmbedder}; +pub use embedding::{Embedding, EmbeddingError, EmbeddingModel}; +pub use tool::ToolSchema; diff --git a/rig-core/src/embeddings/tool.rs b/rig-core/src/embeddings/tool.rs new file mode 100644 index 00000000..a8441a23 --- /dev/null +++ b/rig-core/src/embeddings/tool.rs @@ -0,0 +1,96 @@ +//! The module defines the [ToolSchema] struct, which is used to embed an object that implements [crate::tool::ToolEmbedding] + +use crate::{tool::ToolEmbeddingDyn, Embed}; +use serde::Serialize; + +use super::embed::EmbedError; + +/// Embeddable document that is used as an intermediate representation of a tool when +/// RAGging tools. 
+#[derive(Clone, Serialize, Default, Eq, PartialEq)] +pub struct ToolSchema { + pub name: String, + pub context: serde_json::Value, + pub embedding_docs: Vec, +} + +impl Embed for ToolSchema { + fn embed(&self, embedder: &mut super::embed::TextEmbedder) -> Result<(), EmbedError> { + for doc in &self.embedding_docs { + embedder.embed(doc.clone()); + } + Ok(()) + } +} + +impl ToolSchema { + /// Convert item that implements [ToolEmbeddingDyn] to an [ToolSchema]. + /// + /// # Example + /// ```rust + /// use rig::{ + /// completion::ToolDefinition, + /// embeddings::ToolSchema, + /// tool::{Tool, ToolEmbedding, ToolEmbeddingDyn}, + /// }; + /// use serde_json::json; + /// + /// #[derive(Debug, thiserror::Error)] + /// #[error("Math error")] + /// struct NothingError; + /// + /// #[derive(Debug, thiserror::Error)] + /// #[error("Init error")] + /// struct InitError; + /// + /// struct Nothing; + /// impl Tool for Nothing { + /// const NAME: &'static str = "nothing"; + /// + /// type Error = NothingError; + /// type Args = (); + /// type Output = (); + /// + /// async fn definition(&self, _prompt: String) -> ToolDefinition { + /// serde_json::from_value(json!({ + /// "name": "nothing", + /// "description": "nothing", + /// "parameters": {} + /// })) + /// .expect("Tool Definition") + /// } + /// + /// async fn call(&self, args: Self::Args) -> Result { + /// Ok(()) + /// } + /// } + /// + /// impl ToolEmbedding for Nothing { + /// type InitError = InitError; + /// type Context = (); + /// type State = (); + /// + /// fn init(_state: Self::State, _context: Self::Context) -> Result { + /// Ok(Nothing) + /// } + /// + /// fn embedding_docs(&self) -> Vec { + /// vec!["Do nothing.".into()] + /// } + /// + /// fn context(&self) -> Self::Context {} + /// } + /// + /// let tool = ToolSchema::try_from(&Nothing).unwrap(); + /// + /// assert_eq!(tool.name, "nothing".to_string()); + /// assert_eq!(tool.embedding_docs, vec!["Do nothing.".to_string()]); + /// ``` + pub fn 
try_from(tool: &dyn ToolEmbeddingDyn) -> Result { + Ok(ToolSchema { + name: tool.name(), + context: tool.context().map_err(EmbedError::new)?, + embedding_docs: tool.embedding_docs(), + }) + } +} diff --git a/rig-core/src/lib.rs b/rig-core/src/lib.rs index 9da3abfc..f1b5427b 100644 --- a/rig-core/src/lib.rs +++ b/rig-core/src/lib.rs @@ -50,9 +50,9 @@ //! RAG systems that can be used to answer questions using a knowledge base. //! //! ## Vector stores and indexes -//! Rig defines a common interface for working with vector stores and indexes. Specifically, the library -//! provides the [VectorStore](crate::vector_store::VectorStore) and [VectorStoreIndex](crate::vector_store::VectorStoreIndex) -//! traits, which can be implemented on a given type to define vector stores and indices respectively. +//! Rig provides a common interface for working with vector stores and indexes. Specifically, the library +//! provides the [VectorStoreIndex](crate::vector_store::VectorStoreIndex) +//! trait, which can be implemented to define vector stores and indices respectively. //! Those can then be used as the knowledge base for a RAG enabled [Agent](crate::agent::Agent), or //! as a source of context documents in a custom architecture that use multiple LLMs or agents. //! 
@@ -85,6 +85,14 @@ pub mod embeddings; pub mod extractor; pub(crate) mod json_utils; pub mod loaders; +pub mod one_or_many; pub mod providers; pub mod tool; pub mod vector_store; + +// Re-export commonly used types and traits +pub use embeddings::Embed; +pub use one_or_many::{EmptyListError, OneOrMany}; + +#[cfg(feature = "derive")] +pub use rig_derive::Embed; diff --git a/rig-core/src/loaders/file.rs b/rig-core/src/loaders/file.rs index 17c2f1f3..84eb1864 100644 --- a/rig-core/src/loaders/file.rs +++ b/rig-core/src/loaders/file.rs @@ -162,7 +162,7 @@ impl<'a, T: 'a> FileLoader<'a, Result> { } } -impl<'a> FileLoader<'a, Result> { +impl FileLoader<'_, Result> { /// Creates a new [FileLoader] using a glob pattern to match files. /// /// # Example @@ -227,7 +227,7 @@ impl<'a, T> IntoIterator for FileLoader<'a, T> { } } -impl<'a, T> Iterator for IntoIter<'a, T> { +impl Iterator for IntoIter<'_, T> { type Item = T; fn next(&mut self) -> Option { diff --git a/rig-core/src/loaders/pdf.rs b/rig-core/src/loaders/pdf.rs index ea18e4e6..410643f8 100644 --- a/rig-core/src/loaders/pdf.rs +++ b/rig-core/src/loaders/pdf.rs @@ -335,7 +335,7 @@ impl<'a, T: 'a> PdfFileLoader<'a, Result> { } } -impl<'a> PdfFileLoader<'a, Result> { +impl PdfFileLoader<'_, Result> { /// Creates a new [PdfFileLoader] using a glob pattern to match files. /// /// # Example @@ -396,7 +396,7 @@ impl<'a, T> IntoIterator for PdfFileLoader<'a, T> { } } -impl<'a, T> Iterator for IntoIter<'a, T> { +impl Iterator for IntoIter<'_, T> { type Item = T; fn next(&mut self) -> Option { diff --git a/rig-core/src/one_or_many.rs b/rig-core/src/one_or_many.rs new file mode 100644 index 00000000..64584603 --- /dev/null +++ b/rig-core/src/one_or_many.rs @@ -0,0 +1,302 @@ +/// Struct containing either a single item or a list of items of type T. +/// If a single item is present, `first` will contain it and `rest` will be empty. 
+/// If multiple items are present, `first` will contain the first item and `rest` will contain the rest. +/// IMPORTANT: this struct cannot be created with an empty vector. +/// OneOrMany objects can only be created using OneOrMany::from() or OneOrMany::try_from(). +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct OneOrMany { + /// First item in the list. + first: T, + /// Rest of the items in the list. + rest: Vec, +} + +/// Error type for when trying to create a OneOrMany object with an empty vector. +#[derive(Debug, thiserror::Error)] +#[error("Cannot create OneOrMany with an empty vector.")] +pub struct EmptyListError; + +impl OneOrMany { + /// Get the first item in the list. + pub fn first(&self) -> T { + self.first.clone() + } + + /// Get the rest of the items in the list (excluding the first one). + pub fn rest(&self) -> Vec { + self.rest.clone() + } + + /// After `OneOrMany` is created, add an item of type T to the `rest`. + pub fn push(&mut self, item: T) { + self.rest.push(item); + } + + /// Length of all items in `OneOrMany`. + pub fn len(&self) -> usize { + 1 + self.rest.len() + } + + /// If `OneOrMany` is empty. This will always be false because you cannot create an empty `OneOrMany`. + /// This method is required when the method `len` exists. + pub fn is_empty(&self) -> bool { + false + } + + /// Create a OneOrMany object with a single item of any type. + pub fn one(item: T) -> Self { + OneOrMany { + first: item, + rest: vec![], + } + } + + /// Create a OneOrMany object with a vector of items of any type. + pub fn many(items: Vec) -> Result { + let mut iter = items.into_iter(); + Ok(OneOrMany { + first: match iter.next() { + Some(item) => item, + None => return Err(EmptyListError), + }, + rest: iter.collect(), + }) + } + + /// Merge a list of OneOrMany items into a single OneOrMany item. 
+ pub fn merge(one_or_many_items: Vec>) -> Result { + let items = one_or_many_items + .into_iter() + .flat_map(|one_or_many| one_or_many.into_iter()) + .collect::>(); + + OneOrMany::many(items) + } + + pub fn iter(&self) -> Iter { + Iter { + first: Some(&self.first), + rest: self.rest.iter(), + } + } + + pub fn iter_mut(&mut self) -> IterMut<'_, T> { + IterMut { + first: Some(&mut self.first), + rest: self.rest.iter_mut(), + } + } +} + +// ================================================================ +// Implementations of Iterator for OneOrMany +// - OneOrMany::iter() -> iterate over references of T objects +// - OneOrMany::into_iter() -> iterate over owned T objects +// - OneOrMany::iter_mut() -> iterate over mutable references of T objects +// ================================================================ + +/// Struct returned by call to `OneOrMany::iter()`. +pub struct Iter<'a, T> { + // References. + first: Option<&'a T>, + rest: std::slice::Iter<'a, T>, +} + +/// Implement `Iterator` for `Iter`. +/// The Item type of the `Iterator` trait is a reference of `T`. +impl<'a, T> Iterator for Iter<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } + } +} + +/// Struct returned by call to `OneOrMany::into_iter()`. +pub struct IntoIter { + // Owned. + first: Option, + rest: std::vec::IntoIter, +} + +/// Implement `Iterator` for `IntoIter`. +impl IntoIterator for OneOrMany { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter { + first: Some(self.first), + rest: self.rest.into_iter(), + } + } +} + +/// Implement `Iterator` for `IntoIter`. +/// The Item type of the `Iterator` trait is an owned `T`. 
+impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } + } +} + +/// Struct returned by call to `OneOrMany::iter_mut()`. +pub struct IterMut<'a, T> { + // Mutable references. + first: Option<&'a mut T>, + rest: std::slice::IterMut<'a, T>, +} + +// Implement `Iterator` for `IterMut`. +// The Item type of the `Iterator` trait is a mutable reference of `OneOrMany`. +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + + fn next(&mut self) -> Option { + if let Some(first) = self.first.take() { + Some(first) + } else { + self.rest.next() + } + } +} + +#[cfg(test)] +mod test { + use super::OneOrMany; + + #[test] + fn test_single() { + let one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.iter().count(), 1); + + one_or_many.iter().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test() { + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.iter().count(), 2); + + one_or_many.iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + + #[test] + fn test_one_or_many_into_iter_single() { + let one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.clone().into_iter().count(), 1); + + one_or_many.into_iter().for_each(|i| { + assert_eq!(i, "hello".to_string()); + }); + } + + #[test] + fn test_one_or_many_into_iter() { + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.clone().into_iter().count(), 2); + + one_or_many.into_iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello".to_string()); + } + if i == 1 { + assert_eq!(item, "word".to_string()); + } + }); + } + + #[test] + fn test_one_or_many_merge() { + let one_or_many_1 = 
OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + let one_or_many_2 = OneOrMany::one("sup".to_string()); + + let merged = OneOrMany::merge(vec![one_or_many_1, one_or_many_2]).unwrap(); + + assert_eq!(merged.iter().count(), 3); + + merged.iter().enumerate().for_each(|(i, item)| { + if i == 0 { + assert_eq!(item, "hello"); + } + if i == 1 { + assert_eq!(item, "word"); + } + if i == 2 { + assert_eq!(item, "sup"); + } + }); + } + + #[test] + fn test_mut_single() { + let mut one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.iter_mut().count(), 1); + + one_or_many.iter_mut().for_each(|i| { + assert_eq!(i, "hello"); + }); + } + + #[test] + fn test_mut() { + let mut one_or_many = + OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.iter_mut().count(), 2); + + one_or_many.iter_mut().enumerate().for_each(|(i, item)| { + if i == 0 { + item.push_str(" world"); + assert_eq!(item, "hello world"); + } + if i == 1 { + assert_eq!(item, "word"); + } + }); + } + + #[test] + fn test_one_or_many_error() { + assert!(OneOrMany::::many(vec![]).is_err()) + } + + #[test] + fn test_len_single() { + let one_or_many = OneOrMany::one("hello".to_string()); + + assert_eq!(one_or_many.len(), 1); + } + + #[test] + fn test_len_many() { + let one_or_many = OneOrMany::many(vec!["hello".to_string(), "word".to_string()]).unwrap(); + + assert_eq!(one_or_many.len(), 2); + } +} diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index baa6f1d5..883204d4 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -15,7 +15,7 @@ use crate::{ completion::{self, CompletionError}, embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, + json_utils, Embed, }; use schemars::JsonSchema; @@ -92,7 +92,11 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings(&self, model: 
&str, input_type: &str) -> EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &str, + input_type: &str, + ) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } @@ -207,7 +211,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { self.ndims } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator, ) -> Result, EmbeddingError> { @@ -238,11 +242,14 @@ impl embeddings::EmbeddingModel for EmbeddingModel { }; if response.embeddings.len() != documents.len() { - return Err(EmbeddingError::DocumentError(format!( - "Expected {} embeddings, got {}", - documents.len(), - response.embeddings.len() - ))); + return Err(EmbeddingError::DocumentError( + format!( + "Expected {} embeddings, got {}", + documents.len(), + response.embeddings.len() + ) + .into(), + )); } Ok(response diff --git a/rig-core/src/providers/gemini/client.rs b/rig-core/src/providers/gemini/client.rs index c22e6996..04316dfe 100644 --- a/rig-core/src/providers/gemini/client.rs +++ b/rig-core/src/providers/gemini/client.rs @@ -2,6 +2,7 @@ use crate::{ agent::AgentBuilder, embeddings::{self}, extractor::ExtractorBuilder, + Embed, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -104,7 +105,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &str, + ) -> embeddings::EmbeddingsBuilder { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/providers/gemini/embedding.rs b/rig-core/src/providers/gemini/embedding.rs index 1249387a..c2b76e02 100644 --- a/rig-core/src/providers/gemini/embedding.rs +++ b/rig-core/src/providers/gemini/embedding.rs @@ -41,7 +41,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator + Send, ) -> Result, EmbeddingError> 
{ diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 33aa7b79..6b5a3079 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -11,9 +11,9 @@ use crate::{ agent::AgentBuilder, completion::{self, CompletionError, CompletionRequest}, - embeddings::{self, EmbeddingError}, + embeddings::{self, EmbeddingError, EmbeddingsBuilder}, extractor::ExtractorBuilder, - json_utils, + json_utils, Embed, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -121,8 +121,8 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder { - embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) + pub fn embeddings(&self, model: &str) -> EmbeddingsBuilder { + EmbeddingsBuilder::new(self.embedding_model(model)) } /// Create a completion model with the given name. @@ -219,7 +219,7 @@ pub struct EmbeddingData { pub index: usize, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize)] pub struct Usage { pub prompt_tokens: usize, pub total_tokens: usize, @@ -229,7 +229,7 @@ impl std::fmt::Display for Usage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Prompt tokens: {}\nTotal tokens: {}", + "Prompt tokens: {} Total tokens: {}", self.prompt_tokens, self.total_tokens ) } @@ -249,7 +249,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { self.ndims } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator, ) -> Result, EmbeddingError> { @@ -535,7 +535,7 @@ impl completion::CompletionModel for CompletionModel { ApiResponse::Ok(response) => { tracing::info!(target: "rig", "OpenAI completion token usage: {:?}", - response.usage + response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string()) ); response.try_into() } diff --git a/rig-core/src/providers/perplexity.rs b/rig-core/src/providers/perplexity.rs index 
41be57f7..fa1e34fb 100644 --- a/rig-core/src/providers/perplexity.rs +++ b/rig-core/src/providers/perplexity.rs @@ -155,7 +155,7 @@ impl std::fmt::Display for Usage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Prompt tokens: {}\nCompletion tokens: {}\nTotal tokens: {}", + "Prompt tokens: {}\nCompletion tokens: {} Total tokens: {}", self.prompt_tokens, self.completion_tokens, self.total_tokens ) } diff --git a/rig-core/src/providers/xai/client.rs b/rig-core/src/providers/xai/client.rs index e03c6978..6af7cd31 100644 --- a/rig-core/src/providers/xai/client.rs +++ b/rig-core/src/providers/xai/client.rs @@ -2,6 +2,7 @@ use crate::{ agent::AgentBuilder, embeddings::{self}, extractor::ExtractorBuilder, + Embed, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -113,7 +114,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &str, + ) -> embeddings::EmbeddingsBuilder { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } diff --git a/rig-core/src/providers/xai/embedding.rs b/rig-core/src/providers/xai/embedding.rs index 1c588071..f0ad1c92 100644 --- a/rig-core/src/providers/xai/embedding.rs +++ b/rig-core/src/providers/xai/embedding.rs @@ -69,7 +69,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { self.ndims } - async fn embed_documents( + async fn embed_texts( &self, documents: impl IntoIterator, ) -> Result, EmbeddingError> { diff --git a/rig-core/src/tool.rs b/rig-core/src/tool.rs index 83a27640..e540c918 100644 --- a/rig-core/src/tool.rs +++ b/rig-core/src/tool.rs @@ -14,7 +14,10 @@ use std::{collections::HashMap, pin::Pin}; use futures::Future; use serde::{Deserialize, Serialize}; -use crate::completion::{self, ToolDefinition}; +use crate::{ + completion::{self, ToolDefinition}, + embeddings::{embed::EmbedError, tool::ToolSchema}, +}; 
#[derive(Debug, thiserror::Error)] pub enum ToolError { @@ -334,6 +337,22 @@ impl ToolSet { } Ok(docs) } + + /// Convert tools in self to objects of type ToolSchema. + /// This is necessary because when adding tools to the EmbeddingBuilder because all + /// documents added to the builder must all be of the same type. + pub fn schemas(&self) -> Result, EmbedError> { + self.tools + .values() + .filter_map(|tool_type| { + if let ToolType::Embedding(tool) = tool_type { + Some(ToolSchema::try_from(&**tool)) + } else { + None + } + }) + .collect::, _>>() + } } #[derive(Default)] diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index eda13ba4..931946eb 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -7,28 +7,32 @@ use std::{ use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; -use super::{VectorStore, VectorStoreError, VectorStoreIndex}; -use crate::embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel, EmbeddingsBuilder}; +use super::{VectorStoreError, VectorStoreIndex}; +use crate::{ + embeddings::{Embedding, EmbeddingModel}, + OneOrMany, +}; /// InMemoryVectorStore is a simple in-memory vector store that stores embeddings /// in-memory using a HashMap. -#[derive(Clone, Default, Deserialize, Serialize)] -pub struct InMemoryVectorStore { - /// The embeddings are stored in a HashMap with the document ID as the key. - embeddings: HashMap, +#[derive(Clone, Default)] +pub struct InMemoryVectorStore { + /// The embeddings are stored in a HashMap. + /// Hashmap key is the document id. + /// Hashmap value is a tuple of the serializable document and its corresponding embeddings. + embeddings: HashMap)>, } -impl InMemoryVectorStore { +impl InMemoryVectorStore { /// Implement vector search on InMemoryVectorStore. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for InMemoryVectorStore. 
- fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { + fn vector_search(&self, prompt_embedding: &Embedding, n: usize) -> EmbeddingRanking { // Sort documents by best embedding distance - let mut docs: EmbeddingRanking = BinaryHeap::new(); + let mut docs = BinaryHeap::new(); - for (id, doc_embeddings) in self.embeddings.iter() { + for (id, (doc, embeddings)) in self.embeddings.iter() { // Get the best context for the document given the prompt - if let Some((distance, embed_doc)) = doc_embeddings - .embeddings + if let Some((distance, embed_doc)) = embeddings .iter() .map(|embedding| { ( @@ -38,12 +42,7 @@ impl InMemoryVectorStore { }) .max_by(|a, b| a.0.cmp(&b.0)) { - docs.push(Reverse(RankingItem( - distance, - id, - doc_embeddings, - embed_doc, - ))); + docs.push(Reverse(RankingItem(distance, id, doc, embed_doc))); }; // If the heap size exceeds n, pop the least old element. @@ -63,77 +62,72 @@ impl InMemoryVectorStore { docs } -} -/// RankingItem(distance, document_id, document, embed_doc) -#[derive(Eq, PartialEq)] -struct RankingItem<'a>( - OrderedFloat, - &'a String, - &'a DocumentEmbeddings, - &'a String, -); - -impl Ord for RankingItem<'_> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.0.cmp(&other.0) - } -} + /// Add documents to the store. + /// Returns the store with the added documents. + pub fn add_documents( + mut self, + documents: Vec<(String, D, OneOrMany)>, + ) -> Result { + for (id, doc, embeddings) in documents { + self.embeddings.insert(id, (doc, embeddings)); + } -impl PartialOrd for RankingItem<'_> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) + Ok(self) } -} -type EmbeddingRanking<'a> = BinaryHeap>>; - -impl VectorStore for InMemoryVectorStore { - type Q = (); - - async fn add_documents( - &mut self, - documents: Vec, - ) -> Result<(), VectorStoreError> { - for doc in documents { - self.embeddings.insert(doc.id.clone(), doc); + /// Add documents to the store. 
Define a function that takes as input the reference of the document and returns its id. + /// Returns the store with the added documents. + pub fn add_documents_with_id( + mut self, + documents: Vec<(D, OneOrMany)>, + id_f: fn(&D) -> String, + ) -> Result { + for (doc, embeddings) in documents { + let id = id_f(&doc); + self.embeddings.insert(id, (doc, embeddings)); } - Ok(()) + Ok(self) } - async fn get_document Deserialize<'a>>( + /// Get the document by its id and deserialize it into the given type. + pub fn get_document Deserialize<'a>>( &self, id: &str, ) -> Result, VectorStoreError> { Ok(self .embeddings .get(id) - .map(|document| serde_json::from_value(document.document.clone())) + .map(|(doc, _)| serde_json::from_str(&serde_json::to_string(doc)?)) .transpose()?) } +} - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result, VectorStoreError> { - Ok(self.embeddings.get(id).cloned()) +/// RankingItem(distance, document_id, serializable document, embeddings document) +#[derive(Eq, PartialEq)] +struct RankingItem<'a, D: Serialize>(OrderedFloat, &'a String, &'a D, &'a String); + +impl Ord for RankingItem<'_, D> { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) } +} - async fn get_document_by_query( - &self, - _query: Self::Q, - ) -> Result, VectorStoreError> { - Ok(None) +impl PartialOrd for RankingItem<'_, D> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } } -impl InMemoryVectorStore { - pub fn index(self, model: M) -> InMemoryVectorIndex { +type EmbeddingRanking<'a, D> = BinaryHeap>>; + +impl InMemoryVectorStore { + pub fn index(self, model: M) -> InMemoryVectorIndex { InMemoryVectorIndex::new(model, self) } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator))> { self.embeddings.iter() } @@ -144,54 +138,19 @@ impl InMemoryVectorStore { pub fn is_empty(&self) -> bool { self.embeddings.is_empty() } - - /// Uitilty method to create an InMemoryVectorStore from a 
list of embeddings. - pub async fn from_embeddings( - embeddings: Vec, - ) -> Result { - let mut store = Self::default(); - store.add_documents(embeddings).await?; - Ok(store) - } - - /// Create an InMemoryVectorStore from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. - pub async fn from_documents( - embedding_model: M, - documents: &[(String, T)], - ) -> Result { - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - let store = Self::from_embeddings(embeddings).await?; - Ok(store) - } } -pub struct InMemoryVectorIndex { +pub struct InMemoryVectorIndex { model: M, - pub store: InMemoryVectorStore, + pub store: InMemoryVectorStore, } -impl InMemoryVectorIndex { - pub fn new(model: M, store: InMemoryVectorStore) -> Self { +impl InMemoryVectorIndex { + pub fn new(model: M, store: InMemoryVectorStore) -> Self { Self { model, store } } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator))> { self.store.iter() } @@ -202,66 +161,30 @@ impl InMemoryVectorIndex { pub fn is_empty(&self) -> bool { self.store.is_empty() } - - /// Create an InMemoryVectorIndex from a list of documents. - /// The documents are serialized to JSON and embedded using the provided embedding model. - /// The resulting embeddings are stored in an InMemoryVectorStore created by the method. - /// The InMemoryVectorIndex is then created from the store and the provided query model. 
- pub async fn from_documents( - embedding_model: M, - query_model: M, - documents: &[(String, T)], - ) -> Result { - let mut store = InMemoryVectorStore::default(); - - let embeddings = documents - .iter() - .fold( - EmbeddingsBuilder::new(embedding_model), - |builder, (id, doc)| { - builder.json_document( - id, - serde_json::to_value(doc).expect("Document should be serializable"), - vec![serde_json::to_string(doc).expect("Document should be serializable")], - ) - }, - ) - .build() - .await?; - - store.add_documents(embeddings).await?; - Ok(store.index(query_model)) - } - - /// Utility method to create an InMemoryVectorIndex from a list of embeddings - /// and an embedding model. - pub async fn from_embeddings( - query_model: M, - embeddings: Vec, - ) -> Result { - let store = InMemoryVectorStore::from_embeddings(embeddings).await?; - Ok(store.index(query_model)) - } } -impl VectorStoreIndex for InMemoryVectorIndex { +impl VectorStoreIndex + for InMemoryVectorIndex +{ async fn top_n Deserialize<'a>>( &self, query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = &self.model.embed_document(query).await?; + let prompt_embedding = &self.model.embed_text(query).await?; let docs = self.store.vector_search(prompt_embedding, n); // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| { + .map(|Reverse(RankingItem(distance, id, doc, _))| { Ok(( distance.0, - doc.id.clone(), - serde_json::from_value(doc.document.clone()) - .map_err(VectorStoreError::JsonError)?, + id.clone(), + serde_json::from_str( + &serde_json::to_string(doc).map_err(VectorStoreError::JsonError)?, + ) + .map_err(VectorStoreError::JsonError)?, )) }) .collect::, _>>() @@ -272,13 +195,159 @@ impl VectorStoreIndex for InMemoryVectorI query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = &self.model.embed_document(query).await?; + let prompt_embedding = &self.model.embed_text(query).await?; let docs = 
self.store.vector_search(prompt_embedding, n); // Return n best docs.into_iter() - .map(|Reverse(RankingItem(distance, _, doc, _))| Ok((distance.0, doc.id.clone()))) + .map(|Reverse(RankingItem(distance, id, _, _))| Ok((distance.0, id.clone()))) .collect::, _>>() } } + +#[cfg(test)] +mod tests { + use std::cmp::Reverse; + + use crate::{embeddings::embedding::Embedding, OneOrMany}; + + use super::{InMemoryVectorStore, RankingItem}; + + #[test] + fn test_single_embedding() { + let index = InMemoryVectorStore::default() + .add_documents(vec![ + ( + "doc1".to_string(), + "glarb-garb", + OneOrMany::one(Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }), + ), + ( + "doc2".to_string(), + "marble-marble", + OneOrMany::one(Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, -0.3, 0.0], + }), + ), + ( + "doc3".to_string(), + "flumb-flumb", + OneOrMany::one(Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }), + ), + ]) + .unwrap(); + + let ranking = index.vector_search( + &Embedding { + document: "glarby-glarble".to_string(), + vec: vec![0.0, 0.1, 0.6], + }, + 1, + ); + + assert_eq!( + ranking + .into_iter() + .map(|Reverse(RankingItem(distance, id, doc, _))| { + ( + distance.0, + id.clone(), + serde_json::from_str(&serde_json::to_string(doc).unwrap()).unwrap(), + ) + }) + .collect::>(), + vec![( + 0.034444444444444444, + "doc1".to_string(), + "glarb-garb".to_string() + )] + ) + } + + #[test] + fn test_multiple_embeddings() { + let index = InMemoryVectorStore::default() + .add_documents(vec![ + ( + "doc1".to_string(), + "glarb-garb", + OneOrMany::many(vec![ + Embedding { + document: "glarb-garb".to_string(), + vec: vec![0.1, 0.1, 0.5], + }, + Embedding { + document: "don't-choose-me".to_string(), + vec: vec![-0.5, 0.9, 0.1], + }, + ]) + .unwrap(), + ), + ( + "doc2".to_string(), + "marble-marble", + OneOrMany::many(vec![ + Embedding { + document: "marble-marble".to_string(), + vec: vec![0.7, 
-0.3, 0.0], + }, + Embedding { + document: "sandwich".to_string(), + vec: vec![0.5, 0.5, -0.7], + }, + ]) + .unwrap(), + ), + ( + "doc3".to_string(), + "flumb-flumb", + OneOrMany::many(vec![ + Embedding { + document: "flumb-flumb".to_string(), + vec: vec![0.3, 0.7, 0.1], + }, + Embedding { + document: "banana".to_string(), + vec: vec![0.1, -0.5, -0.5], + }, + ]) + .unwrap(), + ), + ]) + .unwrap(); + + let ranking = index.vector_search( + &Embedding { + document: "glarby-glarble".to_string(), + vec: vec![0.0, 0.1, 0.6], + }, + 1, + ); + + assert_eq!( + ranking + .into_iter() + .map(|Reverse(RankingItem(distance, id, doc, _))| { + ( + distance.0, + id.clone(), + serde_json::from_str(&serde_json::to_string(doc).unwrap()).unwrap(), + ) + }) + .collect::>(), + vec![( + 0.034444444444444444, + "doc1".to_string(), + "glarb-garb".to_string() + )] + ) + } +} diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 405ef83b..3d6e8369 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -2,7 +2,7 @@ use futures::future::BoxFuture; use serde::Deserialize; use serde_json::Value; -use crate::embeddings::{DocumentEmbeddings, EmbeddingError}; +use crate::embeddings::EmbeddingError; pub mod in_memory_store; @@ -16,44 +16,17 @@ pub enum VectorStoreError { JsonError(#[from] serde_json::Error), #[error("Datastore error: {0}")] - DatastoreError(#[from] Box), -} - -/// Trait for vector stores -pub trait VectorStore: Send + Sync { - /// Query type for the vector store - type Q; - - /// Add a list of documents to the vector store - fn add_documents( - &mut self, - documents: Vec, - ) -> impl std::future::Future> + Send; - - /// Get the embeddings of a document by its id - fn get_document_embeddings( - &self, - id: &str, - ) -> impl std::future::Future, VectorStoreError>> + Send; - - /// Get the document by its id and deserialize it into the given type - fn get_document Deserialize<'a>>( - &self, - id: &str, - ) -> impl 
std::future::Future, VectorStoreError>> + Send; + DatastoreError(#[from] Box), - /// Get the document by a query and deserialize it into the given type - fn get_document_by_query( - &self, - query: Self::Q, - ) -> impl std::future::Future, VectorStoreError>> + Send; + #[error("Missing Id: {0}")] + MissingIdError(String), } /// Trait for vector store indexes pub trait VectorStoreIndex: Send + Sync { /// Get the top n documents based on the distance to the given query. /// The result is a list of tuples of the form (score, id, document) - fn top_n Deserialize<'a> + std::marker::Send>( + fn top_n Deserialize<'a> + Send>( &self, query: &str, n: usize, diff --git a/rig-core/tests/embed_macro.rs b/rig-core/tests/embed_macro.rs new file mode 100644 index 00000000..daf6dcf8 --- /dev/null +++ b/rig-core/tests/embed_macro.rs @@ -0,0 +1,217 @@ +use rig::{ + embeddings::{self, embed::EmbedError, TextEmbedder}, + Embed, +}; +use serde::Serialize; + +#[test] +fn test_custom_embed() { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + word: String, + #[embed(embed_with = "custom_embedding_function")] + definition: Definition, + } + + #[derive(Serialize, Clone)] + struct Definition { + word: String, + link: String, + speech: String, + } + + fn custom_embedding_function( + embedder: &mut TextEmbedder, + definition: Definition, + ) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); + + Ok(()) + } + + let definition = WordDefinition { + id: "doc1".to_string(), + word: "house".to_string(), + definition: Definition { + speech: "noun".to_string(), + word: "a building in which people live; residence for human beings.".to_string(), + link: "https://www.dictionary.com/browse/house".to_string(), + }, + }; + + assert_eq!( + embeddings::to_texts(definition).unwrap(), + vec!["{\"word\":\"a building in which people live; residence for human 
beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string()] + ) +} + +#[test] +fn test_custom_and_basic_embed() { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[embed] + word: String, + #[embed(embed_with = "custom_embedding_function")] + definition: Definition, + } + + #[derive(Serialize, Clone)] + struct Definition { + word: String, + link: String, + speech: String, + } + + fn custom_embedding_function( + embedder: &mut TextEmbedder, + definition: Definition, + ) -> Result<(), EmbedError> { + embedder.embed(serde_json::to_string(&definition).map_err(EmbedError::new)?); + + Ok(()) + } + + let definition = WordDefinition { + id: "doc1".to_string(), + word: "house".to_string(), + definition: Definition { + speech: "noun".to_string(), + word: "a building in which people live; residence for human beings.".to_string(), + link: "https://www.dictionary.com/browse/house".to_string(), + }, + }; + + let texts = embeddings::to_texts(definition).unwrap(); + + assert_eq!( + texts, + vec![ + "house".to_string(), + "{\"word\":\"a building in which people live; residence for human beings.\",\"link\":\"https://www.dictionary.com/browse/house\",\"speech\":\"noun\"}".to_string() + ] + ); +} + +#[test] +fn test_single_embed() { + #[derive(Embed)] + struct WordDefinition { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + word: String, + #[embed] + definition: String, + } + + let definition = "a building in which people live; residence for human beings.".to_string(); + + let word_definition = WordDefinition { + id: "doc1".to_string(), + word: "house".to_string(), + definition: definition.clone(), + }; + + assert_eq!( + embeddings::to_texts(word_definition).unwrap(), + vec![definition] + ) +} + +#[test] +fn test_embed_vec_non_string() { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + company: String, + #[embed] + employee_ages: Vec, + } + + let company 
= Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + assert_eq!( + embeddings::to_texts(company).unwrap(), + vec![ + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ] + ); +} + +#[test] +fn test_embed_vec_string() { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[allow(dead_code)] + company: String, + #[embed] + employee_names: Vec, + } + + let company = Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_names: vec![ + "Alice".to_string(), + "Bob".to_string(), + "Charlie".to_string(), + "David".to_string(), + ], + }; + + assert_eq!( + embeddings::to_texts(company).unwrap(), + vec![ + "Alice".to_string(), + "Bob".to_string(), + "Charlie".to_string(), + "David".to_string() + ] + ); +} + +#[test] +fn test_multiple_embed_tags() { + #[derive(Embed)] + struct Company { + #[allow(dead_code)] + id: String, + #[embed] + company: String, + #[embed] + employee_ages: Vec, + } + + let company = Company { + id: "doc1".to_string(), + company: "Google".to_string(), + employee_ages: vec![25, 30, 35, 40], + }; + + assert_eq!( + embeddings::to_texts(company).unwrap(), + vec![ + "Google".to_string(), + "25".to_string(), + "30".to_string(), + "35".to_string(), + "40".to_string() + ] + ); +} diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 0eb5fdbe..12d51b08 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -18,3 +18,15 @@ futures = "0.3.30" [dev-dependencies] tokio = "1.40.0" anyhow = "1.0.89" + +[[example]] +name = "vector_search_local_ann" +required-features = ["rig-core/derive"] + +[[example]] +name = "vector_search_local_enn" +required-features = ["rig-core/derive"] + +[[example]] +name = "vector_search_s3_ann" +required-features = ["rig-core/derive"] diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs index 9a91432e..94704822 100644 --- 
a/rig-lancedb/examples/fixtures/lib.rs +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -2,13 +2,39 @@ use std::sync::Arc; use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; -use rig::embeddings::DocumentEmbeddings; +use rig::embeddings::Embedding; +use rig::{Embed, OneOrMany}; +use serde::Deserialize; + +#[derive(Embed, Clone, Deserialize, Debug)] +pub struct WordDefinition { + pub id: String, + #[embed] + pub definition: String, +} + +pub fn word_definitions() -> Vec { + vec![ + WordDefinition { + id: "doc0".to_string(), + definition: "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.".to_string() + }, + WordDefinition { + id: "doc1".to_string(), + definition: "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive.".to_string() + }, + WordDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string() + } + ] +} // Schema of table in LanceDB. pub fn schema(dims: usize) -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), + Field::new("definition", DataType::Utf8, false), Field::new( "embedding", DataType::FixedSizeList( @@ -20,40 +46,37 @@ pub fn schema(dims: usize) -> Schema { ])) } -// Convert DocumentEmbeddings objects to a RecordBatch. +// Convert WordDefinition objects and their embedding to a RecordBatch. 
pub fn as_record_batch( - records: Vec, + records: Vec<(WordDefinition, OneOrMany)>, dims: usize, ) -> Result { let id = StringArray::from_iter_values( records .iter() - .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id))) + .map(|(WordDefinition { id, .. }, _)| id) .collect::>(), ); - let content = StringArray::from_iter_values( + let definition = StringArray::from_iter_values( records .iter() - .flat_map(|record| { - record - .embeddings - .iter() - .map(|embedding| embedding.document.clone()) - }) + .map(|(WordDefinition { definition, .. }, _)| definition) .collect::>(), ); let embedding = FixedSizeListArray::from_iter_primitive::( records .into_iter() - .flat_map(|record| { - record - .embeddings - .into_iter() - .map(|embedding| embedding.vec.into_iter().map(Some).collect::>()) - .map(Some) - .collect::>() + .map(|(_, embeddings)| { + Some( + embeddings + .first() + .vec + .into_iter() + .map(Some) + .collect::>(), + ) }) .collect::>(), dims as i32, @@ -61,7 +84,7 @@ pub fn as_record_batch( RecordBatch::try_from_iter(vec![ ("id", Arc::new(id) as ArrayRef), - ("content", Arc::new(content) as ArrayRef), + ("definition", Arc::new(definition) as ArrayRef), ("embedding", Arc::new(embedding) as ArrayRef), ]) } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 3ecd6b23..03636089 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,25 +1,18 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, schema, word_definitions, WordDefinition}; use lancedb::index::vector::IvfPqIndexBuilder; -use rig::vector_store::VectorStoreIndex; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + vector_store::VectorStoreIndex, }; -use rig_lancedb::{LanceDbVectorStore, 
SearchParams}; -use serde::Deserialize; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub content: String, -} - #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). @@ -32,27 +25,27 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - // Set up test data for RAG demo - let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - let definitions = vec![definition; 256]; - // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") - .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .documents(word_definitions())? + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. 
+ .documents( + (0..256) + .map(|i| WordDefinition { + id: format!("doc{}", i), + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() + }) + )? .build() .await?; - // Create table with embeddings. - let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; @@ -68,11 +61,11 @@ async fn main() -> Result<(), anyhow::Error> { // Define search_params params that will be used by the vector store to perform the vector search. let search_params = SearchParams::default(); - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store_index = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index - let results = vector_store - .top_n::("My boss says I zindle too much, what does that mean?", 1) + let results = vector_store_index + .top_n::("My boss says I zindle too much, what does that mean?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 5932dcd0..0244d33e 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,13 +1,13 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, schema, word_definitions}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndexDyn, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use 
rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; @@ -21,10 +21,9 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .documents(word_definitions())? .build() .await?; @@ -34,17 +33,18 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - // Create table with embeddings. 
- let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index let results = vector_store diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 70f0c8c5..f296d1d7 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,25 +1,18 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; -use fixture::{as_record_batch, schema}; +use fixture::{as_record_batch, schema, word_definitions, WordDefinition}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, }; -use rig_lancedb::{LanceDbVectorStore, SearchParams}; -use serde::Deserialize; +use rig_lancedb::{LanceDbVectorIndex, SearchParams}; #[path = "./fixtures/lib.rs"] mod fixture; -#[derive(Deserialize, Debug)] -pub struct VectorSearchResult { - pub id: String, - pub content: String, -} - // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. 
// https://lancedb.github.io/lancedb/guides/storage/ #[tokio::main] @@ -38,27 +31,27 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; - // Set up test data for RAG demo - let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - - // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. - let definitions = vec![definition; 256]; - // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") - .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") - .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") - .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .documents(word_definitions())? + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + .documents( + (0..256) + .map(|i| WordDefinition { + id: format!("doc{}", i), + definition: "Definition of *flumbuzzle (noun)*: A sudden, inexplicable urge to rearrange or reorganize small objects, such as desk items or books, for no apparent reason.".to_string() + }) + )? .build() .await?; - // Create table with embeddings. 
- let record_batch = as_record_batch(embeddings, model.ndims()); let table = db .create_table( "definitions", - RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + RecordBatchIterator::new( + vec![as_record_batch(embeddings, model.ndims())], + Arc::new(schema(model.ndims())), + ), ) .execute() .await?; @@ -80,11 +73,11 @@ async fn main() -> Result<(), anyhow::Error> { // Define search_params params that will be used by the vector store to perform the vector search. let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + let vector_store = LanceDbVectorIndex::new(table, model, "id", search_params).await?; // Query the index let results = vector_store - .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) + .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 13009bf9..7567bc60 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -3,7 +3,7 @@ use lancedb::{ DistanceType, }; use rig::{ - embeddings::EmbeddingModel, + embeddings::embedding::EmbeddingModel, vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::Deserialize; @@ -20,75 +20,17 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { VectorStoreError::JsonError(e) } +/// Type on which vector searches can be performed for a lanceDb table. 
/// # Example /// ``` -/// use std::{env, sync::Arc}; +/// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; +/// use rig::embeddings::EmbeddingModel; /// -/// use arrow_array::RecordBatchIterator; -/// use fixture::{as_record_batch, schema}; -/// use rig::{ -/// embeddings::{EmbeddingModel, EmbeddingsBuilder}, -/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, -/// vector_store::VectorStoreIndexDyn, -/// }; -/// use rig_lancedb::{LanceDbVectorStore, SearchParams}; -/// use serde::Deserialize; -/// -/// #[derive(Deserialize, Debug)] -/// pub struct VectorSearchResult { -/// pub id: String, -/// pub content: String, -/// } -/// -/// // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). -/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); -/// let openai_client = Client::new(&openai_api_key); -/// -/// // Select the embedding model and generate our embeddings -/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); -/// -/// let embeddings = EmbeddingsBuilder::new(model.clone()) -/// .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") -/// .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") -/// .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") -/// .build() -/// .await?; -/// -/// // Define search_params params that will be used by the vector store to perform the vector search. -/// let search_params = SearchParams::default(); -/// -/// // Initialize LanceDB locally. -/// let db = lancedb::connect("data/lancedb-store").execute().await?; -/// -/// // Create table with embeddings. 
-/// let record_batch = as_record_batch(embeddings, model.ndims()); -/// let table = db -/// .create_table( -/// "definitions", -/// RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), -/// ) -/// .execute() -/// .await?; -/// -/// let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; -/// -/// // Query the index -/// let results = vector_store -/// .top_n("My boss says I zindle too much, what does that mean?", 1) -/// .await? -/// .into_iter() -/// .map(|(score, id, doc)| { -/// anyhow::Ok(( -/// score, -/// id, -/// serde_json::from_value::(doc)?, -/// )) -/// }) -/// .collect::, _>>()?; -/// -/// println!("Results: {:?}", results); +/// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. +/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. +/// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; /// ``` -pub struct LanceDbVectorStore { +pub struct LanceDbVectorIndex { /// Defines which model is used to generate embeddings for the vector store. model: M, /// LanceDB table containing embeddings. @@ -99,7 +41,24 @@ pub struct LanceDbVectorStore { search_params: SearchParams, } -impl LanceDbVectorStore { +impl LanceDbVectorIndex { + /// Create an instance of `LanceDbVectorIndex` with an existing table and model. + /// Define the id field name of the table. + /// Define search parameters that will be used to perform vector searches on the table. + pub async fn new( + table: lancedb::Table, + model: M, + id_field: &str, + search_params: SearchParams, + ) -> Result { + Ok(Self { + table, + model, + id_field: id_field.to_string(), + search_params, + }) + } + /// Apply the search_params to the vector query. /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait. 
fn build_query(&self, mut query: VectorQuery) -> VectorQuery { @@ -151,6 +110,10 @@ pub enum SearchType { } /// Parameters used to perform a vector search on a LanceDb table. +/// # Example +/// ``` +/// let search_params = SearchParams::default().distance_type(DistanceType::Cosine); +/// ``` #[derive(Debug, Clone, Default)] pub struct SearchParams { distance_type: Option, @@ -211,32 +174,28 @@ impl SearchParams { } } -impl LanceDbVectorStore { - /// Create an instance of `LanceDbVectorStore` with an existing table and model. - /// Define the id field name of the table. - /// Define search parameters that will be used to perform vector searches on the table. - pub async fn new( - table: lancedb::Table, - model: M, - id_field: &str, - search_params: SearchParams, - ) -> Result { - Ok(Self { - table, - model, - id_field: id_field.to_string(), - search_params, - }) - } -} - -impl VectorStoreIndex for LanceDbVectorStore { - async fn top_n Deserialize<'a> + std::marker::Send>( +impl VectorStoreIndex for LanceDbVectorIndex { + /// Implement the `top_n` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. + /// # Example + /// ``` + /// use rig_lancedb::{LanceDbVectorIndex, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// let table: lancedb::Table = db.create_table("fake_definitions"); // <-- Replace with your lancedb table here. + /// let model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. 
+ /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// + /// // Query the index + /// let result = vector_store_index + /// .top_n::("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// ``` + async fn top_n Deserialize<'a> + Send>( &self, query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let query = self .table @@ -268,12 +227,24 @@ impl VectorStoreIndex for LanceDbV .collect() } + /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `LanceDbVectorIndex`. + /// # Example + /// ``` + /// let table: table: lancedb::Table = db.create_table(""); // <-- Replace with your lancedb table here. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- Replace with your embedding model here. + /// let vector_store_index = LanceDbVectorIndex::new(table, model, "id", SearchParams::default()).await?; + /// + /// // Query the index + /// let result = vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// ``` async fn top_n_ids( &self, query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let query = self .table diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index fd890280..7f9d1d7d 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -356,9 +356,9 @@ fn type_matcher(column: &Arc) -> Result, VectorStoreError> } } -/////////////////////////////////////////////////////////////////////////////////// -/// Everything below includes helpers for the recursive function `type_matcher`./// 
-/////////////////////////////////////////////////////////////////////////////////// +// ================================================================ +// Everything below includes helpers for the recursive function `type_matcher` +// ================================================================ /// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. trait DeserializePrimitiveArray { diff --git a/rig-mongodb/Cargo.toml b/rig-mongodb/Cargo.toml index edb755e0..f52f8630 100644 --- a/rig-mongodb/Cargo.toml +++ b/rig-mongodb/Cargo.toml @@ -20,3 +20,7 @@ tracing = "0.1.40" [dev-dependencies] anyhow = "1.0.86" tokio = { version = "1.38.0", features = ["macros"] } + +[[example]] +name = "vector_search_mongodb" +required-features = ["rig-core/derive"] \ No newline at end of file diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 5943ac3f..5d0ed81b 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,20 +1,36 @@ -use mongodb::bson; -use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; -use rig::vector_store::VectorStore; -use rig::{ - embeddings::EmbeddingsBuilder, - providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::VectorStoreIndex, -}; -use rig_mongodb::{MongoDbVectorStore, SearchParams}; +use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; +use rig::providers::openai::TEXT_EMBEDDING_ADA_002; use serde::{Deserialize, Serialize}; use std::env; -#[derive(Clone, Eq, PartialEq, Serialize, Deserialize, Debug)] -pub struct DocumentResponse { +use rig::{ + embeddings::EmbeddingsBuilder, providers::openai::Client, vector_store::VectorStoreIndex, Embed, +}; +use rig_mongodb::{MongoDbVectorIndex, SearchParams}; + +// Shape of data that needs to be RAG'ed. +// The definition field will be used to generate embeddings. 
+#[derive(Embed, Clone, Deserialize, Debug)] +struct WordDefinition { + #[serde(rename = "_id")] + id: String, + #[embed] + definition: String, +} + +#[derive(Clone, Deserialize, Debug, Serialize)] +struct Link { + word: String, + link: String, +} + +// Shape of the document to be stored in MongoDB, with embeddings. +#[derive(Serialize, Debug)] +struct Document { #[serde(rename = "_id")] - pub id: String, - pub document: serde_json::Value, + id: String, + definition: String, + embedding: Vec, } #[tokio::main] @@ -34,41 +50,59 @@ async fn main() -> Result<(), anyhow::Error> { MongoClient::with_options(options).expect("MongoDB client options should be valid"); // Initialize MongoDB vector store - let collection: Collection = mongodb_client + let collection: Collection = mongodb_client .database("knowledgebase") .collection("context"); - let mut vector_store = MongoDbVectorStore::new(collection); - // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let fake_definitions = vec![ + WordDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + }, + WordDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + }, + WordDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + } + ]; + let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("doc2", 
"Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .documents(fake_definitions)? .build() .await?; - // Add embeddings to vector store - match vector_store.add_documents(embeddings).await { + let mongo_documents = embeddings + .iter() + .map( + |(WordDefinition { id, definition, .. }, embedding)| Document { + id: id.clone(), + definition: definition.clone(), + embedding: embedding.first().vec.clone(), + }, + ) + .collect::>(); + + match collection.insert_many(mongo_documents, None).await { Ok(_) => println!("Documents added successfully"), Err(e) => println!("Error adding documents: {:?}", e), - } + }; - // Create a vector index on our vector store + // Create a vector index on our vector store. + // Note: a vector index called "vector_index" must exist on the MongoDB collection you are querying. // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store - .index(model, "vector_index", SearchParams::default()) - .await?; + let index = + MongoDbVectorIndex::new(collection, model, "vector_index", SearchParams::new()).await?; // Query the index let results = index - .top_n::("What is a linglingdong?", 1) - .await? - .into_iter() - .map(|(score, id, doc)| (score, id, doc.document)) - .collect::>(); + .top_n::("What is a linglingdong?", 1) + .await?; println!("Results: {:?}", results); diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 56c0009d..813818e9 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -2,16 +2,11 @@ use futures::StreamExt; use mongodb::bson::{self, doc}; use rig::{ - embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, - vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, + embeddings::embedding::{Embedding, EmbeddingModel}, + vector_store::{VectorStoreError, VectorStoreIndex}, }; use serde::{Deserialize, Serialize}; -/// A MongoDB vector store. 
-pub struct MongoDbVectorStore { - collection: mongodb::Collection, -} - #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] struct SearchIndex { @@ -25,8 +20,8 @@ struct SearchIndex { } impl SearchIndex { - async fn get_search_index( - collection: mongodb::Collection, + async fn get_search_index( + collection: mongodb::Collection, index_name: &str, ) -> Result { collection @@ -61,100 +56,38 @@ fn mongodb_to_rig_error(e: mongodb::error::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -impl VectorStore for MongoDbVectorStore { - type Q = mongodb::bson::Document; - - async fn add_documents( - &mut self, - documents: Vec, - ) -> Result<(), VectorStoreError> { - self.collection - .clone_with_type::() - .insert_many(documents, None) - .await - .map_err(mongodb_to_rig_error)?; - Ok(()) - } - - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result, VectorStoreError> { - self.collection - .clone_with_type::() - .find_one(doc! { "_id": id }, None) - .await - .map_err(mongodb_to_rig_error) - } - - async fn get_document serde::Deserialize<'a>>( - &self, - id: &str, - ) -> Result, VectorStoreError> { - Ok(self - .collection - .clone_with_type::() - .aggregate( - [ - doc! {"$match": { "_id": id}}, - doc! {"$project": { "document": 1 }}, - doc! {"$replaceRoot": { "newRoot": "$document" }}, - ], - None, - ) - .await - .map_err(mongodb_to_rig_error)? - .with_type::() - .next() - .await - .transpose() - .map_err(mongodb_to_rig_error)? - .map(|doc| serde_json::from_str(&doc)) - .transpose()?) - } - - async fn get_document_by_query( - &self, - query: Self::Q, - ) -> Result, VectorStoreError> { - self.collection - .clone_with_type::() - .find_one(query, None) - .await - .map_err(mongodb_to_rig_error) - } -} - -impl MongoDbVectorStore { - /// Create a new `MongoDbVectorStore` from a MongoDB collection. 
- pub fn new(collection: mongodb::Collection) -> Self { - Self { collection } - } - - /// Create a new `MongoDbVectorIndex` from an existing `MongoDbVectorStore`. - /// - /// The index (of type "vector") must already exist for the MongoDB collection. - /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. - pub async fn index( - &self, - model: M, - index_name: &str, - search_params: SearchParams, - ) -> Result, VectorStoreError> { - MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params).await - } -} - /// A vector index for a MongoDB collection. -pub struct MongoDbVectorIndex { - collection: mongodb::Collection, +/// # Example +/// ``` +/// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; +/// use rig::embeddings::EmbeddingModel; +/// +/// #[derive(serde::Serialize, Debug)] +/// struct Document { +/// #[serde(rename = "_id")] +/// id: String, +/// definition: String, +/// embedding: Vec, +/// } +/// +/// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. +/// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. +/// let index = MongoDbVectorIndex::new( +/// collection, +/// model, +/// "vector_index", // <-- replace with the name of the index in your mongodb collection. +/// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. +/// ); +/// ``` +pub struct MongoDbVectorIndex { + collection: mongodb::Collection, model: M, index_name: String, embedded_field: String, search_params: SearchParams, } -impl MongoDbVectorIndex { +impl MongoDbVectorIndex { /// Vector search stage of aggregation pipeline of mongoDB collection. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. 
fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document { @@ -188,9 +121,13 @@ impl MongoDbVectorIndex { } } -impl MongoDbVectorIndex { +impl MongoDbVectorIndex { + /// Create a new `MongoDbVectorIndex`. + /// + /// The index (of type "vector") must already exist for the MongoDB collection. + /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. pub async fn new( - collection: mongodb::Collection, + collection: mongodb::Collection, model: M, index_name: &str, search_params: SearchParams, @@ -226,6 +163,7 @@ impl MongoDbVectorIndex { /// See [MongoDB Vector Search](`https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/`) for more information /// on each of the fields +#[derive(Default)] pub struct SearchParams { filter: mongodb::bson::Document, exact: Option, @@ -268,19 +206,51 @@ impl SearchParams { } } -impl Default for SearchParams { - fn default() -> Self { - Self::new() - } -} - -impl VectorStoreIndex for MongoDbVectorIndex { - async fn top_n Deserialize<'a> + std::marker::Send>( +impl VectorStoreIndex + for MongoDbVectorIndex +{ + /// Implement the `top_n` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. + /// # Example + /// ``` + /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// #[derive(serde::Serialize, Debug)] + /// struct Document { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// embedding: Vec, + /// } + /// + /// #[derive(serde::Deserialize, Debug)] + /// struct Definition { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// } + /// + /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. 
+ /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. + /// + /// let vector_store_index = MongoDbVectorIndex::new( + /// collection, + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. + /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n::("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// ``` + async fn top_n Deserialize<'a> + Send>( &self, query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let mut cursor = self .collection @@ -322,12 +292,40 @@ impl VectorStoreIndex for MongoDbV Ok(results) } + /// Implement the `top_n_ids` method of the `VectorStoreIndex` trait for `MongoDbVectorIndex`. + /// # Example + /// ``` + /// use rig_mongodb::{MongoDbVectorIndex, SearchParams}; + /// use rig::embeddings::EmbeddingModel; + /// + /// #[derive(serde::Serialize, Debug)] + /// struct Document { + /// #[serde(rename = "_id")] + /// id: String, + /// definition: String, + /// embedding: Vec, + /// } + /// + /// let collection: collection: mongodb::Collection = mongodb_client.collection(""); // <-- replace with your mongodb collection. + /// let model: model: EmbeddingModel = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // <-- replace with your embedding model. + /// let vector_store_index = MongoDbVectorIndex::new( + /// collection, + /// model, + /// "vector_index", // <-- replace with the name of the index in your mongodb collection. + /// SearchParams::new("embedding"), // <-- field name in `Document` that contains the embeddings. 
+ /// ); + /// + /// // Query the index + /// vector_store_index + /// .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + /// .await?; + /// ``` async fn top_n_ids( &self, query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.model.embed_document(query).await?; + let prompt_embedding = self.model.embed_text(query).await?; let mut cursor = self .collection diff --git a/rig-neo4j/Cargo.toml b/rig-neo4j/Cargo.toml index a0a94633..aae275c7 100644 --- a/rig-neo4j/Cargo.toml +++ b/rig-neo4j/Cargo.toml @@ -22,3 +22,7 @@ anyhow = "1.0.86" tokio = { version = "1.38.0", features = ["macros"] } textwrap = { version = "0.16.1"} term_size = { version = "0.3.2"} + +[[example]] +name = "vector_search_simple" +required-features = ["rig-core/derive"] \ No newline at end of file diff --git a/rig-neo4j/examples/vector_search_simple.rs b/rig-neo4j/examples/vector_search_simple.rs index 2cb0030d..fca43d27 100644 --- a/rig-neo4j/examples/vector_search_simple.rs +++ b/rig-neo4j/examples/vector_search_simple.rs @@ -13,12 +13,20 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex as _, + Embed, }; use rig_neo4j::{ vector_index::{IndexConfig, SearchParams}, Neo4jClient, ToBoltType, }; +#[derive(Embed, Clone, Debug)] +pub struct WordDefinition { + pub id: String, + #[embed] + pub definition: String, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client @@ -36,9 +44,18 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) - .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("doc1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - 
.simple_document("doc2", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .document(WordDefinition { + id: "doc0".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + })? + .document(WordDefinition { + id: "doc1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + })? + .document(WordDefinition { + id: "doc2".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + })? .build() .await?; @@ -54,7 +71,7 @@ async fn main() -> Result<(), anyhow::Error> { } let create_nodes = futures::stream::iter(embeddings) - .map(|doc| { + .map(|(doc, embeddings)| { neo4j_client.graph.run( neo4rs::query( " @@ -68,8 +85,8 @@ async fn main() -> Result<(), anyhow::Error> { .param("id", doc.id) // Here we use the first embedding but we could use any of them. // Neo4j only takes primitive types or arrays as properties. 
- .param("embedding", doc.embeddings[0].vec.clone()) - .param("document", doc.document.to_bolt_type()), + .param("embedding", embeddings.first().vec.clone()) + .param("document", doc.definition.to_bolt_type()), ) }) .buffer_unordered(3) diff --git a/rig-neo4j/src/vector_index.rs b/rig-neo4j/src/vector_index.rs index db6fd3df..bf39644d 100644 --- a/rig-neo4j/src/vector_index.rs +++ b/rig-neo4j/src/vector_index.rs @@ -259,7 +259,7 @@ impl VectorStoreIndex for Neo4jVec query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.embedding_model.embed_document(query).await?; + let prompt_embedding = self.embedding_model.embed_text(query).await?; let query = self.build_vector_search_query(prompt_embedding, true, n); let rows = self.execute_and_collect::>(query).await?; @@ -279,7 +279,7 @@ impl VectorStoreIndex for Neo4jVec query: &str, n: usize, ) -> Result, VectorStoreError> { - let prompt_embedding = self.embedding_model.embed_document(query).await?; + let prompt_embedding = self.embedding_model.embed_text(query).await?; let query = self.build_vector_search_query(prompt_embedding, false, n); diff --git a/rig-qdrant/Cargo.toml b/rig-qdrant/Cargo.toml index 4a7360a9..35c4b96e 100644 --- a/rig-qdrant/Cargo.toml +++ b/rig-qdrant/Cargo.toml @@ -16,3 +16,7 @@ qdrant-client = "1.12.1" [dev-dependencies] tokio = { version = "1.40.0", features = ["rt-multi-thread"] } anyhow = "1.0.89" + +[[example]] +name = "qdrant_vector_search" +required-features = ["rig-core/derive"] \ No newline at end of file diff --git a/rig-qdrant/examples/qdrant_vector_search.rs b/rig-qdrant/examples/qdrant_vector_search.rs index c9148d69..b1a91349 100644 --- a/rig-qdrant/examples/qdrant_vector_search.rs +++ b/rig-qdrant/examples/qdrant_vector_search.rs @@ -19,10 +19,18 @@ use rig::{ embeddings::EmbeddingsBuilder, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, + Embed, }; use rig_qdrant::QdrantVectorStore; use serde_json::json; 
+#[derive(Embed)] +struct WordDefinition { + id: String, + #[embed] + definition: String, +} + #[tokio::main] async fn main() -> Result<(), anyhow::Error> { const COLLECTION_NAME: &str = "rig-collection"; @@ -49,21 +57,30 @@ async fn main() -> Result<(), anyhow::Error> { let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let documents = EmbeddingsBuilder::new(model.clone()) - .simple_document("0981d983-a5f8-49eb-89ea-f7d3b2196d2e", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") - .simple_document("62a36d43-80b6-4fd6-990c-f75bb02287d1", "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.") - .simple_document("f9e17d59-32e5-440c-be02-b2759a654824", "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.") + .document(WordDefinition { + id: "0981d983-a5f8-49eb-89ea-f7d3b2196d2e".to_string(), + definition: "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets".to_string(), + })? + .document(WordDefinition { + id: "62a36d43-80b6-4fd6-990c-f75bb02287d1".to_string(), + definition: "Definition of a *glarb-glarb*: A glarb-glarb is a ancient tool used by the ancestors of the inhabitants of planet Jiro to farm the land.".to_string(), + })? + .document(WordDefinition { + id: "f9e17d59-32e5-440c-be02-b2759a654824".to_string(), + definition: "Definition of a *linglingdong*: A term used by inhabitants of the far side of the moon to describe humans.".to_string(), + })? 
.build() .await?; let points: Vec = documents .into_iter() - .map(|d| { - let vec: Vec = d.embeddings[0].vec.iter().map(|&x| x as f32).collect(); + .map(|(d, embeddings)| { + let vec: Vec = embeddings.first().vec.iter().map(|&x| x as f32).collect(); PointStruct::new( d.id, vec, Payload::try_from(json!({ - "document": d.document, + "document": d.definition, })) .unwrap(), ) diff --git a/rig-qdrant/src/lib.rs b/rig-qdrant/src/lib.rs index e88878cc..666d0a4f 100644 --- a/rig-qdrant/src/lib.rs +++ b/rig-qdrant/src/lib.rs @@ -36,7 +36,7 @@ impl QdrantVectorStore { /// Embed query based on `QdrantVectorStore` model and modify the vector in the required format. async fn generate_query_vector(&self, query: &str) -> Result, VectorStoreError> { - let embedding = self.model.embed_document(query).await?; + let embedding = self.model.embed_text(query).await?; Ok(embedding.vec.iter().map(|&x| x as f32).collect()) }