From 9d7cbb778b7162525855858f9067a382f0b094b9 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 16 Sep 2024 16:43:50 -0400 Subject: [PATCH 01/39] feat: start implementing VectorStore trait for lancedb --- Cargo.lock | 4347 +++++++++++++++++++++++++++++++++++----- Cargo.toml | 2 +- rig-lancedb/Cargo.toml | 13 + rig-lancedb/src/lib.rs | 182 ++ 4 files changed, 4089 insertions(+), 455 deletions(-) create mode 100644 rig-lancedb/Cargo.toml create mode 100644 rig-lancedb/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 76e03fd9..ab13cda6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,12 +24,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", + "const-random", "getrandom", "once_cell", "version_check", "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "allocator-api2" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -52,704 +68,2795 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" [[package]] -name = "async-trait" -version = "0.1.80" +name = "arc-swap" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.65", -] +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" [[package]] -name = "autocfg" -version = "1.3.0" +name = "arrow" +version = "52.2.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "05048a8932648b63f21c37d88b552ccc8a65afb6dfe9fc9f30ce79174c2e7a85" +dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-csv", + "arrow-data", + "arrow-ipc", + "arrow-json", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", +] [[package]] -name = "backtrace" -version = "0.3.71" +name = "arrow-arith" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "1d8a57966e43bfe9a3277984a14c24ec617ad874e4c0e1d2a1b083a39cfbf22c" dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "num", ] [[package]] -name = "base64" -version = "0.13.1" +name = "arrow-array" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" +checksum = "16f4a9468c882dc66862cef4e1fd8423d47e67972377d85d80e022786427768c" +dependencies = [ + "ahash", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "chrono-tz", + "half", + "hashbrown 0.14.5", + "num", +] [[package]] -name = "base64" -version = "0.21.7" +name = "arrow-buffer" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +checksum = "c975484888fc95ec4a632cdc98be39c085b1bb518531b0c80c5d462063e5daa1" +dependencies = [ + "bytes", + "half", + "num", +] [[package]] -name = "bitflags" -version = "1.3.2" +name = "arrow-cast" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +checksum = "da26719e76b81d8bc3faad1d4dbdc1bcc10d14704e63dc17fc9f3e7e1e567c8e" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "atoi", + "base64 0.22.1", + "chrono", + "comfy-table", + "half", + "lexical-core", + "num", + "ryu", +] [[package]] -name = "bitflags" -version = "2.5.0" +name = "arrow-csv" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "c13c36dc5ddf8c128df19bab27898eea64bf9da2b555ec1cd17a8ff57fba9ec2" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "csv", + "csv-core", + "lazy_static", + "lexical-core", + "regex", +] [[package]] -name = "bitvec" -version = "1.0.1" +name = "arrow-data" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +checksum = "dd9d6f18c65ef7a2573ab498c374d8ae364b4a4edf67105357491c031f716ca5" dependencies = [ - "funty", - "radium", - "tap", - "wyz", + "arrow-buffer", + "arrow-schema", + "half", + "num", ] [[package]] -name = "block-buffer" -version = "0.10.4" +name = "arrow-ipc" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +checksum = "e786e1cdd952205d9a8afc69397b317cfbb6e0095e445c69cda7e8da5c1eeb0f" dependencies = [ - "generic-array", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "flatbuffers", + "lz4_flex", + "zstd", ] [[package]] -name = "bson" -version = "2.10.0" +name = "arrow-json" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d43b38e074cc0de2957f10947e376a1d88b9c4dbab340b590800cc1b2e066b2" 
+checksum = "fb22284c5a2a01d73cebfd88a33511a3234ab45d66086b2ca2d1228c3498e445" dependencies = [ - "ahash", - "base64 0.13.1", - "bitvec", - "hex", - "indexmap", - "js-sys", - "once_cell", - "rand", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "indexmap 2.2.6", + "lexical-core", + "num", "serde", - "serde_bytes", "serde_json", - "time", - "uuid", ] [[package]] -name = "bumpalo" -version = "3.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" - -[[package]] -name = "bytes" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" - -[[package]] -name = "cc" -version = "1.0.98" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" - -[[package]] -name = "cfg-if" -version = "1.0.0" +name = "arrow-ord" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "42745f86b1ab99ef96d1c0bcf49180848a64fe2c7a7a0d945bc64fa2b21ba9bc" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "half", + "num", +] [[package]] -name = "chrono" -version = "0.4.38" +name = "arrow-row" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "4cd09a518c602a55bd406bcc291a967b284cfa7a63edfbf8b897ea4748aad23c" dependencies = [ - "android-tzdata", - "iana-time-zone", - "num-traits", - "windows-targets 0.52.5", + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "half", ] [[package]] -name = "convert_case" -version = "0.4.0" 
+name = "arrow-schema" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +checksum = "9e972cd1ff4a4ccd22f86d3e53e835c2ed92e0eea6a3e8eadb72b4f1ac802cf8" +dependencies = [ + "bitflags 2.5.0", +] [[package]] -name = "core-foundation" -version = "0.9.4" +name = "arrow-select" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +checksum = "600bae05d43483d216fb3494f8c32fdbefd8aa4e1de237e790dbb3d9f44690a3" dependencies = [ - "core-foundation-sys", - "libc", + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "num", ] [[package]] -name = "core-foundation-sys" -version = "0.8.6" +name = "arrow-string" +version = "52.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "f0dc1985b67cb45f6606a248ac2b4a288849f196bab8c657ea5589f47cdd55e6" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "memchr", + "num", + "regex", + "regex-syntax", +] [[package]] -name = "cpufeatures" -version = "0.2.12" +name = "async-io" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" dependencies = [ - "libc", + "async-lock", + "autocfg", + "cfg-if", + "concurrent-queue", + "futures-lite", + "log", + "parking", + "polling", + "rustix 0.37.27", + "slab", + "socket2 0.4.10", + "waker-fn", ] [[package]] -name = "crypto-common" -version = "0.1.6" +name = "async-lock" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +checksum = "287272293e9d8c41773cec55e365490fe034813a2f172f502d6ddcf75b2f582b" dependencies = [ - "generic-array", - "typenum", + "event-listener 2.5.3", ] [[package]] -name = "darling" -version = "0.13.4" +name = "async-priority-channel" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a01d95850c592940db9b8194bc39f4bc0e89dee5c4265e4b1807c34a9aba453c" +checksum = "acde96f444d31031f760c5c43dc786b97d3e1cb2ee49dd06898383fe9a999758" dependencies = [ - "darling_core", - "darling_macro", + "event-listener 4.0.3", ] [[package]] -name = "darling_core" -version = "0.13.4" +name = "async-recursion" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "859d65a907b6852c9361e3185c862aae7fafd2887876799fa55f5f99dc40d610" +checksum = "3b43422f69d8ff38f95f1b2bb76517c91589a924d1559a0e935d7c8ce0274c11" dependencies = [ - "fnv", - "ident_case", "proc-macro2", "quote", - "strsim", - "syn 1.0.109", + "syn 2.0.77", ] [[package]] -name = "darling_macro" -version = "0.13.4" +name = "async-trait" +version = "0.1.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c972679f83bdf9c42bd905396b6c3588a843a17f0f16dfcfa3e2c5d57441835" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" dependencies = [ - "darling_core", + "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.77", ] [[package]] -name = "data-encoding" -version = "2.6.0" +name = "async_cell" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +checksum = "834eee9ce518130a3b4d5af09ecc43e9d6b57ee76613f227a1ddd6b77c7a62bc" [[package]] -name = "deranged" -version = "0.3.11" +name = "atoi" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" dependencies = [ - "powerfmt", + "num-traits", ] [[package]] -name = "derivative" -version = "2.2.0" +name = "atomic-waker" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] -name = "derive_more" -version = "0.99.17" +name = "autocfg" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "aws-config" +version = "1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "848d7b9b605720989929279fa644ce8f244d0ce3146fcca5b70e4eb7b3c020fc" dependencies = [ - "convert_case", - "proc-macro2", - "quote", - "rustc_version 0.4.0", - "syn 1.0.109", + "aws-credential-types", + "aws-runtime", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand 2.1.0", + "hex", + "http 0.2.12", + "ring", + "time", + "tokio", + "tracing", + "url", + "zeroize", ] [[package]] -name = "digest" -version = "0.10.7" +name = "aws-credential-types" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" dependencies = [ - "block-buffer", - "crypto-common", - "subtle", + "aws-smithy-async", + 
"aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", ] [[package]] -name = "dyn-clone" -version = "1.0.17" +name = "aws-runtime" +version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "a10d5c055aa540164d9561a0e2e74ad30f0dcf7393c3a92f6733ddf9c5762468" +dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand 2.1.0", + "http 0.2.12", + "http-body 0.4.6", + "once_cell", + "percent-encoding", + "pin-project-lite", + "tracing", + "uuid", +] [[package]] -name = "encoding_rs" -version = "0.8.34" +name = "aws-sdk-dynamodb" +version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +checksum = "c7f3d9e807092149e3df266e3f4d9760dac439b90f82d8438e5b2c0bbe62007f" dependencies = [ - "cfg-if", + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand 2.1.0", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", ] [[package]] -name = "enum-as-inner" -version = "0.4.0" +name = "aws-sdk-sso" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21cdad81446a7f7dc43f6a77409efeb9733d2fa65553efef6018ef257c959b73" +checksum = "27bf24cd0d389daa923e974b0e7c38daf308fc21e963c049f57980235017175e" dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn 1.0.109", + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.12", + "once_cell", + 
"regex-lite", + "tracing", ] [[package]] -name = "equivalent" -version = "1.0.1" +name = "aws-sdk-ssooidc" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +checksum = "3b43b3220f1c46ac0e9dcc0a97d94b93305dacb36d1dd393996300c6b9b74364" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", +] [[package]] -name = "errno" -version = "0.3.9" +name = "aws-sdk-sts" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "d1c46924fb1add65bba55636e12812cae2febf68c0f37361766f627ddcca91ce" dependencies = [ - "libc", - "windows-sys 0.52.0", + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "http 0.2.12", + "once_cell", + "regex-lite", + "tracing", ] [[package]] -name = "fastrand" -version = "2.1.0" +name = "aws-sigv4" +version = "1.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" +checksum = "cc8db6904450bafe7473c6ca9123f88cc11089e41a025408f992db4e22d3be68" +dependencies = [ + "aws-credential-types", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.1.0", + "once_cell", + "percent-encoding", + "sha2", + "time", + "tracing", +] [[package]] -name = "finl_unicode" -version = "1.2.0" +name = "aws-smithy-async" +version = "1.2.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" +checksum = "62220bc6e97f946ddd51b5f1361f78996e704677afc518a4ff66b7a72ea1378c" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] [[package]] -name = "fnv" -version = "1.0.7" +name = "aws-smithy-http" +version = "0.60.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +checksum = "5c8bc3e8fdc6b8d07d976e301c02fe553f72a39b7a9fea820e023268467d7ab6" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http-body 0.4.6", + "once_cell", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] [[package]] -name = "foreign-types" -version = "0.3.2" +name = "aws-smithy-json" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6" dependencies = [ - "foreign-types-shared", + "aws-smithy-types", ] [[package]] -name = "foreign-types-shared" -version = "0.1.1" +name = "aws-smithy-query" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" +dependencies = [ + "aws-smithy-types", + "urlencoding", +] [[package]] -name = "form_urlencoded" -version = "1.2.1" +name = "aws-smithy-runtime" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" dependencies = [ - "percent-encoding", + 
"aws-smithy-async", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "fastrand 2.1.0", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "http-body 1.0.1", + "httparse", + "hyper 0.14.28", + "hyper-rustls 0.24.2", + "once_cell", + "pin-project-lite", + "pin-utils", + "rustls 0.21.12", + "tokio", + "tracing", ] [[package]] -name = "funty" -version = "2.0.0" +name = "aws-smithy-runtime-api" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" +checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.1.0", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] [[package]] -name = "futures" -version = "0.3.30" +name = "aws-smithy-types" +version = "1.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "03701449087215b5369c7ea17fef0dd5d24cb93439ec5af0c7615f58c3f22605" dependencies = [ - "futures-channel", + "base64-simd", + "bytes", + "bytes-utils", "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", + "http 0.2.12", + "http 1.1.0", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", + "tokio", + "tokio-util", ] [[package]] -name = "futures-channel" -version = "0.3.30" +name = "aws-smithy-xml" +version = "0.60.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" dependencies = [ - "futures-core", - "futures-sink", + "xmlparser", ] [[package]] -name = 
"futures-core" -version = "0.3.30" +name = "aws-types" +version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "5221b91b3e441e6675310829fd8984801b772cb1546ef6c0e54dec9f1ac13fef" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version 0.4.0", + "tracing", +] [[package]] -name = "futures-executor" -version = "0.3.30" +name = "backtrace" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ - "futures-core", - "futures-task", - "futures-util", + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", ] [[package]] -name = "futures-io" -version = "0.3.30" +name = "base64" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] -name = "futures-macro" -version = "0.3.30" +name = "base64" +version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" dependencies = [ - 
"proc-macro2", - "quote", - "syn 2.0.65", + "outref", + "vsimd", ] [[package]] -name = "futures-sink" -version = "0.3.30" +name = "bitflags" +version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] -name = "futures-task" -version = "0.3.30" +name = "bitflags" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] -name = "futures-util" -version = "0.3.30" +name = "bitpacking" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "4c1d3e2bfd8d06048a179f7b17afc3188effa10385e7b00dc65af6aae732ea92" dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "pin-utils", - "slab", + "crunchy", ] [[package]] -name = "generic-array" -version = "0.14.7" +name = "bitvec" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" dependencies = [ - "typenum", - "version_check", + "funty", + "radium", + "tap", + "wyz", ] [[package]] -name = "getrandom" -version = "0.2.15" +name = "block-buffer" +version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ - "cfg-if", - "libc", - "wasi", + 
"generic-array", ] [[package]] -name = "gimli" -version = "0.28.1" +name = "bson" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "4d43b38e074cc0de2957f10947e376a1d88b9c4dbab340b590800cc1b2e066b2" +dependencies = [ + "ahash", + "base64 0.13.1", + "bitvec", + "hex", + "indexmap 2.2.6", + "js-sys", + "once_cell", + "rand", + "serde", + "serde_bytes", + "serde_json", + "time", + "uuid", +] [[package]] -name = "h2" -version = "0.3.26" +name = "bumpalo" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] -name = "hashbrown" -version = "0.14.5" +name = "bytecount" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +checksum = "5ce89b21cab1437276d2650d57e971f9d548a2d9037cc231abdc0562b97498ce" [[package]] -name = "heck" -version = "0.4.1" +name = "bytemuck" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "94bbb0ad554ad961ddc5da507a12a29b14e4ae5bda06b19f575a3e6079d2e2ae" [[package]] -name = "hermit-abi" -version = "0.3.9" +name = "byteorder" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] -name = "hex" -version = "0.4.3" +name = "bytes" +version 
= "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] -name = "hmac" -version = "0.12.1" +name = "bytes-utils" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" dependencies = [ - "digest", + "bytes", + "either", +] + +[[package]] +name = "camino" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b96ec4966b5813e2c0507c1f86115c8c5abaadc3980879c3424042a02fd1ad3" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo-platform" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24b1f0365a6c6bb4020cd05806fd0d33c44d38046b8bd7f0e40814b9763cabfc" +dependencies = [ + "serde", +] + +[[package]] +name = "cargo_metadata" +version = "0.14.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4acbb09d9ee8e23699b9634375c72795d095bf268439da88562cf9b501f181fa" +dependencies = [ + "camino", + "cargo-platform", + "semver 1.0.23", + "serde", + "serde_json", +] + +[[package]] +name = "cc" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +dependencies = [ + "jobserver", + "libc", + "once_cell", +] + +[[package]] +name = "census" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f4c707c6a209cbe82d10abd08e1ea8995e9ea937d2550646e02798948992be0" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-targets 0.52.5", +] + +[[package]] +name = "chrono-tz" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93698b29de5e97ad0ae26447b344c482a7284c737d9ddc5f9e52b74a336671bb" +dependencies = [ + "chrono", + "chrono-tz-build", + "phf", +] + +[[package]] +name = "chrono-tz-build" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c088aee841df9c3041febbb73934cfc39708749bf96dc827e3359cd39ef11b1" +dependencies = [ + "parse-zoneinfo", + "phf", + "phf_codegen", +] + +[[package]] +name = "comfy-table" +version = "7.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" +dependencies = [ + "strum", + "strum_macros", + "unicode-width", +] + +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom", + "once_cell", + "tiny-keccak", +] + 
+[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + +[[package]] +name = "cpufeatures" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + +[[package]] +name = "darling" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a01d95850c592940db9b8194bc39f4bc0e89dee5c4265e4b1807c34a9aba453c" +dependencies = [ + "darling_core 0.13.4", + "darling_macro 0.13.4", +] + +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core 0.20.10", + "darling_macro 0.20.10", +] + +[[package]] +name = "darling_core" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "859d65a907b6852c9361e3185c862aae7fafd2887876799fa55f5f99dc40d610" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.10.0", + "syn 1.0.109", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.77", +] + +[[package]] +name = "darling_macro" +version = "0.13.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c972679f83bdf9c42bd905396b6c3588a843a17f0f16dfcfa3e2c5d57441835" +dependencies = [ + "darling_core 0.13.4", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core 0.20.10", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + +[[package]] +name = "datafusion" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab9d55a9cd2634818953809f75ebe5248b00dd43c3227efb2a51a2d5feaad54e" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-ipc", + "arrow-schema", + "async-trait", + "bytes", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-common-runtime", + "datafusion-execution", + 
"datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", + "datafusion-functions-array", + "datafusion-optimizer", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-plan", + "datafusion-sql", + "futures", + "glob", + "half", + "hashbrown 0.14.5", + "indexmap 2.2.6", + "itertools 0.12.1", + "log", + "num_cpus", + "object_store", + "parking_lot", + "paste", + "pin-project-lite", + "rand", + "sqlparser", + "tempfile", + "tokio", + "url", + "uuid", +] + +[[package]] +name = "datafusion-common" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "def66b642959e7f96f5d2da22e1f43d3bd35598f821e5ce351a0553e0f1b7367" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema", + "chrono", + "half", + "hashbrown 0.14.5", + "instant", + "libc", + "num_cpus", + "object_store", + "sqlparser", +] + +[[package]] +name = "datafusion-common-runtime" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f104bb9cb44c06c9badf8a0d7e0855e5f7fa5e395b887d7f835e8a9457dc1352" +dependencies = [ + "tokio", +] + +[[package]] +name = "datafusion-execution" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ac0fd8b5d80bbca3fc3b6f40da4e9f6907354824ec3b18bbd83fee8cf5c3c3e" +dependencies = [ + "arrow", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-expr", + "futures", + "hashbrown 0.14.5", + "log", + "object_store", + "parking_lot", + "rand", + "tempfile", + "url", +] + +[[package]] +name = "datafusion-expr" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2103d2cc16fb11ef1fa993a6cac57ed5cb028601db4b97566c90e5fa77aa1e68" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-buffer", + "chrono", + "datafusion-common", + "paste", + "serde_json", + "sqlparser", + "strum", + "strum_macros", +] + 
+[[package]] +name = "datafusion-functions" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a369332afd0ef5bd565f6db2139fb9f1dfdd0afa75a7f70f000b74208d76994f" +dependencies = [ + "arrow", + "base64 0.22.1", + "chrono", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "hashbrown 0.14.5", + "hex", + "itertools 0.12.1", + "log", + "rand", + "regex", + "unicode-segmentation", + "uuid", +] + +[[package]] +name = "datafusion-functions-aggregate" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92718db1aff70c47e5abf9fc975768530097059e5db7c7b78cd64b5e9a11fc77" +dependencies = [ + "ahash", + "arrow", + "arrow-schema", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr-common", + "log", + "paste", + "sqlparser", +] + +[[package]] +name = "datafusion-functions-array" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30bb80f46ff3dcf4bb4510209c2ba9b8ce1b716ac8b7bf70c6bf7dca6260c831" +dependencies = [ + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-schema", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions", + "datafusion-functions-aggregate", + "itertools 0.12.1", + "log", + "paste", +] + +[[package]] +name = "datafusion-optimizer" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82f34692011bec4fdd6fc18c264bf8037b8625d801e6dd8f5111af15cb6d71d3" +dependencies = [ + "arrow", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-expr", + "hashbrown 0.14.5", + "indexmap 2.2.6", + "itertools 0.12.1", + "log", + "paste", + "regex-syntax", +] + +[[package]] +name = "datafusion-physical-expr" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"45538630defedb553771434a437f7ca8f04b9b3e834344aafacecb27dc65d5e5" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-schema", + "arrow-string", + "base64 0.22.1", + "chrono", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr-common", + "half", + "hashbrown 0.14.5", + "hex", + "indexmap 2.2.6", + "itertools 0.12.1", + "log", + "paste", + "petgraph", + "regex", +] + +[[package]] +name = "datafusion-physical-expr-common" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d8a72b0ca908e074aaeca52c14ddf5c28d22361e9cb6bc79bb733cd6661b536" +dependencies = [ + "ahash", + "arrow", + "datafusion-common", + "datafusion-expr", + "hashbrown 0.14.5", + "rand", +] + +[[package]] +name = "datafusion-physical-plan" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b504eae6107a342775e22e323e9103f7f42db593ec6103b28605b7b7b1405c4a" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-schema", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-common-runtime", + "datafusion-execution", + "datafusion-expr", + "datafusion-functions-aggregate", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "futures", + "half", + "hashbrown 0.14.5", + "indexmap 2.2.6", + "itertools 0.12.1", + "log", + "once_cell", + "parking_lot", + "pin-project-lite", + "rand", + "tokio", +] + +[[package]] +name = "datafusion-sql" +version = "40.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5db33f323f41b95ae201318ba654a9bf11113e58a51a1dff977b1a836d3d889" +dependencies = [ + "arrow", + "arrow-array", + "arrow-schema", + "datafusion-common", + "datafusion-expr", + "log", + "regex", + "sqlparser", + "strum", +] + +[[package]] +name = "deepsize" +version = "0.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cdb987ec36f6bf7bfbea3f928b75590b736fc42af8e54d97592481351b2b96c" +dependencies = [ + "deepsize_derive", +] + +[[package]] +name = "deepsize_derive" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990101d41f3bc8c1a45641024377ee284ecc338e5ecf3ea0f0e236d897c72796" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", + "serde", +] + +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version 0.4.0", + "syn 1.0.109", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + 
"option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + +[[package]] +name = "downcast-rs" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" + +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "encoding_rs" +version = "0.8.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "enum-as-inner" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21cdad81446a7f7dc43f6a77409efeb9733d2fa65553efef6018ef257c959b73" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "errno" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "error-chain" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2d2f06b9cac1506ece98fe3231e3cc9c4410ec3d5b1f24ae1c8946f0742cdefc" +dependencies = [ + "version_check", +] + +[[package]] +name = "event-listener" +version = "2.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" + +[[package]] +name = "event-listener" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "fastdivide" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59668941c55e5c186b8b58c391629af56774ec768f73c08bbcd56f09348eb00b" + +[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + +[[package]] +name = "fastrand" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" + +[[package]] +name = "finl_unicode" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + +[[package]] +name = "flatbuffers" +version = "24.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" +dependencies = [ + "bitflags 1.3.2", + "rustc_version 0.4.0", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fs4" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7e180ac76c23b45e767bd7ae9579bc0bb458618c4bc71835926e098e61d15f8" +dependencies = [ + "rustix 0.38.34", + "windows-sys 0.52.0", +] + +[[package]] +name = "fsst" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26212c1db7eee3ec0808bd99107cf62ba4d3edd3489df601e2d0c73c5d739aec" +dependencies = [ + "rand", +] + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-lite" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" +dependencies = [ + "fastrand 1.9.0", + "futures-core", + "futures-io", + "memchr", + "parking", + "pin-project-lite", + "waker-fn", +] + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "js-sys", + "libc", + "wasi", + "wasm-bindgen", +] + +[[package]] +name = "gimli" +version = "0.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "h2" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap 2.2.6", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.1.0", + "indexmap 2.2.6", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" +version = "2.4.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", + "allocator-api2", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = "hostname" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" +dependencies = [ + "libc", + "match_cfg", + "winapi", +] + +[[package]] +name = "htmlescape" +version = "0.3.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9025058dae765dee5070ec375f591e2ba14638c63feff74f13805a72e523163" + +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.1.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] +name = "hyper" +version = "0.14.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2 0.5.7", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "hyper" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.28", + "log", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", + "tokio", + "tokio-rustls 0.24.1", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http 1.1.0", + "hyper 1.4.1", + "hyper-util", + "rustls 0.23.13", + "rustls-native-certs 0.8.0", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.0", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper 0.14.28", + "native-tls", + 
"tokio", + "tokio-native-tls", +] + +[[package]] +name = "hyper-util" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "hyper 1.4.1", + "pin-project-lite", + "socket2 0.5.7", + "tokio", + "tower", + "tower-service", + "tracing", +] + +[[package]] +name = "hyperloglogplus" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "621debdf94dcac33e50475fdd76d34d5ea9c0362a834b9db08c3024696c1fbe3" +dependencies = [ + "serde", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" +dependencies = [ + "matches", + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + 
"unicode-normalization", +] + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", + "serde", +] + +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown 0.14.5", + "serde", +] + +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "io-lifetimes" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "ipconfig" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" +dependencies = [ + "socket2 0.5.7", + "widestring", + "windows-sys 0.48.0", + "winreg 0.50.0", +] + +[[package]] +name = "ipnet" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lance" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a427160737dd74d2d4f566f3111027edc63927106541d173459d010209371c42" +dependencies = [ + "arrow", + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "async-recursion", + "async-trait", + "async_cell", + "aws-credential-types", + "aws-sdk-dynamodb", + "byteorder", + "bytes", + "chrono", + "dashmap", + "datafusion", + "datafusion-functions", + "datafusion-physical-expr", + "deepsize", + "futures", + "half", + "itertools 0.12.1", + "lance-arrow", + "lance-core", + "lance-datafusion", + "lance-encoding", + "lance-file", + "lance-index", + "lance-io", + "lance-linalg", + "lance-table", + "lazy_static", + "log", + "moka", + "object_store", + "pin-project", + "prost", + "prost-build", + "rand", + "roaring", + "serde", + "serde_json", + "snafu", + "tantivy", + "tempfile", + "tokio", + "tracing", + "url", + "uuid", +] + +[[package]] +name = "lance-arrow" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2f1cfebe08c64b1edabe9b6ccd6f8ea1bc6349d0870d47f2db8cdadf02ab8e2" +dependencies = [ 
+ "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "arrow-select", + "getrandom", + "half", + "num-traits", + "rand", +] + +[[package]] +name = "lance-core" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "243aa2323dee6fcab6bb9bb3a21ae8f040c98a5de9bbfb7ab8484a036176185a" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-schema", + "async-trait", + "byteorder", + "bytes", + "chrono", + "datafusion-common", + "datafusion-sql", + "deepsize", + "futures", + "lance-arrow", + "lazy_static", + "libc", + "log", + "mock_instant", + "moka", + "num_cpus", + "object_store", + "pin-project", + "prost", + "rand", + "roaring", + "serde_json", + "snafu", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", + "url", ] [[package]] -name = "hostname" -version = "0.3.1" +name = "lance-datafusion" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" +checksum = "c0a69d039f93a43477245b51a8f1ce58a1f41485f8ded946f53031a11ded8c97" dependencies = [ - "libc", - "match_cfg", - "winapi", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-schema", + "arrow-select", + "async-trait", + "datafusion", + "datafusion-common", + "datafusion-functions", + "datafusion-physical-expr", + "futures", + "lance-arrow", + "lance-core", + "lazy_static", + "log", + "prost", + "snafu", + "tokio", ] [[package]] -name = "http" -version = "0.2.12" +name = "lance-encoding" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +checksum = "3b713e49ce6039d0ca0f88e8ded66ee64d89c42f85107bc9e684fbff41386a65" dependencies = [ + "arrow", + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "arrow-select", "bytes", - "fnv", - "itoa", + 
"fsst", + "futures", + "hex", + "hyperloglogplus", + "itertools 0.12.1", + "lance-arrow", + "lance-core", + "log", + "num-traits", + "prost", + "prost-build", + "prost-types", + "rand", + "snafu", + "tokio", + "tracing", + "zstd", ] [[package]] -name = "http-body" -version = "0.4.6" +name = "lance-file" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +checksum = "4e0eaebd40c77f8f06a0cbadcd07f9344aea616dcd4d8712f6cad81c2eda14d5" dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "async-recursion", + "async-trait", + "byteorder", "bytes", - "http", - "pin-project-lite", + "datafusion-common", + "deepsize", + "futures", + "lance-arrow", + "lance-core", + "lance-encoding", + "lance-io", + "log", + "num-traits", + "object_store", + "prost", + "prost-build", + "prost-types", + "roaring", + "snafu", + "tempfile", + "tokio", + "tracing", ] [[package]] -name = "httparse" -version = "1.8.0" +name = "lance-index" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +checksum = "d1ad5a42b9a4909749ee62fc94c64d19259c2aadca7c8446f42ee8e9c7a097d3" +dependencies = [ + "arrow", + "arrow-array", + "arrow-ord", + "arrow-schema", + "arrow-select", + "async-recursion", + "async-trait", + "bitvec", + "bytes", + "crossbeam-queue", + "datafusion", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-expr", + "datafusion-sql", + "deepsize", + "futures", + "half", + "itertools 0.12.1", + "lance-arrow", + "lance-core", + "lance-datafusion", + "lance-encoding", + "lance-file", + "lance-io", + "lance-linalg", + "lance-table", + "lazy_static", + "log", + "moka", + "num-traits", + "object_store", + "prost", + "prost-build", + "rand", + "rayon", + "roaring", + "serde", + "serde_json", + 
"snafu", + "tantivy", + "tempfile", + "tokio", + "tracing", + "uuid", +] [[package]] -name = "httpdate" -version = "1.0.3" +name = "lance-io" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +checksum = "d0f334f2c279f80f19803141cf7f98c6b82e6ace3c7f75c8740f1df7a73bb720" +dependencies = [ + "arrow", + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "arrow-select", + "async-priority-channel", + "async-recursion", + "async-trait", + "aws-config", + "aws-credential-types", + "byteorder", + "bytes", + "chrono", + "deepsize", + "futures", + "lance-arrow", + "lance-core", + "lazy_static", + "log", + "object_store", + "path_abs", + "pin-project", + "prost", + "prost-build", + "rand", + "shellexpand", + "snafu", + "tokio", + "tracing", + "url", +] [[package]] -name = "hyper" -version = "0.14.28" +name = "lance-linalg" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" +checksum = "8fa019770a0afb287360a4ea919cff482371ad43318607d1e797534c819bf356" dependencies = [ - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "h2", - "http", - "http-body", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "socket2 0.5.7", + "arrow-array", + "arrow-ord", + "arrow-schema", + "bitvec", + "cc", + "deepsize", + "futures", + "half", + "lance-arrow", + "lance-core", + "lazy_static", + "log", + "num-traits", + "rand", + "rayon", "tokio", - "tower-service", "tracing", - "want", ] [[package]] -name = "hyper-tls" -version = "0.5.0" +name = "lance-table" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +checksum = "99fa39bede133d578431db3f77b0c2c63fcaff12a088000648014c27266830a2" 
dependencies = [ + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-ipc", + "arrow-schema", + "async-trait", + "aws-credential-types", + "aws-sdk-dynamodb", + "byteorder", "bytes", - "hyper", - "native-tls", + "chrono", + "deepsize", + "futures", + "lance-arrow", + "lance-core", + "lance-file", + "lance-io", + "lazy_static", + "log", + "object_store", + "prost", + "prost-build", + "prost-types", + "rand", + "rangemap", + "roaring", + "serde", + "serde_json", + "snafu", "tokio", - "tokio-native-tls", + "tracing", + "url", + "uuid", ] [[package]] -name = "iana-time-zone" -version = "0.1.60" +name = "lance-testing" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "fb4eae0993cda6130cfd75d3d9d11830dd4ec8d4c66cf81def939837c419d4bc" dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "wasm-bindgen", - "windows-core", + "arrow-array", + "arrow-schema", + "lance-arrow", + "num-traits", + "rand", ] [[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" +name = "lancedb" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +checksum = "d867f115c1c23e77dc46a967664e3073630d53e3d61399f0b984c4a4753fd3a7" dependencies = [ - "cc", + "arrow", + "arrow-array", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-ord", + "arrow-schema", + "async-trait", + "bytes", + "chrono", + "datafusion-physical-plan", + "futures", + "half", + "lance", + "lance-datafusion", + "lance-encoding", + "lance-index", + "lance-linalg", + "lance-table", + "lance-testing", + "lazy_static", + "log", + "num-traits", + "object_store", + "pin-project", + "regex", + "serde", + "serde_json", + "serde_with 3.9.0", + "snafu", + "tokio", + "url", ] [[package]] -name = "ident_case" -version = "1.0.1" +name = 
"lazy_static" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] -name = "idna" -version = "0.2.3" +name = "levenshtein_automata" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" -dependencies = [ - "matches", - "unicode-bidi", - "unicode-normalization", -] +checksum = "0c2cdeb66e45e9f36bfad5bbdb4d2384e70936afbee843c6f6543f0c551ebb25" [[package]] -name = "idna" -version = "0.5.0" +name = "lexical-core" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", ] [[package]] -name = "indexmap" -version = "2.2.6" +name = "lexical-parse-float" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" dependencies = [ - "equivalent", - "hashbrown", + "lexical-parse-integer", + "lexical-util", + "static_assertions", ] [[package]] -name = "ipconfig" -version = "0.3.2" +name = "lexical-parse-integer" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b58db92f96b720de98181bbbe63c831e87005ab460c1bf306eb2622b4707997f" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" dependencies = [ - "socket2 0.5.7", - "widestring", - "windows-sys 0.48.0", - 
"winreg", + "lexical-util", + "static_assertions", ] [[package]] -name = "ipnet" -version = "2.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" - -[[package]] -name = "itoa" -version = "1.0.11" +name = "lexical-util" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] [[package]] -name = "js-sys" -version = "0.3.69" +name = "lexical-write-float" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" dependencies = [ - "wasm-bindgen", + "lexical-util", + "lexical-write-integer", + "static_assertions", ] [[package]] -name = "lazy_static" -version = "1.4.0" +name = "lexical-write-integer" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + "lexical-util", + "static_assertions", +] [[package]] name = "libc" @@ -757,12 +2864,34 @@ version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags 2.5.0", + "libc", +] + [[package]] name = "linked-hash-map" version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +[[package]] +name = "linux-raw-sys" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" + [[package]] name = "linux-raw-sys" version = "0.4.14" @@ -785,6 +2914,15 @@ version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +[[package]] +name = "lru" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37ee39891760e7d94734f6f63fedc29a2e4a152f836120753a72503f09fcf904" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "lru-cache" version = "0.1.2" @@ -794,6 +2932,24 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "lz4_flex" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" +dependencies = [ + "twox-hash", +] + +[[package]] +name = "mach2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b955cdeb2a02b9117f121ce63aa52d08ade45de53e48fe6a38b39c10f6f709" +dependencies = [ + "libc", +] + [[package]] name = "match_cfg" version = "0.1.0" @@ -816,18 +2972,43 @@ dependencies = [ "digest", ] +[[package]] +name = "measure_time" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbefd235b0aadd181626f281e1d684e116972988c14c264e42069d5e8a5775cc" +dependencies = [ + "instant", + "log", +] + [[package]] name = "memchr" version = "2.7.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +[[package]] +name = "memmap2" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" +dependencies = [ + "libc", +] + [[package]] name = "mime" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.7.3" @@ -838,14 +3019,49 @@ dependencies = [ ] [[package]] -name = "mio" -version = "0.8.11" +name = "mio" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys 0.52.0", +] + +[[package]] +name = "mock_instant" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9366861eb2a2c436c20b12c8dbec5f798cea6b47ad99216be0282942e2c81ea0" +dependencies = [ + "once_cell", +] + +[[package]] +name = "moka" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +checksum = "fa6e72583bf6830c956235bff0d5afec8cf2952f579ebad18ae7821a917d950f" dependencies = [ - "libc", - "wasi", - "windows-sys 0.48.0", + "async-io", + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "futures-util", + "once_cell", + "parking_lot", + "quanta", + "rustc_version 0.4.0", + "scheduled-thread-pool", + "skeptic", + "smallvec", + "tagptr", + "thiserror", + 
"triomphe", + "uuid", ] [[package]] @@ -873,20 +3089,20 @@ dependencies = [ "percent-encoding", "rand", "rustc_version_runtime", - "rustls", - "rustls-pemfile", + "rustls 0.21.12", + "rustls-pemfile 1.0.4", "serde", "serde_bytes", - "serde_with", + "serde_with 1.14.0", "sha-1", "sha2", "socket2 0.4.10", "stringprep", - "strsim", + "strsim 0.10.0", "take_mut", "thiserror", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tokio-util", "trust-dns-proto", "trust-dns-resolver", @@ -895,6 +3111,18 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "multimap" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" + +[[package]] +name = "murmurhash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2195bf6aa996a481483b29d62a7663eed3fe39600c460e323f8ff41e90bdd89b" + [[package]] name = "native-tls" version = "0.2.11" @@ -913,6 +3141,16 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -923,12 +3161,76 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" 
+version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -936,6 +3238,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -957,12 +3260,49 @@ dependencies = [ "memchr", ] +[[package]] +name = "object_store" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" +dependencies = [ + "async-trait", + "base64 0.22.1", + "bytes", + "chrono", + "futures", + "humantime", + "hyper 1.4.1", + "itertools 0.13.0", + "md-5", + "parking_lot", + "percent-encoding", + "quick-xml", + "rand", + "reqwest 0.12.5", + "ring", + "rustls-pemfile 2.1.3", + "serde", + "serde_json", + 
"snafu", + "tokio", + "tracing", + "url", + "walkdir", +] + [[package]] name = "once_cell" version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "oneshot" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e296cf87e61c9cfc1a61c3c63a0f7f286ed4554e0e22be84e8a38e1d264a2a29" + [[package]] name = "openssl" version = "0.10.64" @@ -986,7 +3326,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", ] [[package]] @@ -1007,6 +3347,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "ordered-float" version = "4.2.0" @@ -1016,12 +3362,33 @@ dependencies = [ "num-traits", ] +[[package]] +name = "outref" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" + [[package]] name = "overload" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +[[package]] +name = "ownedbytes" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3a059efb063b8f425b948e042e6b9bd85edfe60e913630ed727b23e2dfcc558" +dependencies = [ + "stable_deref_trait", +] + +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.2" @@ -1045,6 +3412,33 @@ dependencies = [ "windows-targets 0.52.5", ] 
+[[package]] +name = "parse-zoneinfo" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" +dependencies = [ + "regex", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "path_abs" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05ef02f6342ac01d8a93b65f96db53fe68a92a15f41144f97fb00a9e669633c3" +dependencies = [ + "serde", + "serde_derive", + "std_prelude", + "stfu8", +] + [[package]] name = "pbkdf2" version = "0.11.0" @@ -1060,6 +3454,74 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.2.6", +] + +[[package]] +name = "phf" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_codegen" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +dependencies = [ + "phf_generator", + "phf_shared", +] + +[[package]] +name = "phf_generator" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +dependencies = [ + "phf_shared", + "rand", +] + +[[package]] +name = "phf_shared" +version = "0.11.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + [[package]] name = "pin-project-lite" version = "0.2.14" @@ -1078,6 +3540,22 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +[[package]] +name = "polling" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" +dependencies = [ + "autocfg", + "bitflags 1.3.2", + "cfg-if", + "concurrent-queue", + "libc", + "log", + "pin-project-lite", + "windows-sys 0.48.0", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1091,19 +3569,167 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] -name = "proc-macro2" -version = "1.0.83" +name = "prettyplease" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479cf940fbbb3426c32c5d5176f62ad57549a0bb84773423ba8be9d089f5faba" +dependencies = [ + "proc-macro2", + "syn 2.0.77", +] + +[[package]] +name = "proc-macro2" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0b33eb56c327dec362a9e55b3ad14f9d2f0904fb5a5b03b513ab5465399e9f43" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "prost" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" +dependencies = [ + "bytes", + "heck 0.5.0", + "itertools 0.12.1", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn 2.0.77", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +dependencies = [ + "anyhow", + "itertools 0.12.1", + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "prost-types" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" +dependencies = [ + "prost", +] + +[[package]] +name = "pulldown-cmark" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57206b407293d2bcd3af849ce869d52068623f19e1b5ff8e8778e3309439682b" +dependencies = [ + "bitflags 2.5.0", + "memchr", + "unicase", +] + +[[package]] +name = "quanta" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab" +dependencies = [ + "crossbeam-utils", + "libc", + "mach2", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + +[[package]] +name = "quick-error" +version = "1.2.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + +[[package]] +name = "quick-xml" +version = "0.36.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96a05e2e8efddfa51a84ca47cec303fac86c8541b686d37cac5efc0e094417bc" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "quinn" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash 2.0.0", + "rustls 0.23.13", + "socket2 0.5.7", + "thiserror", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b33eb56c327dec362a9e55b3ad14f9d2f0904fb5a5b03b513ab5465399e9f43" +checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" dependencies = [ - "unicode-ident", + "bytes", + "rand", + "ring", + "rustc-hash 2.0.0", + "rustls 0.23.13", + "slab", + "thiserror", + "tinyvec", + "tracing", ] [[package]] -name = "quick-error" -version = "1.2.3" +name = "quinn-udp" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" +checksum = "8bffec3605b73c6f1754535084a85229fa8a30f86014e6c81aeec4abb68b0285" +dependencies = [ + "libc", + "once_cell", + "socket2 0.5.7", + "tracing", + "windows-sys 0.52.0", +] [[package]] name = "quote" @@ -1150,6 +3776,51 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rangemap" +version = "1.5.1" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60fcc7d6849342eff22c4350c8b9a989ee8ceabc4b481253e8946b9fe83d684" + +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.1" @@ -1159,6 +3830,52 @@ dependencies = [ "bitflags 2.5.0", ] +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom", + "libredox", + "thiserror", +] + +[[package]] +name = "regex" +version = "1.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + +[[package]] +name = "regex-syntax" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" + [[package]] name = "reqwest" version = "0.11.27" @@ -1170,10 +3887,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "hyper 0.14.28", "hyper-tls", "ipnet", "js-sys", @@ -1183,11 +3900,11 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls-pemfile", + "rustls-pemfile 1.0.4", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper", + "sync_wrapper 0.1.2", "system-configuration", "tokio", "tokio-native-tls", @@ -1196,7 +3913,52 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "winreg", + "winreg 0.50.0", +] + +[[package]] +name = "reqwest" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-core", + "futures-util", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-rustls 0.27.3", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "once_cell", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls 0.23.13", + "rustls-native-certs 0.7.3", + "rustls-pemfile 2.1.3", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "tokio", + "tokio-rustls 0.26.0", + "tokio-util", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "winreg 0.52.0", ] [[package]] @@ -1216,7 +3978,7 @@ dependencies = [ "anyhow", "futures", "ordered-float", - "reqwest", + "reqwest 0.11.27", "schemars", "serde", "serde_json", @@ -1226,6 
+3988,19 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "rig-lancedb" +version = "0.1.0" +dependencies = [ + "arrow-array", + "futures", + "lancedb", + "rig-core", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "rig-mongodb" version = "0.1.0" @@ -1255,12 +4030,44 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "roaring" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f4b84ba6e838ceb47b41de5194a60244fac43d9fe03b71dbe8c5a201081d6d1" +dependencies = [ + "bytemuck", + "byteorder", +] + +[[package]] +name = "rust-stemmers" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e46a2036019fdb888131db7a4c847a1063a7493f971ed94ea82c67eada63ca54" +dependencies = [ + "serde", + "serde_derive", +] + [[package]] name = "rustc-demangle" version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + +[[package]] +name = "rustc-hash" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" + [[package]] name = "rustc_version" version = "0.2.3" @@ -1289,6 +4096,20 @@ dependencies = [ "semver 0.9.0", ] +[[package]] +name = "rustix" +version = "0.37.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" +dependencies = [ + "bitflags 1.3.2", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys 0.3.8", + "windows-sys 0.48.0", +] + [[package]] name = "rustix" version = "0.38.34" @@ -1298,7 +4119,7 @@ dependencies = [ "bitflags 2.5.0", "errno", "libc", 
- "linux-raw-sys", + "linux-raw-sys 0.4.14", "windows-sys 0.52.0", ] @@ -1310,10 +4131,62 @@ checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.23.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.8", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile 1.0.4", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-native-certs" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.3", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.1.3", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -1323,6 +4196,22 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pemfile" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" +dependencies = [ + "base64 0.22.1", + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.8.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -1333,12 +4222,38 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "ryu" version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.23" @@ -1348,6 +4263,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot", +] + [[package]] name = "schemars" version = "0.8.20" @@ -1369,7 +4293,7 @@ dependencies = [ "proc-macro2", "quote", "serde_derive_internals", - "syn 2.0.65", + "syn 2.0.77", ] [[package]] @@ -1425,6 +4349,9 @@ name = "semver" version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +dependencies = [ + "serde", +] [[package]] name = "semver-parser" @@ -1434,9 +4361,9 @@ 
checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3" [[package]] name = "serde" -version = "1.0.203" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] @@ -1452,13 +4379,13 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", ] [[package]] @@ -1469,17 +4396,18 @@ checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", ] [[package]] name = "serde_json" -version = "1.0.117" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "455182ea6142b14f93f4bc5320a2b31c1f266b66a4a5c858b013302a5d8cbfc3" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ - "indexmap", + "indexmap 2.2.6", "itoa", + "memchr", "ryu", "serde", ] @@ -1503,7 +4431,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "678b5a069e50bf00ecd22d0cd8ddf7c236f68581b03db652061ed5eb13a312ff" dependencies = [ "serde", - "serde_with_macros", + "serde_with_macros 1.5.2", +] + +[[package]] +name = "serde_with" +version = "3.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" +dependencies = [ + "base64 0.22.1", + "chrono", + "hex", + "indexmap 1.9.3", + "indexmap 2.2.6", + "serde", + "serde_derive", + "serde_json", + 
"serde_with_macros 3.9.0", + "time", ] [[package]] @@ -1512,12 +4458,24 @@ version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e182d6ec6f05393cc0e5ed1bf81ad6db3a8feedf8ee515ecdd369809bcce8082" dependencies = [ - "darling", + "darling 0.13.4", "proc-macro2", "quote", "syn 1.0.109", ] +[[package]] +name = "serde_with_macros" +version = "3.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" +dependencies = [ + "darling 0.20.10", + "proc-macro2", + "quote", + "syn 2.0.77", +] + [[package]] name = "sha-1" version = "0.10.1" @@ -1549,6 +4507,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shellexpand" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da03fa3b94cc19e3ebfc88c4229c49d8f08cdbd1228870a45f0ffdf84988e14b" +dependencies = [ + "dirs", +] + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -1558,6 +4525,36 @@ dependencies = [ "libc", ] +[[package]] +name = "siphasher" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" + +[[package]] +name = "skeptic" +version = "0.13.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16d23b015676c90a0f01c197bfdc786c20342c73a0afdda9025adb0bc42940a8" +dependencies = [ + "bytecount", + "cargo_metadata", + "error-chain", + "glob", + "pulldown-cmark", + "tempfile", + "walkdir", +] + +[[package]] +name = "sketches-ddsketch" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" +dependencies = [ + "serde", +] + [[package]] name = "slab" version = "0.4.9" @@ -1573,6 +4570,28 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "socket2" version = "0.4.10" @@ -1599,6 +4618,51 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "sqlparser" +version = "0.47.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "295e9930cd7a97e58ca2a070541a3ca502b17f5d1fa7157376d0fabd85324f25" +dependencies = [ + "log", + "sqlparser_derive", +] + +[[package]] +name = "sqlparser_derive" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.77", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "std_prelude" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8207e78455ffdf55661170876f88daf85356e4edd54e0a3dbc79586ca1e50cbe" + +[[package]] +name = 
"stfu8" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51f1e89f093f99e7432c491c382b88a6860a5adbe6bf02574bf0a08efff1978" + [[package]] name = "stringprep" version = "0.1.4" @@ -1616,6 +4680,34 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.77", +] + [[package]] name = "subtle" version = "2.5.0" @@ -1635,9 +4727,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.65" +version = "2.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2863d96a84c6439701d7a38f9de935ec562c8832cc55d1dde0f513b52fad106" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" dependencies = [ "proc-macro2", "quote", @@ -1650,6 +4742,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "sync_wrapper" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" + [[package]] name = "system-configuration" version = 
"0.5.1" @@ -1671,12 +4769,159 @@ dependencies = [ "libc", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "take_mut" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f764005d11ee5f36500a149ace24e00e3da98b0158b3e2d53a7495660d3f4d60" +[[package]] +name = "tantivy" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8d0582f186c0a6d55655d24543f15e43607299425c5ad8352c242b914b31856" +dependencies = [ + "aho-corasick", + "arc-swap", + "base64 0.22.1", + "bitpacking", + "byteorder", + "census", + "crc32fast", + "crossbeam-channel", + "downcast-rs", + "fastdivide", + "fnv", + "fs4", + "htmlescape", + "itertools 0.12.1", + "levenshtein_automata", + "log", + "lru", + "lz4_flex", + "measure_time", + "memmap2", + "num_cpus", + "once_cell", + "oneshot", + "rayon", + "regex", + "rust-stemmers", + "rustc-hash 1.1.0", + "serde", + "serde_json", + "sketches-ddsketch", + "smallvec", + "tantivy-bitpacker", + "tantivy-columnar", + "tantivy-common", + "tantivy-fst", + "tantivy-query-grammar", + "tantivy-stacker", + "tantivy-tokenizer-api", + "tempfile", + "thiserror", + "time", + "uuid", + "winapi", +] + +[[package]] +name = "tantivy-bitpacker" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "284899c2325d6832203ac6ff5891b297fc5239c3dc754c5bc1977855b23c10df" +dependencies = [ + "bitpacking", +] + +[[package]] +name = "tantivy-columnar" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12722224ffbe346c7fec3275c699e508fd0d4710e629e933d5736ec524a1f44e" +dependencies = [ + "downcast-rs", + "fastdivide", + "itertools 0.12.1", + "serde", + "tantivy-bitpacker", + "tantivy-common", + "tantivy-sstable", + "tantivy-stacker", +] + +[[package]] 
+name = "tantivy-common" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8019e3cabcfd20a1380b491e13ff42f57bb38bf97c3d5fa5c07e50816e0621f4" +dependencies = [ + "async-trait", + "byteorder", + "ownedbytes", + "serde", + "time", +] + +[[package]] +name = "tantivy-fst" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d60769b80ad7953d8a7b2c70cdfe722bbcdcac6bccc8ac934c40c034d866fc18" +dependencies = [ + "byteorder", + "regex-syntax", + "utf8-ranges", +] + +[[package]] +name = "tantivy-query-grammar" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "847434d4af57b32e309f4ab1b4f1707a6c566656264caa427ff4285c4d9d0b82" +dependencies = [ + "nom", +] + +[[package]] +name = "tantivy-sstable" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c69578242e8e9fc989119f522ba5b49a38ac20f576fc778035b96cc94f41f98e" +dependencies = [ + "tantivy-bitpacker", + "tantivy-common", + "tantivy-fst", + "zstd", +] + +[[package]] +name = "tantivy-stacker" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56d6ff5591fc332739b3ce7035b57995a3ce29a93ffd6012660e0949c956ea8" +dependencies = [ + "murmurhash32", + "rand_distr", + "tantivy-common", +] + +[[package]] +name = "tantivy-tokenizer-api" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a0dcade25819a89cfe6f17d932c9cedff11989936bf6dd4f336d50392053b04" +dependencies = [ + "serde", +] + [[package]] name = "tap" version = "1.0.1" @@ -1690,8 +4935,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" dependencies = [ "cfg-if", - "fastrand", - "rustix", + "fastrand 2.1.0", + "rustix 0.38.34", "windows-sys 0.52.0", ] @@ -1712,7 +4957,7 @@ checksum = 
"46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", ] [[package]] @@ -1756,6 +5001,15 @@ dependencies = [ "time-core", ] +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tinyvec" version = "1.6.0" @@ -1773,32 +5027,31 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.38.0" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2 0.5.7", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "tokio-macros" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", ] [[package]] @@ -1817,7 +5070,29 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls", + "rustls 0.21.12", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls 0.23.13", + "rustls-pki-types", + "tokio", +] 
+ +[[package]] +name = "tokio-stream" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" +dependencies = [ + "futures-core", + "pin-project-lite", "tokio", ] @@ -1835,6 +5110,27 @@ dependencies = [ "tokio", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "pin-project", + "pin-project-lite", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.2" @@ -1860,7 +5156,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", ] [[package]] @@ -1898,6 +5194,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "triomphe" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6631e42e10b40c0690bf92f404ebcfe6e1fdb480391d15f17cc8e96eeed5369" + [[package]] name = "trust-dns-proto" version = "0.21.2" @@ -1949,6 +5251,16 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + [[package]] name = "typed-builder" version = "0.10.0" @@ -1966,6 +5278,15 @@ version = "1.17.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unicase" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7d2d4dafb69621809a81864c9c1b864479e1235c0dd4e199924b9742439ed89" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.15" @@ -1987,6 +5308,18 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-width" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" + [[package]] name = "untrusted" version = "0.9.0" @@ -2004,6 +5337,18 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + +[[package]] +name = "utf8-ranges" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcfc827f90e53a02eaef5e535ee14266c1d569214c6aa70133a624d8a3164ba" + [[package]] name = "uuid" version = "1.8.0" @@ -2032,6 +5377,28 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + +[[package]] +name = "waker-fn" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"317211a0dc0ceedd78fb2ca9a44aed3d7b9b26f81870d485c07122b4350673b7" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2068,7 +5435,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", "wasm-bindgen-shared", ] @@ -2102,7 +5469,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2113,6 +5480,19 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +[[package]] +name = "wasm-streams" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.69" @@ -2151,6 +5531,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2315,6 +5704,16 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "winreg" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5" +dependencies = [ + "cfg-if", + "windows-sys 0.48.0", +] + [[package]] name = "wyz" version = "0.5.1" @@ -2324,6 +5723,12 @@ dependencies = [ "tap", ] +[[package]] +name = "xmlparser" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" + [[package]] name = "zerocopy" version = "0.7.34" @@ -2341,5 +5746,39 @@ checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.65", + "syn 2.0.77", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", ] diff --git a/Cargo.toml b/Cargo.toml index 2501b86f..33e2782f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" members = [ - "rig-core", + "rig-core", "rig-lancedb", "rig-mongodb", ] diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml new file mode 100644 index 00000000..c97b9b78 --- /dev/null +++ b/rig-lancedb/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "rig-lancedb" +version = "0.1.0" 
+edition = "2021" + +[dependencies] +lancedb = "0.10.0" +tokio = "1.40.0" +rig-core = { path = "../rig-core", version = "0.1.0" } +arrow-array = "52.2.0" +serde_json = "1.0.128" +serde = "1.0.210" +futures = "0.3.30" diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs new file mode 100644 index 00000000..c10cb0b1 --- /dev/null +++ b/rig-lancedb/src/lib.rs @@ -0,0 +1,182 @@ +use std::sync::Arc; + +use arrow_array::{ + builder::{Float64Builder, ListBuilder, StringBuilder, StructBuilder}, + cast::AsArray, + RecordBatch, RecordBatchIterator, StringArray, +}; +use futures::StreamExt; +use lancedb::{ + arrow::arrow_schema::{DataType, Field, Fields, Schema}, + query::{ExecutableQuery, QueryBase}, +}; +use rig::{ + embeddings::DocumentEmbeddings, + vector_store::{VectorStore, VectorStoreError}, +}; + +pub struct LanceDbVectorStore { + table: lancedb::Table, +} + +fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + +impl VectorStore for LanceDbVectorStore { + type Q = lancedb::query::Query; + + async fn add_documents( + &mut self, + documents: Vec, + ) -> Result<(), VectorStoreError> { + let id = StringArray::from_iter_values(documents.clone().into_iter().map(|doc| doc.id)); + let document = StringArray::from_iter_values( + documents + .clone() + .into_iter() + .map(|doc| serde_json::to_string(&doc.document).unwrap()), + ); + + let mut list_builder = ListBuilder::new(StructBuilder::from_fields(self.schema(), 0)); + documents.into_iter().map(|doc| { + let struct_builder = list_builder.values(); + + doc.embeddings.into_iter().for_each(|embedding| { + struct_builder + .field_builder::(0) + .unwrap() + .append_value(embedding.document); + struct_builder + .field_builder::>(1) + .unwrap() + .append_value(embedding.vec.into_iter().map(Some).collect::>()); + struct_builder.append(true); // Append the first struct + }); + + list_builder.append(true) + }); + let embeddings = list_builder.finish(); + + let 
batches = RecordBatchIterator::new( + vec![RecordBatch::try_new( + Arc::new(Schema::new(self.schema())), + vec![Arc::new(id), Arc::new(document), Arc::new(embeddings)], + ) + .unwrap()] + .into_iter() + .map(Ok), + Arc::new(Schema::new(self.schema())), + ); + + self.table + .add(batches) + .execute() + .await + .map_err(lancedb_to_rig_error)?; + + Ok(()) + } + + async fn get_document_embeddings( + &self, + id: &str, + ) -> Result, VectorStoreError> { + let mut stream = self + .table + .query() + .only_if(format!("id = {id}")) + .execute() + .await + .map_err(lancedb_to_rig_error)?; + + // let record_batches = stream.try_collect::>().await.map_err(lancedb_to_rig_error)?; + + stream.next().await.map(|maybe_record_batch| { + let record_batch = maybe_record_batch?; + + Ok::<(), lancedb::Error>(()) + }); + + todo!() + } + + async fn get_document serde::Deserialize<'a>>( + &self, + id: &str, + ) -> Result, VectorStoreError> { + todo!() + } + + async fn get_document_by_query( + &self, + query: Self::Q, + ) -> Result, VectorStoreError> { + query.execute().await.map_err(lancedb_to_rig_error)?; + + todo!() + } +} + +pub fn to_document_embeddings(record_batch: arrow_array::RecordBatch) -> Vec { + let columns = record_batch.columns().into_iter(); + + let ids = match columns.next() { + Some(column) => match column.data_type() { + DataType::Utf8 => column.as_string::().into_iter().collect(), + _ => vec![], + }, + None => vec![], + }; + let documents = match columns.next() { + Some(column) => match column.data_type() { + DataType::Utf8 => column.as_string::().into_iter().collect(), + _ => vec![], + }, + None => vec![], + }; + + let embeddings = match columns.next() { + Some(column) => match column.data_type() { + DataType::List(embeddings_list) => match embeddings_list.data_type() { + DataType::Struct(embedding_fields) => match embedding_fields.into_iter().next() { + Some(field) => match field.data_type() { + DataType::Utf8 => {} + _ => vec![], + }, + None => vec![], + }, + _ => 
vec![], + }, + _ => vec![], + }, + None => vec![], + }; + + todo!() +} + +impl LanceDbVectorStore { + pub fn schema(&self) -> Vec { + vec![ + Field::new("id", DataType::Utf8, false), + Field::new("document", DataType::Utf8, false), + Field::new( + "embeddings", + DataType::List(Arc::new(Field::new( + "embedding_item", + DataType::Struct(Fields::from(vec![ + Arc::new(Field::new("document", DataType::Utf8, false)), + Arc::new(Field::new( + "vec", + DataType::List(Arc::new(Field::new("float", DataType::Float64, false))), + false, + )), + ])), + false, + ))), + false, + ), + ] + } +} From 1d9fd646bee925945bf9f1adc632ae0d852d4921 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 16 Sep 2024 17:14:29 -0400 Subject: [PATCH 02/39] refactor: create wrapper for vec for from/tryfrom traits --- .../src/conversions/document_embeddings.rs | 43 ++++++ rig-lancedb/src/conversions/mod.rs | 42 ++++++ rig-lancedb/src/conversions/record_batch.rs | 56 ++++++++ rig-lancedb/src/lib.rs | 126 +++--------------- 4 files changed, 157 insertions(+), 110 deletions(-) create mode 100644 rig-lancedb/src/conversions/document_embeddings.rs create mode 100644 rig-lancedb/src/conversions/mod.rs create mode 100644 rig-lancedb/src/conversions/record_batch.rs diff --git a/rig-lancedb/src/conversions/document_embeddings.rs b/rig-lancedb/src/conversions/document_embeddings.rs new file mode 100644 index 00000000..64bb3f6e --- /dev/null +++ b/rig-lancedb/src/conversions/document_embeddings.rs @@ -0,0 +1,43 @@ +use arrow_array::cast::AsArray; +use lancedb::arrow::arrow_schema::DataType; + +impl From for super::DocumentEmbeddings { + fn from(record_batch: arrow_array::RecordBatch) -> Self { + let columns = record_batch.columns().into_iter(); + + let ids = match columns.next() { + Some(column) => match column.data_type() { + DataType::Utf8 => column.as_string::().into_iter().collect(), + _ => vec![], + }, + None => vec![], + }; + let documents = match columns.next() { + Some(column) => match 
column.data_type() { + DataType::Utf8 => column.as_string::().into_iter().collect(), + _ => vec![], + }, + None => vec![], + }; + + let embeddings = match columns.next() { + Some(column) => match column.data_type() { + DataType::List(embeddings_list) => match embeddings_list.data_type() { + DataType::Struct(embedding_fields) => match embedding_fields.into_iter().next() + { + Some(field) => match field.data_type() { + DataType::Utf8 => {} + _ => vec![], + }, + None => vec![], + }, + _ => vec![], + }, + _ => vec![], + }, + None => vec![], + }; + + todo!() + } +} diff --git a/rig-lancedb/src/conversions/mod.rs b/rig-lancedb/src/conversions/mod.rs new file mode 100644 index 00000000..305eb005 --- /dev/null +++ b/rig-lancedb/src/conversions/mod.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use lancedb::arrow::arrow_schema::{DataType, Field, Fields}; + +pub mod document_embeddings; +pub mod record_batch; + +#[derive(Clone)] +pub struct DocumentEmbeddings(pub Vec); + +impl DocumentEmbeddings { + pub fn new(documents: Vec) -> Self { + Self(documents) + } + + pub fn as_iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn schema(&self) -> Vec { + vec![ + Field::new("id", DataType::Utf8, false), + Field::new("document", DataType::Utf8, false), + Field::new( + "embeddings", + DataType::List(Arc::new(Field::new( + "embedding_item", + DataType::Struct(Fields::from(vec![ + Arc::new(Field::new("document", DataType::Utf8, false)), + Arc::new(Field::new( + "vec", + DataType::List(Arc::new(Field::new("float", DataType::Float64, false))), + false, + )), + ])), + false, + ))), + false, + ), + ] + } +} diff --git a/rig-lancedb/src/conversions/record_batch.rs b/rig-lancedb/src/conversions/record_batch.rs new file mode 100644 index 00000000..247e57bf --- /dev/null +++ b/rig-lancedb/src/conversions/record_batch.rs @@ -0,0 +1,56 @@ +use std::sync::Arc; + +use arrow_array::{ + builder::{Float64Builder, ListBuilder, StringBuilder, StructBuilder}, + RecordBatch, StringArray, +}; +use 
lancedb::arrow::arrow_schema::{ArrowError, Schema}; +use rig::vector_store::VectorStoreError; + +pub fn arrow_to_rig_error(e: lancedb::arrow::arrow_schema::ArrowError) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + +impl TryFrom for RecordBatch { + type Error = ArrowError; + + fn try_from(documents: super::DocumentEmbeddings) -> Result { + let id = StringArray::from_iter_values(documents.as_iter().map(|doc| doc.id.clone())); + let document = StringArray::from_iter_values( + documents + .as_iter() + .map(|doc| serde_json::to_string(&doc.document).unwrap()), + ); + + let mut list_builder = ListBuilder::new(StructBuilder::from_fields(documents.schema(), 0)); + documents.as_iter().map(|doc| { + let struct_builder = list_builder.values(); + + doc.embeddings.clone().into_iter().for_each(|embedding| { + struct_builder + .field_builder::(0) + .unwrap() + .append_value(embedding.document); + struct_builder + .field_builder::>(1) + .unwrap() + .append_value(embedding.vec.into_iter().map(Some).collect::>()); + struct_builder.append(true); // Append the first struct + }); + + list_builder.append(true) + }); + let embeddings = list_builder.finish(); + + RecordBatch::try_new( + Arc::new(Schema::new(documents.schema())), + vec![Arc::new(id), Arc::new(document), Arc::new(embeddings)], + ) + } +} + +#[cfg(test)] +mod tests { + #[tokio::test] + async fn record_batch_deserialization() {} +} diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index c10cb0b1..3837bb03 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,19 +1,15 @@ use std::sync::Arc; -use arrow_array::{ - builder::{Float64Builder, ListBuilder, StringBuilder, StructBuilder}, - cast::AsArray, - RecordBatch, RecordBatchIterator, StringArray, -}; +use arrow_array::{cast::AsArray, RecordBatch, RecordBatchIterator}; +use conversions::{record_batch::arrow_to_rig_error, DocumentEmbeddings}; use futures::StreamExt; use lancedb::{ - arrow::arrow_schema::{DataType, Field, 
Fields, Schema}, + arrow::arrow_schema::{DataType, Schema}, query::{ExecutableQuery, QueryBase}, }; -use rig::{ - embeddings::DocumentEmbeddings, - vector_store::{VectorStore, VectorStoreError}, -}; +use rig::vector_store::{VectorStore, VectorStoreError}; + +mod conversions; pub struct LanceDbVectorStore { table: lancedb::Table, @@ -28,45 +24,18 @@ impl VectorStore for LanceDbVectorStore { async fn add_documents( &mut self, - documents: Vec, + documents: Vec, ) -> Result<(), VectorStoreError> { - let id = StringArray::from_iter_values(documents.clone().into_iter().map(|doc| doc.id)); - let document = StringArray::from_iter_values( - documents - .clone() - .into_iter() - .map(|doc| serde_json::to_string(&doc.document).unwrap()), - ); + let document_embeddings = DocumentEmbeddings::new(documents); - let mut list_builder = ListBuilder::new(StructBuilder::from_fields(self.schema(), 0)); - documents.into_iter().map(|doc| { - let struct_builder = list_builder.values(); - - doc.embeddings.into_iter().for_each(|embedding| { - struct_builder - .field_builder::(0) - .unwrap() - .append_value(embedding.document); - struct_builder - .field_builder::>(1) - .unwrap() - .append_value(embedding.vec.into_iter().map(Some).collect::>()); - struct_builder.append(true); // Append the first struct - }); - - list_builder.append(true) - }); - let embeddings = list_builder.finish(); + let record_batch = document_embeddings + .clone() + .try_into() + .map_err(arrow_to_rig_error)?; let batches = RecordBatchIterator::new( - vec![RecordBatch::try_new( - Arc::new(Schema::new(self.schema())), - vec![Arc::new(id), Arc::new(document), Arc::new(embeddings)], - ) - .unwrap()] - .into_iter() - .map(Ok), - Arc::new(Schema::new(self.schema())), + vec![record_batch].into_iter().map(Ok), + Arc::new(Schema::new(document_embeddings.schema())), ); self.table @@ -81,7 +50,7 @@ impl VectorStore for LanceDbVectorStore { async fn get_document_embeddings( &self, id: &str, - ) -> Result, VectorStoreError> { + ) 
-> Result, VectorStoreError> { let mut stream = self .table .query() @@ -111,72 +80,9 @@ impl VectorStore for LanceDbVectorStore { async fn get_document_by_query( &self, query: Self::Q, - ) -> Result, VectorStoreError> { + ) -> Result, VectorStoreError> { query.execute().await.map_err(lancedb_to_rig_error)?; todo!() } } - -pub fn to_document_embeddings(record_batch: arrow_array::RecordBatch) -> Vec { - let columns = record_batch.columns().into_iter(); - - let ids = match columns.next() { - Some(column) => match column.data_type() { - DataType::Utf8 => column.as_string::().into_iter().collect(), - _ => vec![], - }, - None => vec![], - }; - let documents = match columns.next() { - Some(column) => match column.data_type() { - DataType::Utf8 => column.as_string::().into_iter().collect(), - _ => vec![], - }, - None => vec![], - }; - - let embeddings = match columns.next() { - Some(column) => match column.data_type() { - DataType::List(embeddings_list) => match embeddings_list.data_type() { - DataType::Struct(embedding_fields) => match embedding_fields.into_iter().next() { - Some(field) => match field.data_type() { - DataType::Utf8 => {} - _ => vec![], - }, - None => vec![], - }, - _ => vec![], - }, - _ => vec![], - }, - None => vec![], - }; - - todo!() -} - -impl LanceDbVectorStore { - pub fn schema(&self) -> Vec { - vec![ - Field::new("id", DataType::Utf8, false), - Field::new("document", DataType::Utf8, false), - Field::new( - "embeddings", - DataType::List(Arc::new(Field::new( - "embedding_item", - DataType::Struct(Fields::from(vec![ - Arc::new(Field::new("document", DataType::Utf8, false)), - Arc::new(Field::new( - "vec", - DataType::List(Arc::new(Field::new("float", DataType::Float64, false))), - false, - )), - ])), - false, - ))), - false, - ), - ] - } -} From e9d18c52d6e16729ebf8482df5097132d838bae9 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 17 Sep 2024 18:15:11 -0400 Subject: [PATCH 03/39] feat: implement add_documents on VectorStore trait --- 
.../src/conversions/document_embeddings.rs | 43 ------- rig-lancedb/src/conversions/mod.rs | 112 ++++++++++++------ rig-lancedb/src/conversions/record_batch.rs | 56 --------- rig-lancedb/src/lib.rs | 63 +++++----- 4 files changed, 111 insertions(+), 163 deletions(-) delete mode 100644 rig-lancedb/src/conversions/document_embeddings.rs delete mode 100644 rig-lancedb/src/conversions/record_batch.rs diff --git a/rig-lancedb/src/conversions/document_embeddings.rs b/rig-lancedb/src/conversions/document_embeddings.rs deleted file mode 100644 index 64bb3f6e..00000000 --- a/rig-lancedb/src/conversions/document_embeddings.rs +++ /dev/null @@ -1,43 +0,0 @@ -use arrow_array::cast::AsArray; -use lancedb::arrow::arrow_schema::DataType; - -impl From for super::DocumentEmbeddings { - fn from(record_batch: arrow_array::RecordBatch) -> Self { - let columns = record_batch.columns().into_iter(); - - let ids = match columns.next() { - Some(column) => match column.data_type() { - DataType::Utf8 => column.as_string::().into_iter().collect(), - _ => vec![], - }, - None => vec![], - }; - let documents = match columns.next() { - Some(column) => match column.data_type() { - DataType::Utf8 => column.as_string::().into_iter().collect(), - _ => vec![], - }, - None => vec![], - }; - - let embeddings = match columns.next() { - Some(column) => match column.data_type() { - DataType::List(embeddings_list) => match embeddings_list.data_type() { - DataType::Struct(embedding_fields) => match embedding_fields.into_iter().next() - { - Some(field) => match field.data_type() { - DataType::Utf8 => {} - _ => vec![], - }, - None => vec![], - }, - _ => vec![], - }, - _ => vec![], - }, - None => vec![], - }; - - todo!() - } -} diff --git a/rig-lancedb/src/conversions/mod.rs b/rig-lancedb/src/conversions/mod.rs index 305eb005..99f153ae 100644 --- a/rig-lancedb/src/conversions/mod.rs +++ b/rig-lancedb/src/conversions/mod.rs @@ -1,42 +1,88 @@ use std::sync::Arc; -use lancedb::arrow::arrow_schema::{DataType, 
Field, Fields}; +use arrow_array::{ + builder::{Float64Builder, ListBuilder}, + RecordBatch, StringArray, +}; +use lancedb::arrow::arrow_schema::{ArrowError, DataType, Field, Fields, Schema}; +use rig::embeddings::DocumentEmbeddings; -pub mod document_embeddings; -pub mod record_batch; +pub fn document_schema() -> Schema { + Schema::new(Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("document", DataType::Utf8, false), + ])) +} + +pub fn embedding_schema() -> Schema { + Schema::new(Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("document_id", DataType::Utf8, false), + Field::new("content", DataType::Utf8, false), + Field::new( + "embedding", + DataType::List(Arc::new(Field::new("float", DataType::Float64, false))), + false, + ), + ])) +} + +pub fn document_records(documents: &Vec) -> Result { + let id = StringArray::from_iter_values(documents.iter().map(|doc| doc.id.clone())); + let document = StringArray::from_iter_values( + documents + .iter() + .map(|doc| serde_json::to_string(&doc.document.clone()).unwrap()), + ); + + RecordBatch::try_new( + Arc::new(document_schema()), + vec![Arc::new(id), Arc::new(document)], + ) +} + +struct EmbeddingRecord { + id: String, + document_id: String, + content: String, + embedding: Vec, +} -#[derive(Clone)] -pub struct DocumentEmbeddings(pub Vec); +pub fn embedding_records(documents: &Vec) -> Result { + let embedding_records = documents.into_iter().flat_map(|document| { + document + .embeddings.clone() + .into_iter() + .map(move |embedding| EmbeddingRecord { + id: "".to_string(), + document_id: document.id.clone(), + content: embedding.document, + embedding: embedding.vec, + }) + }); -impl DocumentEmbeddings { - pub fn new(documents: Vec) -> Self { - Self(documents) - } + let id = StringArray::from_iter_values(embedding_records.clone().map(|record| record.id)); + let document_id = + StringArray::from_iter_values(embedding_records.clone().map(|record| record.document_id)); + 
let content = + StringArray::from_iter_values(embedding_records.clone().map(|record| record.content)); - pub fn as_iter(&self) -> impl Iterator { - self.0.iter() - } + let mut builder = ListBuilder::new(Float64Builder::new()); + embedding_records.for_each(|record| { + record + .embedding + .iter() + .for_each(|value| builder.values().append_value(*value)); + builder.append(true); + }); - pub fn schema(&self) -> Vec { + RecordBatch::try_new( + Arc::new(document_schema()), vec![ - Field::new("id", DataType::Utf8, false), - Field::new("document", DataType::Utf8, false), - Field::new( - "embeddings", - DataType::List(Arc::new(Field::new( - "embedding_item", - DataType::Struct(Fields::from(vec![ - Arc::new(Field::new("document", DataType::Utf8, false)), - Arc::new(Field::new( - "vec", - DataType::List(Arc::new(Field::new("float", DataType::Float64, false))), - false, - )), - ])), - false, - ))), - false, - ), - ] - } + Arc::new(id), + Arc::new(document_id), + Arc::new(content), + Arc::new(builder.finish()), + ], + ) } diff --git a/rig-lancedb/src/conversions/record_batch.rs b/rig-lancedb/src/conversions/record_batch.rs deleted file mode 100644 index 247e57bf..00000000 --- a/rig-lancedb/src/conversions/record_batch.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::sync::Arc; - -use arrow_array::{ - builder::{Float64Builder, ListBuilder, StringBuilder, StructBuilder}, - RecordBatch, StringArray, -}; -use lancedb::arrow::arrow_schema::{ArrowError, Schema}; -use rig::vector_store::VectorStoreError; - -pub fn arrow_to_rig_error(e: lancedb::arrow::arrow_schema::ArrowError) -> VectorStoreError { - VectorStoreError::DatastoreError(Box::new(e)) -} - -impl TryFrom for RecordBatch { - type Error = ArrowError; - - fn try_from(documents: super::DocumentEmbeddings) -> Result { - let id = StringArray::from_iter_values(documents.as_iter().map(|doc| doc.id.clone())); - let document = StringArray::from_iter_values( - documents - .as_iter() - .map(|doc| 
serde_json::to_string(&doc.document).unwrap()), - ); - - let mut list_builder = ListBuilder::new(StructBuilder::from_fields(documents.schema(), 0)); - documents.as_iter().map(|doc| { - let struct_builder = list_builder.values(); - - doc.embeddings.clone().into_iter().for_each(|embedding| { - struct_builder - .field_builder::(0) - .unwrap() - .append_value(embedding.document); - struct_builder - .field_builder::>(1) - .unwrap() - .append_value(embedding.vec.into_iter().map(Some).collect::>()); - struct_builder.append(true); // Append the first struct - }); - - list_builder.append(true) - }); - let embeddings = list_builder.finish(); - - RecordBatch::try_new( - Arc::new(Schema::new(documents.schema())), - vec![Arc::new(id), Arc::new(document), Arc::new(embeddings)], - ) - } -} - -#[cfg(test)] -mod tests { - #[tokio::test] - async fn record_batch_deserialization() {} -} diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 3837bb03..bad7c1f9 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,18 +1,15 @@ use std::sync::Arc; -use arrow_array::{cast::AsArray, RecordBatch, RecordBatchIterator}; -use conversions::{record_batch::arrow_to_rig_error, DocumentEmbeddings}; -use futures::StreamExt; -use lancedb::{ - arrow::arrow_schema::{DataType, Schema}, - query::{ExecutableQuery, QueryBase}, -}; +use arrow_array::RecordBatchIterator; +use conversions::{document_records, document_schema, embedding_records, embedding_schema}; +use lancedb::{arrow::arrow_schema::{ArrowError, Schema}, query::ExecutableQuery}; use rig::vector_store::{VectorStore, VectorStoreError}; mod conversions; pub struct LanceDbVectorStore { - table: lancedb::Table, + document_table: lancedb::Table, + embedding_table: lancedb::Table, } fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { @@ -26,20 +23,24 @@ impl VectorStore for LanceDbVectorStore { &mut self, documents: Vec, ) -> Result<(), VectorStoreError> { - let document_embeddings = 
DocumentEmbeddings::new(documents); - - let record_batch = document_embeddings - .clone() - .try_into() - .map_err(arrow_to_rig_error)?; + let document_batches = RecordBatchIterator::new( + vec![document_records(&documents)], + Arc::new(document_schema()), + ); - let batches = RecordBatchIterator::new( - vec![record_batch].into_iter().map(Ok), - Arc::new(Schema::new(document_embeddings.schema())), + let embedding_batches = RecordBatchIterator::new( + vec![embedding_records(&documents)], + Arc::new(embedding_schema()), ); - self.table - .add(batches) + self.document_table + .add(document_batches) + .execute() + .await + .map_err(lancedb_to_rig_error)?; + + self.embedding_table + .add(embedding_batches) .execute() .await .map_err(lancedb_to_rig_error)?; @@ -51,21 +52,21 @@ impl VectorStore for LanceDbVectorStore { &self, id: &str, ) -> Result, VectorStoreError> { - let mut stream = self - .table - .query() - .only_if(format!("id = {id}")) - .execute() - .await - .map_err(lancedb_to_rig_error)?; + // let mut stream = self + // .table + // .query() + // .only_if(format!("id = {id}")) + // .execute() + // .await + // .map_err(lancedb_to_rig_error)?; - // let record_batches = stream.try_collect::>().await.map_err(lancedb_to_rig_error)?; + // // let record_batches = stream.try_collect::>().await.map_err(lancedb_to_rig_error)?; - stream.next().await.map(|maybe_record_batch| { - let record_batch = maybe_record_batch?; + // stream.next().await.map(|maybe_record_batch| { + // let record_batch = maybe_record_batch?; - Ok::<(), lancedb::Error>(()) - }); + // Ok::<(), lancedb::Error>(()) + // }); todo!() } From d3c7f9ac13aba9076dfe1779e4887b414ae8dcef Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 18 Sep 2024 16:22:14 -0400 Subject: [PATCH 04/39] feat: implement search by id for VectorStore trait --- Cargo.lock | 12 +- rig-lancedb/src/conversions/mod.rs | 88 -------- rig-lancedb/src/lib.rs | 82 ++++--- rig-lancedb/src/table_schemas/document.rs | 177 +++++++++++++++ 
rig-lancedb/src/table_schemas/embedding.rs | 243 +++++++++++++++++++++ rig-lancedb/src/table_schemas/mod.rs | 42 ++++ rig-lancedb/src/utils/mod.rs | 70 ++++++ 7 files changed, 587 insertions(+), 127 deletions(-) delete mode 100644 rig-lancedb/src/conversions/mod.rs create mode 100644 rig-lancedb/src/table_schemas/document.rs create mode 100644 rig-lancedb/src/table_schemas/embedding.rs create mode 100644 rig-lancedb/src/table_schemas/mod.rs create mode 100644 rig-lancedb/src/utils/mod.rs diff --git a/Cargo.lock b/Cargo.lock index ab13cda6..cbddf0c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -63,9 +63,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" [[package]] name = "arc-swap" @@ -4942,18 +4942,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.61" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", diff --git a/rig-lancedb/src/conversions/mod.rs b/rig-lancedb/src/conversions/mod.rs deleted file mode 100644 index 99f153ae..00000000 --- a/rig-lancedb/src/conversions/mod.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::sync::Arc; - -use arrow_array::{ - builder::{Float64Builder, ListBuilder}, - RecordBatch, StringArray, -}; -use 
lancedb::arrow::arrow_schema::{ArrowError, DataType, Field, Fields, Schema}; -use rig::embeddings::DocumentEmbeddings; - -pub fn document_schema() -> Schema { - Schema::new(Fields::from(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("document", DataType::Utf8, false), - ])) -} - -pub fn embedding_schema() -> Schema { - Schema::new(Fields::from(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("document_id", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), - Field::new( - "embedding", - DataType::List(Arc::new(Field::new("float", DataType::Float64, false))), - false, - ), - ])) -} - -pub fn document_records(documents: &Vec) -> Result { - let id = StringArray::from_iter_values(documents.iter().map(|doc| doc.id.clone())); - let document = StringArray::from_iter_values( - documents - .iter() - .map(|doc| serde_json::to_string(&doc.document.clone()).unwrap()), - ); - - RecordBatch::try_new( - Arc::new(document_schema()), - vec![Arc::new(id), Arc::new(document)], - ) -} - -struct EmbeddingRecord { - id: String, - document_id: String, - content: String, - embedding: Vec, -} - -pub fn embedding_records(documents: &Vec) -> Result { - let embedding_records = documents.into_iter().flat_map(|document| { - document - .embeddings.clone() - .into_iter() - .map(move |embedding| EmbeddingRecord { - id: "".to_string(), - document_id: document.id.clone(), - content: embedding.document, - embedding: embedding.vec, - }) - }); - - let id = StringArray::from_iter_values(embedding_records.clone().map(|record| record.id)); - let document_id = - StringArray::from_iter_values(embedding_records.clone().map(|record| record.document_id)); - let content = - StringArray::from_iter_values(embedding_records.clone().map(|record| record.content)); - - let mut builder = ListBuilder::new(Float64Builder::new()); - embedding_records.for_each(|record| { - record - .embedding - .iter() - .for_each(|value| builder.values().append_value(*value)); - 
builder.append(true); - }); - - RecordBatch::try_new( - Arc::new(document_schema()), - vec![ - Arc::new(id), - Arc::new(document_id), - Arc::new(content), - Arc::new(builder.finish()), - ], - ) -} diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index bad7c1f9..4eab3e99 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,11 +1,17 @@ use std::sync::Arc; use arrow_array::RecordBatchIterator; -use conversions::{document_records, document_schema, embedding_records, embedding_schema}; -use lancedb::{arrow::arrow_schema::{ArrowError, Schema}, query::ExecutableQuery}; +use lancedb::query::QueryBase; use rig::vector_store::{VectorStore, VectorStoreError}; +use table_schemas::{ + document::{document_schema, DocumentRecords}, + embedding::{embedding_schema, EmbeddingRecordsBatch}, + merge, +}; +use utils::Query; -mod conversions; +mod table_schemas; +mod utils; pub struct LanceDbVectorStore { document_table: lancedb::Table, @@ -16,6 +22,10 @@ fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } +fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + impl VectorStore for LanceDbVectorStore { type Q = lancedb::query::Query; @@ -23,24 +33,25 @@ impl VectorStore for LanceDbVectorStore { &mut self, documents: Vec, ) -> Result<(), VectorStoreError> { - let document_batches = RecordBatchIterator::new( - vec![document_records(&documents)], - Arc::new(document_schema()), - ); - - let embedding_batches = RecordBatchIterator::new( - vec![embedding_records(&documents)], - Arc::new(embedding_schema()), - ); + let document_records = + DocumentRecords::try_from(documents.clone()).map_err(serde_to_rig_error)?; self.document_table - .add(document_batches) + .add(RecordBatchIterator::new( + vec![document_records.try_into()], + Arc::new(document_schema()), + )) .execute() .await .map_err(lancedb_to_rig_error)?; + let embedding_records = 
EmbeddingRecordsBatch::from(documents); + self.embedding_table - .add(embedding_batches) + .add(RecordBatchIterator::new( + embedding_records.record_batch_iter(), + Arc::new(embedding_schema()), + )) .execute() .await .map_err(lancedb_to_rig_error)?; @@ -52,23 +63,21 @@ impl VectorStore for LanceDbVectorStore { &self, id: &str, ) -> Result, VectorStoreError> { - // let mut stream = self - // .table - // .query() - // .only_if(format!("id = {id}")) - // .execute() - // .await - // .map_err(lancedb_to_rig_error)?; - - // // let record_batches = stream.try_collect::>().await.map_err(lancedb_to_rig_error)?; - - // stream.next().await.map(|maybe_record_batch| { - // let record_batch = maybe_record_batch?; - - // Ok::<(), lancedb::Error>(()) - // }); - - todo!() + let documents: DocumentRecords = self + .document_table + .query() + .only_if(format!("id = {id}")) + .execute_query() + .await?; + + let embeddings: EmbeddingRecordsBatch = self + .embedding_table + .query() + .only_if(format!("document_id = {id}")) + .execute_query() + .await?; + + Ok(merge(documents, embeddings)?.into_iter().next()) } async fn get_document serde::Deserialize<'a>>( @@ -82,8 +91,15 @@ impl VectorStore for LanceDbVectorStore { &self, query: Self::Q, ) -> Result, VectorStoreError> { - query.execute().await.map_err(lancedb_to_rig_error)?; + let documents: DocumentRecords = query.execute_query().await?; - todo!() + let embeddings: EmbeddingRecordsBatch = self + .embedding_table + .query() + .only_if(format!("document_id IN [{}]", documents.ids().join(","))) + .execute_query() + .await?; + + Ok(merge(documents, embeddings)?.into_iter().next()) } } diff --git a/rig-lancedb/src/table_schemas/document.rs b/rig-lancedb/src/table_schemas/document.rs new file mode 100644 index 00000000..977e7722 --- /dev/null +++ b/rig-lancedb/src/table_schemas/document.rs @@ -0,0 +1,177 @@ +use std::sync::Arc; + +use arrow_array::{RecordBatch, StringArray}; +use lancedb::arrow::arrow_schema::{ArrowError, DataType, 
Field, Fields, Schema}; +use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; + +use crate::utils::DeserializeArrow; + +/// Schema of `documents` table in LanceDB defined as a struct. +#[derive(Clone, Debug)] +pub struct DocumentRecord { + pub id: String, + pub document: String, +} + +/// Schema of `documents` table in LanceDB defined in `Schema` terms. +pub fn document_schema() -> Schema { + Schema::new(Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("document", DataType::Utf8, false), + ])) +} + +/// Wrapper around `Vec` +#[derive(Debug)] +pub struct DocumentRecords(Vec); + +impl DocumentRecords { + fn new() -> Self { + Self(Vec::new()) + } + + pub fn as_iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn ids(&self) -> Vec { + self.as_iter().map(|doc| doc.id.clone()).collect() + } + + fn add_records(&mut self, records: Vec) { + self.0.extend(records); + } +} + +/// Converts a `DocumentEmbeddings` object to a `DocumentRecord` object. +/// The `DocumentRecord` contains the correct schema required by the `documents` table. +impl TryFrom for DocumentRecord { + type Error = serde_json::Error; + + fn try_from(document: DocumentEmbeddings) -> Result { + Ok(DocumentRecord { + id: document.id, + document: serde_json::to_string(&document.document)?, + }) + } +} + +/// Converts a list of `DocumentEmbeddings` objects to a list of `DocumentRecord` objects. +/// This is useful when we need to write many `DocumentEmbeddings` items to the `documents` table at once. +impl TryFrom> for DocumentRecords { + type Error = serde_json::Error; + + fn try_from(documents: Vec) -> Result { + Ok(Self( + documents + .into_iter() + .map(DocumentRecord::try_from) + .collect::, _>>()?, + )) + } +} + +/// Convert a list of `DocumentRecord` objects to a `RecordBatch` object. +/// All data written to a lanceDB table must be a `RecordBatch` object. 
+impl TryFrom for RecordBatch { + type Error = ArrowError; + + fn try_from(document_records: DocumentRecords) -> Result { + let id = + StringArray::from_iter_values(document_records.as_iter().map(|doc| doc.id.clone())); + let document = StringArray::from_iter_values( + document_records.as_iter().map(|doc| doc.document.clone()), + ); + + RecordBatch::try_new( + Arc::new(document_schema()), + vec![Arc::new(id), Arc::new(document)], + ) + } +} + +/// Convert a `RecordBatch` object, read from a lanceDb table, to a list of `DocumentRecord` objects. +/// This allows us to convert the query result to our data format. +impl TryFrom for DocumentRecords { + type Error = ArrowError; + + fn try_from(record_batch: RecordBatch) -> Result { + let ids = record_batch.deserialize_str_column(0)?; + let documents = record_batch.deserialize_str_column(1)?; + + Ok(DocumentRecords( + ids.into_iter() + .zip(documents) + .map(|(id, document)| DocumentRecord { + id: id.to_string(), + document: document.to_string(), + }) + .collect(), + )) + } +} + +/// Convert a list of `RecordBatch` objects, read from a lanceDb table, to a list of `DocumentRecord` objects. 
+impl TryFrom> for DocumentRecords { + type Error = VectorStoreError; + + fn try_from(record_batches: Vec) -> Result { + let documents = record_batches + .into_iter() + .map(DocumentRecords::try_from) + .collect::, _>>() + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + Ok(documents + .into_iter() + .fold(DocumentRecords::new(), |mut acc, document| { + acc.add_records(document.0); + acc + })) + } +} + +#[cfg(test)] +mod tests { + use arrow_array::RecordBatch; + + use crate::table_schemas::document::{DocumentRecord, DocumentRecords}; + + #[tokio::test] + async fn test_record_batch_deserialize() { + let document_records = DocumentRecords(vec![ + DocumentRecord { + id: "ABC".to_string(), + document: serde_json::json!({ + "title": "Hello world", + "body": "Greetings", + }) + .to_string(), + }, + DocumentRecord { + id: "DEF".to_string(), + document: serde_json::json!({ + "title": "Sup dog", + "body": "Greetings", + }) + .to_string(), + }, + ]); + + let record_batch = RecordBatch::try_from(document_records).unwrap(); + + let deserialized_record_batch = DocumentRecords::try_from(record_batch).unwrap(); + + assert_eq!(deserialized_record_batch.0.len(), 2); + + assert_eq!(deserialized_record_batch.0[0].id, "ABC"); + assert_eq!( + deserialized_record_batch.0[0].document, + serde_json::json!({ + "title": "Hello world", + "body": "Greetings", + }) + .to_string() + ); + } +} diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs new file mode 100644 index 00000000..7807ca08 --- /dev/null +++ b/rig-lancedb/src/table_schemas/embedding.rs @@ -0,0 +1,243 @@ +use std::{collections::HashMap, sync::Arc}; + +use arrow_array::{ + builder::{Float64Builder, ListBuilder}, + RecordBatch, StringArray, +}; +use lancedb::arrow::arrow_schema::{ArrowError, DataType, Field, Fields, Schema}; +use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; + +use crate::utils::DeserializeArrow; + +// Data format in the 
LanceDB table `embeddings` +#[derive(Clone, Debug, PartialEq)] +pub struct EmbeddingRecord { + pub id: String, + pub document_id: String, + pub content: String, + pub embedding: Vec, +} + +#[derive(Clone, Debug)] +pub struct EmbeddingRecords(Vec); + +impl EmbeddingRecords { + fn new(records: Vec) -> Self { + EmbeddingRecords(records) + } + + pub fn as_iter(&self) -> impl Iterator { + self.0.iter() + } + + fn add_record(&mut self, record: EmbeddingRecord) { + self.0.push(record); + } +} + +impl From for EmbeddingRecords { + fn from(document: DocumentEmbeddings) -> Self { + EmbeddingRecords( + document + .embeddings + .clone() + .into_iter() + .map(move |embedding| EmbeddingRecord { + id: "".to_string(), + document_id: document.id.clone(), + content: embedding.document, + embedding: embedding.vec, + }) + .collect(), + ) + } +} + +impl From> for EmbeddingRecordsBatch { + fn from(documents: Vec) -> Self { + EmbeddingRecordsBatch( + documents + .into_iter() + .fold(HashMap::new(), |mut acc, document| { + acc.insert(document.id.clone(), EmbeddingRecords::from(document)); + acc + }), + ) + } +} + +impl TryFrom for RecordBatch { + fn try_from(embedding_records: EmbeddingRecords) -> Result { + let id = StringArray::from_iter_values( + embedding_records.as_iter().map(|record| record.id.clone()), + ); + let document_id = StringArray::from_iter_values( + embedding_records + .as_iter() + .map(|record| record.document_id.clone()), + ); + let content = StringArray::from_iter_values( + embedding_records + .as_iter() + .map(|record| record.content.clone()), + ); + + let mut builder = ListBuilder::new(Float64Builder::new()); + embedding_records.as_iter().for_each(|record| { + record + .embedding + .iter() + .for_each(|value| builder.values().append_value(*value)); + builder.append(true); + }); + + RecordBatch::try_new( + Arc::new(embedding_schema()), + vec![ + Arc::new(id), + Arc::new(document_id), + Arc::new(content), + Arc::new(builder.finish()), + ], + ) + } + + type Error = 
ArrowError; +} + +pub struct EmbeddingRecordsBatch(HashMap); +impl EmbeddingRecordsBatch { + fn as_iter(&self) -> impl Iterator { + self.0.clone().into_values().collect::>().into_iter() + } + + pub fn record_batch_iter(&self) -> impl Iterator> { + self.as_iter().map(RecordBatch::try_from) + } + + pub fn get_by_id(&self, id: &str) -> Option { + self.0.get(id).cloned() + } +} + +pub fn embedding_schema() -> Schema { + Schema::new(Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("document_id", DataType::Utf8, false), + Field::new("content", DataType::Utf8, false), + Field::new( + "embedding", + DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), + false, + ), + ])) +} + +impl TryFrom for EmbeddingRecords { + type Error = ArrowError; + + fn try_from(record_batch: RecordBatch) -> Result { + let ids = record_batch.deserialize_str_column(0)?; + let document_ids = record_batch.deserialize_str_column(1)?; + let contents = record_batch.deserialize_str_column(2)?; + let embeddings = record_batch.deserialize_list_column(3)?; + + Ok(EmbeddingRecords( + ids.into_iter() + .zip(document_ids) + .zip(contents) + .zip(embeddings) + .map( + |(((id, document_id), content), embedding)| EmbeddingRecord { + id: id.to_string(), + document_id: document_id.to_string(), + content: content.to_string(), + embedding, + }, + ) + .collect(), + )) + } +} + +impl TryFrom> for EmbeddingRecordsBatch { + type Error = VectorStoreError; + + fn try_from(record_batches: Vec) -> Result { + let embedding_records = record_batches + .into_iter() + .map(EmbeddingRecords::try_from) + .collect::, _>>() + .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; + + let grouped_records = + embedding_records + .into_iter() + .fold(HashMap::new(), |mut acc, records| { + records.as_iter().for_each(|record| { + acc.entry(record.document_id.clone()) + .and_modify(|item: &mut EmbeddingRecords| { + item.add_record(record.clone()) + }) + 
.or_insert(EmbeddingRecords::new(vec![record.clone()])); + }); + acc + }); + + Ok(EmbeddingRecordsBatch(grouped_records)) + } +} + +#[cfg(test)] +mod tests { + use arrow_array::RecordBatch; + + use crate::table_schemas::embedding::{EmbeddingRecord, EmbeddingRecords}; + + #[tokio::test] + async fn test_record_batch_deserialize() { + let embedding_records = EmbeddingRecords(vec![ + EmbeddingRecord { + id: "some_id".to_string(), + document_id: "ABC".to_string(), + content: serde_json::json!({ + "title": "Hello world", + "body": "Greetings", + }) + .to_string(), + embedding: vec![1.0, 2.0, 3.0], + }, + EmbeddingRecord { + id: "another_id".to_string(), + document_id: "DEF".to_string(), + content: serde_json::json!({ + "title": "Sup dog", + "body": "Greetings", + }) + .to_string(), + embedding: vec![4.0, 5.0, 6.0], + }, + ]); + + let record_batch = RecordBatch::try_from(embedding_records).unwrap(); + + let deserialized_record_batch = EmbeddingRecords::try_from(record_batch).unwrap(); + + assert_eq!(deserialized_record_batch.as_iter().count(), 2); + assert_eq!( + deserialized_record_batch.as_iter().nth(0).unwrap().clone(), + EmbeddingRecord { + id: "some_id".to_string(), + document_id: "ABC".to_string(), + content: serde_json::json!({ + "title": "Hello world", + "body": "Greetings", + }) + .to_string(), + embedding: vec![1.0, 2.0, 3.0], + } + ); + + assert!(false) + } +} diff --git a/rig-lancedb/src/table_schemas/mod.rs b/rig-lancedb/src/table_schemas/mod.rs new file mode 100644 index 00000000..9f6f1ba5 --- /dev/null +++ b/rig-lancedb/src/table_schemas/mod.rs @@ -0,0 +1,42 @@ +use document::{DocumentRecord, DocumentRecords}; +use embedding::{EmbeddingRecord, EmbeddingRecordsBatch}; +use rig::{ + embeddings::{DocumentEmbeddings, Embedding}, + vector_store::VectorStoreError, +}; + +use crate::serde_to_rig_error; + +pub mod document; +pub mod embedding; + +pub fn merge( + documents: DocumentRecords, + embeddings: EmbeddingRecordsBatch, +) -> Result, VectorStoreError> { + 
documents + .as_iter() + .map(|DocumentRecord { id, document }| { + let emebedding_records = embeddings.get_by_id(id); + + Ok::<_, VectorStoreError>(DocumentEmbeddings { + id: id.to_string(), + document: serde_json::from_str(document).map_err(serde_to_rig_error)?, + embeddings: match emebedding_records { + Some(records) => records + .as_iter() + .map( + |EmbeddingRecord { + content, embedding, .. + }| Embedding { + document: content.to_string(), + vec: embedding.to_vec(), + }, + ) + .collect::>(), + None => vec![], + }, + }) + }) + .collect::, _>>() +} diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs new file mode 100644 index 00000000..c1cfa5f0 --- /dev/null +++ b/rig-lancedb/src/utils/mod.rs @@ -0,0 +1,70 @@ +use arrow_array::{Array, Float64Array, ListArray, RecordBatch, StringArray}; +use futures::TryStreamExt; +use lancedb::{arrow::arrow_schema::ArrowError, query::ExecutableQuery}; +use rig::vector_store::VectorStoreError; + +use crate::lancedb_to_rig_error; + +pub trait DeserializeArrow { + fn deserialize_str_column(&self, i: usize) -> Result, ArrowError>; + fn deserialize_list_column(&self, i: usize) -> Result>, ArrowError>; +} + +impl DeserializeArrow for RecordBatch { + fn deserialize_str_column(&self, i: usize) -> Result, ArrowError> { + let column = self.column(i); + match column.as_any().downcast_ref::() { + Some(str_array) => Ok((0..str_array.len()) + .map(|j| str_array.value(j)) + .collect::>()), + None => Err(ArrowError::CastError(format!( + "Can't cast column {i} to string array" + ))), + } + } + + fn deserialize_list_column(&self, i: usize) -> Result>, ArrowError> { + let column = self.column(i); + match column.as_any().downcast_ref::() { + Some(list_array) => (0..list_array.len()) + .map( + |j| match list_array.value(j).as_any().downcast_ref::() { + Some(float_array) => Ok((0..float_array.len()) + .map(|k| float_array.value(k)) + .collect::>()), + None => Err(ArrowError::CastError(format!( + "Can't cast value at index {j} to 
float array" + ))), + }, + ) + .collect::, _>>(), + None => Err(ArrowError::CastError(format!( + "Can't cast column {i} to list array" + ))), + } + } +} + +pub trait Query +where + T: TryFrom, Error = VectorStoreError>, +{ + async fn execute_query(&self) -> Result; +} + +impl Query for lancedb::query::Query +where + T: TryFrom, Error = VectorStoreError>, +{ + async fn execute_query(&self) -> Result { + let record_batches = self + .execute() + .await + .map_err(lancedb_to_rig_error)? + .try_collect::>() + .await + .map_err(lancedb_to_rig_error)?; + + T::try_from(record_batches) + } +} From 7be1f854bbe334db6696f9f8831152e17bbf2495 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 19 Sep 2024 10:32:27 -0400 Subject: [PATCH 05/39] feat: implement get_document method of VectorStore trait --- rig-lancedb/src/lib.rs | 65 ++++++++++++++-------- rig-lancedb/src/table_schemas/document.rs | 31 +++++++---- rig-lancedb/src/table_schemas/embedding.rs | 10 ++-- rig-lancedb/src/utils/mod.rs | 23 +++++++- 4 files changed, 90 insertions(+), 39 deletions(-) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 4eab3e99..97f509a2 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,21 +1,17 @@ -use std::sync::Arc; - -use arrow_array::RecordBatchIterator; -use lancedb::query::QueryBase; -use rig::vector_store::{VectorStore, VectorStoreError}; -use table_schemas::{ - document::{document_schema, DocumentRecords}, - embedding::{embedding_schema, EmbeddingRecordsBatch}, - merge, -}; -use utils::Query; +use lancedb::{arrow::arrow_schema::Schema, query::QueryBase}; +use rig::vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}; +use table_schemas::{document::DocumentRecords, embedding::EmbeddingRecordsBatch, merge}; +use utils::{Insert, Query}; mod table_schemas; mod utils; pub struct LanceDbVectorStore { document_table: lancedb::Table, + document_schema: Schema, + embedding_table: lancedb::Table, + embedding_schema: Schema, } fn 
lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { @@ -23,7 +19,7 @@ fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { } fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { - VectorStoreError::DatastoreError(Box::new(e)) + VectorStoreError::JsonError(e) } impl VectorStore for LanceDbVectorStore { @@ -37,22 +33,14 @@ impl VectorStore for LanceDbVectorStore { DocumentRecords::try_from(documents.clone()).map_err(serde_to_rig_error)?; self.document_table - .add(RecordBatchIterator::new( - vec![document_records.try_into()], - Arc::new(document_schema()), - )) - .execute() + .insert(document_records, self.document_schema.clone()) .await .map_err(lancedb_to_rig_error)?; let embedding_records = EmbeddingRecordsBatch::from(documents); self.embedding_table - .add(RecordBatchIterator::new( - embedding_records.record_batch_iter(), - Arc::new(embedding_schema()), - )) - .execute() + .insert(embedding_records, self.embedding_schema.clone()) .await .map_err(lancedb_to_rig_error)?; @@ -84,7 +72,20 @@ impl VectorStore for LanceDbVectorStore { &self, id: &str, ) -> Result, VectorStoreError> { - todo!() + let documents: DocumentRecords = self + .document_table + .query() + .only_if(format!("id = {id}")) + .execute_query() + .await?; + + let document = documents + .as_iter() + .next() + .map(|document| serde_json::from_str(&document.document).map_err(serde_to_rig_error)) + .transpose(); + + document } async fn get_document_by_query( @@ -103,3 +104,21 @@ impl VectorStore for LanceDbVectorStore { Ok(merge(documents, embeddings)?.into_iter().next()) } } + +impl VectorStoreIndex for LanceDbVectorStore { + fn top_n_from_query( + &self, + query: &str, + n: usize, + ) -> impl std::future::Future, VectorStoreError>> + Send { + todo!() + } + + fn top_n_from_embedding( + &self, + prompt_embedding: &rig::embeddings::Embedding, + n: usize, + ) -> impl std::future::Future, VectorStoreError>> + Send { + todo!() + } +} \ No newline at end of file diff --git 
a/rig-lancedb/src/table_schemas/document.rs b/rig-lancedb/src/table_schemas/document.rs index 977e7722..c9ce1639 100644 --- a/rig-lancedb/src/table_schemas/document.rs +++ b/rig-lancedb/src/table_schemas/document.rs @@ -30,16 +30,24 @@ impl DocumentRecords { Self(Vec::new()) } - pub fn as_iter(&self) -> impl Iterator { - self.0.iter() + fn records(&self) -> Vec { + self.0.clone() + } + + fn add_records(&mut self, records: Vec) { + self.0.extend(records); + } + + fn documents(&self) -> Vec { + self.as_iter().map(|doc| doc.document.clone()).collect() } pub fn ids(&self) -> Vec { self.as_iter().map(|doc| doc.id.clone()).collect() } - fn add_records(&mut self, records: Vec) { - self.0.extend(records); + pub fn as_iter(&self) -> impl Iterator { + self.0.iter() } } @@ -77,11 +85,8 @@ impl TryFrom for RecordBatch { type Error = ArrowError; fn try_from(document_records: DocumentRecords) -> Result { - let id = - StringArray::from_iter_values(document_records.as_iter().map(|doc| doc.id.clone())); - let document = StringArray::from_iter_values( - document_records.as_iter().map(|doc| doc.document.clone()), - ); + let id = StringArray::from_iter_values(document_records.ids()); + let document = StringArray::from_iter_values(document_records.documents()); RecordBatch::try_new( Arc::new(document_schema()), @@ -90,6 +95,12 @@ impl TryFrom for RecordBatch { } } +impl From for Vec> { + fn from(documents: DocumentRecords) -> Self { + vec![RecordBatch::try_from(documents)] + } +} + /// Convert a `RecordBatch` object, read from a lanceDb table, to a list of `DocumentRecord` objects. /// This allows us to convert the query result to our data format. 
impl TryFrom for DocumentRecords { @@ -125,7 +136,7 @@ impl TryFrom> for DocumentRecords { Ok(documents .into_iter() .fold(DocumentRecords::new(), |mut acc, document| { - acc.add_records(document.0); + acc.add_records(document.records()); acc })) } diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs index 7807ca08..a6e34f0f 100644 --- a/rig-lancedb/src/table_schemas/embedding.rs +++ b/rig-lancedb/src/table_schemas/embedding.rs @@ -111,15 +111,17 @@ impl EmbeddingRecordsBatch { self.0.clone().into_values().collect::>().into_iter() } - pub fn record_batch_iter(&self) -> impl Iterator> { - self.as_iter().map(RecordBatch::try_from) - } - pub fn get_by_id(&self, id: &str) -> Option { self.0.get(id).cloned() } } +impl From for Vec> { + fn from(embeddings: EmbeddingRecordsBatch) -> Self { + embeddings.as_iter().map(RecordBatch::try_from).collect() + } +} + pub fn embedding_schema() -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index c1cfa5f0..2b5bb1d5 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -1,6 +1,11 @@ -use arrow_array::{Array, Float64Array, ListArray, RecordBatch, StringArray}; +use std::sync::Arc; + +use arrow_array::{Array, Float64Array, ListArray, RecordBatch, RecordBatchIterator, StringArray}; use futures::TryStreamExt; -use lancedb::{arrow::arrow_schema::ArrowError, query::ExecutableQuery}; +use lancedb::{ + arrow::arrow_schema::{ArrowError, Schema}, + query::ExecutableQuery, +}; use rig::vector_store::VectorStoreError; use crate::lancedb_to_rig_error; @@ -68,3 +73,17 @@ where T::try_from(record_batches) } } + +pub trait Insert { + async fn insert(&self, data: T, schema: Schema) -> Result<(), lancedb::Error>; +} + +impl>>> Insert for lancedb::Table { + async fn insert(&self, data: T, schema: Schema) -> Result<(), lancedb::Error> { + 
self.add(RecordBatchIterator::new(data.into(), Arc::new(schema))) + .execute() + .await?; + + Ok(()) + } +} From 168f53e798b6aaac0ff47757f3f17b53b040146c Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 19 Sep 2024 16:01:16 -0400 Subject: [PATCH 06/39] feat: start implementing top_n_from_query for trait VectorStoreIndex --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/rag.rs | 37 ++++-- rig-core/src/vector_store/in_memory_store.rs | 7 +- rig-core/src/vector_store/mod.rs | 24 +++- rig-lancedb/src/lib.rs | 113 ++++++++++++++++-- rig-lancedb/src/table_schemas/embedding.rs | 4 + rig-lancedb/src/utils/mod.rs | 17 +++ rig-mongodb/examples/vector_search_mongodb.rs | 6 +- rig-mongodb/src/lib.rs | 50 +++++--- 13 files changed, 220 insertions(+), 48 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index 0fc540f2..380c0277 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -272,7 +272,7 @@ async fn main() -> Result<(), anyhow::Error> { ) // Add a dynamic tool source with a sample rate of 1 (i.e.: only // 1 additional tool will be added to prompts) - .dynamic_tools(4, index, toolset) + .dynamic_tools(4, index, toolset, ()) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index d6929b52..8d9d7a92 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -35,7 +35,7 @@ async fn main() -> Result<(), anyhow::Error> { You are a dictionary assistant here to assist the user in understanding the meaning of words. You will find additional non-standard word definitions that could be useful below. 
") - .dynamic_context(1, index) + .dynamic_context(1, index, ()) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 348c0fcf..e426e1ab 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -174,7 +174,7 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("You are a calculator here to help the user perform arithmetic operations.") // Add a dynamic tool source with a sample rate of 1 (i.e.: only // 1 additional tool will be added to prompts) - .dynamic_tools(1, index, toolset) + .dynamic_tools(1, index, toolset, ()) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index f9ac6cf0..ebf9feff 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -28,7 +28,7 @@ async fn main() -> Result<(), anyhow::Error> { let index = vector_store.index(model); let results = index - .top_n_from_query("What is a linglingdong?", 1) + .top_n_from_query("What is a linglingdong?", 1, &()) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 7e9226af..e0c0ab0e 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -29,7 +29,7 @@ async fn main() -> Result<(), anyhow::Error> { let index = vector_store.index(search_model); let results = index - .top_n_from_query("What is a linglingdong?", 1) + .top_n_from_query("What is a linglingdong?", 1, &()) .await? 
.into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-core/src/rag.rs b/rig-core/src/rag.rs index 89bb21d1..31a1f380 100644 --- a/rig-core/src/rag.rs +++ b/rig-core/src/rag.rs @@ -87,9 +87,9 @@ pub struct RagAgent, /// List of vector store, with the sample number - dynamic_context: Vec<(usize, C)>, + dynamic_context: Vec<(usize, C, C::S)>, /// Dynamic tools - dynamic_tools: Vec<(usize, T)>, + dynamic_tools: Vec<(usize, T, T::S)>, /// Actual tool implementations pub tools: ToolSet, } @@ -106,10 +106,10 @@ impl Completion chat_history: Vec, ) -> Result, CompletionError> { let dynamic_context = stream::iter(self.dynamic_context.iter()) - .then(|(num_sample, index)| async { + .then(|(num_sample, index, search_params)| async { Ok::<_, VectorStoreError>( index - .top_n_from_query(prompt, *num_sample) + .top_n_from_query(prompt, *num_sample, search_params) .await? .into_iter() .map(|(_, doc)| { @@ -133,10 +133,10 @@ impl Completion .map_err(|e| CompletionError::RequestError(Box::new(e)))?; let dynamic_tools = stream::iter(self.dynamic_tools.iter()) - .then(|(num_sample, index)| async { + .then(|(num_sample, index, search_params)| async { Ok::<_, VectorStoreError>( index - .top_n_ids_from_query(prompt, *num_sample) + .top_n_ids_from_query(prompt, *num_sample, search_params) .await? .into_iter() .map(|(_, doc)| doc) @@ -242,9 +242,9 @@ pub struct RagAgentBuilder, /// List of vector store, with the sample number - dynamic_context: Vec<(usize, C)>, + dynamic_context: Vec<(usize, C, C::S)>, /// Dynamic tools - dynamic_tools: Vec<(usize, T)>, + dynamic_tools: Vec<(usize, T, T::S)>, /// Temperature of the model temperature: Option, /// Actual tool implementations @@ -292,15 +292,28 @@ impl RagAgentBuild /// Add some dynamic context to the RAG agent. On each prompt, `sample` documents from the /// dynamic context will be inserted in the request. 
- pub fn dynamic_context(mut self, sample: usize, dynamic_context: C) -> Self { - self.dynamic_context.push((sample, dynamic_context)); + pub fn dynamic_context( + mut self, + sample: usize, + dynamic_context: C, + search_params: C::S, + ) -> Self { + self.dynamic_context + .push((sample, dynamic_context, search_params)); self } /// Add some dynamic tools to the RAG agent. On each prompt, `sample` tools from the /// dynamic toolset will be inserted in the request. - pub fn dynamic_tools(mut self, sample: usize, dynamic_tools: T, toolset: ToolSet) -> Self { - self.dynamic_tools.push((sample, dynamic_tools)); + pub fn dynamic_tools( + mut self, + sample: usize, + dynamic_tools: T, + toolset: ToolSet, + search_params: T::S, + ) -> Self { + self.dynamic_tools + .push((sample, dynamic_tools, search_params)); self.tools.add_tools(toolset); self } diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index 494b339b..0d4a93a3 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -162,19 +162,24 @@ impl InMemoryVectorIndex { } impl VectorStoreIndex for InMemoryVectorIndex { + type S = (); + async fn top_n_from_query( &self, query: &str, n: usize, + search_params: &Self::S, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n).await + self.top_n_from_embedding(&prompt_embedding, n, search_params) + .await } async fn top_n_from_embedding( &self, query_embedding: &Embedding, n: usize, + _search_params: &Self::S, ) -> Result, VectorStoreError> { // Sort documents by best embedding distance let mut docs: EmbeddingRanking = BinaryHeap::new(); diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 0caea0f0..8c3fc183 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -49,6 +49,8 @@ pub trait VectorStore: Send + Sync 
{ /// Trait for vector store indexes pub trait VectorStoreIndex: Send + Sync { + type S: Send + Sync; + /// Get the top n documents based on the distance to the given embedding. /// The distance is calculated as the cosine distance between the prompt and /// the document embedding. @@ -57,6 +59,7 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, + search_params: &Self::S, ) -> impl std::future::Future, VectorStoreError>> + Send; /// Same as `top_n_from_query` but returns the documents without its embeddings. @@ -65,9 +68,10 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, + search_params: &Self::S, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_query(query, n).await?; + let documents = self.top_n_from_query(query, n, search_params).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap())) @@ -80,10 +84,11 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, + search_params: &Self::S, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_query(query, n).await?; + let documents = self.top_n_from_query(query, n, search_params).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -99,6 +104,7 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, + search_params: &Self::S, ) -> impl std::future::Future, VectorStoreError>> + Send; /// Same as `top_n_from_embedding` but returns the documents without its embeddings. 
@@ -107,9 +113,12 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, + search_params: &Self::S, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_embedding(prompt_embedding, n).await?; + let documents = self + .top_n_from_embedding(prompt_embedding, n, search_params) + .await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap())) @@ -122,10 +131,13 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, + search_params: &Self::S, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_embedding(prompt_embedding, n).await?; + let documents = self + .top_n_from_embedding(prompt_embedding, n, search_params) + .await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -137,10 +149,13 @@ pub trait VectorStoreIndex: Send + Sync { pub struct NoIndex; impl VectorStoreIndex for NoIndex { + type S = (); + async fn top_n_from_query( &self, _query: &str, _n: usize, + _search_params: &Self::S, ) -> Result, VectorStoreError> { Ok(vec![]) } @@ -149,6 +164,7 @@ impl VectorStoreIndex for NoIndex { &self, _prompt_embedding: &Embedding, _n: usize, + _search_params: &Self::S, ) -> Result, VectorStoreError> { Ok(vec![]) } diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 97f509a2..e3effe7f 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,5 +1,8 @@ -use lancedb::{arrow::arrow_schema::Schema, query::QueryBase}; -use rig::vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}; +use lancedb::{arrow::arrow_schema::Schema, query::QueryBase, DistanceType}; +use rig::{ + embeddings::EmbeddingModel, + vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, +}; use table_schemas::{document::DocumentRecords, embedding::EmbeddingRecordsBatch, merge}; use utils::{Insert, Query}; @@ 
-105,20 +108,114 @@ impl VectorStore for LanceDbVectorStore { } } -impl VectorStoreIndex for LanceDbVectorStore { - fn top_n_from_query( +/// A vector index for a LanceDB collection. +pub struct LanceDbVectorIndex { + model: M, + embedding_table: lancedb::Table, + document_table: lancedb::Table, +} + +impl LanceDbVectorIndex { + pub fn new(model: M, embedding_table: lancedb::Table, document_table: lancedb::Table) -> Self { + Self { + model, + embedding_table, + document_table, + } + } +} + +/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. +pub enum SearchType { + /// Flat search, also called ENN or kNN. + Flat, + /// Approximate Nearest Neighbor search, also called ANN. + Approximate, +} + +pub struct SearchParams { + /// Always set the distance_type to match the value used to train the index + distance_type: DistanceType, + /// By default, ANN will be used if there is an index on the table. + /// By default, kNN will be used if there is NO index on the table. + /// To use defaults, set to None. + search_type: Option, + /// Set this value only when search type is ANN. + /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information + nprobes: Option, + /// Set this value only when search type is ANN.
+ /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information + refine_factor: Option, + /// If set to true, filtering will happen after the vector search instead of before + /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information + post_filter: Option, +} + +impl VectorStoreIndex for LanceDbVectorIndex { + async fn top_n_from_query( &self, query: &str, n: usize, - ) -> impl std::future::Future, VectorStoreError>> + Send { + search_params: &Self::S, + ) -> Result, VectorStoreError> { + let prompt_embedding = self.model.embed_document(query).await?; + + let SearchParams { + distance_type, + search_type, + nprobes, + refine_factor, + post_filter, + } = search_params; + + let query = self + .embedding_table + .vector_search(prompt_embedding.vec) + .map_err(lancedb_to_rig_error)? + .distance_type(*distance_type) + .limit(n); + + if let Some(SearchType::Flat) = &search_type { + query.clone().bypass_vector_index(); + } + + if let Some(SearchType::Approximate) = &search_type { + if let Some(nprobes) = nprobes { + query.clone().nprobes(*nprobes); + } + if let Some(refine_factor) = refine_factor { + query.clone().refine_factor(*refine_factor); + } + } + + if let Some(true) = &post_filter { + query.clone().postfilter(); + } + + let embeddings: EmbeddingRecordsBatch = query.execute_query().await?; + + let documents: DocumentRecords = self + .document_table + .query() + .only_if(format!("id IN [{}]", embeddings.document_ids().join(","))) + .execute_query() + .await?; + + // Todo: get distances for each returned vector + + merge(documents, embeddings)?; + todo!() } - fn top_n_from_embedding( + async fn top_n_from_embedding( &self, prompt_embedding: &rig::embeddings::Embedding, n: usize, - ) -> impl std::future::Future, VectorStoreError>> + Send { + search_params: &Self::S, + ) -> Result, VectorStoreError> { todo!() } -} \ No newline at end of file + + 
type S = SearchParams; +} diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs index a6e34f0f..363c613d 100644 --- a/rig-lancedb/src/table_schemas/embedding.rs +++ b/rig-lancedb/src/table_schemas/embedding.rs @@ -114,6 +114,10 @@ impl EmbeddingRecordsBatch { pub fn get_by_id(&self, id: &str) -> Option { self.0.get(id).cloned() } + + pub fn document_ids(&self) -> Vec { + self.0.clone().into_keys().collect() + } } impl From for Vec> { diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index 2b5bb1d5..92532612 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -74,6 +74,23 @@ where } } +impl Query for lancedb::query::VectorQuery +where + T: TryFrom, Error = VectorStoreError>, +{ + async fn execute_query(&self) -> Result { + let record_batches = self + .execute() + .await + .map_err(lancedb_to_rig_error)? + .try_collect::>() + .await + .map_err(lancedb_to_rig_error)?; + + T::try_from(record_batches) + } +} + pub trait Insert { async fn insert(&self, data: T, schema: Schema) -> Result<(), lancedb::Error>; } diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index a39e7c93..0eb79ca9 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -6,7 +6,7 @@ use rig::{ providers::openai::Client, vector_store::{VectorStore, VectorStoreIndex}, }; -use rig_mongodb::MongoDbVectorStore; +use rig_mongodb::{MongoDbVectorStore, SearchParams}; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -49,11 +49,11 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "context_vector_index", doc! 
{}); + let index = vector_store.index(model, "context_vector_index"); // Query the index let results = index - .top_n_from_query("What is a linglingdong?", 1) + .top_n_from_query("What is a linglingdong?", 1, &SearchParams::new()) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 7d85201f..96a1e03c 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -89,13 +89,8 @@ impl MongoDbVectorStore { /// /// An additional filter can be provided to further restrict the documents that are /// considered in the search. - pub fn index( - &self, - model: M, - index_name: &str, - filter: mongodb::bson::Document, - ) -> MongoDbVectorIndex { - MongoDbVectorIndex::new(self.collection.clone(), model, index_name, filter) + pub fn index(&self, model: M, index_name: &str) -> MongoDbVectorIndex { + MongoDbVectorIndex::new(self.collection.clone(), model, index_name) } } @@ -104,7 +99,6 @@ pub struct MongoDbVectorIndex { collection: mongodb::Collection, model: M, index_name: String, - filter: mongodb::bson::Document, } impl MongoDbVectorIndex { @@ -112,13 +106,32 @@ impl MongoDbVectorIndex { collection: mongodb::Collection, model: M, index_name: &str, - filter: mongodb::bson::Document, ) -> Self { Self { collection, model, index_name: index_name.to_string(), - filter, + } + } +} + +/// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information +/// on each of the fields +pub struct SearchParams { + filter: mongodb::bson::Document, + /// Whether to use ANN or ENN search + exact: Option, + /// Only set this field if exact is set to false + /// Number of nearest neighbors to use during the search + num_candidates: Option, +} + +impl SearchParams { + pub fn new() -> Self { + Self { + filter: doc! 
{}, + exact: None, + num_candidates: None, } } } @@ -128,15 +141,18 @@ impl VectorStoreIndex for MongoDbV &self, query: &str, n: usize, + search_params: &Self::S, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n).await + self.top_n_from_embedding(&prompt_embedding, n, search_params) + .await } async fn top_n_from_embedding( &self, prompt_embedding: &Embedding, n: usize, + search_params: &Self::S, ) -> Result, VectorStoreError> { let mut cursor = self .collection @@ -144,12 +160,13 @@ impl VectorStoreIndex for MongoDbV [ doc! { "$vectorSearch": { + "queryVector": &prompt_embedding.vec, "index": &self.index_name, + "exact": search_params.exact.unwrap_or(false), "path": "embeddings.vec", - "queryVector": &prompt_embedding.vec, - "numCandidates": (n * 10) as u32, + "numCandidates": search_params.num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, - "filter": &self.filter, + "filter": &search_params.filter, } }, doc! 
{ @@ -168,7 +185,8 @@ impl VectorStoreIndex for MongoDbV while let Some(doc) = cursor.next().await { let doc = doc.map_err(mongodb_to_rig_error)?; let score = doc.get("score").expect("score").as_f64().expect("f64"); - let document: DocumentEmbeddings = serde_json::from_value(doc).expect("document"); + let document: DocumentEmbeddings = + serde_json::from_value(doc).map_err(VectorStoreError::JsonError)?; results.push((score, document)); } @@ -182,4 +200,6 @@ impl VectorStoreIndex for MongoDbV Ok(results) } + + type S = SearchParams; } From 1e52e6a1c30be0e84697e11700c7a2089ced1cf9 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 19 Sep 2024 16:51:32 -0400 Subject: [PATCH 07/39] docs: add doc string to mongodb search params struct --- rig-lancedb/src/lib.rs | 19 +++++++++---------- rig-mongodb/src/lib.rs | 1 + 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 25d3eebb..97acd523 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -162,7 +162,15 @@ impl VectorStoreIndex for LanceDbV search_params: Self::SearchParams, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; + self.top_n_from_embedding(&prompt_embedding, n, search_params).await + } + async fn top_n_from_embedding( + &self, + prompt_embedding: &rig::embeddings::Embedding, + n: usize, + search_params: Self::SearchParams, + ) -> Result, VectorStoreError> { let SearchParams { distance_type, search_type, @@ -173,7 +181,7 @@ impl VectorStoreIndex for LanceDbV let query = self .embedding_table - .vector_search(prompt_embedding.vec) + .vector_search(prompt_embedding.vec.clone()) .map_err(lancedb_to_rig_error)? 
.distance_type(distance_type) .limit(n); @@ -211,14 +219,5 @@ impl VectorStoreIndex for LanceDbV todo!() } - async fn top_n_from_embedding( - &self, - prompt_embedding: &rig::embeddings::Embedding, - n: usize, - search_params: Self::SearchParams, - ) -> Result, VectorStoreError> { - todo!() - } - type SearchParams = SearchParams; } diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 0fa51853..8a8b4467 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -120,6 +120,7 @@ impl MongoDbVectorIndex { /// on each of the fields #[derive(Deserialize)] pub struct SearchParams { + /// Pre-filter filter: mongodb::bson::Document, /// Whether to use ANN or ENN search exact: Option, From bb5c76767ed61ad4a896558b115995778f94f352 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 19 Sep 2024 17:20:28 -0400 Subject: [PATCH 08/39] docs: Add doc strings to utility methods --- rig-lancedb/src/table_schemas/embedding.rs | 2 +- rig-lancedb/src/utils/mod.rs | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs index 363c613d..109ede40 100644 --- a/rig-lancedb/src/table_schemas/embedding.rs +++ b/rig-lancedb/src/table_schemas/embedding.rs @@ -146,7 +146,7 @@ impl TryFrom for EmbeddingRecords { let ids = record_batch.deserialize_str_column(0)?; let document_ids = record_batch.deserialize_str_column(1)?; let contents = record_batch.deserialize_str_column(2)?; - let embeddings = record_batch.deserialize_list_column(3)?; + let embeddings = record_batch.deserialize_float_list_column(3)?; Ok(EmbeddingRecords( ids.into_iter() diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index 92532612..0d642ab9 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -10,9 +10,15 @@ use rig::vector_store::VectorStoreError; use crate::lancedb_to_rig_error; +/// Trait used to "deserialize" a column of a RecordBatch 
object into a list of primitive types pub trait DeserializeArrow { + /// Define the column number that contains strings, i. + /// For each item in the column, convert it to a string and collect the result in a vector of strings. fn deserialize_str_column(&self, i: usize) -> Result, ArrowError>; - fn deserialize_list_column(&self, i: usize) -> Result>, ArrowError>; + /// Define the column number that contains the list of floats, i. + /// For each item in the column, convert it to a list and for each item in the list, convert it to a float. + /// Collect the result as a vector of vectors of floats. + fn deserialize_float_list_column(&self, i: usize) -> Result>, ArrowError>; } impl DeserializeArrow for RecordBatch { @@ -28,7 +34,7 @@ impl DeserializeArrow for RecordBatch { } } - fn deserialize_list_column(&self, i: usize) -> Result>, ArrowError> { + fn deserialize_float_list_column(&self, i: usize) -> Result>, ArrowError> { let column = self.column(i); match column.as_any().downcast_ref::() { Some(list_array) => (0..list_array.len()) @@ -50,6 +56,10 @@ impl DeserializeArrow for RecordBatch { } } +/// Trait that facilitates the conversion of columnar data returned by a lanceDb query to the desired struct. +/// Used whenever a lanceDb table is queried. +/// First, execute the query and get the result as a list of RecordBatches (columnar data). +/// Then, convert the record batches to the desired type using the try_from trait. pub trait Query where T: TryFrom, Error = VectorStoreError>, @@ -74,6 +84,8 @@ where } } +/// Same as the above trait but for the VectorQuery type. +/// Used whenever a lanceDb table vector search is executed. impl Query for lancedb::query::VectorQuery where T: TryFrom, Error = VectorStoreError>, @@ -91,6 +103,7 @@ where } } +/// Trait that facilitates inserting data defined as Rust structs into a lanceDB table which contains columnar data. 
pub trait Insert { async fn insert(&self, data: T, schema: Schema) -> Result<(), lancedb::Error>; } From b788cd5fc1e1fedea56f5c417f620265015af514 Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 20 Sep 2024 17:11:28 -0400 Subject: [PATCH 09/39] feat: implement ANN search example --- Cargo.lock | 1 + rig-lancedb/Cargo.toml | 5 +- .../examples/vector_search_local_ann.rs | 84 ++++++++ rig-lancedb/src/lib.rs | 165 ++++++++++++--- rig-lancedb/src/table_schemas/document.rs | 28 +-- rig-lancedb/src/table_schemas/embedding.rs | 193 +++++++++++------- rig-lancedb/src/table_schemas/mod.rs | 4 +- rig-lancedb/src/utils/mod.rs | 24 ++- rig-mongodb/src/lib.rs | 3 - 9 files changed, 377 insertions(+), 130 deletions(-) create mode 100644 rig-lancedb/examples/vector_search_local_ann.rs diff --git a/Cargo.lock b/Cargo.lock index cbddf0c3..f6305eb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3992,6 +3992,7 @@ dependencies = [ name = "rig-lancedb" version = "0.1.0" dependencies = [ + "anyhow", "arrow-array", "futures", "lancedb", diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index c97b9b78..3ec378f7 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -5,9 +5,12 @@ edition = "2021" [dependencies] lancedb = "0.10.0" -tokio = "1.40.0" rig-core = { path = "../rig-core", version = "0.1.0" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" futures = "0.3.30" + +[dev-dependencies] +tokio = "1.40.0" +anyhow = "1.0.89" \ No newline at end of file diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs new file mode 100644 index 00000000..ec02a704 --- /dev/null +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -0,0 +1,84 @@ +use std::env; + +use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; +use rig::{ + completion::Prompt, + embeddings::EmbeddingsBuilder, + providers::openai::Client, + vector_store::{VectorStore, VectorStoreIndexDyn}, +}; +use 
rig_lancedb::{LanceDbVectorStore, SearchParams}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Initialize LanceDB locally. + let db = lancedb::connect("data/lancedb-store").execute().await?; + let mut vector_store = LanceDbVectorStore::new(&db, 1536).await?; + + // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + + // Generate test data for RAG demo + let agent = openai_client + .agent("gpt-4o") + .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. Return ONLY the JSON string generated, nothing else.") + .build(); + let response = agent + .prompt("Invent at least 150 words and their definitions") + .await?; + let mut definitions: Vec = serde_json::from_str(&response)?; + + // Note: need at least 256 rows in order to create an index on a table but OpenAi limits the output size + // so we duplicate the vector for testing purposes. 
+ definitions.extend(definitions.clone()); + + // Select the embedding model and generate our embeddings + let model = openai_client.embedding_model("text-embedding-ada-002"); + + let embeddings = EmbeddingsBuilder::new(model.clone()) + .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .build() + .await?; + + // Add embeddings to vector store + vector_store.add_documents(embeddings).await?; + + // Create a vector index on our vector store + // IMPORTANT: Reuse the same model that was used to generate the embeddings + let index = vector_store.index(model); + + // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information + index + .create_index(lancedb::index::Index::IvfPq( + IvfPqIndexBuilder::default() + // This overrides the default distance type of L2 + .distance_type(DistanceType::Cosine), + )) + .await?; + + // Query the index + let results = index + .top_n_from_query( + "My boss says I zindle too much, what does that mean?", + 1, + &serde_json::to_string(&SearchParams::new( + Some(DistanceType::Cosine), + None, + None, + None, + None, + ))?, + ) + .await? 
+ .into_iter() + .map(|(score, doc)| (score, doc.id, doc.document)) + .collect::>(); + + println!("Results: {:?}", results); + + Ok(()) +} diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 97acd523..300f65f8 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,9 +1,16 @@ -use lancedb::{arrow::arrow_schema::Schema, query::QueryBase, DistanceType}; +use std::sync::Arc; + +use lancedb::{ + arrow::arrow_schema::{DataType, Field, Fields, Schema}, + index::Index, + query::QueryBase, + DistanceType, +}; use rig::{ embeddings::EmbeddingModel, vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, }; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use table_schemas::{document::DocumentRecords, embedding::EmbeddingRecordsBatch, merge}; use utils::{Insert, Query}; @@ -12,10 +19,63 @@ mod utils; pub struct LanceDbVectorStore { document_table: lancedb::Table, - document_schema: Schema, - embedding_table: lancedb::Table, - embedding_schema: Schema, + embedding_dimension: i32, +} + +impl LanceDbVectorStore { + /// Note: Tables are created inside the new function rather than created outside and passed as reference to new function. + /// This is because a specific schema needs to be enforced on the tables and this is done at creation time. 
+ pub async fn new( + db: &lancedb::Connection, + embedding_dimension: i32, + ) -> Result { + Ok(Self { + document_table: db + .create_empty_table("documents", Arc::new(Self::document_schema())) + .execute() + .await?, + embedding_table: db + .create_empty_table( + "embeddings", + Arc::new(Self::embedding_schema(embedding_dimension)), + ) + .execute() + .await?, + embedding_dimension, + }) + } + + fn document_schema() -> Schema { + Schema::new(Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("document", DataType::Utf8, false), + ])) + } + + fn embedding_schema(dimension: i32) -> Schema { + Schema::new(Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("document_id", DataType::Utf8, false), + Field::new("content", DataType::Utf8, false), + Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float64, true)), + dimension, + ), + false, + ), + ])) + } + + pub fn index(&self, model: M) -> LanceDbVectorIndex { + LanceDbVectorIndex::new( + model, + self.embedding_table.clone(), + self.document_table.clone(), + ) + } } fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { @@ -37,14 +97,17 @@ impl VectorStore for LanceDbVectorStore { DocumentRecords::try_from(documents.clone()).map_err(serde_to_rig_error)?; self.document_table - .insert(document_records, self.document_schema.clone()) + .insert(document_records, Self::document_schema()) .await .map_err(lancedb_to_rig_error)?; let embedding_records = EmbeddingRecordsBatch::from(documents); self.embedding_table - .insert(embedding_records, self.embedding_schema.clone()) + .insert( + embedding_records, + Self::embedding_schema(self.embedding_dimension), + ) .await .map_err(lancedb_to_rig_error)?; @@ -69,7 +132,7 @@ impl VectorStore for LanceDbVectorStore { .execute_query() .await?; - Ok(merge(documents, embeddings)?.into_iter().next()) + Ok(merge(&documents, &embeddings)?.into_iter().next()) } async fn get_document 
serde::Deserialize<'a>>( @@ -105,7 +168,7 @@ impl VectorStore for LanceDbVectorStore { .execute_query() .await?; - Ok(merge(documents, embeddings)?.into_iter().next()) + Ok(merge(&documents, &embeddings)?.into_iter().next()) } } @@ -124,10 +187,19 @@ impl LanceDbVectorIndex { document_table, } } + + pub async fn create_index(&self, index: Index) -> Result<(), lancedb::Error> { + self.embedding_table + .create_index(&["embedding"], index) + .execute() + .await?; + + Ok(()) + } } /// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. -#[derive(Deserialize)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub enum SearchType { // Flat search, also called ENN or kNN. Flat, @@ -135,10 +207,11 @@ pub enum SearchType { Approximate, } -#[derive(Deserialize)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub struct SearchParams { /// Always set the distance_type to match the value used to train the index - distance_type: DistanceType, + /// By default, set to L2 + distance_type: Option, /// By default, ANN will be used if there is an index on the table. /// By default, kNN will be used if there is NO index on the table. /// To use defaults, set to None. 
@@ -154,6 +227,24 @@ pub struct SearchParams { post_filter: Option, } +impl SearchParams { + pub fn new( + distance_type: Option, + search_type: Option, + nprobes: Option, + refine_factor: Option, + post_filter: Option, + ) -> Self { + Self { + distance_type, + search_type, + nprobes, + refine_factor, + post_filter, + } + } +} + impl VectorStoreIndex for LanceDbVectorIndex { async fn top_n_from_query( &self, @@ -162,7 +253,8 @@ impl VectorStoreIndex for LanceDbV search_params: Self::SearchParams, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n, search_params).await + self.top_n_from_embedding(&prompt_embedding, n, search_params) + .await } async fn top_n_from_embedding( @@ -177,20 +269,23 @@ impl VectorStoreIndex for LanceDbV nprobes, refine_factor, post_filter, - } = search_params; + } = search_params.clone(); let query = self .embedding_table .vector_search(prompt_embedding.vec.clone()) .map_err(lancedb_to_rig_error)? 
- .distance_type(distance_type) .limit(n); - if let Some(SearchType::Flat) = &search_type { + if let Some(distance_type) = distance_type { + query.clone().distance_type(distance_type); + } + + if let Some(SearchType::Flat) = search_type { query.clone().bypass_vector_index(); } - if let Some(SearchType::Approximate) = &search_type { + if let Some(SearchType::Approximate) = search_type { if let Some(nprobes) = nprobes { query.clone().nprobes(nprobes); } @@ -199,7 +294,7 @@ impl VectorStoreIndex for LanceDbV } } - if let Some(true) = &post_filter { + if let Some(true) = post_filter { query.clone().postfilter(); } @@ -208,15 +303,37 @@ impl VectorStoreIndex for LanceDbV let documents: DocumentRecords = self .document_table .query() - .only_if(format!("id IN [{}]", embeddings.document_ids().join(","))) + .only_if(format!( + "id IN ({})", + embeddings + .document_ids() + .iter() + .map(|id| format!("'{id}'")) + .collect::>() + .join(",") + )) .execute_query() .await?; - // Todo: get distances for each returned vector - - merge(documents, embeddings)?; - - todo!() + let document_embeddings = merge(&documents, &embeddings)?; + + Ok(document_embeddings + .into_iter() + .map(|doc| { + let distance = embeddings + .get_by_id(&doc.id) + .map(|records| { + records + .as_iter() + .next() + .map(|record| record.distance.unwrap_or(0.0)) + .unwrap_or(0.0) + }) + .unwrap_or(0.0); + + (distance as f64, doc) + }) + .collect()) } type SearchParams = SearchParams; diff --git a/rig-lancedb/src/table_schemas/document.rs b/rig-lancedb/src/table_schemas/document.rs index c9ce1639..084518c1 100644 --- a/rig-lancedb/src/table_schemas/document.rs +++ b/rig-lancedb/src/table_schemas/document.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use arrow_array::{RecordBatch, StringArray}; -use lancedb::arrow::arrow_schema::{ArrowError, DataType, Field, Fields, Schema}; +use arrow_array::{ArrayRef, RecordBatch, StringArray}; +use lancedb::arrow::arrow_schema::ArrowError; use 
rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; use crate::utils::DeserializeArrow; @@ -13,14 +13,6 @@ pub struct DocumentRecord { pub document: String, } -/// Schema of `documents` table in LanceDB defined in `Schema` terms. -pub fn document_schema() -> Schema { - Schema::new(Fields::from(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("document", DataType::Utf8, false), - ])) -} - /// Wrapper around `Vec` #[derive(Debug)] pub struct DocumentRecords(Vec); @@ -79,19 +71,17 @@ impl TryFrom> for DocumentRecords { } } -/// Convert a list of `DocumentRecord` objects to a `RecordBatch` object. -/// All data written to a lanceDB table must be a `RecordBatch` object. +/// Convert a list of documents (`DocumentRecords`) to a `RecordBatch`, the data structure that needs to be written to LanceDB. +/// All documents will be written to the database as part of the same batch. impl TryFrom for RecordBatch { type Error = ArrowError; fn try_from(document_records: DocumentRecords) -> Result { - let id = StringArray::from_iter_values(document_records.ids()); - let document = StringArray::from_iter_values(document_records.documents()); + let id = Arc::new(StringArray::from_iter_values(document_records.ids())) as ArrayRef; + let document = + Arc::new(StringArray::from_iter_values(document_records.documents())) as ArrayRef; - RecordBatch::try_new( - Arc::new(document_schema()), - vec![Arc::new(id), Arc::new(document)], - ) + RecordBatch::try_from_iter(vec![("id", id), ("document", document)]) } } @@ -149,7 +139,7 @@ mod tests { use crate::table_schemas::document::{DocumentRecord, DocumentRecords}; #[tokio::test] - async fn test_record_batch_deserialize() { + async fn test_record_batch_conversion() { let document_records = DocumentRecords(vec![ DocumentRecord { id: "ABC".to_string(), diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs index 109ede40..d66c6555 100644 --- 
a/rig-lancedb/src/table_schemas/embedding.rs +++ b/rig-lancedb/src/table_schemas/embedding.rs @@ -1,58 +1,95 @@ use std::{collections::HashMap, sync::Arc}; use arrow_array::{ - builder::{Float64Builder, ListBuilder}, - RecordBatch, StringArray, + builder::{FixedSizeListBuilder, Float64Builder}, + ArrayRef, RecordBatch, StringArray, }; -use lancedb::arrow::arrow_schema::{ArrowError, DataType, Field, Fields, Schema}; +use lancedb::arrow::arrow_schema::ArrowError; use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; use crate::utils::DeserializeArrow; -// Data format in the LanceDB table `embeddings` +/// Data format in the LanceDB table `embeddings` #[derive(Clone, Debug, PartialEq)] pub struct EmbeddingRecord { pub id: String, pub document_id: String, pub content: String, pub embedding: Vec, + /// Distance from prompt. + /// This value is only present after vector search executes and determines the distance + pub distance: Option, } +/// Group of EmbeddingRecord objects. This represents the list of embedding objects in a `DocumentEmbeddings` object. #[derive(Clone, Debug)] -pub struct EmbeddingRecords(Vec); +pub struct EmbeddingRecords { + records: Vec, + dimension: i32, +} impl EmbeddingRecords { - fn new(records: Vec) -> Self { - EmbeddingRecords(records) + fn new(records: Vec, dimension: i32) -> Self { + EmbeddingRecords { records, dimension } + } + + fn add_record(&mut self, record: EmbeddingRecord) { + self.records.push(record); } pub fn as_iter(&self) -> impl Iterator { - self.0.iter() + self.records.iter() } +} - fn add_record(&mut self, record: EmbeddingRecord) { - self.0.push(record); +/// HashMap where the key is the `DocumentEmbeddings` id +/// and the value is the`EmbeddingRecords` object that corresponds to the document. 
+#[derive(Debug)] +pub struct EmbeddingRecordsBatch(HashMap); + +impl EmbeddingRecordsBatch { + fn as_iter(&self) -> impl Iterator { + self.0.clone().into_values().collect::>().into_iter() + } + + pub fn get_by_id(&self, id: &str) -> Option { + self.0.get(id).cloned() + } + + pub fn document_ids(&self) -> Vec { + self.0.clone().into_keys().collect() + } } +/// Convert from a `DocumentEmbeddings` to an `EmbeddingRecords` object (a list of `EmbeddingRecord` objects) impl From for EmbeddingRecords { fn from(document: DocumentEmbeddings) -> Self { - EmbeddingRecords( + EmbeddingRecords::new( document .embeddings .clone() .into_iter() - .map(move |embedding| EmbeddingRecord { - id: "".to_string(), + .enumerate() + .map(move |(i, embedding)| EmbeddingRecord { + id: format!("{}-{i}", document.id), document_id: document.id.clone(), content: embedding.document, embedding: embedding.vec, + distance: None, }) .collect(), + document + .embeddings + .first() + .map(|embedding| embedding.vec.len() as i32) + .unwrap_or(0), ) } } +/// Convert from a list of `DocumentEmbeddings` to an `EmbeddingRecordsBatch` object +/// For each `DocumentEmbeddings`, we create an `EmbeddingRecords` and add it to the +/// hashmap with its corresponding `DocumentEmbeddings` id. impl From> for EmbeddingRecordsBatch { fn from(documents: Vec) -> Self { EmbeddingRecordsBatch( @@ -66,6 +103,8 @@ impl From> for EmbeddingRecordsBatch { } } +/// Convert a list of embeddings (`EmbeddingRecords`) to a `RecordBatch`, the data structure that needs to be written to LanceDB. +/// All embeddings related to a document will be written to the database as part of the same batch. 
impl TryFrom for RecordBatch { fn try_from(embedding_records: EmbeddingRecords) -> Result { let id = StringArray::from_iter_values( @@ -82,7 +121,8 @@ impl TryFrom for RecordBatch { .map(|record| record.content.clone()), ); - let mut builder = ListBuilder::new(Float64Builder::new()); + let mut builder = + FixedSizeListBuilder::new(Float64Builder::new(), embedding_records.dimension); embedding_records.as_iter().for_each(|record| { record .embedding @@ -91,54 +131,23 @@ impl TryFrom for RecordBatch { builder.append(true); }); - RecordBatch::try_new( - Arc::new(embedding_schema()), - vec![ - Arc::new(id), - Arc::new(document_id), - Arc::new(content), - Arc::new(builder.finish()), - ], - ) + RecordBatch::try_from_iter(vec![ + ("id", Arc::new(id) as ArrayRef), + ("document_id", Arc::new(document_id) as ArrayRef), + ("content", Arc::new(content) as ArrayRef), + ("embedding", Arc::new(builder.finish()) as ArrayRef), + ]) } type Error = ArrowError; } -pub struct EmbeddingRecordsBatch(HashMap); -impl EmbeddingRecordsBatch { - fn as_iter(&self) -> impl Iterator { - self.0.clone().into_values().collect::>().into_iter() - } - - pub fn get_by_id(&self, id: &str) -> Option { - self.0.get(id).cloned() - } - - pub fn document_ids(&self) -> Vec { - self.0.clone().into_keys().collect() - } -} - impl From for Vec> { fn from(embeddings: EmbeddingRecordsBatch) -> Self { embeddings.as_iter().map(RecordBatch::try_from).collect() } } -pub fn embedding_schema() -> Schema { - Schema::new(Fields::from(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("document_id", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), - Field::new( - "embedding", - DataType::List(Arc::new(Field::new("item", DataType::Float64, true))), - false, - ), - ])) -} - impl TryFrom for EmbeddingRecords { type Error = ArrowError; @@ -148,20 +157,39 @@ impl TryFrom for EmbeddingRecords { let contents = record_batch.deserialize_str_column(2)?; let embeddings = 
record_batch.deserialize_float_list_column(3)?; - Ok(EmbeddingRecords( + // There is a `_distance` field in the response if the executed query was a VectorQuery + // Otherwise, for normal queries, the `_distance` field is not present in the response. + let distances = if record_batch.num_columns() == 5 { + record_batch + .deserialize_float32_column(4)? + .into_iter() + .map(Some) + .collect() + } else { + vec![None; record_batch.num_rows()] + }; + + Ok(EmbeddingRecords::new( ids.into_iter() .zip(document_ids) .zip(contents) - .zip(embeddings) + .zip(embeddings.clone()) + .zip(distances) .map( - |(((id, document_id), content), embedding)| EmbeddingRecord { + |((((id, document_id), content), embedding), distance)| EmbeddingRecord { id: id.to_string(), document_id: document_id.to_string(), content: content.to_string(), embedding, + distance, }, ) .collect(), + embeddings + .iter() + .map(|embedding| embedding.len() as i32) + .next() + .unwrap_or(0), )) } } @@ -185,7 +213,10 @@ impl TryFrom> for EmbeddingRecordsBatch { .and_modify(|item: &mut EmbeddingRecords| { item.add_record(record.clone()) }) - .or_insert(EmbeddingRecords::new(vec![record.clone()])); + .or_insert(EmbeddingRecords::new( + vec![record.clone()], + record.embedding.len() as i32, + )); }); acc }); @@ -201,29 +232,34 @@ mod tests { use crate::table_schemas::embedding::{EmbeddingRecord, EmbeddingRecords}; #[tokio::test] - async fn test_record_batch_deserialize() { - let embedding_records = EmbeddingRecords(vec![ - EmbeddingRecord { - id: "some_id".to_string(), - document_id: "ABC".to_string(), - content: serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string(), - embedding: vec![1.0, 2.0, 3.0], - }, - EmbeddingRecord { - id: "another_id".to_string(), - document_id: "DEF".to_string(), - content: serde_json::json!({ - "title": "Sup dog", - "body": "Greetings", - }) - .to_string(), - embedding: vec![4.0, 5.0, 6.0], - }, - ]); + async fn test_record_batch_conversion() { + let 
embedding_records = EmbeddingRecords::new( + vec![ + EmbeddingRecord { + id: "some_id".to_string(), + document_id: "ABC".to_string(), + content: serde_json::json!({ + "title": "Hello world", + "body": "Greetings", + }) + .to_string(), + embedding: vec![1.0, 2.0, 3.0], + distance: None, + }, + EmbeddingRecord { + id: "another_id".to_string(), + document_id: "DEF".to_string(), + content: serde_json::json!({ + "title": "Sup dog", + "body": "Greetings", + }) + .to_string(), + embedding: vec![4.0, 5.0, 6.0], + distance: None, + }, + ], + 3, + ); let record_batch = RecordBatch::try_from(embedding_records).unwrap(); @@ -241,6 +277,7 @@ mod tests { }) .to_string(), embedding: vec![1.0, 2.0, 3.0], + distance: None } ); diff --git a/rig-lancedb/src/table_schemas/mod.rs b/rig-lancedb/src/table_schemas/mod.rs index 9f6f1ba5..e880ef87 100644 --- a/rig-lancedb/src/table_schemas/mod.rs +++ b/rig-lancedb/src/table_schemas/mod.rs @@ -11,8 +11,8 @@ pub mod document; pub mod embedding; pub fn merge( - documents: DocumentRecords, - embeddings: EmbeddingRecordsBatch, + documents: &DocumentRecords, + embeddings: &EmbeddingRecordsBatch, ) -> Result, VectorStoreError> { documents .as_iter() diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index 0d642ab9..4ee956f1 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -1,6 +1,9 @@ use std::sync::Arc; -use arrow_array::{Array, Float64Array, ListArray, RecordBatch, RecordBatchIterator, StringArray}; +use arrow_array::{ + Array, FixedSizeListArray, Float32Array, Float64Array, RecordBatch, RecordBatchIterator, + StringArray, +}; use futures::TryStreamExt; use lancedb::{ arrow::arrow_schema::{ArrowError, Schema}, @@ -15,6 +18,9 @@ pub trait DeserializeArrow { /// Define the column number that contains strings, i. /// For each item in the column, convert it to a string and collect the result in a vector of strings. 
fn deserialize_str_column(&self, i: usize) -> Result, ArrowError>; + /// Define the column number that contains float32's, i. + /// For each item in the column, convert it to a float32 and collect the result in a vector of float32. + fn deserialize_float32_column(&self, i: usize) -> Result, ArrowError>; /// Define the column number that contains the list of floats, i. /// For each item in the column, convert it to a list and for each item in the list, convert it to a float. /// Collect the result as a vector of vectors of floats. @@ -34,9 +40,21 @@ impl DeserializeArrow for RecordBatch { } } + fn deserialize_float32_column(&self, i: usize) -> Result, ArrowError> { + let column = self.column(i); + match column.as_any().downcast_ref::() { + Some(float_array) => Ok((0..float_array.len()) + .map(|j| float_array.value(j)) + .collect::>()), + None => Err(ArrowError::CastError(format!( + "Can't cast column {i} to string array" + ))), + } + } + fn deserialize_float_list_column(&self, i: usize) -> Result>, ArrowError> { let column = self.column(i); - match column.as_any().downcast_ref::() { + match column.as_any().downcast_ref::() { Some(list_array) => (0..list_array.len()) .map( |j| match list_array.value(j).as_any().downcast_ref::() { @@ -50,7 +68,7 @@ impl DeserializeArrow for RecordBatch { ) .collect::, _>>(), None => Err(ArrowError::CastError(format!( - "Can't cast column {i} to list array" + "Can't cast column {i} to fixed size list array" ))), } } diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 8a8b4467..67a04664 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -87,9 +87,6 @@ impl MongoDbVectorStore { /// /// The index (of type "vector") must already exist for the MongoDB collection. /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. 
- /// - /// An additional filter can be provided to further restrict the documents that are - /// considered in the search. pub fn index(&self, model: M, index_name: &str) -> MongoDbVectorIndex { MongoDbVectorIndex::new(self.collection.clone(), model, index_name) } From d62bbbfff84efa2ecaf8afb61a7a9dab1b312a9b Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 23 Sep 2024 11:13:48 -0400 Subject: [PATCH 10/39] refactor: conversions from arrow types to primitive types --- .../examples/vector_search_local_ann.rs | 2 +- .../examples/vector_search_local_enn.rs | 51 +++++++++++++++ rig-lancedb/src/lib.rs | 12 +--- rig-lancedb/src/table_schemas/document.rs | 4 +- rig-lancedb/src/table_schemas/embedding.rs | 23 ++++--- rig-lancedb/src/table_schemas/mod.rs | 17 +++-- rig-lancedb/src/utils/mod.rs | 63 ++++++++++--------- 7 files changed, 111 insertions(+), 61 deletions(-) create mode 100644 rig-lancedb/examples/vector_search_local_enn.rs diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index ec02a704..41d8ccca 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -25,7 +25,7 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. 
Return ONLY the JSON string generated, nothing else.") .build(); let response = agent - .prompt("Invent at least 150 words and their definitions") + .prompt("Invent at least 175 words and their definitions") .await?; let mut definitions: Vec = serde_json::from_str(&response)?; diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs new file mode 100644 index 00000000..8da18ef5 --- /dev/null +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -0,0 +1,51 @@ +use std::env; + +use rig::{ + embeddings::EmbeddingsBuilder, + providers::openai::Client, + vector_store::{VectorStore, VectorStoreIndexDyn}, +}; +use rig_lancedb::{LanceDbVectorStore, SearchParams}; + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Initialize LanceDB locally. + let db = lancedb::connect("data/lancedb-store").execute().await?; + let mut vector_store = LanceDbVectorStore::new(&db, 1536).await?; + + // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). 
+ let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + + // Select the embedding model and generate our embeddings + let model = openai_client.embedding_model("text-embedding-ada-002"); + + let embeddings = EmbeddingsBuilder::new(model.clone()) + .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .build() + .await?; + + // Add embeddings to vector store + vector_store.add_documents(embeddings).await?; + + // Create a vector index on our vector store + let index = vector_store.index(model); + + // Query the index + let results = index + .top_n_from_query( + "My boss says I zindle too much, what does that mean?", + 1, + &serde_json::to_string(&SearchParams::new(None, None, None, None, None))?, + ) + .await? + .into_iter() + .map(|(score, doc)| (score, doc.id, doc.document)) + .collect::>(); + + println!("Results: {:?}", results); + + Ok(()) +} diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 300f65f8..477fc58f 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -172,7 +172,7 @@ impl VectorStore for LanceDbVectorStore { } } -/// A vector index for a MongoDB collection. +/// A vector index for a LanceDB collection. 
pub struct LanceDbVectorIndex { model: M, embedding_table: lancedb::Table, @@ -303,15 +303,7 @@ impl VectorStoreIndex for LanceDbV let documents: DocumentRecords = self .document_table .query() - .only_if(format!( - "id IN ({})", - embeddings - .document_ids() - .iter() - .map(|id| format!("'{id}'")) - .collect::>() - .join(",") - )) + .only_if(format!("id IN ({})", embeddings.document_ids())) .execute_query() .await?; diff --git a/rig-lancedb/src/table_schemas/document.rs b/rig-lancedb/src/table_schemas/document.rs index 084518c1..56b19bef 100644 --- a/rig-lancedb/src/table_schemas/document.rs +++ b/rig-lancedb/src/table_schemas/document.rs @@ -97,8 +97,8 @@ impl TryFrom for DocumentRecords { type Error = ArrowError; fn try_from(record_batch: RecordBatch) -> Result { - let ids = record_batch.deserialize_str_column(0)?; - let documents = record_batch.deserialize_str_column(1)?; + let ids = record_batch.to_str(0)?; + let documents = record_batch.to_str(1)?; Ok(DocumentRecords( ids.into_iter() diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs index d66c6555..c73d4e53 100644 --- a/rig-lancedb/src/table_schemas/embedding.rs +++ b/rig-lancedb/src/table_schemas/embedding.rs @@ -2,12 +2,13 @@ use std::{collections::HashMap, sync::Arc}; use arrow_array::{ builder::{FixedSizeListBuilder, Float64Builder}, + types::{Float32Type, Float64Type}, ArrayRef, RecordBatch, StringArray, }; use lancedb::arrow::arrow_schema::ArrowError; use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; -use crate::utils::DeserializeArrow; +use crate::utils::{DeserializeArrow, DeserializePrimitiveArray}; /// Data format in the LanceDB table `embeddings` #[derive(Clone, Debug, PartialEq)] @@ -56,8 +57,13 @@ impl EmbeddingRecordsBatch { self.0.get(id).cloned() } - pub fn document_ids(&self) -> Vec { - self.0.clone().into_keys().collect() + pub fn document_ids(&self) -> String { + self.0 + .clone() + .into_keys() + .map(|id| 
format!("'{id}'")) + .collect::>() + .join(",") } } @@ -152,16 +158,17 @@ impl TryFrom for EmbeddingRecords { type Error = ArrowError; fn try_from(record_batch: RecordBatch) -> Result { - let ids = record_batch.deserialize_str_column(0)?; - let document_ids = record_batch.deserialize_str_column(1)?; - let contents = record_batch.deserialize_str_column(2)?; - let embeddings = record_batch.deserialize_float_list_column(3)?; + let ids = record_batch.to_str(0)?; + let document_ids = record_batch.to_str(1)?; + let contents = record_batch.to_str(2)?; + let embeddings = record_batch.to_float_list::(3)?; // There is a `_distance` field in the response if the executed query was a VectorQuery // Otherwise, for normal queries, the `_distance` field is not present in the response. let distances = if record_batch.num_columns() == 5 { record_batch - .deserialize_float32_column(4)? + .column(4) + .to_float::()? .into_iter() .map(Some) .collect() diff --git a/rig-lancedb/src/table_schemas/mod.rs b/rig-lancedb/src/table_schemas/mod.rs index e880ef87..5c175fd4 100644 --- a/rig-lancedb/src/table_schemas/mod.rs +++ b/rig-lancedb/src/table_schemas/mod.rs @@ -1,27 +1,26 @@ use document::{DocumentRecord, DocumentRecords}; use embedding::{EmbeddingRecord, EmbeddingRecordsBatch}; -use rig::{ - embeddings::{DocumentEmbeddings, Embedding}, - vector_store::VectorStoreError, -}; - -use crate::serde_to_rig_error; +use rig::embeddings::{DocumentEmbeddings, Embedding}; pub mod document; pub mod embedding; +/// Merge an `DocumentRecords` object with an `EmbeddingRecordsBatch` object. +/// These objects contain document and embedding data, respectively, read from LanceDB. +/// For each document in `DocumentRecords` find the embeddings from `EmbeddingRecordsBatch` that correspond to that document, +/// using the document_id as reference. 
pub fn merge( documents: &DocumentRecords, embeddings: &EmbeddingRecordsBatch, -) -> Result, VectorStoreError> { +) -> Result, serde_json::Error> { documents .as_iter() .map(|DocumentRecord { id, document }| { let emebedding_records = embeddings.get_by_id(id); - Ok::<_, VectorStoreError>(DocumentEmbeddings { + Ok(DocumentEmbeddings { id: id.to_string(), - document: serde_json::from_str(document).map_err(serde_to_rig_error)?, + document: serde_json::from_str(document)?, embeddings: match emebedding_records { Some(records) => records .as_iter() diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index 4ee956f1..a8ef758d 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use arrow_array::{ - Array, FixedSizeListArray, Float32Array, Float64Array, RecordBatch, RecordBatchIterator, - StringArray, + Array, ArrowPrimitiveType, FixedSizeListArray, PrimitiveArray, RecordBatch, + RecordBatchIterator, StringArray, }; use futures::TryStreamExt; use lancedb::{ @@ -13,22 +13,41 @@ use rig::vector_store::VectorStoreError; use crate::lancedb_to_rig_error; +pub trait DeserializePrimitiveArray { + fn to_float( + &self, + ) -> Result::Native>, ArrowError>; +} + +impl DeserializePrimitiveArray for &Arc { + fn to_float( + &self, + ) -> Result::Native>, ArrowError> { + match self.as_any().downcast_ref::>() { + Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), + None => Err(ArrowError::CastError(format!( + "Can't cast array: {self:?} to float array" + ))), + } + } +} + /// Trait used to "deserialize" a column of a RecordBatch object into a list o primitive types pub trait DeserializeArrow { /// Define the column number that contains strings, i. /// For each item in the column, convert it to a string and collect the result in a vector of strings. - fn deserialize_str_column(&self, i: usize) -> Result, ArrowError>; - /// Define the column number that contains float32's, i. 
- /// For each item in the column, convert it to a float32 and collect the result in a vector of float32. - fn deserialize_float32_column(&self, i: usize) -> Result, ArrowError>; + fn to_str(&self, i: usize) -> Result, ArrowError>; /// Define the column number that contains the list of floats, i. /// For each item in the column, convert it to a list and for each item in the list, convert it to a float. /// Collect the result as a vector of vectors of floats. - fn deserialize_float_list_column(&self, i: usize) -> Result>, ArrowError>; + fn to_float_list( + &self, + i: usize, + ) -> Result::Native>>, ArrowError>; } impl DeserializeArrow for RecordBatch { - fn deserialize_str_column(&self, i: usize) -> Result, ArrowError> { + fn to_str(&self, i: usize) -> Result, ArrowError> { let column = self.column(i); match column.as_any().downcast_ref::() { Some(str_array) => Ok((0..str_array.len()) @@ -40,32 +59,14 @@ impl DeserializeArrow for RecordBatch { } } - fn deserialize_float32_column(&self, i: usize) -> Result, ArrowError> { - let column = self.column(i); - match column.as_any().downcast_ref::() { - Some(float_array) => Ok((0..float_array.len()) - .map(|j| float_array.value(j)) - .collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast column {i} to string array" - ))), - } - } - - fn deserialize_float_list_column(&self, i: usize) -> Result>, ArrowError> { + fn to_float_list( + &self, + i: usize, + ) -> Result::Native>>, ArrowError> { let column = self.column(i); match column.as_any().downcast_ref::() { Some(list_array) => (0..list_array.len()) - .map( - |j| match list_array.value(j).as_any().downcast_ref::() { - Some(float_array) => Ok((0..float_array.len()) - .map(|k| float_array.value(k)) - .collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast value at index {j} to float array" - ))), - }, - ) + .map(|j| (&list_array.value(j)).to_float::()) .collect::, _>>(), None => Err(ArrowError::CastError(format!( "Can't cast column {i} to 
fixed size list array" From 22b43ba99e45b59c90f0a6263e8cf5b40b11ecc0 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 23 Sep 2024 11:35:09 -0400 Subject: [PATCH 11/39] feat: add vector_search_s3_ann example --- rig-lancedb/examples/vector_search_s3_ann.rs | 91 ++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 rig-lancedb/examples/vector_search_s3_ann.rs diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs new file mode 100644 index 00000000..482db617 --- /dev/null +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -0,0 +1,91 @@ +use std::env; + +use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; +use rig::{ + completion::Prompt, + embeddings::EmbeddingsBuilder, + providers::openai::Client, + vector_store::{VectorStore, VectorStoreIndexDyn}, +}; +use rig_lancedb::{LanceDbVectorStore, SearchParams}; + +// Note: see docs to deploy LanceDB on other cloud providers such as google and azure. +// https://lancedb.github.io/lancedb/guides/storage/ + +#[tokio::main] +async fn main() -> Result<(), anyhow::Error> { + // Initialize LanceDB on S3. + // Note: see below docs for more options and IAM permission required to read/write to S3. + // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 + let db = lancedb::connect("s3://lancedb-test-829666124233") + .execute() + .await?; + let mut vector_store = LanceDbVectorStore::new(&db, 1536).await?; + + // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + + // Generate test data for RAG demo + let agent = openai_client + .agent("gpt-4o") + .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. 
Return ONLY the JSON string generated, nothing else.") + .build(); + let response = agent + .prompt("Invent at least 175 words and their definitions") + .await?; + let mut definitions: Vec = serde_json::from_str(&response)?; + + // Note: need at least 256 rows in order to create an index on a table but OpenAi limits the output size + // so we duplicate the vector for testing purposes. + definitions.extend(definitions.clone()); + + // Select the embedding model and generate our embeddings + let model = openai_client.embedding_model("text-embedding-ada-002"); + + let embeddings = EmbeddingsBuilder::new(model.clone()) + .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") + .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") + .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") + .simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect()) + .build() + .await?; + + // Add embeddings to vector store + vector_store.add_documents(embeddings).await?; + + // Create a vector index on our vector store + // IMPORTANT: Reuse the same model that was used to generate the embeddings + let index = vector_store.index(model); + + // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information + index + .create_index(lancedb::index::Index::IvfPq( + IvfPqIndexBuilder::default() + // This overrides the default distance type of L2 + .distance_type(DistanceType::Cosine), + )) + .await?; + + // Query the index + let results = index + .top_n_from_query( + "My boss says I zindle too much, what does that 
mean?", + 1, + &serde_json::to_string(&SearchParams::new( + Some(DistanceType::Cosine), + None, + None, + None, + None, + ))?, + ) + .await? + .into_iter() + .map(|(score, doc)| (score, doc.id, doc.document)) + .collect::>(); + + println!("Results: {:?}", results); + + Ok(()) +} From ad4569030c31f1f65dadab500820e10be63f0d27 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 23 Sep 2024 13:48:17 -0400 Subject: [PATCH 12/39] feat: create enum for embedding models --- rig-core/examples/calculator_chatbot.rs | 4 +- rig-core/examples/rag.rs | 4 +- rig-core/examples/rag_dynamic_tools.rs | 4 +- rig-core/examples/vector_search.rs | 4 +- rig-core/examples/vector_search_cohere.rs | 8 +- rig-core/src/embeddings.rs | 7 ++ rig-core/src/providers/cohere.rs | 87 ++++++++++++++----- rig-core/src/providers/openai.rs | 83 +++++++++++++++--- .../examples/vector_search_local_ann.rs | 24 +++-- .../examples/vector_search_local_enn.rs | 17 ++-- rig-lancedb/examples/vector_search_s3_ann.rs | 29 +++---- rig-lancedb/src/lib.rs | 61 ++++--------- rig-mongodb/examples/vector_search_mongodb.rs | 4 +- 13 files changed, 207 insertions(+), 129 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index b7c5cdcc..bba21914 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -3,7 +3,7 @@ use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, embeddings::EmbeddingsBuilder, - providers::openai::Client, + providers::openai::{Client, OpenAIEmbeddingModel}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; @@ -245,7 +245,7 @@ async fn main() -> Result<(), anyhow::Error> { .dynamic_tool(Divide) .build(); - let embedding_model = openai_client.embedding_model("text-embedding-ada-002"); + let embedding_model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); let embeddings = 
EmbeddingsBuilder::new(embedding_model.clone()) .tools(&toolset)? .build() diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index f7ec9d27..3a390b2a 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -3,7 +3,7 @@ use std::env; use rig::{ completion::Prompt, embeddings::EmbeddingsBuilder, - providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, + providers::openai::{Client, OpenAIEmbeddingModel}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; @@ -13,7 +13,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + let embedding_model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); // Create vector store, compute embeddings and load them in the store let mut vector_store = InMemoryVectorStore::default(); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 824bdc9f..777c75ce 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, ToolDefinition}, embeddings::EmbeddingsBuilder, - providers::openai::Client, + providers::openai::{Client, OpenAIEmbeddingModel}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; @@ -148,7 +148,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - let embedding_model = openai_client.embedding_model("text-embedding-ada-002"); + let embedding_model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); // Create vector store, compute tool embeddings and load them in the store let mut vector_store 
= InMemoryVectorStore::default(); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 41532a27..df6363d6 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -2,7 +2,7 @@ use std::env; use rig::{ embeddings::EmbeddingsBuilder, - providers::openai::Client, + providers::openai::{Client, OpenAIEmbeddingModel}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; @@ -12,7 +12,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - let model = openai_client.embedding_model("text-embedding-ada-002"); + let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); let mut vector_store = InMemoryVectorStore::default(); diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 5bac6ff8..579a298d 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -2,7 +2,7 @@ use std::env; use rig::{ embeddings::EmbeddingsBuilder, - providers::cohere::Client, + providers::cohere::{Client, CohereEmbeddingModel}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; @@ -12,8 +12,10 @@ async fn main() -> Result<(), anyhow::Error> { let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"); let cohere_client = Client::new(&cohere_api_key); - let document_model = cohere_client.embedding_model("embed-english-v3.0", "search_document"); - let search_model = cohere_client.embedding_model("embed-english-v3.0", "search_query"); + let document_model = + cohere_client.embedding_model(&CohereEmbeddingModel::EmbedEnglishV3, "search_document"); + let search_model = + cohere_client.embedding_model(&CohereEmbeddingModel::EmbedEnglishV3, "search_query"); let mut vector_store = 
InMemoryVectorStore::default(); diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 2d40bbc5..321b05a2 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -66,6 +66,10 @@ pub enum EmbeddingError { /// Error returned by the embedding model provider #[error("ProviderError: {0}")] ProviderError(String), + + /// Http error (e.g.: connection error, timeout, etc.) + #[error("BadModel: {0}")] + BadModel(String), } /// Trait for embedding models that can generate embeddings for documents. @@ -73,6 +77,9 @@ pub trait EmbeddingModel: Clone + Sync + Send { /// The maximum number of documents that can be embedded in a single request. const MAX_DOCUMENTS: usize; + /// The number of dimensions in the embedding vector. + fn ndims(&self) -> usize; + /// Embed a single document fn embed_document( &self, diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 5f5fc397..7030d87c 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -62,11 +62,19 @@ impl Client { self.http_client.post(url) } - pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel { + pub fn embedding_model( + &self, + model: &CohereEmbeddingModel, + input_type: &str, + ) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, input_type) } - pub fn embeddings(&self, model: &str, input_type: &str) -> EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &CohereEmbeddingModel, + input_type: &str, + ) -> EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } @@ -133,20 +141,47 @@ enum ApiResponse { // ================================================================ // Cohere Embedding API // ================================================================ -/// `embed-english-v3.0` embedding model -pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0"; -/// `embed-english-light-v3.0` embedding model -pub const EMBED_ENGLISH_LIGHT_V3: &str = 
"embed-english-light-v3.0"; -/// `embed-multilingual-v3.0` embedding model -pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0"; -/// `embed-multilingual-light-v3.0` embedding model -pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0"; -/// `embed-english-v2.0` embedding model -pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0"; -/// `embed-english-light-v2.0` embedding model -pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0"; -/// `embed-multilingual-v2.0` embedding model -pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0"; +#[derive(Debug, Clone)] +pub enum CohereEmbeddingModel { + EmbedEnglishV3, + EmbedEnglishLightV3, + EmbedMultilingualV3, + EmbedMultilingualLightV3, + EmbedEnglishV2, + EmbedEnglishLightV2, + EmbedMultilingualV2, +} + +impl std::str::FromStr for CohereEmbeddingModel { + type Err = EmbeddingError; + + fn from_str(s: &str) -> std::result::Result { + match s { + "embed-english-v3.0" => Ok(Self::EmbedEnglishV3), + "embed-english-light-v3.0" => Ok(Self::EmbedEnglishLightV3), + "embed-multilingual-v3.0" => Ok(Self::EmbedMultilingualV3), + "embed-multilingual-light-v3.0" => Ok(Self::EmbedMultilingualLightV3), + "embed-english-v2.0" => Ok(Self::EmbedEnglishV2), + "embed-english-light-v2.0" => Ok(Self::EmbedEnglishLightV2), + "embed-multilingual-v2.0" => Ok(Self::EmbedMultilingualV2), + _ => Err(EmbeddingError::BadModel(s.to_string())), + } + } +} + +impl std::fmt::Display for CohereEmbeddingModel { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::EmbedEnglishLightV3 => write!(f, "embed-english-light-v3.0"), + Self::EmbedEnglishV3 => write!(f, "embed-english-v3.0"), + Self::EmbedMultilingualLightV3 => write!(f, "embed-multilingual-light-v3.0"), + Self::EmbedMultilingualV3 => write!(f, "embed-multilingual-v3.0"), + Self::EmbedEnglishV2 => write!(f, "embed-english-v2.0"), + Self::EmbedEnglishLightV2 => write!(f, 
"embed-english-light-v2.0"), + Self::EmbedMultilingualV2 => write!(f, "embed-multilingual-v2.0"), + } + } +} #[derive(Deserialize)] pub struct EmbeddingResponse { @@ -191,13 +226,25 @@ pub struct BilledUnits { #[derive(Clone)] pub struct EmbeddingModel { client: Client, - pub model: String, + pub model: CohereEmbeddingModel, pub input_type: String, } impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 96; + fn ndims(&self) -> usize { + match self.model { + CohereEmbeddingModel::EmbedEnglishV3 => 1024, + CohereEmbeddingModel::EmbedEnglishLightV3 => 384, + CohereEmbeddingModel::EmbedMultilingualV3 => 1024, + CohereEmbeddingModel::EmbedMultilingualLightV3 => 384, + CohereEmbeddingModel::EmbedEnglishV2 => 4096, + CohereEmbeddingModel::EmbedEnglishLightV2 => 1024, + CohereEmbeddingModel::EmbedMultilingualV2 => 768, + } + } + async fn embed_documents( &self, documents: Vec, @@ -206,7 +253,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .client .post("/v1/embed") .json(&json!({ - "model": self.model, + "model": self.model.to_string(), "texts": documents, "input_type": self.input_type, })) @@ -242,10 +289,10 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } impl EmbeddingModel { - pub fn new(client: Client, model: &str, input_type: &str) -> Self { + pub fn new(client: Client, model: &CohereEmbeddingModel, input_type: &str) -> Self { Self { client, - model: model.to_string(), + model: model.clone(), input_type: input_type.to_string(), } } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 87c3f557..a2e63727 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -79,7 +79,7 @@ impl Client { /// /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE); /// ``` - pub fn embedding_model(&self, model: &str) -> EmbeddingModel { + pub fn embedding_model(&self, model: &OpenAIEmbeddingModel) -> EmbeddingModel { EmbeddingModel::new(self.clone(), 
model) } @@ -99,7 +99,10 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder { + pub fn embeddings( + &self, + model: &OpenAIEmbeddingModel, + ) -> embeddings::EmbeddingsBuilder { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } @@ -205,12 +208,60 @@ enum ApiResponse { // ================================================================ // OpenAI Embedding API // ================================================================ -/// `text-embedding-3-large` embedding model -pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large"; -/// `text-embedding-3-small` embedding model -pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small"; -/// `text-embedding-ada-002` embedding model -pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002"; +#[derive(Debug, Clone)] +pub enum OpenAIEmbeddingModel { + TextEmbedding3Large, + TextEmbedding3Small, + TextEmbeddingAda002, +} + +impl std::str::FromStr for OpenAIEmbeddingModel { + type Err = EmbeddingError; + + fn from_str(s: &str) -> std::result::Result { + match s { + "text-embedding-3-large" => Ok(Self::TextEmbedding3Large), + "text-embedding-3-small" => Ok(Self::TextEmbedding3Small), + "text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002), + _ => Err(EmbeddingError::BadModel(s.to_string())), + } + } +} + +impl std::fmt::Display for OpenAIEmbeddingModel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OpenAIEmbeddingModel::TextEmbedding3Large => write!(f, "text-embedding-3-large"), + OpenAIEmbeddingModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"), + OpenAIEmbeddingModel::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct EmbeddingRequest { + pub model: String, + pub input: Vec, + pub encoding_format: String, +} + +impl EmbeddingRequest { + pub fn new(model: 
&str, input: Vec) -> Self { + Self { + model: model.to_string(), + input, + encoding_format: "float".to_string(), + } + } +} + +#[derive(Debug, Deserialize)] +pub struct EmbeddingResponseError { + pub code: String, + pub message: String, + pub param: Option, + pub type_: String, +} #[derive(Debug, Deserialize)] pub struct EmbeddingResponse { @@ -251,12 +302,20 @@ pub struct Usage { #[derive(Clone)] pub struct EmbeddingModel { client: Client, - pub model: String, + pub model: OpenAIEmbeddingModel, } impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 1024; + fn ndims(&self) -> usize { + match self.model { + OpenAIEmbeddingModel::TextEmbedding3Large => 3072, + OpenAIEmbeddingModel::TextEmbedding3Small => 1536, + OpenAIEmbeddingModel::TextEmbeddingAda002 => 1536, + } + } + async fn embed_documents( &self, documents: Vec, @@ -265,7 +324,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .client .post("/v1/embeddings") .json(&json!({ - "model": self.model, + "model": self.model.to_string(), "input": documents, })) .send() @@ -298,10 +357,10 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } impl EmbeddingModel { - pub fn new(client: Client, model: &str) -> Self { + pub fn new(client: Client, model: &OpenAIEmbeddingModel) -> Self { Self { client, - model: model.to_string(), + model: model.clone(), } } } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 41d8ccca..0b1a39e6 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -4,21 +4,24 @@ use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ completion::Prompt, embeddings::EmbeddingsBuilder, - providers::openai::Client, + providers::openai::{Client, OpenAIEmbeddingModel}, vector_store::{VectorStore, VectorStoreIndexDyn}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; #[tokio::main] async fn main() -> Result<(), 
anyhow::Error> { - // Initialize LanceDB locally. - let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, 1536).await?; - // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); + // Select the embedding model and generate our embeddings + let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + + // Initialize LanceDB locally. + let db = lancedb::connect("data/lancedb-store").execute().await?; + let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; + // Generate test data for RAG demo let agent = openai_client .agent("gpt-4o") @@ -33,9 +36,6 @@ async fn main() -> Result<(), anyhow::Error> { // so we duplicate the vector for testing purposes. definitions.extend(definitions.clone()); - // Select the embedding model and generate our embeddings - let model = openai_client.embedding_model("text-embedding-ada-002"); - let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") @@ -47,12 +47,8 @@ async fn main() -> Result<(), anyhow::Error> { // Add embeddings to vector store vector_store.add_documents(embeddings).await?; - // Create a vector index on our vector store - // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model); - // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information - index + vector_store 
.create_index(lancedb::index::Index::IvfPq( IvfPqIndexBuilder::default() // This overrides the default distance type of L2 @@ -61,7 +57,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; // Query the index - let results = index + let results = vector_store .top_n_from_query( "My boss says I zindle too much, what does that mean?", 1, diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 8da18ef5..125ea73c 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -2,23 +2,23 @@ use std::env; use rig::{ embeddings::EmbeddingsBuilder, - providers::openai::Client, + providers::openai::{Client, OpenAIEmbeddingModel}, vector_store::{VectorStore, VectorStoreIndexDyn}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { - // Initialize LanceDB locally. - let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, 1536).await?; - // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); // Select the embedding model and generate our embeddings - let model = openai_client.embedding_model("text-embedding-ada-002"); + let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + + // Initialize LanceDB locally. 
+ let db = lancedb::connect("data/lancedb-store").execute().await?; + let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") @@ -30,11 +30,8 @@ async fn main() -> Result<(), anyhow::Error> { // Add embeddings to vector store vector_store.add_documents(embeddings).await?; - // Create a vector index on our vector store - let index = vector_store.index(model); - // Query the index - let results = index + let results = vector_store .top_n_from_query( "My boss says I zindle too much, what does that mean?", 1, diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 482db617..73d3a222 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,7 @@ use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ completion::Prompt, embeddings::EmbeddingsBuilder, - providers::openai::Client, + providers::openai::{Client, OpenAIEmbeddingModel}, vector_store::{VectorStore, VectorStoreIndexDyn}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; @@ -14,17 +14,20 @@ use rig_lancedb::{LanceDbVectorStore, SearchParams}; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { + // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). + let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); + let openai_client = Client::new(&openai_api_key); + + // Select the embedding model and generate our embeddings + let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + // Initialize LanceDB on S3. // Note: see below docs for more options and IAM permission required to read/write to S3. 
// https://lancedb.github.io/lancedb/guides/storage/#aws-s3 let db = lancedb::connect("s3://lancedb-test-829666124233") .execute() .await?; - let mut vector_store = LanceDbVectorStore::new(&db, 1536).await?; - - // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). - let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); - let openai_client = Client::new(&openai_api_key); + let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; // Generate test data for RAG demo let agent = openai_client @@ -40,10 +43,7 @@ async fn main() -> Result<(), anyhow::Error> { // so we duplicate the vector for testing purposes. definitions.extend(definitions.clone()); - // Select the embedding model and generate our embeddings - let model = openai_client.embedding_model("text-embedding-ada-002"); - - let embeddings = EmbeddingsBuilder::new(model.clone()) + let embeddings: Vec = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") @@ -54,12 +54,8 @@ async fn main() -> Result<(), anyhow::Error> { // Add embeddings to vector store vector_store.add_documents(embeddings).await?; - // Create a vector index on our vector store - // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model); - // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information - index + vector_store 
.create_index(lancedb::index::Index::IvfPq( IvfPqIndexBuilder::default() // This overrides the default distance type of L2 @@ -68,11 +64,12 @@ async fn main() -> Result<(), anyhow::Error> { .await?; // Query the index - let results = index + let results = vector_store .top_n_from_query( "My boss says I zindle too much, what does that mean?", 1, &serde_json::to_string(&SearchParams::new( + // Important: use the same same distance type that was used to train the index. Some(DistanceType::Cosine), None, None, diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 477fc58f..6e938066 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -17,19 +17,17 @@ use utils::{Insert, Query}; mod table_schemas; mod utils; -pub struct LanceDbVectorStore { +pub struct LanceDbVectorStore { + model: M, document_table: lancedb::Table, embedding_table: lancedb::Table, - embedding_dimension: i32, } -impl LanceDbVectorStore { +impl LanceDbVectorStore { /// Note: Tables are created inside the new function rather than created outside and passed as reference to new function. /// This is because a specific schema needs to be enforced on the tables and this is done at creation time. 
- pub async fn new( - db: &lancedb::Connection, - embedding_dimension: i32, - ) -> Result { + pub async fn new(db: &lancedb::Connection, model: &M) -> Result { + // db.embedding_registry().register(name, function) Ok(Self { document_table: db .create_empty_table("documents", Arc::new(Self::document_schema())) @@ -38,11 +36,11 @@ impl LanceDbVectorStore { embedding_table: db .create_empty_table( "embeddings", - Arc::new(Self::embedding_schema(embedding_dimension)), + Arc::new(Self::embedding_schema(model.ndims() as i32)), ) .execute() .await?, - embedding_dimension, + model: model.clone(), }) } @@ -69,12 +67,13 @@ impl LanceDbVectorStore { ])) } - pub fn index(&self, model: M) -> LanceDbVectorIndex { - LanceDbVectorIndex::new( - model, - self.embedding_table.clone(), - self.document_table.clone(), - ) + pub async fn create_index(&self, index: Index) -> Result<(), lancedb::Error> { + self.embedding_table + .create_index(&["embedding"], index) + .execute() + .await?; + + Ok(()) } } @@ -86,7 +85,7 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { VectorStoreError::JsonError(e) } -impl VectorStore for LanceDbVectorStore { +impl VectorStore for LanceDbVectorStore { type Q = lancedb::query::Query; async fn add_documents( @@ -106,7 +105,7 @@ impl VectorStore for LanceDbVectorStore { self.embedding_table .insert( embedding_records, - Self::embedding_schema(self.embedding_dimension), + Self::embedding_schema(self.model.ndims() as i32), ) .await .map_err(lancedb_to_rig_error)?; @@ -172,32 +171,6 @@ impl VectorStore for LanceDbVectorStore { } } -/// A vector index for a LanceDB collection. 
-pub struct LanceDbVectorIndex { - model: M, - embedding_table: lancedb::Table, - document_table: lancedb::Table, -} - -impl LanceDbVectorIndex { - pub fn new(model: M, embedding_table: lancedb::Table, document_table: lancedb::Table) -> Self { - Self { - model, - embedding_table, - document_table, - } - } - - pub async fn create_index(&self, index: Index) -> Result<(), lancedb::Error> { - self.embedding_table - .create_index(&["embedding"], index) - .execute() - .await?; - - Ok(()) - } -} - /// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. #[derive(Deserialize, Serialize, Debug, Clone)] pub enum SearchType { @@ -245,7 +218,7 @@ impl SearchParams { } } -impl VectorStoreIndex for LanceDbVectorIndex { +impl VectorStoreIndex for LanceDbVectorStore { async fn top_n_from_query( &self, query: &str, diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index dbfa83cc..d51d6d48 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -3,7 +3,7 @@ use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, - providers::openai::Client, + providers::openai::{Client, OpenAIEmbeddingModel}, vector_store::{VectorStore, VectorStoreIndex}, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; @@ -32,7 +32,7 @@ async fn main() -> Result<(), anyhow::Error> { let mut vector_store = MongoDbVectorStore::new(collection); // Select the embedding model and generate our embeddings - let model = openai_client.embedding_model("text-embedding-ada-002"); + let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") From e22d778d4c0d9c1c9f2f427fb17e9d0d2efa76ff Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 23 Sep 2024 14:15:10 -0400 
Subject: [PATCH 13/39] ci: makes the protoc compiler available on github workflows --- .github/workflows/ci.yaml | 3 +++ rig-core/src/agent.rs | 16 ++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index a098a14b..584532ca 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -49,5 +49,8 @@ jobs: with: components: clippy + - name: Install Protoc + uses: arduino/setup-protoc@v3 + - name: Run clippy action uses: clechasseur/rs-clippy-check@v3 diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 94851f4f..0647815a 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -167,10 +167,10 @@ impl Completion for Agent { chat_history: Vec, ) -> Result, CompletionError> { let dynamic_context = stream::iter(self.dynamic_context.iter()) - .then(|(num_sample, index, search_query)| async { + .then(|(num_sample, index, search_params)| async { Ok::<_, VectorStoreError>( index - .top_n_from_query(prompt, *num_sample, search_query) + .top_n_from_query(prompt, *num_sample, search_params) .await? .into_iter() .map(|(_, doc)| { @@ -195,10 +195,10 @@ impl Completion for Agent { .map_err(|e| CompletionError::RequestError(Box::new(e)))?; let dynamic_tools = stream::iter(self.dynamic_tools.iter()) - .then(|(num_sample, index, search_query)| async { + .then(|(num_sample, index, search_params)| async { Ok::<_, VectorStoreError>( index - .top_n_ids_from_query(prompt, *num_sample, search_query) + .top_n_ids_from_query(prompt, *num_sample, search_params) .await? 
.into_iter() .map(|(_, doc)| doc) @@ -360,10 +360,10 @@ impl AgentBuilder { mut self, sample: usize, dynamic_context: impl VectorStoreIndexDyn + 'static, - params: String, + search_params: String, ) -> Self { self.dynamic_context - .push((sample, Box::new(dynamic_context), params)); + .push((sample, Box::new(dynamic_context), search_params)); self } @@ -374,10 +374,10 @@ impl AgentBuilder { sample: usize, dynamic_tools: impl VectorStoreIndexDyn + 'static, toolset: ToolSet, - params: String, + search_params: String, ) -> Self { - self.dynamic_tools - .push((sample, Box::new(dynamic_tools), params)); + self.dynamic_tools - .push((sample, Box::new(dynamic_tools), search_params)); self.tools.add_tools(toolset); self } From 4debc0edef0005d36bf98ce6d1f16a6fff96b8d7 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 23 Sep 2024 14:46:48 -0400 Subject: [PATCH 14/39] fix: reduce openai generated content in ANN examples --- rig-lancedb/examples/vector_search_local_ann.rs | 7 ++++--- rig-lancedb/examples/vector_search_s3_ann.rs | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 0b1a39e6..3317a8df 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -28,12 +28,13 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. Return ONLY the JSON string generated, nothing else.") .build(); let response = agent - .prompt("Invent at least 175 words and their definitions") + .prompt("Invent at least 100 words and their definitions") .await?; let mut definitions: Vec = serde_json::from_str(&response)?; - // Note: need at least 256 rows in order to create an index on a table but OpenAi limits the output size - // so we duplicate the vector for testing purposes. 
+ // Note: need at least 256 rows in order to create an index on a table but OpenAI limits the output size + // so we triplicate the vector for testing purposes. + definitions.extend(definitions.clone()); definitions.extend(definitions.clone()); let embeddings = EmbeddingsBuilder::new(model.clone()) diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 73d3a222..6ff25d94 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -35,12 +35,13 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. Return ONLY the JSON string generated, nothing else.") .build(); let response = agent - .prompt("Invent at least 175 words and their definitions") + .prompt("Invent at least 100 words and their definitions") .await?; let mut definitions: Vec = serde_json::from_str(&response)?; - // Note: need at least 256 rows in order to create an index on a table but OpenAi limits the output size - // so we duplicate the vector for testing purposes. + // Note: need at least 256 rows in order to create an index on a table but OpenAI limits the output size + // so we triplicate the vector for testing purposes. 
+ definitions.extend(definitions.clone()); definitions.extend(definitions.clone()); let embeddings: Vec = EmbeddingsBuilder::new(model.clone()) From 7b71aa1ae15aa2abbc0b91c9cc3062b66bf832cb Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 23 Sep 2024 15:03:16 -0400 Subject: [PATCH 15/39] feat: add indexes and tables for simple search --- rig-lancedb/src/lib.rs | 43 +++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 6e938066..815344ae 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use lancedb::{ arrow::arrow_schema::{DataType, Field, Fields, Schema}, - index::Index, + index::{Index}, query::QueryBase, DistanceType, }; @@ -27,19 +27,22 @@ impl LanceDbVectorStore { /// Note: Tables are created inside the new function rather than created outside and passed as reference to new function. /// This is because a specific schema needs to be enforced on the tables and this is done at creation time. 
pub async fn new(db: &lancedb::Connection, model: &M) -> Result { - // db.embedding_registry().register(name, function) + let document_table = db + .create_empty_table("documents", Arc::new(Self::document_schema())) + .execute() + .await?; + + let embedding_table = db + .create_empty_table( + "embeddings", + Arc::new(Self::embedding_schema(model.ndims() as i32)), + ) + .execute() + .await?; + Ok(Self { - document_table: db - .create_empty_table("documents", Arc::new(Self::document_schema())) - .execute() - .await?, - embedding_table: db - .create_empty_table( - "embeddings", - Arc::new(Self::embedding_schema(model.ndims() as i32)), - ) - .execute() - .await?, + document_table, + embedding_table, model: model.clone(), }) } @@ -67,6 +70,20 @@ impl LanceDbVectorStore { ])) } + pub async fn create_document_index(&self, index: Index) -> Result<(), lancedb::Error>{ + self.document_table + .create_index(&["id"], index) + .execute() + .await + } + + pub async fn create_embedding_index(&self, index: Index) -> Result<(), lancedb::Error>{ + self.embedding_table + .create_index(&["id", "document_id"], index) + .execute() + .await + } + pub async fn create_index(&self, index: Index) -> Result<(), lancedb::Error> { self.embedding_table .create_index(&["embedding"], index) From e63d5a1fd4975e754a416d4b257afcfe3d687474 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 23 Sep 2024 15:53:09 -0400 Subject: [PATCH 16/39] style: cargo fmt --- rig-lancedb/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 815344ae..f81a643c 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use lancedb::{ arrow::arrow_schema::{DataType, Field, Fields, Schema}, - index::{Index}, + index::Index, query::QueryBase, DistanceType, }; @@ -70,14 +70,14 @@ impl LanceDbVectorStore { ])) } - pub async fn create_document_index(&self, index: Index) -> Result<(), lancedb::Error>{ + 
pub async fn create_document_index(&self, index: Index) -> Result<(), lancedb::Error> { self.document_table .create_index(&["id"], index) .execute() .await } - pub async fn create_embedding_index(&self, index: Index) -> Result<(), lancedb::Error>{ + pub async fn create_embedding_index(&self, index: Index) -> Result<(), lancedb::Error> { self.embedding_table .create_index(&["id", "document_id"], index) .execute() From 4d28f615b26813e2112573a29cc95eb85b6645e9 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 24 Sep 2024 10:21:48 -0400 Subject: [PATCH 17/39] refactor: remove associated type on VectorStoreIndex trait --- rig-core/examples/calculator_chatbot.rs | 2 +- rig-core/examples/rag.rs | 2 +- rig-core/examples/rag_dynamic_tools.rs | 2 +- rig-core/examples/vector_search.rs | 2 +- rig-core/examples/vector_search_cohere.rs | 2 +- rig-core/src/agent.rs | 23 +-- rig-core/src/vector_store/in_memory_store.rs | 7 +- rig-core/src/vector_store/mod.rs | 53 +----- .../examples/vector_search_local_ann.rs | 19 +- .../examples/vector_search_local_enn.rs | 8 +- rig-lancedb/examples/vector_search_s3_ann.rs | 20 +- rig-lancedb/src/lib.rs | 174 ++++++++++-------- rig-lancedb/src/table_schemas/document.rs | 19 +- rig-lancedb/src/table_schemas/embedding.rs | 18 +- rig-lancedb/src/utils/mod.rs | 48 +++-- rig-mongodb/examples/vector_search_mongodb.rs | 4 +- rig-mongodb/src/lib.rs | 46 +++-- 17 files changed, 221 insertions(+), 228 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index bba21914..c096857d 100644 --- a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -272,7 +272,7 @@ async fn main() -> Result<(), anyhow::Error> { ) // Add a dynamic tool source with a sample rate of 1 (i.e.: only // 1 additional tool will be added to prompts) - .dynamic_tools(4, index, toolset, "".to_string()) + .dynamic_tools(4, index, toolset) .build(); // Prompt the agent and print the response diff --git 
a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index 3a390b2a..fae0b91d 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -35,7 +35,7 @@ async fn main() -> Result<(), anyhow::Error> { You are a dictionary assistant here to assist the user in understanding the meaning of words. You will find additional non-standard word definitions that could be useful below. ") - .dynamic_context(1, index, "".to_string()) + .dynamic_context(1, index) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index 777c75ce..cb7a955d 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -174,7 +174,7 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("You are a calculator here to help the user perform arithmetic operations.") // Add a dynamic tool source with a sample rate of 1 (i.e.: only // 1 additional tool will be added to prompts) - .dynamic_tools(1, index, toolset, "".to_string()) + .dynamic_tools(1, index, toolset) .build(); // Prompt the agent and print the response diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 4e2e5870..04ba2ab3 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -24,7 +24,7 @@ async fn main() -> Result<(), anyhow::Error> { let index = InMemoryVectorIndex::from_embeddings(model, embeddings).await?; let results = index - .top_n_from_query("What is a linglingdong?", 1, ()) + .top_n_from_query("What is a linglingdong?", 1) .await? 
.into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index 579a298d..a8c2d163 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -31,7 +31,7 @@ async fn main() -> Result<(), anyhow::Error> { let index = vector_store.index(search_model); let results = index - .top_n_from_query("What is a linglingdong?", 1, ()) + .top_n_from_query("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-core/src/agent.rs b/rig-core/src/agent.rs index 0647815a..8648a9eb 100644 --- a/rig-core/src/agent.rs +++ b/rig-core/src/agent.rs @@ -153,9 +153,9 @@ pub struct Agent { /// Additional parameters to be passed to the model additional_params: Option, /// List of vector store, with the sample number - dynamic_context: Vec<(usize, Box, String)>, + dynamic_context: Vec<(usize, Box)>, /// Dynamic tools - dynamic_tools: Vec<(usize, Box, String)>, + dynamic_tools: Vec<(usize, Box)>, /// Actual tool implementations pub tools: ToolSet, } @@ -167,10 +167,10 @@ impl Completion for Agent { chat_history: Vec, ) -> Result, CompletionError> { let dynamic_context = stream::iter(self.dynamic_context.iter()) - .then(|(num_sample, index, search_params)| async { + .then(|(num_sample, index)| async { Ok::<_, VectorStoreError>( index - .top_n_from_query(prompt, *num_sample, search_params) + .top_n_from_query(prompt, *num_sample) .await? .into_iter() .map(|(_, doc)| { @@ -195,10 +195,10 @@ impl Completion for Agent { .map_err(|e| CompletionError::RequestError(Box::new(e)))?; let dynamic_tools = stream::iter(self.dynamic_tools.iter()) - .then(|(num_sample, index, search_params)| async { + .then(|(num_sample, index)| async { Ok::<_, VectorStoreError>( index - .top_n_ids_from_query(prompt, *num_sample, search_params) + .top_n_ids_from_query(prompt, *num_sample) .await? 
.into_iter() .map(|(_, doc)| doc) @@ -296,9 +296,9 @@ pub struct AgentBuilder { /// Additional parameters to be passed to the model additional_params: Option, /// List of vector store, with the sample number - dynamic_context: Vec<(usize, Box, String)>, + dynamic_context: Vec<(usize, Box)>, /// Dynamic tools - dynamic_tools: Vec<(usize, Box, String)>, + dynamic_tools: Vec<(usize, Box)>, /// Temperature of the model temperature: Option, /// Actual tool implementations @@ -360,10 +360,9 @@ impl AgentBuilder { mut self, sample: usize, dynamic_context: impl VectorStoreIndexDyn + 'static, - search_params: String, ) -> Self { self.dynamic_context - .push((sample, Box::new(dynamic_context), search_params)); + .push((sample, Box::new(dynamic_context))); self } @@ -374,10 +373,8 @@ impl AgentBuilder { sample: usize, dynamic_tools: impl VectorStoreIndexDyn + 'static, toolset: ToolSet, - search_params: String, ) -> Self { - self.dynamic_tools - .push((sample, Box::new(dynamic_tools), search_params)); + self.dynamic_tools.push((sample, Box::new(dynamic_tools))); self.tools.add_tools(toolset); self } diff --git a/rig-core/src/vector_store/in_memory_store.rs b/rig-core/src/vector_store/in_memory_store.rs index c2239a23..02c19cf8 100644 --- a/rig-core/src/vector_store/in_memory_store.rs +++ b/rig-core/src/vector_store/in_memory_store.rs @@ -198,24 +198,19 @@ impl InMemoryVectorIndex { } impl VectorStoreIndex for InMemoryVectorIndex { - type SearchParams = (); - async fn top_n_from_query( &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n, search_params) - .await + self.top_n_from_embedding(&prompt_embedding, n).await } async fn top_n_from_embedding( &self, query_embedding: &Embedding, n: usize, - _search_params: Self::SearchParams, ) -> Result, VectorStoreError> { // Sort documents by best embedding distance let mut 
docs: EmbeddingRanking = BinaryHeap::new(); diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index 0e2d9983..c042f0c3 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -50,8 +50,6 @@ pub trait VectorStore: Send + Sync { /// Trait for vector store indexes pub trait VectorStoreIndex: Send + Sync { - type SearchParams: for<'a> Deserialize<'a> + Send + Sync; - /// Get the top n documents based on the distance to the given embedding. /// The distance is calculated as the cosine distance between the prompt and /// the document embedding. @@ -60,7 +58,6 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send; /// Same as `top_n_from_query` but returns the documents without its embeddings. @@ -69,10 +66,9 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_query(query, n, search_params).await?; + let documents = self.top_n_from_query(query, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap())) @@ -85,11 +81,10 @@ pub trait VectorStoreIndex: Send + Sync { &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self.top_n_from_query(query, n, search_params).await?; + let documents = self.top_n_from_query(query, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -105,7 +100,6 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send; /// Same as `top_n_from_embedding` but returns the documents without its 
embeddings. @@ -114,12 +108,9 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self - .top_n_from_embedding(prompt_embedding, n, search_params) - .await?; + let documents = self.top_n_from_embedding(prompt_embedding, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, serde_json::from_value(doc.document).unwrap())) @@ -132,13 +123,10 @@ pub trait VectorStoreIndex: Send + Sync { &self, prompt_embedding: &Embedding, n: usize, - search_params: Self::SearchParams, ) -> impl std::future::Future, VectorStoreError>> + Send { async move { - let documents = self - .top_n_from_embedding(prompt_embedding, n, search_params) - .await?; + let documents = self.top_n_from_embedding(prompt_embedding, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -152,17 +140,15 @@ pub trait VectorStoreIndexDyn: Send + Sync { &'a self, query: &'a str, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>>; fn top_n_ids_from_query<'a>( &'a self, query: &'a str, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>> { Box::pin(async move { - let documents = self.top_n_from_query(query, n, search_params).await?; + let documents = self.top_n_from_query(query, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -174,19 +160,15 @@ pub trait VectorStoreIndexDyn: Send + Sync { &'a self, prompt_embedding: &'a Embedding, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>>; fn top_n_ids_from_embedding<'a>( &'a self, prompt_embedding: &'a Embedding, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>> { Box::pin(async move { - let documents = self - .top_n_from_embedding(prompt_embedding, n, search_params) - .await?; + let documents = 
self.top_n_from_embedding(prompt_embedding, n).await?; Ok(documents .into_iter() .map(|(distance, doc)| (distance, doc.id)) @@ -200,44 +182,26 @@ impl VectorStoreIndexDyn for I { &'a self, query: &'a str, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(async move { - match serde_json::from_str(search_params) { - Ok(search_params) => self.top_n_from_query(query, n, search_params).await, - Err(e) => Err(VectorStoreError::JsonError(e)), - } - }) + Box::pin(async move { self.top_n_from_query(query, n).await }) } fn top_n_from_embedding<'a>( &'a self, prompt_embedding: &'a Embedding, n: usize, - search_params: &'a str, ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(async move { - match serde_json::from_str(search_params) { - Ok(search_params) => { - self.top_n_from_embedding(prompt_embedding, n, search_params) - .await - } - Err(e) => Err(VectorStoreError::JsonError(e)), - } - }) + Box::pin(async move { self.top_n_from_embedding(prompt_embedding, n).await }) } } pub struct NoIndex; impl VectorStoreIndex for NoIndex { - type SearchParams = (); - async fn top_n_from_query( &self, _query: &str, _n: usize, - _search_params: Self::SearchParams, ) -> Result, VectorStoreError> { Ok(vec![]) } @@ -246,7 +210,6 @@ impl VectorStoreIndex for NoIndex { &self, _prompt_embedding: &Embedding, _n: usize, - _search_params: Self::SearchParams, ) -> Result, VectorStoreError> { Ok(vec![]) } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 3317a8df..4a6a518a 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -18,9 +18,11 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + // 
Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; + let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?; // Generate test data for RAG demo let agent = openai_client @@ -52,24 +54,15 @@ async fn main() -> Result<(), anyhow::Error> { vector_store .create_index(lancedb::index::Index::IvfPq( IvfPqIndexBuilder::default() - // This overrides the default distance type of L2 + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. .distance_type(DistanceType::Cosine), )) .await?; // Query the index let results = vector_store - .top_n_from_query( - "My boss says I zindle too much, what does that mean?", - 1, - &serde_json::to_string(&SearchParams::new( - Some(DistanceType::Cosine), - None, - None, - None, - None, - ))?, - ) + .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 125ea73c..94d6aaf8 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -18,7 +18,7 @@ async fn main() -> Result<(), anyhow::Error> { // Initialize LanceDB locally. 
let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; + let mut vector_store = LanceDbVectorStore::new(&db, &model, &SearchParams::default()).await?; let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") @@ -32,11 +32,7 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n_from_query( - "My boss says I zindle too much, what does that mean?", - 1, - &serde_json::to_string(&SearchParams::new(None, None, None, None, None))?, - ) + .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 6ff25d94..1ec4f94a 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -21,13 +21,15 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + // Initialize LanceDB on S3. // Note: see below docs for more options and IAM permission required to read/write to S3. 
// https://lancedb.github.io/lancedb/guides/storage/#aws-s3 let db = lancedb::connect("s3://lancedb-test-829666124233") .execute() .await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model).await?; + let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?; // Generate test data for RAG demo let agent = openai_client @@ -59,25 +61,15 @@ async fn main() -> Result<(), anyhow::Error> { vector_store .create_index(lancedb::index::Index::IvfPq( IvfPqIndexBuilder::default() - // This overrides the default distance type of L2 + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. .distance_type(DistanceType::Cosine), )) .await?; // Query the index let results = vector_store - .top_n_from_query( - "My boss says I zindle too much, what does that mean?", - 1, - &serde_json::to_string(&SearchParams::new( - // Important: use the same same distance type that was used to train the index. - Some(DistanceType::Cosine), - None, - None, - None, - None, - ))?, - ) + .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index f81a643c..a8463d2b 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -17,16 +17,90 @@ use utils::{Insert, Query}; mod table_schemas; mod utils; +fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + +fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { + VectorStoreError::JsonError(e) +} + pub struct LanceDbVectorStore { + /// Defines which model is used to generate embeddings for the vector store model: M, + /// Table containing documents only document_table: lancedb::Table, + /// Table containing embeddings only. + /// Foreign key references the document in document table. 
embedding_table: lancedb::Table, + /// Vector search params that are used during vector search operations. + search_params: SearchParams, +} + +/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. +#[derive(Deserialize, Serialize, Debug, Clone)] +pub enum SearchType { + // Flat search, also called ENN or kNN. + Flat, + /// Approximal Nearest Neighbor search, also called ANN. + Approximate, +} + +#[derive(Deserialize, Serialize, Debug, Clone, Default)] +pub struct SearchParams { + /// Always set the distance_type to match the value used to train the index + /// By default, set to L2 + distance_type: Option, + /// By default, ANN will be used if there is an index on the table. + /// By default, kNN will be used if there is NO index on the table. + /// To use defaults, set to None. + search_type: Option, + /// Set this value only when search type is ANN. + /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information + nprobes: Option, + /// Set this value only when search type is ANN. 
+ /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information + refine_factor: Option, + /// If set to true, filtering will happen after the vector search instead of before + /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information + post_filter: Option, +} + +impl SearchParams { + pub fn distance_type(mut self, distance_type: DistanceType) -> Self { + self.distance_type = Some(distance_type); + self + } + + pub fn search_type(mut self, search_type: SearchType) -> Self { + self.search_type = Some(search_type); + self + } + + pub fn nprobes(mut self, nprobes: usize) -> Self { + self.nprobes = Some(nprobes); + self + } + + pub fn refine_factor(mut self, refine_factor: u32) -> Self { + self.refine_factor = Some(refine_factor); + self + } + + pub fn post_filter(mut self, post_filter: bool) -> Self { + self.post_filter = Some(post_filter); + self + } } impl LanceDbVectorStore { /// Note: Tables are created inside the new function rather than created outside and passed as reference to new function. /// This is because a specific schema needs to be enforced on the tables and this is done at creation time. - pub async fn new(db: &lancedb::Connection, model: &M) -> Result { + pub async fn new( + db: &lancedb::Connection, + model: &M, + search_params: &SearchParams, + ) -> Result { let document_table = db .create_empty_table("documents", Arc::new(Self::document_schema())) .execute() @@ -44,9 +118,11 @@ impl LanceDbVectorStore { document_table, embedding_table, model: model.clone(), + search_params: search_params.clone(), }) } + /// Schema of records in document table. fn document_schema() -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), @@ -54,6 +130,8 @@ impl LanceDbVectorStore { ])) } + /// Schema of records in embeddings table. + /// Every embedding vector in the table must have the same size. 
fn embedding_schema(dimension: i32) -> Schema { Schema::new(Fields::from(vec![ Field::new("id", DataType::Utf8, false), @@ -70,6 +148,7 @@ impl LanceDbVectorStore { ])) } + /// Define index on document table `id` field for search optimization. pub async fn create_document_index(&self, index: Index) -> Result<(), lancedb::Error> { self.document_table .create_index(&["id"], index) @@ -77,6 +156,7 @@ impl LanceDbVectorStore { .await } + /// Define index on embedding table `id` and `document_id` fields for search optimization. pub async fn create_embedding_index(&self, index: Index) -> Result<(), lancedb::Error> { self.embedding_table .create_index(&["id", "document_id"], index) @@ -84,6 +164,7 @@ impl LanceDbVectorStore { .await } + /// Define index on embedding table `embedding` fields for vector search optimization. pub async fn create_index(&self, index: Index) -> Result<(), lancedb::Error> { self.embedding_table .create_index(&["embedding"], index) @@ -94,14 +175,6 @@ impl LanceDbVectorStore { } } -fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { - VectorStoreError::DatastoreError(Box::new(e)) -} - -fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { - VectorStoreError::JsonError(e) -} - impl VectorStore for LanceDbVectorStore { type Q = lancedb::query::Query; @@ -137,14 +210,14 @@ impl VectorStore for LanceDbVector let documents: DocumentRecords = self .document_table .query() - .only_if(format!("id = {id}")) + .only_if(format!("id = '{id}'")) .execute_query() .await?; let embeddings: EmbeddingRecordsBatch = self .embedding_table .query() - .only_if(format!("document_id = {id}")) + .only_if(format!("document_id = '{id}'")) .execute_query() .await?; @@ -158,7 +231,7 @@ impl VectorStore for LanceDbVector let documents: DocumentRecords = self .document_table .query() - .only_if(format!("id = {id}")) + .only_if(format!("id = '{id}'")) .execute_query() .await?; @@ -180,7 +253,14 @@ impl VectorStore for LanceDbVector let embeddings: 
EmbeddingRecordsBatch = self .embedding_table .query() - .only_if(format!("document_id IN [{}]", documents.ids().join(","))) + .only_if(format!( + "document_id IN ({})", + documents + .ids() + .map(|id| format!("'{id}'")) + .collect::>() + .join(",") + )) .execute_query() .await?; @@ -188,84 +268,34 @@ impl VectorStore for LanceDbVector } } -/// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. -#[derive(Deserialize, Serialize, Debug, Clone)] -pub enum SearchType { - // Flat search, also called ENN or kNN. - Flat, - /// Approximal Nearest Neighbor search, also called ANN. - Approximate, -} - -#[derive(Deserialize, Serialize, Debug, Clone)] -pub struct SearchParams { - /// Always set the distance_type to match the value used to train the index - /// By default, set to L2 - distance_type: Option, - /// By default, ANN will be used if there is an index on the table. - /// By default, kNN will be used if there is NO index on the table. - /// To use defaults, set to None. - search_type: Option, - /// Set this value only when search type is ANN. - /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information - nprobes: Option, - /// Set this value only when search type is ANN. 
- /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information - refine_factor: Option, - /// If set to true, filtering will happen after the vector search instead of before - /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information - post_filter: Option, -} - -impl SearchParams { - pub fn new( - distance_type: Option, - search_type: Option, - nprobes: Option, - refine_factor: Option, - post_filter: Option, - ) -> Self { - Self { - distance_type, - search_type, - nprobes, - refine_factor, - post_filter, - } - } -} - impl VectorStoreIndex for LanceDbVectorStore { async fn top_n_from_query( &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n, search_params) - .await + self.top_n_from_embedding(&prompt_embedding, n).await } async fn top_n_from_embedding( &self, prompt_embedding: &rig::embeddings::Embedding, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { + let query = self + .embedding_table + .vector_search(prompt_embedding.vec.clone()) + .map_err(lancedb_to_rig_error)? + .limit(n); + let SearchParams { distance_type, search_type, nprobes, refine_factor, post_filter, - } = search_params.clone(); - - let query = self - .embedding_table - .vector_search(prompt_embedding.vec.clone()) - .map_err(lancedb_to_rig_error)? 
- .limit(n); + } = self.search_params.clone(); if let Some(distance_type) = distance_type { query.clone().distance_type(distance_type); @@ -317,6 +347,4 @@ impl VectorStoreIndex for LanceDbV }) .collect()) } - - type SearchParams = SearchParams; } diff --git a/rig-lancedb/src/table_schemas/document.rs b/rig-lancedb/src/table_schemas/document.rs index 56b19bef..384eb4bf 100644 --- a/rig-lancedb/src/table_schemas/document.rs +++ b/rig-lancedb/src/table_schemas/document.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use arrow_array::{ArrayRef, RecordBatch, StringArray}; +use arrow_array::{types::Utf8Type, ArrayRef, RecordBatch, StringArray}; use lancedb::arrow::arrow_schema::ArrowError; use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; -use crate::utils::DeserializeArrow; +use crate::utils::DeserializeByteArray; /// Schema of `documents` table in LanceDB defined as a struct. #[derive(Clone, Debug)] @@ -30,12 +30,12 @@ impl DocumentRecords { self.0.extend(records); } - fn documents(&self) -> Vec { - self.as_iter().map(|doc| doc.document.clone()).collect() + fn documents(&self) -> impl Iterator + '_ { + self.as_iter().map(|doc| doc.document.clone()) } - pub fn ids(&self) -> Vec { - self.as_iter().map(|doc| doc.id.clone()).collect() + pub fn ids(&self) -> impl Iterator + '_ { + self.as_iter().map(|doc| doc.id.clone()) } pub fn as_iter(&self) -> impl Iterator { @@ -97,8 +97,11 @@ impl TryFrom for DocumentRecords { type Error = ArrowError; fn try_from(record_batch: RecordBatch) -> Result { - let ids = record_batch.to_str(0)?; - let documents = record_batch.to_str(1)?; + let binding_0 = record_batch.column(0); + let ids = binding_0.to_str::()?; + + let binding_1 = record_batch.column(1); + let documents = binding_1.to_str::()?; Ok(DocumentRecords( ids.into_iter() diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs index c73d4e53..7f74dd12 100644 --- a/rig-lancedb/src/table_schemas/embedding.rs +++ 
b/rig-lancedb/src/table_schemas/embedding.rs @@ -2,13 +2,13 @@ use std::{collections::HashMap, sync::Arc}; use arrow_array::{ builder::{FixedSizeListBuilder, Float64Builder}, - types::{Float32Type, Float64Type}, + types::{Float32Type, Float64Type, Utf8Type}, ArrayRef, RecordBatch, StringArray, }; use lancedb::arrow::arrow_schema::ArrowError; use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; -use crate::utils::{DeserializeArrow, DeserializePrimitiveArray}; +use crate::utils::{DeserializeByteArray, DeserializeListArray, DeserializePrimitiveArray}; /// Data format in the LanceDB table `embeddings` #[derive(Clone, Debug, PartialEq)] @@ -158,10 +158,16 @@ impl TryFrom for EmbeddingRecords { type Error = ArrowError; fn try_from(record_batch: RecordBatch) -> Result { - let ids = record_batch.to_str(0)?; - let document_ids = record_batch.to_str(1)?; - let contents = record_batch.to_str(2)?; - let embeddings = record_batch.to_float_list::(3)?; + let binding_0 = record_batch.column(0); + let ids = binding_0.to_str::()?; + + let binding_1 = record_batch.column(1); + let document_ids = binding_1.to_str::()?; + + let binding_2 = record_batch.column(2); + let contents = binding_2.to_str::()?; + + let embeddings = record_batch.column(3).to_float_list::()?; // There is a `_distance` field in the response if the executed query was a VectorQuery // Otherwise, for normal queries, the `_distance` field is not present in the response. 
diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index a8ef758d..bf8874e2 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use arrow_array::{ - Array, ArrowPrimitiveType, FixedSizeListArray, PrimitiveArray, RecordBatch, - RecordBatchIterator, StringArray, + types::ByteArrayType, Array, ArrowPrimitiveType, FixedSizeListArray, GenericByteArray, + PrimitiveArray, RecordBatch, RecordBatchIterator, }; use futures::TryStreamExt; use lancedb::{ @@ -13,6 +13,7 @@ use rig::vector_store::VectorStoreError; use crate::lancedb_to_rig_error; +/// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. pub trait DeserializePrimitiveArray { fn to_float( &self, @@ -32,44 +33,39 @@ impl DeserializePrimitiveArray for &Arc { } } -/// Trait used to "deserialize" a column of a RecordBatch object into a list o primitive types -pub trait DeserializeArrow { - /// Define the column number that contains strings, i. - /// For each item in the column, convert it to a string and collect the result in a vector of strings. - fn to_str(&self, i: usize) -> Result, ArrowError>; - /// Define the column number that contains the list of floats, i. - /// For each item in the column, convert it to a list and for each item in the list, convert it to a float. - /// Collect the result as a vector of vectors of floats. - fn to_float_list( - &self, - i: usize, - ) -> Result::Native>>, ArrowError>; +/// Trait used to "deserialize" an arrow_array::Array as as list of byte objects. 
+pub trait DeserializeByteArray { + fn to_str(&self) -> Result::Native>, ArrowError>; } -impl DeserializeArrow for RecordBatch { - fn to_str(&self, i: usize) -> Result, ArrowError> { - let column = self.column(i); - match column.as_any().downcast_ref::() { - Some(str_array) => Ok((0..str_array.len()) - .map(|j| str_array.value(j)) - .collect::>()), +impl DeserializeByteArray for &Arc { + fn to_str(&self) -> Result::Native>, ArrowError> { + match self.as_any().downcast_ref::>() { + Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), None => Err(ArrowError::CastError(format!( - "Can't cast column {i} to string array" + "Can't cast array: {self:?} to float array" ))), } } +} + +/// Trait used to "deserialize" an arrow_array::Array as as list of lists of primitive objects. +pub trait DeserializeListArray { + fn to_float_list( + &self, + ) -> Result::Native>>, ArrowError>; +} +impl DeserializeListArray for &Arc { fn to_float_list( &self, - i: usize, ) -> Result::Native>>, ArrowError> { - let column = self.column(i); - match column.as_any().downcast_ref::() { + match self.as_any().downcast_ref::() { Some(list_array) => (0..list_array.len()) .map(|j| (&list_array.value(j)).to_float::()) .collect::, _>>(), None => Err(ArrowError::CastError(format!( - "Can't cast column {i} to fixed size list array" + "Can't cast column {self:?} to fixed size list array" ))), } } diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index d51d6d48..894499e7 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -49,11 +49,11 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "context_vector_index"); + let index = vector_store.index(model, "context_vector_index", SearchParams::new()); // Query the index 
let results = index - .top_n_from_query("What is a linglingdong?", 1, SearchParams::new()) + .top_n_from_query("What is a linglingdong?", 1) .await? .into_iter() .map(|(score, doc)| (score, doc.id, doc.document)) diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 67a04664..41bf8500 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -87,8 +87,13 @@ impl MongoDbVectorStore { /// /// The index (of type "vector") must already exist for the MongoDB collection. /// See the MongoDB [documentation](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/) for more information on creating indexes. - pub fn index(&self, model: M, index_name: &str) -> MongoDbVectorIndex { - MongoDbVectorIndex::new(self.collection.clone(), model, index_name) + pub fn index( + &self, + model: M, + index_name: &str, + search_params: SearchParams, + ) -> MongoDbVectorIndex { + MongoDbVectorIndex::new(self.collection.clone(), model, index_name, search_params) } } @@ -97,6 +102,7 @@ pub struct MongoDbVectorIndex { collection: mongodb::Collection, model: M, index_name: String, + search_params: SearchParams, } impl MongoDbVectorIndex { @@ -104,11 +110,13 @@ impl MongoDbVectorIndex { collection: mongodb::Collection, model: M, index_name: &str, + search_params: SearchParams, ) -> Self { Self { collection, model, index_name: index_name.to_string(), + search_params, } } } @@ -134,6 +142,21 @@ impl SearchParams { num_candidates: None, } } + + pub fn filter(mut self, filter: mongodb::bson::Document) -> Self { + self.filter = filter; + self + } + + pub fn exact(mut self, exact: bool) -> Self { + self.exact = Some(exact); + self + } + + pub fn num_candidates(mut self, num_candidates: u32) -> Self { + self.num_candidates = Some(num_candidates); + self + } } impl Default for SearchParams { @@ -147,19 +170,22 @@ impl VectorStoreIndex for MongoDbV &self, query: &str, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { let 
prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n, search_params) - .await + self.top_n_from_embedding(&prompt_embedding, n).await } async fn top_n_from_embedding( &self, prompt_embedding: &Embedding, n: usize, - search_params: Self::SearchParams, ) -> Result, VectorStoreError> { + let SearchParams { + filter, + exact, + num_candidates, + } = &self.search_params; + let mut cursor = self .collection .aggregate( @@ -168,11 +194,11 @@ impl VectorStoreIndex for MongoDbV "$vectorSearch": { "queryVector": &prompt_embedding.vec, "index": &self.index_name, - "exact": search_params.exact.unwrap_or(false), + "exact": exact.unwrap_or(false), "path": "embeddings.vec", - "numCandidates": search_params.num_candidates.unwrap_or((n * 10) as u32), + "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, - "filter": &search_params.filter, + "filter": filter, } }, doc! { @@ -206,6 +232,4 @@ impl VectorStoreIndex for MongoDbV Ok(results) } - - type SearchParams = SearchParams; } From 0bcf5a36b3c12f006d2027f73a05766ff079eef9 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 24 Sep 2024 13:08:43 -0400 Subject: [PATCH 18/39] refactor: use constants instead of enum for model names --- rig-core/examples/calculator_chatbot.rs | 4 +- rig-core/examples/rag.rs | 4 +- rig-core/examples/rag_dynamic_tools.rs | 4 +- rig-core/examples/vector_search.rs | 4 +- rig-core/examples/vector_search_cohere.rs | 8 +- rig-core/src/providers/cohere.rs | 97 +++++++------------ rig-core/src/providers/openai.rs | 78 +++++++-------- .../examples/vector_search_local_ann.rs | 4 +- .../examples/vector_search_local_enn.rs | 4 +- rig-lancedb/examples/vector_search_s3_ann.rs | 4 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- 11 files changed, 89 insertions(+), 124 deletions(-) diff --git a/rig-core/examples/calculator_chatbot.rs b/rig-core/examples/calculator_chatbot.rs index c096857d..04d26dc3 100644 --- 
a/rig-core/examples/calculator_chatbot.rs +++ b/rig-core/examples/calculator_chatbot.rs @@ -3,7 +3,7 @@ use rig::{ cli_chatbot::cli_chatbot, completion::ToolDefinition, embeddings::EmbeddingsBuilder, - providers::openai::{Client, OpenAIEmbeddingModel}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; @@ -245,7 +245,7 @@ async fn main() -> Result<(), anyhow::Error> { .dynamic_tool(Divide) .build(); - let embedding_model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(embedding_model.clone()) .tools(&toolset)? .build() diff --git a/rig-core/examples/rag.rs b/rig-core/examples/rag.rs index fae0b91d..3abd8ee9 100644 --- a/rig-core/examples/rag.rs +++ b/rig-core/examples/rag.rs @@ -3,7 +3,7 @@ use std::env; use rig::{ completion::Prompt, embeddings::EmbeddingsBuilder, - providers::openai::{Client, OpenAIEmbeddingModel}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; @@ -13,7 +13,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - let embedding_model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // Create vector store, compute embeddings and load them in the store let mut vector_store = InMemoryVectorStore::default(); diff --git a/rig-core/examples/rag_dynamic_tools.rs b/rig-core/examples/rag_dynamic_tools.rs index cb7a955d..6e45730b 100644 --- a/rig-core/examples/rag_dynamic_tools.rs +++ b/rig-core/examples/rag_dynamic_tools.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::{ completion::{Prompt, 
ToolDefinition}, embeddings::EmbeddingsBuilder, - providers::openai::{Client, OpenAIEmbeddingModel}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, tool::{Tool, ToolEmbedding, ToolSet}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore}, }; @@ -148,7 +148,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - let embedding_model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let embedding_model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // Create vector store, compute tool embeddings and load them in the store let mut vector_store = InMemoryVectorStore::default(); diff --git a/rig-core/examples/vector_search.rs b/rig-core/examples/vector_search.rs index 04ba2ab3..5975e43c 100644 --- a/rig-core/examples/vector_search.rs +++ b/rig-core/examples/vector_search.rs @@ -2,7 +2,7 @@ use std::env; use rig::{ embeddings::EmbeddingsBuilder, - providers::openai::{Client, OpenAIEmbeddingModel}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{in_memory_store::InMemoryVectorIndex, VectorStoreIndex}, }; @@ -12,7 +12,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") diff --git a/rig-core/examples/vector_search_cohere.rs b/rig-core/examples/vector_search_cohere.rs index a8c2d163..06d2cb1e 100644 --- a/rig-core/examples/vector_search_cohere.rs +++ b/rig-core/examples/vector_search_cohere.rs @@ -2,7 +2,7 @@ use std::env; 
use rig::{ embeddings::EmbeddingsBuilder, - providers::cohere::{Client, CohereEmbeddingModel}, + providers::cohere::{Client, EMBED_ENGLISH_V3}, vector_store::{in_memory_store::InMemoryVectorStore, VectorStore, VectorStoreIndex}, }; @@ -12,10 +12,8 @@ async fn main() -> Result<(), anyhow::Error> { let cohere_api_key = env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set"); let cohere_client = Client::new(&cohere_api_key); - let document_model = - cohere_client.embedding_model(&CohereEmbeddingModel::EmbedEnglishV3, "search_document"); - let search_model = - cohere_client.embedding_model(&CohereEmbeddingModel::EmbedEnglishV3, "search_query"); + let document_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_document"); + let search_model = cohere_client.embedding_model(EMBED_ENGLISH_V3, "search_query"); let mut vector_store = InMemoryVectorStore::default(); diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 7030d87c..42cef908 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -62,19 +62,27 @@ impl Client { self.http_client.post(url) } - pub fn embedding_model( + pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel { + let ndims = match model { + EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024, + EMBED_ENGLISH_LIGHT_V3 | EMBED_MULTILINGUAL_LIGHT_V3 => 384, + EMBED_ENGLISH_V2 => 4096, + EMBED_MULTILINGUAL_V2 => 768, + _ => 0, + }; + EmbeddingModel::new(self.clone(), model, input_type, ndims) + } + + pub fn embedding_model_with_ndims( &self, - model: &CohereEmbeddingModel, + model: &str, input_type: &str, + ndims: usize, ) -> EmbeddingModel { - EmbeddingModel::new(self.clone(), model, input_type) + EmbeddingModel::new(self.clone(), model, input_type, ndims) } - pub fn embeddings( - &self, - model: &CohereEmbeddingModel, - input_type: &str, - ) -> EmbeddingsBuilder { + pub fn embeddings(&self, model: &str, input_type: &str) -> 
EmbeddingsBuilder { EmbeddingsBuilder::new(self.embedding_model(model, input_type)) } @@ -141,47 +149,20 @@ enum ApiResponse { // ================================================================ // Cohere Embedding API // ================================================================ -#[derive(Debug, Clone)] -pub enum CohereEmbeddingModel { - EmbedEnglishV3, - EmbedEnglishLightV3, - EmbedMultilingualV3, - EmbedMultilingualLightV3, - EmbedEnglishV2, - EmbedEnglishLightV2, - EmbedMultilingualV2, -} - -impl std::str::FromStr for CohereEmbeddingModel { - type Err = EmbeddingError; - - fn from_str(s: &str) -> std::result::Result { - match s { - "embed-english-v3.0" => Ok(Self::EmbedEnglishV3), - "embed-english-light-v3.0" => Ok(Self::EmbedEnglishLightV3), - "embed-multilingual-v3.0" => Ok(Self::EmbedMultilingualV3), - "embed-multilingual-light-v3.0" => Ok(Self::EmbedMultilingualLightV3), - "embed-english-v2.0" => Ok(Self::EmbedEnglishV2), - "embed-english-light-v2.0" => Ok(Self::EmbedEnglishLightV2), - "embed-multilingual-v2.0" => Ok(Self::EmbedMultilingualV2), - _ => Err(EmbeddingError::BadModel(s.to_string())), - } - } -} - -impl std::fmt::Display for CohereEmbeddingModel { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::EmbedEnglishLightV3 => write!(f, "embed-english-light-v3.0"), - Self::EmbedEnglishV3 => write!(f, "embed-english-v3.0"), - Self::EmbedMultilingualLightV3 => write!(f, "embed-multilingual-light-v3.0"), - Self::EmbedMultilingualV3 => write!(f, "embed-multilingual-v3.0"), - Self::EmbedEnglishV2 => write!(f, "embed-english-v2.0"), - Self::EmbedEnglishLightV2 => write!(f, "embed-english-light-v2.0"), - Self::EmbedMultilingualV2 => write!(f, "embed-multilingual-v2.0"), - } - } -} +/// `embed-english-v3.0` embedding model +pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0"; +/// `embed-english-light-v3.0` embedding model +pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0"; +/// 
`embed-multilingual-v3.0` embedding model +pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0"; +/// `embed-multilingual-light-v3.0` embedding model +pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0"; +/// `embed-english-v2.0` embedding model +pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0"; +/// `embed-english-light-v2.0` embedding model +pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0"; +/// `embed-multilingual-v2.0` embedding model +pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0"; #[derive(Deserialize)] pub struct EmbeddingResponse { @@ -226,23 +207,16 @@ pub struct BilledUnits { #[derive(Clone)] pub struct EmbeddingModel { client: Client, - pub model: CohereEmbeddingModel, + pub model: String, pub input_type: String, + ndims: usize, } impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 96; fn ndims(&self) -> usize { - match self.model { - CohereEmbeddingModel::EmbedEnglishV3 => 1024, - CohereEmbeddingModel::EmbedEnglishLightV3 => 384, - CohereEmbeddingModel::EmbedMultilingualV3 => 1024, - CohereEmbeddingModel::EmbedMultilingualLightV3 => 384, - CohereEmbeddingModel::EmbedEnglishV2 => 4096, - CohereEmbeddingModel::EmbedEnglishLightV2 => 1024, - CohereEmbeddingModel::EmbedMultilingualV2 => 768, - } + self.ndims } async fn embed_documents( @@ -289,11 +263,12 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } impl EmbeddingModel { - pub fn new(client: Client, model: &CohereEmbeddingModel, input_type: &str) -> Self { + pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self { Self { client, - model: model.clone(), + model: model.to_string(), input_type: input_type.to_string(), + ndims, } } } diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index a2e63727..38be8e89 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -79,8 +79,28 @@ impl Client { /// /// 
let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE); /// ``` - pub fn embedding_model(&self, model: &OpenAIEmbeddingModel) -> EmbeddingModel { - EmbeddingModel::new(self.clone(), model) + pub fn embedding_model(&self, model: &str) -> EmbeddingModel { + let ndims = match model { + TEXT_EMBEDDING_3_LARGE => 3072, + TEXT_EMBEDDING_3_SMALL | TEXT_EMBEDDING_ADA_002 => 1536, + _ => 0, + }; + EmbeddingModel::new(self.clone(), model, ndims) + } + + /// Create an embedding model with the given name. + /// + /// # Example + /// ``` + /// use rig::providers::openai::{Client, self}; + /// + /// // Initialize the OpenAI client + /// let openai = Client::new("your-open-ai-api-key"); + /// + /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE, 3072); + /// ``` + pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { + EmbeddingModel::new(self.clone(), model, ndims) } /// Create an embedding builder with the given embedding model. 
@@ -99,10 +119,7 @@ impl Client { /// .await /// .expect("Failed to embed documents"); /// ``` - pub fn embeddings( - &self, - model: &OpenAIEmbeddingModel, - ) -> embeddings::EmbeddingsBuilder { + pub fn embeddings(&self, model: &str) -> embeddings::EmbeddingsBuilder { embeddings::EmbeddingsBuilder::new(self.embedding_model(model)) } @@ -208,35 +225,12 @@ enum ApiResponse { // ================================================================ // OpenAI Embedding API // ================================================================ -#[derive(Debug, Clone)] -pub enum OpenAIEmbeddingModel { - TextEmbedding3Large, - TextEmbedding3Small, - TextEmbeddingAda002, -} - -impl std::str::FromStr for OpenAIEmbeddingModel { - type Err = EmbeddingError; - - fn from_str(s: &str) -> std::result::Result { - match s { - "text-embedding-3-large" => Ok(Self::TextEmbedding3Large), - "text-embedding-3-small" => Ok(Self::TextEmbedding3Small), - "text-embedding-ada-002" => Ok(Self::TextEmbeddingAda002), - _ => Err(EmbeddingError::BadModel(s.to_string())), - } - } -} - -impl std::fmt::Display for OpenAIEmbeddingModel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - OpenAIEmbeddingModel::TextEmbedding3Large => write!(f, "text-embedding-3-large"), - OpenAIEmbeddingModel::TextEmbedding3Small => write!(f, "text-embedding-3-small"), - OpenAIEmbeddingModel::TextEmbeddingAda002 => write!(f, "text-embedding-ada-002"), - } - } -} +/// `text-embedding-3-large` embedding model +pub const TEXT_EMBEDDING_3_LARGE: &str = "text-embedding-3-large"; +/// `text-embedding-3-small` embedding model +pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small"; +/// `text-embedding-ada-002` embedding model +pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002"; #[derive(Debug, Deserialize)] pub struct EmbeddingRequest { @@ -302,18 +296,15 @@ pub struct Usage { #[derive(Clone)] pub struct EmbeddingModel { client: Client, - pub model: OpenAIEmbeddingModel, 
+ pub model: String, + ndims: usize, } impl embeddings::EmbeddingModel for EmbeddingModel { const MAX_DOCUMENTS: usize = 1024; fn ndims(&self) -> usize { - match self.model { - OpenAIEmbeddingModel::TextEmbedding3Large => 3072, - OpenAIEmbeddingModel::TextEmbedding3Small => 1536, - OpenAIEmbeddingModel::TextEmbeddingAda002 => 1536, - } + self.ndims } async fn embed_documents( @@ -357,10 +348,11 @@ impl embeddings::EmbeddingModel for EmbeddingModel { } impl EmbeddingModel { - pub fn new(client: Client, model: &OpenAIEmbeddingModel) -> Self { + pub fn new(client: Client, model: &str, ndims: usize) -> Self { Self { client, - model: model.clone(), + model: model.to_string(), + ndims, } } } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 4a6a518a..fc17e1a4 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -4,7 +4,7 @@ use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ completion::Prompt, embeddings::EmbeddingsBuilder, - providers::openai::{Client, OpenAIEmbeddingModel}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{VectorStore, VectorStoreIndexDyn}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; @@ -16,7 +16,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_client = Client::new(&openai_api_key); // Select the embedding model and generate our embeddings - let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let search_params = SearchParams::default().distance_type(DistanceType::Cosine); diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 94d6aaf8..4b1cb010 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -2,7 +2,7 @@ use std::env; use 
rig::{ embeddings::EmbeddingsBuilder, - providers::openai::{Client, OpenAIEmbeddingModel}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{VectorStore, VectorStoreIndexDyn}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; @@ -14,7 +14,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_client = Client::new(&openai_api_key); // Select the embedding model and generate our embeddings - let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); // Initialize LanceDB locally. let db = lancedb::connect("data/lancedb-store").execute().await?; diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 1ec4f94a..2eefec2d 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,7 @@ use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ completion::Prompt, embeddings::EmbeddingsBuilder, - providers::openai::{Client, OpenAIEmbeddingModel}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{VectorStore, VectorStoreIndexDyn}, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; @@ -19,7 +19,7 @@ async fn main() -> Result<(), anyhow::Error> { let openai_client = Client::new(&openai_api_key); // Select the embedding model and generate our embeddings - let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let search_params = SearchParams::default().distance_type(DistanceType::Cosine); diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 894499e7..5f633039 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -32,7 +32,7 @@ async fn main() -> Result<(), anyhow::Error> { let 
mut vector_store = MongoDbVectorStore::new(collection); // Select the embedding model and generate our embeddings - let model = openai_client.embedding_model(&OpenAIEmbeddingModel::TextEmbeddingAda002); + let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of a *flurbo*: A flurbo is a green alien that lives on cold planets") From 2f5844df2542989aa1a3372ca356a0d133365e40 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 24 Sep 2024 14:00:42 -0400 Subject: [PATCH 19/39] fix: make PR requested changes --- rig-core/src/embeddings.rs | 4 ---- rig-core/src/providers/cohere.rs | 3 +++ rig-core/src/providers/openai.rs | 4 +++- rig-core/src/vector_store/mod.rs | 4 ++-- rig-lancedb/examples/vector_search_local_ann.rs | 2 +- rig-lancedb/src/lib.rs | 17 ++++++++--------- rig-mongodb/examples/vector_search_mongodb.rs | 4 ++-- rig-mongodb/src/lib.rs | 1 - 8 files changed, 19 insertions(+), 20 deletions(-) diff --git a/rig-core/src/embeddings.rs b/rig-core/src/embeddings.rs index 321b05a2..e805d6d3 100644 --- a/rig-core/src/embeddings.rs +++ b/rig-core/src/embeddings.rs @@ -66,10 +66,6 @@ pub enum EmbeddingError { /// Error returned by the embedding model provider #[error("ProviderError: {0}")] ProviderError(String), - - /// Http error (e.g.: connection error, timeout, etc.) - #[error("BadModel: {0}")] - BadModel(String), } /// Trait for embedding models that can generate embeddings for documents. diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index 42cef908..cb01a70c 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -62,6 +62,8 @@ impl Client { self.http_client.post(url) } + /// Note: default embedding dimension of 0 will be used if model cannot be matched. 
+ /// If this is the case, it's better to use function `embedding_model_with_ndims` pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel { let ndims = match model { EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024, @@ -73,6 +75,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, input_type, ndims) } + /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model. pub fn embedding_model_with_ndims( &self, model: &str, diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 38be8e89..c3e33227 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -69,6 +69,8 @@ impl Client { } /// Create an embedding model with the given name. + /// Note: default embedding dimension of 0 will be used if model cannot be matched. + /// If this is the case, it's better to use function `embedding_model_with_ndims` /// /// # Example /// ``` @@ -88,7 +90,7 @@ impl Client { EmbeddingModel::new(self.clone(), model, ndims) } - /// Create an embedding model with the given name. + /// Create an embedding model with the given name and the number of dimensions in the embedding generated by the model. 
/// /// # Example /// ``` diff --git a/rig-core/src/vector_store/mod.rs b/rig-core/src/vector_store/mod.rs index c042f0c3..2e6a652f 100644 --- a/rig-core/src/vector_store/mod.rs +++ b/rig-core/src/vector_store/mod.rs @@ -183,7 +183,7 @@ impl VectorStoreIndexDyn for I { query: &'a str, n: usize, ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(async move { self.top_n_from_query(query, n).await }) + Box::pin(self.top_n_from_query(query, n)) } fn top_n_from_embedding<'a>( @@ -191,7 +191,7 @@ impl VectorStoreIndexDyn for I { prompt_embedding: &'a Embedding, n: usize, ) -> BoxFuture<'a, Result, VectorStoreError>> { - Box::pin(async move { self.top_n_from_embedding(prompt_embedding, n).await }) + Box::pin(self.top_n_from_embedding(prompt_embedding, n)) } } diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index fc17e1a4..be9c7038 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -30,7 +30,7 @@ async fn main() -> Result<(), anyhow::Error> { .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. 
Return ONLY the JSON string generated, nothing else.") .build(); let response = agent - .prompt("Invent at least 100 words and their definitions") + .prompt("Invent 100 words and their definitions") .await?; let mut definitions: Vec = serde_json::from_str(&response)?; diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index a8463d2b..ac6f7876 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -10,7 +10,6 @@ use rig::{ embeddings::EmbeddingModel, vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, }; -use serde::{Deserialize, Serialize}; use table_schemas::{document::DocumentRecords, embedding::EmbeddingRecordsBatch, merge}; use utils::{Insert, Query}; @@ -38,7 +37,7 @@ pub struct LanceDbVectorStore { } /// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. -#[derive(Deserialize, Serialize, Debug, Clone)] +#[derive(Debug, Clone)] pub enum SearchType { // Flat search, also called ENN or kNN. Flat, @@ -46,7 +45,7 @@ pub enum SearchType { Approximate, } -#[derive(Deserialize, Serialize, Debug, Clone, Default)] +#[derive(Debug, Clone, Default)] pub struct SearchParams { /// Always set the distance_type to match the value used to train the index /// By default, set to L2 @@ -283,7 +282,7 @@ impl VectorStoreIndex for LanceDbV prompt_embedding: &rig::embeddings::Embedding, n: usize, ) -> Result, VectorStoreError> { - let query = self + let mut query = self .embedding_table .vector_search(prompt_embedding.vec.clone()) .map_err(lancedb_to_rig_error)? 
@@ -298,24 +297,24 @@ impl VectorStoreIndex for LanceDbV } = self.search_params.clone(); if let Some(distance_type) = distance_type { - query.clone().distance_type(distance_type); + query = query.distance_type(distance_type); } if let Some(SearchType::Flat) = search_type { - query.clone().bypass_vector_index(); + query = query.bypass_vector_index(); } if let Some(SearchType::Approximate) = search_type { if let Some(nprobes) = nprobes { - query.clone().nprobes(nprobes); + query = query.nprobes(nprobes); } if let Some(refine_factor) = refine_factor { - query.clone().refine_factor(refine_factor); + query = query.refine_factor(refine_factor); } } if let Some(true) = post_filter { - query.clone().postfilter(); + query = query.postfilter(); } let embeddings: EmbeddingRecordsBatch = query.execute_query().await?; diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 5f633039..204a0d9b 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -1,9 +1,9 @@ -use mongodb::{bson::doc, options::ClientOptions, Client as MongoClient, Collection}; +use mongodb::{options::ClientOptions, Client as MongoClient, Collection}; use std::env; use rig::{ embeddings::{DocumentEmbeddings, EmbeddingsBuilder}, - providers::openai::{Client, OpenAIEmbeddingModel}, + providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::{VectorStore, VectorStoreIndex}, }; use rig_mongodb::{MongoDbVectorStore, SearchParams}; diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 41bf8500..8f2afc04 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -123,7 +123,6 @@ impl MongoDbVectorIndex { /// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information /// on each of the fields -#[derive(Deserialize)] pub struct SearchParams { /// Pre-filter filter: mongodb::bson::Document, From 
dfe32e2a1620c169524636e309ff4d165ba4fbe0 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 24 Sep 2024 14:14:32 -0400 Subject: [PATCH 20/39] fix: make PR requested changes --- rig-core/src/providers/cohere.rs | 4 ++-- rig-core/src/providers/openai.rs | 29 ++--------------------------- 2 files changed, 4 insertions(+), 29 deletions(-) diff --git a/rig-core/src/providers/cohere.rs b/rig-core/src/providers/cohere.rs index cb01a70c..1800fb40 100644 --- a/rig-core/src/providers/cohere.rs +++ b/rig-core/src/providers/cohere.rs @@ -62,7 +62,7 @@ impl Client { self.http_client.post(url) } - /// Note: default embedding dimension of 0 will be used if model cannot be matched. + /// Note: default embedding dimension of 0 will be used if model is not known. /// If this is the case, it's better to use function `embedding_model_with_ndims` pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel { let ndims = match model { @@ -230,7 +230,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .client .post("/v1/embed") .json(&json!({ - "model": self.model.to_string(), + "model": self.model, "texts": documents, "input_type": self.input_type, })) diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index c3e33227..0c0e6ef6 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -69,7 +69,7 @@ impl Client { } /// Create an embedding model with the given name. - /// Note: default embedding dimension of 0 will be used if model cannot be matched. + /// Note: default embedding dimension of 0 will be used if model is not known. 
/// If this is the case, it's better to use function `embedding_model_with_ndims` /// /// # Example @@ -234,31 +234,6 @@ pub const TEXT_EMBEDDING_3_SMALL: &str = "text-embedding-3-small"; /// `text-embedding-ada-002` embedding model pub const TEXT_EMBEDDING_ADA_002: &str = "text-embedding-ada-002"; -#[derive(Debug, Deserialize)] -pub struct EmbeddingRequest { - pub model: String, - pub input: Vec, - pub encoding_format: String, -} - -impl EmbeddingRequest { - pub fn new(model: &str, input: Vec) -> Self { - Self { - model: model.to_string(), - input, - encoding_format: "float".to_string(), - } - } -} - -#[derive(Debug, Deserialize)] -pub struct EmbeddingResponseError { - pub code: String, - pub message: String, - pub param: Option, - pub type_: String, -} - #[derive(Debug, Deserialize)] pub struct EmbeddingResponse { pub object: String, @@ -317,7 +292,7 @@ impl embeddings::EmbeddingModel for EmbeddingModel { .client .post("/v1/embeddings") .json(&json!({ - "model": self.model.to_string(), + "model": self.model, "input": documents, })) .send() From 5644a1d03f79738dda6f7ef91ac1fdc999fc7ee4 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 24 Sep 2024 14:17:19 -0400 Subject: [PATCH 21/39] style: cargo clippy --- rig-mongodb/src/lib.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 8f2afc04..7175fbcc 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -5,8 +5,6 @@ use rig::{ embeddings::{DocumentEmbeddings, Embedding, EmbeddingModel}, vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, }; -use serde::Deserialize; - /// A MongoDB vector store. 
pub struct MongoDbVectorStore { collection: mongodb::Collection, From 6fede36516600378e8adbb4551648df627f3df78 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 25 Sep 2024 17:58:42 -0400 Subject: [PATCH 22/39] feat: implement deserialization for any recordbatch returned from lanceDB --- rig-lancedb/src/utils/deserializer.rs | 508 ++++++++++++++++++++++++++ rig-lancedb/src/utils/mod.rs | 1 + 2 files changed, 509 insertions(+) create mode 100644 rig-lancedb/src/utils/deserializer.rs diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs new file mode 100644 index 00000000..bbe915b5 --- /dev/null +++ b/rig-lancedb/src/utils/deserializer.rs @@ -0,0 +1,508 @@ +use std::sync::Arc; + +use arrow_array::{ + types::{ + BinaryType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + LargeBinaryType, LargeUtf8Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, Utf8Type, + }, + Array, ArrowPrimitiveType, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, + GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, +}; +use lancedb::arrow::arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; +use rig::vector_store::VectorStoreError; +use serde::Serialize; +use serde_json::{json, Value}; + +use crate::serde_to_rig_error; + +fn arrow_to_rig_error(e: ArrowError) -> VectorStoreError { + VectorStoreError::DatastoreError(Box::new(e)) +} + +trait Test { + fn deserialize(&self) -> Result; +} + +impl Test for RecordBatch { + fn deserialize(&self) -> Result { + 
fn type_matcher(column: &Arc) -> Result, VectorStoreError> { + match column.data_type() { + DataType::Null => Ok(vec![serde_json::Value::Null]), + // f16 does not implement serde_json::Deserialize. Need to cast to f32. + DataType::Float16 => column + .to_primitive::() + .map_err(arrow_to_rig_error)? + .iter() + .map(|float_16| serde_json::to_value(float_16.to_f32())) + .collect::, _>>() + .map_err(serde_to_rig_error), + DataType::Float32 => column.to_primitive_value::(), + DataType::Float64 => column.to_primitive_value::(), + DataType::Int8 => column.to_primitive_value::(), + DataType::Int16 => column.to_primitive_value::(), + DataType::Int32 => column.to_primitive_value::(), + DataType::Int64 => column.to_primitive_value::(), + DataType::UInt8 => column.to_primitive_value::(), + DataType::UInt16 => column.to_primitive_value::(), + DataType::UInt32 => column.to_primitive_value::(), + DataType::UInt64 => column.to_primitive_value::(), + DataType::Date32 => column.to_primitive_value::(), + DataType::Date64 => column.to_primitive_value::(), + DataType::Decimal128(..) => column.to_primitive_value::(), + // i256 does not implement serde_json::Deserialize. Need to cast to i128. + DataType::Decimal256(..) => column + .to_primitive::() + .map_err(arrow_to_rig_error)? + .iter() + .map(|dec_256| serde_json::to_value(dec_256.as_i128())) + .collect::, _>>() + .map_err(serde_to_rig_error), + DataType::Time32(TimeUnit::Second) => { + column.to_primitive_value::() + } + DataType::Time32(TimeUnit::Millisecond) => { + column.to_primitive_value::() + } + DataType::Time64(TimeUnit::Microsecond) => { + column.to_primitive_value::() + } + DataType::Time64(TimeUnit::Nanosecond) => { + column.to_primitive_value::() + } + DataType::Timestamp(TimeUnit::Microsecond, ..) => { + column.to_primitive_value::() + } + DataType::Timestamp(TimeUnit::Millisecond, ..) => { + column.to_primitive_value::() + } + DataType::Timestamp(TimeUnit::Second, ..) 
=> { + column.to_primitive_value::() + } + DataType::Timestamp(TimeUnit::Nanosecond, ..) => { + column.to_primitive_value::() + } + DataType::Duration(TimeUnit::Microsecond) => { + column.to_primitive_value::() + } + DataType::Duration(TimeUnit::Millisecond) => { + column.to_primitive_value::() + } + DataType::Duration(TimeUnit::Nanosecond) => { + column.to_primitive_value::() + } + DataType::Duration(TimeUnit::Second) => { + column.to_primitive_value::() + } + DataType::Interval(IntervalUnit::DayTime) => Ok(column + .to_primitive::() + .map_err(arrow_to_rig_error)? + .iter() + .map(|interval| { + json!({ + "days": interval.days, + "milliseconds": interval.milliseconds, + }) + }) + .collect()), + DataType::Interval(IntervalUnit::MonthDayNano) => Ok(column + .to_primitive::() + .map_err(arrow_to_rig_error)? + .iter() + .map(|interval| { + json!({ + "months": interval.months, + "days": interval.days, + "nanoseconds": interval.nanoseconds, + }) + }) + .collect()), + DataType::Interval(IntervalUnit::YearMonth) => { + column.to_primitive_value::() + } + DataType::Utf8 | DataType::Utf8View => column.to_str_value::(), + DataType::LargeUtf8 => column.to_str_value::(), + DataType::Binary => column.to_str_value::(), + DataType::LargeBinary => column.to_str_value::(), + DataType::FixedSizeBinary(n) => { + match column.as_any().downcast_ref::() { + Some(list_array) => (0..*n) + .map(|j| serde_json::to_value(list_array.value(j as usize))) + .collect::, _>>() + .map_err(serde_to_rig_error), + None => Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "Can't cast column {column:?} to fixed size list array" + )), + ))), + } + } + DataType::FixedSizeList(..) => column + .fixed_nested_lists() + .map_err(arrow_to_rig_error)? + .iter() + .map(|nested_list| type_matcher(nested_list)) + .map_ok(), + DataType::List(..) | DataType::ListView(..) => column + .nested_lists::() + .map_err(arrow_to_rig_error)? 
+ .iter() + .map(|nested_list| type_matcher(nested_list)) + .map_ok(), + DataType::LargeList(..) | DataType::LargeListView(..) => column + .nested_lists::() + .map_err(arrow_to_rig_error)? + .iter() + .map(|nested_list| type_matcher(nested_list)) + .map_ok(), + DataType::Struct(..) => match column.as_any().downcast_ref::() { + Some(struct_array) => struct_array + .nested_lists() + .iter() + .map(|nested_list| type_matcher(nested_list)) + .map_ok(), + None => Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "Can't cast array: {column:?} to struct array" + )), + ))), + }, + // DataType::Map(..) => { + // let item = match column.as_any().downcast_ref::() { + // Some(map_array) => map_array + // .entries() + // .nested_lists() + // .iter() + // .map(|nested_list| type_matcher(nested_list, nested_list.data_type())) + // .collect::, _>>(), + // None => Err(VectorStoreError::DatastoreError(Box::new( + // ArrowError::CastError(format!( + // "Can't cast array: {column:?} to map array" + // )), + // ))), + // }?; + // } + // DataType::Dictionary(key_data_type, value_data_type) => { + // let item = match column.as_any().downcast_ref::() { + // Some(map_array) => { + // let keys = &Arc::new(map_array.keys()); + // type_matcher(keys, keys.data_type()) + // } + // None => Err(ArrowError::CastError(format!( + // "Can't cast array: {column:?} to map array" + // ))), + // }?; + // }, + _ => { + println!("Unsupported data type"); + Ok(vec![serde_json::Value::Null]) + } + } + } + + let columns = self + .columns() + .iter() + .map(type_matcher) + .collect::, _>>()?; + + println!("{:?}", serde_json::to_string(&columns).unwrap()); + + Ok(json!({})) + } +} + +/// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. 
+pub trait DeserializePrimitiveArray { + fn to_primitive( + &self, + ) -> Result::Native>, ArrowError>; + + fn to_primitive_value(&self) -> Result, VectorStoreError> + where + ::Native: Serialize; +} + +impl DeserializePrimitiveArray for &Arc { + fn to_primitive( + &self, + ) -> Result::Native>, ArrowError> { + match self.as_any().downcast_ref::>() { + Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), + None => Err(ArrowError::CastError(format!( + "Can't cast array: {self:?} to float array" + ))), + } + } + + fn to_primitive_value(&self) -> Result, VectorStoreError> + where + ::Native: Serialize, + { + self.to_primitive::() + .map_err(arrow_to_rig_error)? + .iter() + .map(serde_json::to_value) + .collect::, _>>() + .map_err(serde_to_rig_error) + } +} + +/// Trait used to "deserialize" an arrow_array::Array as as list of byte objects. +pub trait DeserializeByteArray { + fn to_str(&self) -> Result::Native>, ArrowError>; + + fn to_str_value(&self) -> Result, VectorStoreError> + where + ::Native: Serialize; +} + +impl DeserializeByteArray for &Arc { + fn to_str(&self) -> Result::Native>, ArrowError> { + match self.as_any().downcast_ref::>() { + Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), + None => Err(ArrowError::CastError(format!( + "Can't cast array: {self:?} to float array" + ))), + } + } + + fn to_str_value(&self) -> Result, VectorStoreError> + where + ::Native: Serialize, + { + self.to_str::() + .map_err(arrow_to_rig_error)? + .iter() + .map(serde_json::to_value) + .collect::, _>>() + .map_err(serde_to_rig_error) + } +} + +/// Trait used to "deserialize" an arrow_array::Array as as list of list objects. 
+trait DeserializeListArray { + fn nested_lists( + &self, + ) -> Result>, ArrowError>; +} + +impl DeserializeListArray for &Arc { + fn nested_lists( + &self, + ) -> Result>, ArrowError> { + match self.as_any().downcast_ref::>() { + Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), + None => Err(ArrowError::CastError(format!( + "Can't cast array: {self:?} to float array" + ))), + } + } +} + +/// Trait used to "deserialize" an arrow_array::Array as as list of list objects. +trait DeserializeArray { + fn fixed_nested_lists(&self) -> Result>, ArrowError>; +} + +impl DeserializeArray for &Arc { + fn fixed_nested_lists(&self) -> Result>, ArrowError> { + match self.as_any().downcast_ref::() { + Some(list_array) => Ok((0..list_array.len()) + .map(|j| list_array.value(j as usize)) + .collect::>()), + None => { + return Err(ArrowError::CastError(format!( + "Can't cast column {self:?} to fixed size list array" + ))); + } + } + } +} + +trait DeserializeStructArray { + fn nested_lists(&self) -> Vec>; +} + +impl DeserializeStructArray for StructArray { + fn nested_lists(&self) -> Vec> { + (0..self.len()) + .map(|j| self.column(j).clone()) + .collect::>() + } +} + +trait MapOk { + fn map_ok(self) -> Result, VectorStoreError>; +} + +impl MapOk for I +where + I: Iterator, VectorStoreError>>, +{ + fn map_ok(self) -> Result, VectorStoreError> { + self.map(|maybe_list| match maybe_list { + Ok(list) => serde_json::to_value(list).map_err(serde_to_rig_error), + Err(e) => Err(e), + }) + .collect::, _>>() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{ + builder::{FixedSizeListBuilder, ListBuilder, StringBuilder, StructBuilder}, ArrayRef, BinaryArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, ListArray, RecordBatch, StringArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array + }; + use lancedb::arrow::arrow_schema::{DataType, 
Field}; + + use crate::utils::deserializer::Test; + + #[tokio::test] + async fn test_primitive_deserialization() { + let string = Arc::new(StringArray::from_iter_values(vec!["Marty", "Tony"])) as ArrayRef; + let large_string = + Arc::new(LargeStringArray::from_iter_values(vec!["Jerry", "Freddy"])) as ArrayRef; + let binary = Arc::new(BinaryArray::from_iter_values(vec![b"hello", b"world"])) as ArrayRef; + let large_binary = Arc::new(LargeBinaryArray::from_iter_values(vec![ + b"The bright sun sets behind the mountains, casting gold", + b"A gentle breeze rustles through the trees at twilight.", + ])) as ArrayRef; + let float_32 = Arc::new(Float32Array::from_iter_values(vec![0.0, 1.0])) as ArrayRef; + let float_64 = Arc::new(Float64Array::from_iter_values(vec![0.0, 1.0])) as ArrayRef; + let int_8 = Arc::new(Int8Array::from_iter_values(vec![0, -1])) as ArrayRef; + let int_16 = Arc::new(Int16Array::from_iter_values(vec![-0, 1])) as ArrayRef; + let int_32 = Arc::new(Int32Array::from_iter_values(vec![0, -1])) as ArrayRef; + let int_64 = Arc::new(Int64Array::from_iter_values(vec![-0, 1])) as ArrayRef; + let uint_8 = Arc::new(UInt8Array::from_iter_values(vec![0, 1])) as ArrayRef; + let uint_16 = Arc::new(UInt16Array::from_iter_values(vec![0, 1])) as ArrayRef; + let uint_32 = Arc::new(UInt32Array::from_iter_values(vec![0, 1])) as ArrayRef; + let uint_64 = Arc::new(UInt64Array::from_iter_values(vec![0, 1])) as ArrayRef; + + let record_batch = RecordBatch::try_from_iter(vec![ + ("float_32", float_32), + ("float_64", float_64), + ("int_8", int_8), + ("int_16", int_16), + ("int_32", int_32), + ("int_64", int_64), + ("uint_8", uint_8), + ("uint_16", uint_16), + ("uint_32", uint_32), + ("uint_64", uint_64), + ("string", string), + ("large_string", large_string), + ("large_binary", large_binary), + ("binary", binary), + ]) + .unwrap(); + + let _t = record_batch.deserialize().unwrap(); + + assert!(false) + } + + #[tokio::test] + async fn test_list_recursion() { + let mut builder = 
FixedSizeListBuilder::new(StringBuilder::new(), 3); + builder.values().append_value("Hi"); + builder.values().append_value("Hey"); + builder.values().append_value("What's up"); + builder.append(true); + builder.values().append_value("Bye"); + builder.values().append_value("Seeya"); + builder.values().append_value("Later"); + builder.append(true); + + let record_batch = RecordBatch::try_from_iter(vec![( + "salutations", + Arc::new(builder.finish()) as ArrayRef, + )]) + .unwrap(); + + let _t = record_batch.deserialize().unwrap(); + + assert!(false) + } + + #[tokio::test] + async fn test_list_recursion_2() { + let mut builder = ListBuilder::new(ListBuilder::new(StringBuilder::new())); + builder + .values() + .append_value(vec![Some("Dog"), Some("Cat")]); + builder + .values() + .append_value(vec![Some("Mouse"), Some("Bird")]); + builder.append(true); + builder + .values() + .append_value(vec![Some("Giraffe"), Some("Mammoth")]); + builder + .values() + .append_value(vec![Some("Cow"), Some("Pig")]); + + let record_batch = + RecordBatch::try_from_iter(vec![("animals", Arc::new(builder.finish()) as ArrayRef)]) + .unwrap(); + + let _t = record_batch.deserialize().unwrap(); + + assert!(false) + } + + #[tokio::test] + async fn test_struct() { + let id_values = StringArray::from(vec!["id1", "id2", "id3"]); + + let age_values = Float32Array::from(vec![25.0, 30.5, 22.1]); + + let mut names_builder = ListBuilder::new(StringBuilder::new()); + names_builder.values().append_value("Alice"); + names_builder.values().append_value("Bob"); + names_builder.append(true); + names_builder.values().append_value("Charlie"); + names_builder.append(true); + names_builder.values().append_value("David"); + names_builder.values().append_value("Eve"); + names_builder.values().append_value("Frank"); + names_builder.append(true); + + let names_array = names_builder.finish(); + + // Step 4: Combine into a StructArray + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("id", 
DataType::Utf8, false)), + Arc::new(id_values) as ArrayRef, + ), + ( + Arc::new(Field::new("age", DataType::Float32, false)), + Arc::new(age_values) as ArrayRef, + ), + ( + Arc::new(Field::new( + "names", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + false, + )), + Arc::new(names_array) as ArrayRef, + ), + ]); + + let record_batch = + RecordBatch::try_from_iter(vec![("employees", Arc::new(struct_array) as ArrayRef)]) + .unwrap(); + + let _t = record_batch.deserialize().unwrap(); + + assert!(false) + } +} diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index bf8874e2..bb9d2599 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod deserializer; use std::sync::Arc; use arrow_array::{ From 4a22b1537e6e1f7f0ae0475f9df27fede392943c Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 26 Sep 2024 17:23:10 -0400 Subject: [PATCH 23/39] feat: finish implementing deserialiser for record batch --- rig-lancedb/src/utils/deserializer.rs | 1054 ++++++++++++++++++------- 1 file changed, 774 insertions(+), 280 deletions(-) diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index bbe915b5..8eace4f4 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -1,18 +1,18 @@ use std::sync::Arc; use arrow_array::{ + cast::AsArray, types::{ - BinaryType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + ArrowDictionaryKeyType, BinaryType, ByteArrayType, Date32Type, Date64Type, Decimal128Type, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, - DurationSecondType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, - Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, - LargeBinaryType, LargeUtf8Type, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, - 
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, - UInt32Type, UInt64Type, UInt8Type, Utf8Type, + DurationSecondType, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + IntervalDayTime, IntervalDayTimeType, IntervalMonthDayNano, IntervalMonthDayNanoType, + IntervalYearMonthType, LargeBinaryType, LargeUtf8Type, RunEndIndexType, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, }, - Array, ArrowPrimitiveType, FixedSizeBinaryArray, FixedSizeListArray, GenericByteArray, - GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, + Array, ArrowPrimitiveType, OffsetSizeTrait, RecordBatch, RunArray, StructArray, UnionArray, }; use lancedb::arrow::arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; use rig::vector_store::VectorStoreError; @@ -25,179 +25,291 @@ fn arrow_to_rig_error(e: ArrowError) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } -trait Test { +pub trait RecordBatchDeserializer { fn deserialize(&self) -> Result; } -impl Test for RecordBatch { +impl RecordBatchDeserializer for RecordBatch { fn deserialize(&self) -> Result { fn type_matcher(column: &Arc) -> Result, VectorStoreError> { match column.data_type() { DataType::Null => Ok(vec![serde_json::Value::Null]), - // f16 does not implement serde_json::Deserialize. Need to cast to f32. - DataType::Float16 => column - .to_primitive::() - .map_err(arrow_to_rig_error)? 
- .iter() - .map(|float_16| serde_json::to_value(float_16.to_f32())) - .collect::, _>>() + DataType::Float32 => column + .to_primitive_value::() .map_err(serde_to_rig_error), - DataType::Float32 => column.to_primitive_value::(), - DataType::Float64 => column.to_primitive_value::(), - DataType::Int8 => column.to_primitive_value::(), - DataType::Int16 => column.to_primitive_value::(), - DataType::Int32 => column.to_primitive_value::(), - DataType::Int64 => column.to_primitive_value::(), - DataType::UInt8 => column.to_primitive_value::(), - DataType::UInt16 => column.to_primitive_value::(), - DataType::UInt32 => column.to_primitive_value::(), - DataType::UInt64 => column.to_primitive_value::(), - DataType::Date32 => column.to_primitive_value::(), - DataType::Date64 => column.to_primitive_value::(), - DataType::Decimal128(..) => column.to_primitive_value::(), - // i256 does not implement serde_json::Deserialize. Need to cast to i128. - DataType::Decimal256(..) => column - .to_primitive::() - .map_err(arrow_to_rig_error)? 
- .iter() - .map(|dec_256| serde_json::to_value(dec_256.as_i128())) - .collect::, _>>() + DataType::Float64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int8 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int16 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt8 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt16 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Date32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Date64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Decimal128(..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time32(TimeUnit::Second) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time32(TimeUnit::Millisecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time64(TimeUnit::Microsecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time64(TimeUnit::Nanosecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Microsecond, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Millisecond, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Second, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Nanosecond, ..) 
=> column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Microsecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Millisecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Nanosecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Second) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Interval(IntervalUnit::YearMonth) => column + .to_primitive_value::() .map_err(serde_to_rig_error), - DataType::Time32(TimeUnit::Second) => { - column.to_primitive_value::() - } - DataType::Time32(TimeUnit::Millisecond) => { - column.to_primitive_value::() - } - DataType::Time64(TimeUnit::Microsecond) => { - column.to_primitive_value::() - } - DataType::Time64(TimeUnit::Nanosecond) => { - column.to_primitive_value::() - } - DataType::Timestamp(TimeUnit::Microsecond, ..) => { - column.to_primitive_value::() - } - DataType::Timestamp(TimeUnit::Millisecond, ..) => { - column.to_primitive_value::() - } - DataType::Timestamp(TimeUnit::Second, ..) => { - column.to_primitive_value::() - } - DataType::Timestamp(TimeUnit::Nanosecond, ..) => { - column.to_primitive_value::() - } - DataType::Duration(TimeUnit::Microsecond) => { - column.to_primitive_value::() - } - DataType::Duration(TimeUnit::Millisecond) => { - column.to_primitive_value::() - } - DataType::Duration(TimeUnit::Nanosecond) => { - column.to_primitive_value::() - } - DataType::Duration(TimeUnit::Second) => { - column.to_primitive_value::() - } DataType::Interval(IntervalUnit::DayTime) => Ok(column .to_primitive::() - .map_err(arrow_to_rig_error)? 
.iter() - .map(|interval| { + .map(|IntervalDayTime { days, milliseconds }| { json!({ - "days": interval.days, - "milliseconds": interval.milliseconds, + "days": days, + "milliseconds": milliseconds, }) }) .collect()), DataType::Interval(IntervalUnit::MonthDayNano) => Ok(column .to_primitive::() - .map_err(arrow_to_rig_error)? .iter() - .map(|interval| { - json!({ - "months": interval.months, - "days": interval.days, - "nanoseconds": interval.nanoseconds, - }) - }) + .map( + |IntervalMonthDayNano { + months, + days, + nanoseconds, + }| { + json!({ + "months": months, + "days": days, + "nanoseconds": nanoseconds, + }) + }, + ) .collect()), - DataType::Interval(IntervalUnit::YearMonth) => { - column.to_primitive_value::() + DataType::Utf8 => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::LargeUtf8 => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::Binary => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::LargeBinary => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::FixedSizeBinary(n) => (0..*n) + .map(|i| serde_json::to_value(column.as_fixed_size_binary().value(i as usize))) + .collect::, _>>() + .map_err(serde_to_rig_error), + DataType::Boolean => { + let bool_array = column.as_boolean(); + (0..bool_array.len()) + .map(|i| bool_array.value(i)) + .map(serde_json::to_value) + .collect::, _>>() + .map_err(serde_to_rig_error) } - DataType::Utf8 | DataType::Utf8View => column.to_str_value::(), - DataType::LargeUtf8 => column.to_str_value::(), - DataType::Binary => column.to_str_value::(), - DataType::LargeBinary => column.to_str_value::(), - DataType::FixedSizeBinary(n) => { - match column.as_any().downcast_ref::() { - Some(list_array) => (0..*n) - .map(|j| serde_json::to_value(list_array.value(j as usize))) - .collect::, _>>() - .map_err(serde_to_rig_error), - None => Err(VectorStoreError::DatastoreError(Box::new( - ArrowError::CastError(format!( - "Can't cast column {column:?} 
to fixed size list array" - )), - ))), - } + DataType::FixedSizeList(..) => { + column.to_fixed_lists().iter().map(type_matcher).map_ok() } - DataType::FixedSizeList(..) => column - .fixed_nested_lists() - .map_err(arrow_to_rig_error)? - .iter() - .map(|nested_list| type_matcher(nested_list)) - .map_ok(), - DataType::List(..) | DataType::ListView(..) => column - .nested_lists::() - .map_err(arrow_to_rig_error)? - .iter() - .map(|nested_list| type_matcher(nested_list)) - .map_ok(), - DataType::LargeList(..) | DataType::LargeListView(..) => column - .nested_lists::() - .map_err(arrow_to_rig_error)? - .iter() - .map(|nested_list| type_matcher(nested_list)) - .map_ok(), - DataType::Struct(..) => match column.as_any().downcast_ref::() { - Some(struct_array) => struct_array - .nested_lists() - .iter() - .map(|nested_list| type_matcher(nested_list)) - .map_ok(), + DataType::List(..) => column.to_list::().iter().map(type_matcher).map_ok(), + DataType::LargeList(..) => { + column.to_list::().iter().map(type_matcher).map_ok() + } + DataType::Struct(..) => { + let struct_array = column.as_struct(); + let struct_columns = struct_array + .inner_lists() + .iter() + .map(type_matcher) + .collect::, _>>()?; + + Ok(struct_columns + .build_struct(struct_array.num_rows(), struct_array.column_names())) + } + DataType::Map(..) => { + let map_columns = column + .as_map() + .entries() + .inner_lists() + .iter() + .map(type_matcher) + .collect::, _>>()?; + + Ok(map_columns.build_map()) + } + DataType::Dictionary(keys_type, ..) 
=> { + let (keys, v) = match **keys_type { + DataType::Int8 => column.to_dict_values::()?, + DataType::Int16 => column.to_dict_values::()?, + DataType::Int32 => column.to_dict_values::()?, + DataType::Int64 => column.to_dict_values::()?, + DataType::UInt8 => column.to_dict_values::()?, + DataType::UInt16 => column.to_dict_values::()?, + DataType::UInt32 => column.to_dict_values::()?, + DataType::UInt64 => column.to_dict_values::()?, + _ => { + return Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "Dictionary keys type is not accepted: {keys_type:?}" + )), + ))) + } + }; + + let values = type_matcher(v)?; + + Ok(keys + .iter() + .zip(values) + .map(|(k, v)| { + let mut map = serde_json::Map::new(); + map.insert(k.to_string(), v); + map + }) + .map(Value::Object) + .collect()) + } + DataType::Union(..) => match column.as_any().downcast_ref::() { + Some(union_array) => (0..union_array.len()) + .map(|i| union_array.value(i).clone()) + .collect::>() + .iter() + .map(type_matcher) + .map_ok(), None => Err(VectorStoreError::DatastoreError(Box::new( ArrowError::CastError(format!( - "Can't cast array: {column:?} to struct array" + "Can't cast column {column:?} to union array" )), ))), }, - // DataType::Map(..) 
=> { - // let item = match column.as_any().downcast_ref::() { - // Some(map_array) => map_array - // .entries() - // .nested_lists() - // .iter() - // .map(|nested_list| type_matcher(nested_list, nested_list.data_type())) - // .collect::, _>>(), - // None => Err(VectorStoreError::DatastoreError(Box::new( - // ArrowError::CastError(format!( - // "Can't cast array: {column:?} to map array" - // )), - // ))), - // }?; - // } - // DataType::Dictionary(key_data_type, value_data_type) => { - // let item = match column.as_any().downcast_ref::() { - // Some(map_array) => { - // let keys = &Arc::new(map_array.keys()); - // type_matcher(keys, keys.data_type()) - // } - // None => Err(ArrowError::CastError(format!( - // "Can't cast array: {column:?} to map array" - // ))), - // }?; - // }, + DataType::RunEndEncoded(counter_type, ..) => { + let items: Vec> = match counter_type.data_type() { + DataType::Int16 => { + let (counter, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + counter + .into_iter() + .zip(type_matcher(&v)?) + .map(|(n, value)| vec![value; n as usize]) + .collect() + } + DataType::Int32 => { + let (counter, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + counter + .into_iter() + .zip(type_matcher(&v)?) + .map(|(n, value)| vec![value; n as usize]) + .collect() + } + DataType::Int64 => { + let (counter, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + counter + .into_iter() + .zip(type_matcher(&v)?) 
+ .map(|(n, value)| vec![value; n as usize]) + .collect() + } + _ => { + return Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "RunEndEncoded index type is not accepted: {counter_type:?}" + )), + ))) + } + }; + + items + .iter() + .map(|item| serde_json::to_value(item).map_err(serde_to_rig_error)) + .collect() + } + // Not yet fully supported + DataType::BinaryView => { + todo!() + } + // Not yet fully supported + DataType::Utf8View => { + todo!() + } + // Not yet fully supported + DataType::ListView(..) => { + todo!() + } + // Not yet fully supported + DataType::LargeListView(..) => { + todo!() + } + // f16 currently unstable + DataType::Float16 => { + todo!() + } + // i256 currently unstable + DataType::Decimal256(..) => { + todo!() + } _ => { println!("Unsupported data type"); Ok(vec![serde_json::Value::Null]) @@ -213,128 +325,194 @@ impl Test for RecordBatch { println!("{:?}", serde_json::to_string(&columns).unwrap()); - Ok(json!({})) + serde_json::to_value(&columns).map_err(serde_to_rig_error) } } /// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. pub trait DeserializePrimitiveArray { - fn to_primitive( - &self, - ) -> Result::Native>, ArrowError>; + /// Downcast arrow Array into a `PrimitiveArray` with items that implement trait `ArrowPrimitiveType`. + /// Return the primitive array values. + fn to_primitive(&self) -> Vec<::Native>; - fn to_primitive_value(&self) -> Result, VectorStoreError> + /// Same as above but convert the resulting array values into serde_json::Value. 
+ fn to_primitive_value(&self) -> Result, serde_json::Error> where ::Native: Serialize; } impl DeserializePrimitiveArray for &Arc { - fn to_primitive( - &self, - ) -> Result::Native>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } + fn to_primitive(&self) -> Vec<::Native> { + let primitive_array = self.as_primitive::(); + + (0..primitive_array.len()) + .map(|i| primitive_array.value(i)) + .collect() } - fn to_primitive_value(&self) -> Result, VectorStoreError> + fn to_primitive_value(&self) -> Result, serde_json::Error> where ::Native: Serialize, { self.to_primitive::() - .map_err(arrow_to_rig_error)? .iter() .map(serde_json::to_value) - .collect::, _>>() - .map_err(serde_to_rig_error) + .collect() } } -/// Trait used to "deserialize" an arrow_array::Array as as list of byte objects. +/// Trait used to "deserialize" an arrow_array::Array as as list of str objects. pub trait DeserializeByteArray { - fn to_str(&self) -> Result::Native>, ArrowError>; + /// Downcast arrow Array into a `GenericByteArray` with items that implement trait `ByteArrayType`. + /// Return the generic byte array values. + fn to_str(&self) -> Vec<&::Native>; - fn to_str_value(&self) -> Result, VectorStoreError> + /// Same as above but convert the resulting array values into serde_json::Value. 
+ fn to_str_value(&self) -> Result, serde_json::Error> where ::Native: Serialize; } impl DeserializeByteArray for &Arc { - fn to_str(&self) -> Result::Native>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } + fn to_str(&self) -> Vec<&::Native> { + let byte_array = self.as_bytes::(); + (0..byte_array.len()).map(|j| byte_array.value(j)).collect() } - fn to_str_value(&self) -> Result, VectorStoreError> + fn to_str_value(&self) -> Result, serde_json::Error> where ::Native: Serialize, { self.to_str::() - .map_err(arrow_to_rig_error)? .iter() .map(serde_json::to_value) - .collect::, _>>() - .map_err(serde_to_rig_error) + .collect() } } -/// Trait used to "deserialize" an arrow_array::Array as as list of list objects. +/// Trait used to "deserialize" an arrow_array::Array as a list of list objects. trait DeserializeListArray { - fn nested_lists( - &self, - ) -> Result>, ArrowError>; + /// Downcast arrow Array into a `GenericListArray` with items that implement trait `OffsetSizeTrait`. + /// Return the generic list array values. + fn to_list(&self) -> Vec>; } impl DeserializeListArray for &Arc { - fn nested_lists( + fn to_list(&self) -> Vec> { + (0..self.as_list::().len()) + .map(|j| self.as_list::().value(j)) + .collect() + } +} + +/// Trait used to "deserialize" an arrow_array::Array as a list of dict objects. +trait DeserializeDictArray { + /// Downcast arrow Array into a `DictionaryArray` with items that implement trait `ArrowDictionaryKeyType`. + /// Return the dictionary keys and values as a tuple. 
+ fn to_dict( &self, - ) -> Result>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } + ) -> ( + Vec<::Native>, + &Arc, + ); + + fn to_dict_values( + &self, + ) -> Result<(Vec, &Arc), serde_json::Error> + where + ::Native: Serialize; +} + +impl DeserializeDictArray for &Arc { + fn to_dict( + &self, + ) -> ( + Vec<::Native>, + &Arc, + ) { + let dict_array = self.as_dictionary::(); + ( + (0..dict_array.keys().len()) + .map(|i| dict_array.keys().value(i)) + .collect(), + dict_array.values(), + ) + } + + fn to_dict_values( + &self, + ) -> Result<(Vec, &Arc), serde_json::Error> + where + ::Native: Serialize, + { + let (k, v) = self.to_dict::(); + + Ok(( + k.iter() + .map(serde_json::to_string) + .collect::, _>>()?, + v, + )) } } -/// Trait used to "deserialize" an arrow_array::Array as as list of list objects. +/// Trait used to "deserialize" an arrow_array::Array as as list of fixed size list objects. trait DeserializeArray { - fn fixed_nested_lists(&self) -> Result>, ArrowError>; + /// Downcast arrow Array into a `FixedSizeListArray`. + /// Return the fixed size list array values. + fn to_fixed_lists(&self) -> Vec>; } impl DeserializeArray for &Arc { - fn fixed_nested_lists(&self) -> Result>, ArrowError> { - match self.as_any().downcast_ref::() { - Some(list_array) => Ok((0..list_array.len()) - .map(|j| list_array.value(j as usize)) - .collect::>()), - None => { - return Err(ArrowError::CastError(format!( - "Can't cast column {self:?} to fixed size list array" - ))); - } + fn to_fixed_lists(&self) -> Vec> { + let list_array = self.as_fixed_size_list(); + + (0..list_array.len()).map(|i| list_array.value(i)).collect() + } +} + +type RunArrayParts = ( + Vec<::Native>, + Arc, +); + +/// Trait used to "deserialize" an arrow_array::Array as a list of list objects. 
+trait DeserializeRunArray { + /// Downcast arrow Array into a `GenericListArray` with items that implement trait `RunEndIndexType`. + /// Return the generic list array values. + fn to_run_end(&self) -> Result, ArrowError>; +} + +impl DeserializeRunArray for &Arc { + fn to_run_end(&self) -> Result, ArrowError> { + if let Some(run_array) = self.as_any().downcast_ref::>() { + return Ok(( + run_array.run_ends().values().to_vec(), + run_array.values().clone(), + )); } + Err(ArrowError::CastError(format!( + "Can't cast array: {self:?} to list array" + ))) } } trait DeserializeStructArray { - fn nested_lists(&self) -> Vec>; + fn inner_lists(&self) -> Vec>; + + fn num_rows(&self) -> usize; } impl DeserializeStructArray for StructArray { - fn nested_lists(&self) -> Vec> { - (0..self.len()) + fn inner_lists(&self) -> Vec> { + (0..self.num_columns()) .map(|j| self.column(j).clone()) .collect::>() } + + fn num_rows(&self) -> usize { + self.column(0).into_data().len() + } } trait MapOk { @@ -354,16 +532,143 @@ where } } +trait RebuildObject { + fn build_struct(&self, num_rows: usize, col_names: Vec<&str>) -> Vec; + + fn build_map(&self) -> Vec; +} + +impl RebuildObject for Vec> { + fn build_struct(&self, num_rows: usize, col_names: Vec<&str>) -> Vec { + (0..num_rows) + .map(|row_i| { + self.iter() + .enumerate() + .fold(serde_json::Map::new(), |mut acc, (col_i, col)| { + acc.insert(col_names[col_i].to_string(), col[row_i].clone()); + acc + }) + }) + .map(Value::Object) + .collect() + } + + fn build_map(&self) -> Vec { + let keys = &self[0]; + let values = &self[1]; + + keys.iter() + .zip(values) + .map(|(k, v)| { + let mut map = serde_json::Map::new(); + map.insert( + match k { + serde_json::Value::String(s) => s.clone(), + _ => k.to_string(), + }, + v.clone(), + ); + map + }) + .map(Value::Object) + .collect() + } +} + #[cfg(test)] mod tests { use std::sync::Arc; use arrow_array::{ - builder::{FixedSizeListBuilder, ListBuilder, StringBuilder, StructBuilder}, ArrayRef, 
BinaryArray, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, ListArray, RecordBatch, StringArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array + builder::{ + FixedSizeListBuilder, ListBuilder, StringBuilder, StringDictionaryBuilder, + StringRunBuilder, UnionBuilder, + }, + types::{Float64Type, Int16Type, Int32Type, Int8Type}, + ArrayRef, BinaryArray, FixedSizeListArray, Float32Array, Float64Array, GenericListArray, + Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, + MapArray, RecordBatch, StringArray, StructArray, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, }; - use lancedb::arrow::arrow_schema::{DataType, Field}; + use lancedb::arrow::arrow_schema::{DataType, Field, Fields}; + use serde_json::json; + + use crate::utils::deserializer::RecordBatchDeserializer; + + fn fixed_list_actors() -> FixedSizeListArray { + let mut builder = FixedSizeListBuilder::new(StringBuilder::new(), 2); + builder.values().append_value("Johnny Depp"); + builder.values().append_value("Cate Blanchet"); + builder.append(true); + builder.values().append_value("Meryl Streep"); + builder.values().append_value("Scarlett Johansson"); + builder.append(true); + builder.values().append_value("Brad Pitt"); + builder.values().append_value("Natalie Portman"); + builder.append(true); + + builder.finish() + } + + fn name_list() -> GenericListArray { + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.values().append_value("Alice"); + builder.values().append_value("Bob"); + builder.append(true); + builder.values().append_value("Charlie"); + builder.append(true); + builder.values().append_value("David"); + builder.values().append_value("Eve"); + builder.values().append_value("Frank"); + builder.append(true); + builder.finish() + } + + fn nested_list_of_animals() -> GenericListArray { + // [ [ [ "Dog", "Cat" ], ["Mouse"] ], [ [ "Giraffe" ], 
["Cow", "Pig"] ], [ [ "Sloth" ], ["Ant", "Monkey"] ] ] + let mut builder = ListBuilder::new(ListBuilder::new(StringBuilder::new())); + builder + .values() + .append_value(vec![Some("Dog"), Some("Cat")]); + builder.values().append_value(vec![Some("Mouse")]); + builder.append(true); + builder.values().append_value(vec![Some("Giraffe")]); + builder + .values() + .append_value(vec![Some("Cow"), Some("Pig")]); + builder.append(true); + builder.values().append_value(vec![Some("Sloth")]); + builder + .values() + .append_value(vec![Some("Ant"), Some("Monkey")]); + builder.append(true); + builder.finish() + } - use crate::utils::deserializer::Test; + fn movie_struct() -> StructArray { + StructArray::from(vec![ + ( + Arc::new(Field::new("name", DataType::Utf8, false)), + Arc::new(StringArray::from(vec![ + "Pulp Fiction", + "The Shawshank Redemption", + "La La Land", + ])) as ArrayRef, + ), + ( + Arc::new(Field::new("year", DataType::UInt32, false)), + Arc::new(UInt32Array::from(vec![1999, 2026, 1745])) as ArrayRef, + ), + ( + Arc::new(Field::new( + "actors", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Utf8, true)), 2), + false, + )), + Arc::new(fixed_list_actors()) as ArrayRef, + ), + ]) + } #[tokio::test] async fn test_primitive_deserialization() { @@ -371,10 +676,8 @@ mod tests { let large_string = Arc::new(LargeStringArray::from_iter_values(vec!["Jerry", "Freddy"])) as ArrayRef; let binary = Arc::new(BinaryArray::from_iter_values(vec![b"hello", b"world"])) as ArrayRef; - let large_binary = Arc::new(LargeBinaryArray::from_iter_values(vec![ - b"The bright sun sets behind the mountains, casting gold", - b"A gentle breeze rustles through the trees at twilight.", - ])) as ArrayRef; + let large_binary = + Arc::new(LargeBinaryArray::from_iter_values(vec![b"abc", b"def"])) as ArrayRef; let float_32 = Arc::new(Float32Array::from_iter_values(vec![0.0, 1.0])) as ArrayRef; let float_64 = Arc::new(Float64Array::from_iter_values(vec![0.0, 1.0])) as ArrayRef; let 
int_8 = Arc::new(Int8Array::from_iter_values(vec![0, -1])) as ArrayRef; @@ -404,80 +707,161 @@ mod tests { ]) .unwrap(); - let _t = record_batch.deserialize().unwrap(); - - assert!(false) + assert_eq!( + record_batch.deserialize().unwrap(), + json!([ + [0.0, 1.0], + [0.0, 1.0], + [0, -1], + [0, 1], + [0, -1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [0, 1], + ["Marty", "Tony"], + ["Jerry", "Freddy"], + [[97, 98, 99], [100, 101, 102]], + [[104, 101, 108, 108, 111], [119, 111, 114, 108, 100]] + ]) + ) } #[tokio::test] - async fn test_list_recursion() { - let mut builder = FixedSizeListBuilder::new(StringBuilder::new(), 3); - builder.values().append_value("Hi"); - builder.values().append_value("Hey"); - builder.values().append_value("What's up"); - builder.append(true); - builder.values().append_value("Bye"); - builder.values().append_value("Seeya"); - builder.values().append_value("Later"); - builder.append(true); + async fn test_dictionary_deserialization() { + let dictionary_values = StringArray::from(vec![None, Some("abc"), Some("def")]); - let record_batch = RecordBatch::try_from_iter(vec![( - "salutations", - Arc::new(builder.finish()) as ArrayRef, - )]) - .unwrap(); + let mut builder = + StringDictionaryBuilder::::new_with_dictionary(3, &dictionary_values) + .unwrap(); + builder.append("def").unwrap(); + builder.append_null(); + builder.append("abc").unwrap(); - let _t = record_batch.deserialize().unwrap(); + let dictionary_array = builder.finish(); + + let record_batch = + RecordBatch::try_from_iter(vec![("some_dict", Arc::new(dictionary_array) as ArrayRef)]) + .unwrap(); - assert!(false) + assert_eq!( + record_batch.deserialize().unwrap(), + json!([ + [ + { + "2": "" + }, + { + "0": "abc" + }, + { + "1": "def" + } + ] + ]) + ) } #[tokio::test] - async fn test_list_recursion_2() { - let mut builder = ListBuilder::new(ListBuilder::new(StringBuilder::new())); - builder - .values() - .append_value(vec![Some("Dog"), Some("Cat")]); - builder - .values() - 
.append_value(vec![Some("Mouse"), Some("Bird")]); - builder.append(true); - builder - .values() - .append_value(vec![Some("Giraffe"), Some("Mammoth")]); - builder - .values() - .append_value(vec![Some("Cow"), Some("Pig")]); + async fn test_union_deserialization() { + let mut builder = UnionBuilder::new_dense(); + builder.append::("type_a", 1).unwrap(); + builder.append::("type_b", 3.0).unwrap(); + builder.append::("type_a", 4).unwrap(); + let union = builder.build().unwrap(); let record_batch = - RecordBatch::try_from_iter(vec![("animals", Arc::new(builder.finish()) as ArrayRef)]) - .unwrap(); - - let _t = record_batch.deserialize().unwrap(); + RecordBatch::try_from_iter(vec![("some_dict", Arc::new(union) as ArrayRef)]).unwrap(); - assert!(false) + assert_eq!( + record_batch.deserialize().unwrap(), + json!([[[1], [3.0], [4]]]) + ) } #[tokio::test] - async fn test_struct() { - let id_values = StringArray::from(vec!["id1", "id2", "id3"]); + async fn test_run_end_deserialization() { + let mut builder = StringRunBuilder::::new(); - let age_values = Float32Array::from(vec![25.0, 30.5, 22.1]); + // The builder builds the dictionary value by value + builder.append_value("abc"); + builder.append_null(); + builder.extend([Some("def"), Some("def"), Some("abc")]); + let array = builder.finish(); - let mut names_builder = ListBuilder::new(StringBuilder::new()); - names_builder.values().append_value("Alice"); - names_builder.values().append_value("Bob"); - names_builder.append(true); - names_builder.values().append_value("Charlie"); - names_builder.append(true); - names_builder.values().append_value("David"); - names_builder.values().append_value("Eve"); - names_builder.values().append_value("Frank"); - names_builder.append(true); + let record_batch = + RecordBatch::try_from_iter(vec![("some_dict", Arc::new(array) as ArrayRef)]).unwrap(); + + assert_eq!( + record_batch.deserialize().unwrap(), + json!([[ + ["abc"], + ["", ""], + ["def", "def", "def", "def"], + ["abc", "abc", 
"abc", "abc", "abc"] + ]]) + ) + } + + #[tokio::test] + async fn test_map_deserialization() { + let record_batch = RecordBatch::try_from_iter(vec![( + "map_col", + Arc::new( + MapArray::new_from_strings( + vec!["tarentino", "darabont", "chazelle"].into_iter(), + &movie_struct(), + &[0, 1, 2], + ) + .unwrap(), + ) as ArrayRef, + )]) + .unwrap(); - let names_array = names_builder.finish(); + assert_eq!( + record_batch.deserialize().unwrap(), + json!([ + [ + { + "tarentino": { + "name": "Pulp Fiction", + "year": 1999, + "actors": [ + "Johnny Depp", + "Cate Blanchet" + ] + } + }, + { + "darabont": { + "name": "The Shawshank Redemption", + "year": 2026, + "actors": [ + "Meryl Streep", + "Scarlett Johansson" + ] + } + }, + { + "chazelle": { + "name": "La La Land", + "year": 1745, + "actors": [ + "Brad Pitt", + "Natalie Portman" + ] + } + } + ] + ]) + ) + } - // Step 4: Combine into a StructArray + #[tokio::test] + async fn test_recursion() { + let id_values = StringArray::from(vec!["id1", "id2", "id3"]); + let age_values = Float32Array::from(vec![25.0, 30.5, 22.1]); let struct_array = StructArray::from(vec![ ( Arc::new(Field::new("id", DataType::Utf8, false)), @@ -493,7 +877,38 @@ mod tests { DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), false, )), - Arc::new(names_array) as ArrayRef, + Arc::new(name_list()) as ArrayRef, + ), + ( + Arc::new(Field::new( + "favorite_animals", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + ))), + false, + )), + Arc::new(nested_list_of_animals()) as ArrayRef, + ), + ( + Arc::new(Field::new( + "favorite_movie", + DataType::Struct(Fields::from_iter(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("year", DataType::UInt32, false), + Field::new( + "actors", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Utf8, true)), + 2, + ), + false, + ), + ])), + false, + )), + Arc::new(movie_struct()) as ArrayRef, ), ]); 
@@ -501,8 +916,87 @@ mod tests { RecordBatch::try_from_iter(vec![("employees", Arc::new(struct_array) as ArrayRef)]) .unwrap(); - let _t = record_batch.deserialize().unwrap(); - - assert!(false) + assert_eq!( + record_batch.deserialize().unwrap(), + json!([ + [ + { + "id": "id1", + "age": 25.0, + "names": [ + "Alice", + "Bob" + ], + "favorite_animals": [ + [ + "Dog", + "Cat" + ], + [ + "Mouse" + ] + ], + "favorite_movie": { + "name": "Pulp Fiction", + "year": 1999, + "actors": [ + "Johnny Depp", + "Cate Blanchet" + ] + } + }, + { + "id": "id2", + "age": 30.5, + "names": [ + "Charlie" + ], + "favorite_animals": [ + [ + "Giraffe" + ], + [ + "Cow", + "Pig" + ] + ], + "favorite_movie": { + "name": "The Shawshank Redemption", + "year": 2026, + "actors": [ + "Meryl Streep", + "Scarlett Johansson" + ] + } + }, + { + "id": "id3", + "age": 22.100000381469727, + "names": [ + "David", + "Eve", + "Frank" + ], + "favorite_animals": [ + [ + "Sloth" + ], + [ + "Ant", + "Monkey" + ] + ], + "favorite_movie": { + "name": "La La Land", + "year": 1745, + "actors": [ + "Brad Pitt", + "Natalie Portman" + ] + } + } + ] + ]) + ) } } From ff85fa552d12480d490799c12a93307a15499c1c Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 26 Sep 2024 17:30:02 -0400 Subject: [PATCH 24/39] refactor: remove print statement --- rig-lancedb/src/utils/deserializer.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index 8eace4f4..96c7ce7b 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -323,8 +323,6 @@ impl RecordBatchDeserializer for RecordBatch { .map(type_matcher) .collect::, _>>()?; - println!("{:?}", serde_json::to_string(&columns).unwrap()); - serde_json::to_value(&columns).map_err(serde_to_rig_error) } } From 70802b35f7b0c6ee520a5d1c592af6beab7bfe5f Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 1 Oct 2024 16:15:28 -0400 Subject: [PATCH 25/39] refactor: update rig 
core version on lancedb crate, remove implementation of VectorStore trait --- rig-lancedb/Cargo.toml | 2 +- .../examples/vector_search_local_ann.rs | 2 +- .../examples/vector_search_local_enn.rs | 2 +- rig-lancedb/examples/vector_search_s3_ann.rs | 2 +- rig-lancedb/src/lib.rs | 93 ------------------- 5 files changed, 4 insertions(+), 97 deletions(-) diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index 3ec378f7..ed7426b0 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] lancedb = "0.10.0" -rig-core = { path = "../rig-core", version = "0.1.0" } +rig-core = { path = "../rig-core", version = "0.2.1" } arrow-array = "52.2.0" serde_json = "1.0.128" serde = "1.0.210" diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index be9c7038..358315bd 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -48,7 +48,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; // Add embeddings to vector store - vector_store.add_documents(embeddings).await?; + // vector_store.add_documents(embeddings).await?; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information vector_store diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 4b1cb010..fe9a6be7 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -28,7 +28,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; // Add embeddings to vector store - vector_store.add_documents(embeddings).await?; + // vector_store.add_documents(embeddings).await?; // Query the index let results = vector_store diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 2eefec2d..1b4e666d 100644 --- 
a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -55,7 +55,7 @@ async fn main() -> Result<(), anyhow::Error> { .await?; // Add embeddings to vector store - vector_store.add_documents(embeddings).await?; + // vector_store.add_documents(embeddings).await?; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information vector_store diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index ac6f7876..fa8b54c3 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -174,99 +174,6 @@ impl LanceDbVectorStore { } } -impl VectorStore for LanceDbVectorStore { - type Q = lancedb::query::Query; - - async fn add_documents( - &mut self, - documents: Vec, - ) -> Result<(), VectorStoreError> { - let document_records = - DocumentRecords::try_from(documents.clone()).map_err(serde_to_rig_error)?; - - self.document_table - .insert(document_records, Self::document_schema()) - .await - .map_err(lancedb_to_rig_error)?; - - let embedding_records = EmbeddingRecordsBatch::from(documents); - - self.embedding_table - .insert( - embedding_records, - Self::embedding_schema(self.model.ndims() as i32), - ) - .await - .map_err(lancedb_to_rig_error)?; - - Ok(()) - } - - async fn get_document_embeddings( - &self, - id: &str, - ) -> Result, VectorStoreError> { - let documents: DocumentRecords = self - .document_table - .query() - .only_if(format!("id = '{id}'")) - .execute_query() - .await?; - - let embeddings: EmbeddingRecordsBatch = self - .embedding_table - .query() - .only_if(format!("document_id = '{id}'")) - .execute_query() - .await?; - - Ok(merge(&documents, &embeddings)?.into_iter().next()) - } - - async fn get_document serde::Deserialize<'a>>( - &self, - id: &str, - ) -> Result, VectorStoreError> { - let documents: DocumentRecords = self - .document_table - .query() - .only_if(format!("id = '{id}'")) - .execute_query() - .await?; - - let document = documents 
- .as_iter() - .next() - .map(|document| serde_json::from_str(&document.document).map_err(serde_to_rig_error)) - .transpose(); - - document - } - - async fn get_document_by_query( - &self, - query: Self::Q, - ) -> Result, VectorStoreError> { - let documents: DocumentRecords = query.execute_query().await?; - - let embeddings: EmbeddingRecordsBatch = self - .embedding_table - .query() - .only_if(format!( - "document_id IN ({})", - documents - .ids() - .map(|id| format!("'{id}'")) - .collect::>() - .join(",") - )) - .execute_query() - .await?; - - Ok(merge(&documents, &embeddings)?.into_iter().next()) - } -} - impl VectorStoreIndex for LanceDbVectorStore { async fn top_n_from_query( &self, From 205f0c706befa39992153f6124ae31cb621117b3 Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 1 Oct 2024 17:16:54 -0400 Subject: [PATCH 26/39] feat: merge all arrow columns into JSON document in deserializer --- rig-lancedb/src/utils/deserializer.rs | 140 +++++++++++++++++++------- 1 file changed, 104 insertions(+), 36 deletions(-) diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index 96c7ce7b..94d28c40 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -287,27 +287,14 @@ impl RecordBatchDeserializer for RecordBatch { .collect() } // Not yet fully supported - DataType::BinaryView => { + DataType::BinaryView + | DataType::Utf8View + | DataType::ListView(..) + | DataType::LargeListView(..) => { todo!() } - // Not yet fully supported - DataType::Utf8View => { - todo!() - } - // Not yet fully supported - DataType::ListView(..) => { - todo!() - } - // Not yet fully supported - DataType::LargeListView(..) => { - todo!() - } - // f16 currently unstable - DataType::Float16 => { - todo!() - } - // i256 currently unstable - DataType::Decimal256(..) => { + // Currently unstable + DataType::Float16 | DataType::Decimal256(..) 
=> { todo!() } _ => { @@ -317,13 +304,32 @@ impl RecordBatchDeserializer for RecordBatch { } } + let binding = self.schema(); + let column_names = binding + .fields() + .iter() + .map(|field| field.name()) + .collect::>(); + let columns = self .columns() .iter() .map(type_matcher) .collect::, _>>()?; - serde_json::to_value(&columns).map_err(serde_to_rig_error) + Ok(Value::Object((0..self.num_rows()).fold( + serde_json::Map::new(), + |mut acc, row_i| { + columns.iter().enumerate().for_each(|(col_i, col)| { + acc.entry(column_names[col_i].to_string()).and_modify(|v| { + if let Value::Array(v_arr) = v { + v_arr.push(col[row_i].clone()) + } + }).or_insert(Value::Array(vec![col[row_i].clone()])); + }); + acc + }, + ))) } } @@ -707,22 +713,84 @@ mod tests { assert_eq!( record_batch.deserialize().unwrap(), - json!([ - [0.0, 1.0], - [0.0, 1.0], - [0, -1], - [0, 1], - [0, -1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - [0, 1], - ["Marty", "Tony"], - ["Jerry", "Freddy"], - [[97, 98, 99], [100, 101, 102]], - [[104, 101, 108, 108, 111], [119, 111, 114, 108, 100]] - ]) + json!({ + "binary": [ + [ + 104, + 101, + 108, + 108, + 111 + ], + [ + 119, + 111, + 114, + 108, + 100 + ] + ], + "float_32": [ + 0.0, + 1.0 + ], + "float_64": [ + 0.0, + 1.0 + ], + "int_16": [ + 0, + 1 + ], + "int_32": [ + 0, + -1 + ], + "int_64": [ + 0, + 1 + ], + "int_8": [ + 0, + -1 + ], + "large_binary": [ + [ + 97, + 98, + 99 + ], + [ + 100, + 101, + 102 + ] + ], + "large_string": [ + "Jerry", + "Freddy" + ], + "string": [ + "Marty", + "Tony" + ], + "uint_16": [ + 0, + 1 + ], + "uint_32": [ + 0, + 1 + ], + "uint_64": [ + 0, + 1 + ], + "uint_8": [ + 0, + 1 + ] + }) ) } From 921b313b9182c8d5371c6442adbc29e48ad6fd5d Mon Sep 17 00:00:00 2001 From: Garance Date: Tue, 1 Oct 2024 21:29:11 -0400 Subject: [PATCH 27/39] feat: replace document embeddings with serde json value --- .../examples/vector_search_local_ann.rs | 4 +- .../examples/vector_search_local_enn.rs | 4 +- rig-lancedb/examples/vector_search_s3_ann.rs 
| 4 +- rig-lancedb/src/lib.rs | 160 +++------ rig-lancedb/src/table_schemas/mod.rs | 39 --- rig-lancedb/src/utils/deserializer.rs | 309 +++++++++--------- rig-lancedb/src/utils/mod.rs | 34 +- 7 files changed, 210 insertions(+), 344 deletions(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 358315bd..ce1dd612 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -62,10 +62,10 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) + .top_n("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() - .map(|(score, doc)| (score, doc.id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index fe9a6be7..7c099807 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -32,10 +32,10 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) + .top_n("My boss says I zindle too much, what does that mean?", 1) .await? 
.into_iter() - .map(|(score, doc)| (score, doc.id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 1b4e666d..00ba96ea 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -69,10 +69,10 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n_from_query("My boss says I zindle too much, what does that mean?", 1) + .top_n("My boss says I zindle too much, what does that mean?", 1) .await? .into_iter() - .map(|(score, doc)| (score, doc.id, doc.document)) + .map(|(score, id, doc)| (score, id, doc)) .collect::>(); println!("Results: {:?}", results); diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index fa8b54c3..a141d8dc 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -8,10 +8,11 @@ use lancedb::{ }; use rig::{ embeddings::EmbeddingModel, - vector_store::{VectorStore, VectorStoreError, VectorStoreIndex}, + vector_store::{VectorStoreError, VectorStoreIndex}, }; -use table_schemas::{document::DocumentRecords, embedding::EmbeddingRecordsBatch, merge}; -use utils::{Insert, Query}; +use serde::Deserialize; +use serde_json::Value; +use utils::Query; mod table_schemas; mod utils; @@ -27,11 +28,8 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { pub struct LanceDbVectorStore { /// Defines which model is used to generate embeddings for the vector store model: M, - /// Table containing documents only - document_table: lancedb::Table, - /// Table containing embeddings only. - /// Foreign key references the document in document table. - embedding_table: lancedb::Table, + table: lancedb::Table, + id_field: String, /// Vector search params that are used during vector search operations. 
search_params: SearchParams, } @@ -93,104 +91,40 @@ impl SearchParams { } impl LanceDbVectorStore { - /// Note: Tables are created inside the new function rather than created outside and passed as reference to new function. - /// This is because a specific schema needs to be enforced on the tables and this is done at creation time. pub async fn new( - db: &lancedb::Connection, - model: &M, - search_params: &SearchParams, + table: lancedb::Table, + model: M, + id_field: String, + search_params: SearchParams, ) -> Result { - let document_table = db - .create_empty_table("documents", Arc::new(Self::document_schema())) - .execute() - .await?; - - let embedding_table = db - .create_empty_table( - "embeddings", - Arc::new(Self::embedding_schema(model.ndims() as i32)), - ) - .execute() - .await?; - Ok(Self { - document_table, - embedding_table, - model: model.clone(), - search_params: search_params.clone(), + table, + model, + id_field, + search_params, }) } - /// Schema of records in document table. - fn document_schema() -> Schema { - Schema::new(Fields::from(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("document", DataType::Utf8, false), - ])) - } - - /// Schema of records in embeddings table. - /// Every embedding vector in the table must have the same size. - fn embedding_schema(dimension: i32) -> Schema { - Schema::new(Fields::from(vec![ - Field::new("id", DataType::Utf8, false), - Field::new("document_id", DataType::Utf8, false), - Field::new("content", DataType::Utf8, false), - Field::new( - "embedding", - DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Float64, true)), - dimension, - ), - false, - ), - ])) - } - /// Define index on document table `id` field for search optimization. 
- pub async fn create_document_index(&self, index: Index) -> Result<(), lancedb::Error> { - self.document_table - .create_index(&["id"], index) - .execute() - .await - } - - /// Define index on embedding table `id` and `document_id` fields for search optimization. - pub async fn create_embedding_index(&self, index: Index) -> Result<(), lancedb::Error> { - self.embedding_table - .create_index(&["id", "document_id"], index) - .execute() - .await - } - - /// Define index on embedding table `embedding` fields for vector search optimization. - pub async fn create_index(&self, index: Index) -> Result<(), lancedb::Error> { - self.embedding_table - .create_index(&["embedding"], index) - .execute() - .await?; - - Ok(()) + pub async fn create_document_index( + &self, + index: Index, + field_names: &[impl AsRef], + ) -> Result<(), lancedb::Error> { + self.table.create_index(field_names, index).execute().await } } impl VectorStoreIndex for LanceDbVectorStore { - async fn top_n_from_query( + async fn top_n Deserialize<'a> + std::marker::Send>( &self, query: &str, n: usize, - ) -> Result, VectorStoreError> { + ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - self.top_n_from_embedding(&prompt_embedding, n).await - } - async fn top_n_from_embedding( - &self, - prompt_embedding: &rig::embeddings::Embedding, - n: usize, - ) -> Result, VectorStoreError> { let mut query = self - .embedding_table + .table .vector_search(prompt_embedding.vec.clone()) .map_err(lancedb_to_rig_error)? .limit(n); @@ -224,33 +158,31 @@ impl VectorStoreIndex for LanceDbV query = query.postfilter(); } - let embeddings: EmbeddingRecordsBatch = query.execute_query().await?; - - let documents: DocumentRecords = self - .document_table - .query() - .only_if(format!("id IN ({})", embeddings.document_ids())) + query .execute_query() - .await?; - - let document_embeddings = merge(&documents, &embeddings)?; - - Ok(document_embeddings + .await? 
.into_iter() - .map(|doc| { - let distance = embeddings - .get_by_id(&doc.id) - .map(|records| { - records - .as_iter() - .next() - .map(|record| record.distance.unwrap_or(0.0)) - .unwrap_or(0.0) - }) - .unwrap_or(0.0); - - (distance as f64, doc) + .map(|value| { + Ok(( + match value.get("distance") { + Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(), + _ => 0.0, + }, + match value.get(self.id_field.clone()) { + Some(Value::String(id)) => id.to_string(), + _ => "".to_string(), + }, + serde_json::from_value(value).map_err(serde_to_rig_error)?, + )) }) - .collect()) + .collect() + } + + async fn top_n_ids( + &self, + query: &str, + n: usize, + ) -> Result, VectorStoreError> { + todo!() } } diff --git a/rig-lancedb/src/table_schemas/mod.rs b/rig-lancedb/src/table_schemas/mod.rs index 5c175fd4..bd24dd65 100644 --- a/rig-lancedb/src/table_schemas/mod.rs +++ b/rig-lancedb/src/table_schemas/mod.rs @@ -1,41 +1,2 @@ -use document::{DocumentRecord, DocumentRecords}; -use embedding::{EmbeddingRecord, EmbeddingRecordsBatch}; -use rig::embeddings::{DocumentEmbeddings, Embedding}; - pub mod document; pub mod embedding; - -/// Merge an `DocumentRecords` object with an `EmbeddingRecordsBatch` object. -/// These objects contain document and embedding data, respectively, read from LanceDB. -/// For each document in `DocumentRecords` find the embeddings from `EmbeddingRecordsBatch` that correspond to that document, -/// using the document_id as reference. -pub fn merge( - documents: &DocumentRecords, - embeddings: &EmbeddingRecordsBatch, -) -> Result, serde_json::Error> { - documents - .as_iter() - .map(|DocumentRecord { id, document }| { - let emebedding_records = embeddings.get_by_id(id); - - Ok(DocumentEmbeddings { - id: id.to_string(), - document: serde_json::from_str(document)?, - embeddings: match emebedding_records { - Some(records) => records - .as_iter() - .map( - |EmbeddingRecord { - content, embedding, .. 
- }| Embedding { - document: content.to_string(), - vec: embedding.to_vec(), - }, - ) - .collect::>(), - None => vec![], - }, - }) - }) - .collect::, _>>() -} diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index 94d28c40..3a686d28 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -26,11 +26,23 @@ fn arrow_to_rig_error(e: ArrowError) -> VectorStoreError { } pub trait RecordBatchDeserializer { - fn deserialize(&self) -> Result; + fn deserialize(&self) -> Result, VectorStoreError>; +} + +impl RecordBatchDeserializer for Vec { + fn deserialize(&self) -> Result, VectorStoreError> { + Ok(self + .iter() + .map(|record_batch| record_batch.deserialize()) + .collect::, _>>()? + .into_iter() + .flatten() + .collect()) + } } impl RecordBatchDeserializer for RecordBatch { - fn deserialize(&self) -> Result { + fn deserialize(&self) -> Result, VectorStoreError> { fn type_matcher(column: &Arc) -> Result, VectorStoreError> { match column.data_type() { DataType::Null => Ok(vec![serde_json::Value::Null]), @@ -317,19 +329,18 @@ impl RecordBatchDeserializer for RecordBatch { .map(type_matcher) .collect::, _>>()?; - Ok(Value::Object((0..self.num_rows()).fold( - serde_json::Map::new(), - |mut acc, row_i| { - columns.iter().enumerate().for_each(|(col_i, col)| { - acc.entry(column_names[col_i].to_string()).and_modify(|v| { - if let Value::Array(v_arr) = v { - v_arr.push(col[row_i].clone()) - } - }).or_insert(Value::Array(vec![col[row_i].clone()])); - }); - acc - }, - ))) + Ok((0..self.num_rows()) + .map(|row_i| { + columns + .iter() + .enumerate() + .fold(serde_json::Map::new(), |mut acc, (col_i, col)| { + acc.insert(column_names[col_i].to_string(), col[row_i].clone()); + acc + }) + }) + .map(Value::Object) + .collect()) } } @@ -713,84 +724,60 @@ mod tests { assert_eq!( record_batch.deserialize().unwrap(), - json!({ - "binary": [ - [ + vec![ + json!({ + "binary": [ 104, 101, 108, 108, 111 ], - [ + 
"float_32": 0.0, + "float_64": 0.0, + "int_16": 0, + "int_32": 0, + "int_64": 0, + "int_8": 0, + "large_binary": [ + 97, + 98, + 99 + ], + "large_string": "Jerry", + "string": "Marty", + "uint_16": 0, + "uint_32": 0, + "uint_64": 0, + "uint_8": 0 + }), + json!({ + "binary": [ 119, 111, 114, 108, 100 - ] - ], - "float_32": [ - 0.0, - 1.0 - ], - "float_64": [ - 0.0, - 1.0 - ], - "int_16": [ - 0, - 1 - ], - "int_32": [ - 0, - -1 - ], - "int_64": [ - 0, - 1 - ], - "int_8": [ - 0, - -1 - ], - "large_binary": [ - [ - 97, - 98, - 99 ], - [ + "float_32": 1.0, + "float_64": 1.0, + "int_16": 1, + "int_32": -1, + "int_64": 1, + "int_8": -1, + "large_binary": [ 100, 101, 102 - ] - ], - "large_string": [ - "Jerry", - "Freddy" - ], - "string": [ - "Marty", - "Tony" - ], - "uint_16": [ - 0, - 1 - ], - "uint_32": [ - 0, - 1 - ], - "uint_64": [ - 0, - 1 - ], - "uint_8": [ - 0, - 1 - ] - }) + ], + "large_string": "Freddy", + "string": "Tony", + "uint_16": 1, + "uint_32": 1, + "uint_64": 1, + "uint_8": 1 + }) + ] ) } @@ -813,19 +800,23 @@ mod tests { assert_eq!( record_batch.deserialize().unwrap(), - json!([ - [ - { + vec![ + json!({ + "some_dict": { "2": "" - }, - { + } + }), + json!({ + "some_dict": { "0": "abc" - }, - { + } + }), + json!({ + "some_dict": { "1": "def" } - ] - ]) + }) + ] ) } @@ -838,11 +829,27 @@ mod tests { let union = builder.build().unwrap(); let record_batch = - RecordBatch::try_from_iter(vec![("some_dict", Arc::new(union) as ArrayRef)]).unwrap(); + RecordBatch::try_from_iter(vec![("some_union", Arc::new(union) as ArrayRef)]).unwrap(); assert_eq!( record_batch.deserialize().unwrap(), - json!([[[1], [3.0], [4]]]) + vec![ + json!({ + "some_union": [ + 1 + ] + }), + json!({ + "some_union": [ + 3.0 + ] + }), + json!({ + "some_union": [ + 4 + ] + }) + ] ) } @@ -859,15 +866,7 @@ mod tests { let record_batch = RecordBatch::try_from_iter(vec![("some_dict", Arc::new(array) as ArrayRef)]).unwrap(); - assert_eq!( - record_batch.deserialize().unwrap(), - json!([[ - 
["abc"], - ["", ""], - ["def", "def", "def", "def"], - ["abc", "abc", "abc", "abc", "abc"] - ]]) - ) + assert_eq!(record_batch.deserialize().unwrap(), vec![json!({})]) } #[tokio::test] @@ -887,40 +886,32 @@ mod tests { assert_eq!( record_batch.deserialize().unwrap(), - json!([ - [ - { + vec![ + json!({ + "map_col": { "tarentino": { - "name": "Pulp Fiction", - "year": 1999, "actors": [ "Johnny Depp", "Cate Blanchet" - ] + ], + "name": "Pulp Fiction", + "year": 1999 } - }, - { + } + }), + json!({ + "map_col": { "darabont": { - "name": "The Shawshank Redemption", - "year": 2026, "actors": [ "Meryl Streep", "Scarlett Johansson" - ] - } - }, - { - "chazelle": { - "name": "La La Land", - "year": 1745, - "actors": [ - "Brad Pitt", - "Natalie Portman" - ] + ], + "name": "The Shawshank Redemption", + "year": 2026 } } - ] - ]) + }) + ] ) } @@ -984,15 +975,10 @@ mod tests { assert_eq!( record_batch.deserialize().unwrap(), - json!([ - [ - { - "id": "id1", + vec![ + json!({ + "employees": { "age": 25.0, - "names": [ - "Alice", - "Bob" - ], "favorite_animals": [ [ "Dog", @@ -1003,20 +989,23 @@ mod tests { ] ], "favorite_movie": { - "name": "Pulp Fiction", - "year": 1999, "actors": [ "Johnny Depp", "Cate Blanchet" - ] - } - }, - { - "id": "id2", - "age": 30.5, + ], + "name": "Pulp Fiction", + "year": 1999 + }, + "id": "id1", "names": [ - "Charlie" - ], + "Alice", + "Bob" + ] + } + }), + json!({ + "employees": { + "age": 30.5, "favorite_animals": [ [ "Giraffe" @@ -1027,22 +1016,22 @@ mod tests { ] ], "favorite_movie": { - "name": "The Shawshank Redemption", - "year": 2026, "actors": [ "Meryl Streep", "Scarlett Johansson" - ] - } - }, - { - "id": "id3", - "age": 22.100000381469727, + ], + "name": "The Shawshank Redemption", + "year": 2026 + }, + "id": "id2", "names": [ - "David", - "Eve", - "Frank" - ], + "Charlie" + ] + } + }), + json!({ + "employees": { + "age": 22.100000381469727, "favorite_animals": [ [ "Sloth" @@ -1053,16 +1042,22 @@ mod tests { ] ], "favorite_movie": { - 
"name": "La La Land", - "year": 1745, "actors": [ "Brad Pitt", "Natalie Portman" - ] - } + ], + "name": "La La Land", + "year": 1745 + }, + "id": "id3", + "names": [ + "David", + "Eve", + "Frank" + ] } - ] - ]) + }) + ] ) } } diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index bb9d2599..a35d7e1b 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -5,6 +5,7 @@ use arrow_array::{ types::ByteArrayType, Array, ArrowPrimitiveType, FixedSizeListArray, GenericByteArray, PrimitiveArray, RecordBatch, RecordBatchIterator, }; +use deserializer::RecordBatchDeserializer; use futures::TryStreamExt; use lancedb::{ arrow::arrow_schema::{ArrowError, Schema}, @@ -76,37 +77,14 @@ impl DeserializeListArray for &Arc { /// Used whenever a lanceDb table is queried. /// First, execute the query and get the result as a list of RecordBatches (columnar data). /// Then, convert the record batches to the desired type using the try_from trait. -pub trait Query -where - T: TryFrom, Error = VectorStoreError>, -{ - async fn execute_query(&self) -> Result; -} - -impl Query for lancedb::query::Query -where - T: TryFrom, Error = VectorStoreError>, -{ - async fn execute_query(&self) -> Result { - let record_batches = self - .execute() - .await - .map_err(lancedb_to_rig_error)? - .try_collect::>() - .await - .map_err(lancedb_to_rig_error)?; - - T::try_from(record_batches) - } +pub trait Query { + async fn execute_query(&self) -> Result, VectorStoreError>; } /// Same as the above trait but for the VectorQuery type. /// Used whenever a lanceDb table vector search is executed. 
-impl Query for lancedb::query::VectorQuery -where - T: TryFrom, Error = VectorStoreError>, -{ - async fn execute_query(&self) -> Result { +impl Query for lancedb::query::VectorQuery { + async fn execute_query(&self) -> Result, VectorStoreError> { let record_batches = self .execute() .await @@ -115,7 +93,7 @@ where .await .map_err(lancedb_to_rig_error)?; - T::try_from(record_batches) + record_batches.deserialize() } } From 005092522daef6532cc94b2e8afef6fc254a376c Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 2 Oct 2024 12:01:40 -0400 Subject: [PATCH 28/39] feat: update examples to use new version of VectorStoreIndex trait --- rig-lancedb/examples/fixtures/lib.rs | 67 ++++ .../examples/vector_search_local_ann.rs | 71 +++-- .../examples/vector_search_local_enn.rs | 42 ++- rig-lancedb/examples/vector_search_s3_ann.rs | 83 +++-- rig-lancedb/src/lib.rs | 110 ++++--- rig-lancedb/src/table_schemas/document.rs | 181 ----------- rig-lancedb/src/table_schemas/embedding.rs | 299 ------------------ rig-lancedb/src/table_schemas/mod.rs | 2 - rig-lancedb/src/utils/mod.rs | 83 +---- 9 files changed, 272 insertions(+), 666 deletions(-) create mode 100644 rig-lancedb/examples/fixtures/lib.rs delete mode 100644 rig-lancedb/src/table_schemas/document.rs delete mode 100644 rig-lancedb/src/table_schemas/embedding.rs delete mode 100644 rig-lancedb/src/table_schemas/mod.rs diff --git a/rig-lancedb/examples/fixtures/lib.rs b/rig-lancedb/examples/fixtures/lib.rs new file mode 100644 index 00000000..9a91432e --- /dev/null +++ b/rig-lancedb/examples/fixtures/lib.rs @@ -0,0 +1,67 @@ +use std::sync::Arc; + +use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray}; +use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema}; +use rig::embeddings::DocumentEmbeddings; + +// Schema of table in LanceDB. 
+pub fn schema(dims: usize) -> Schema { + Schema::new(Fields::from(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("content", DataType::Utf8, false), + Field::new( + "embedding", + DataType::FixedSizeList( + Arc::new(Field::new("item", DataType::Float64, true)), + dims as i32, + ), + false, + ), + ])) +} + +// Convert DocumentEmbeddings objects to a RecordBatch. +pub fn as_record_batch( + records: Vec, + dims: usize, +) -> Result { + let id = StringArray::from_iter_values( + records + .iter() + .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id))) + .collect::>(), + ); + + let content = StringArray::from_iter_values( + records + .iter() + .flat_map(|record| { + record + .embeddings + .iter() + .map(|embedding| embedding.document.clone()) + }) + .collect::>(), + ); + + let embedding = FixedSizeListArray::from_iter_primitive::( + records + .into_iter() + .flat_map(|record| { + record + .embeddings + .into_iter() + .map(|embedding| embedding.vec.into_iter().map(Some).collect::>()) + .map(Some) + .collect::>() + }) + .collect::>(), + dims as i32, + ); + + RecordBatch::try_from_iter(vec![ + ("id", Arc::new(id) as ArrayRef), + ("content", Arc::new(content) as ArrayRef), + ("embedding", Arc::new(embedding) as ArrayRef), + ]) +} diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index ce1dd612..358ead03 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -1,13 +1,25 @@ -use std::env; +use std::{env, sync::Arc}; +use arrow_array::RecordBatchIterator; +use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ completion::Prompt, - embeddings::EmbeddingsBuilder, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndexDyn}, + vector_store::VectorStoreIndexDyn, 
}; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; + +#[path = "./fixtures/lib.rs"] +mod fixture; + +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub content: String, +} #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -15,15 +27,9 @@ async fn main() -> Result<(), anyhow::Error> { let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); let openai_client = Client::new(&openai_api_key); - // Select the embedding model and generate our embeddings + // Select an embedding model. let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - - // Initialize LanceDB locally. - let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?; - // Generate test data for RAG demo let agent = openai_client .agent("gpt-4o") @@ -39,6 +45,7 @@ async fn main() -> Result<(), anyhow::Error> { definitions.extend(definitions.clone()); definitions.extend(definitions.clone()); + // Generate embeddings for the test data. let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") @@ -47,17 +54,35 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Add embeddings to vector store - // vector_store.add_documents(embeddings).await?; + // Define search_params params that will be used by the vector store to perform the vector search. 
+ let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + + // Initialize LanceDB locally. + let db = lancedb::connect("data/lancedb-store").execute().await?; + + // Create table with embeddings. + let record_batch = as_record_batch(embeddings, model.ndims()); + let table = db + .create_table( + "definitions", + RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + ) + .execute() + .await?; + + let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information vector_store - .create_index(lancedb::index::Index::IvfPq( - IvfPqIndexBuilder::default() - // This overrides the default distance type of L2. - // Needs to be the same distance type as the one used in search params. - .distance_type(DistanceType::Cosine), - )) + .create_index( + lancedb::index::Index::IvfPq( + IvfPqIndexBuilder::default() + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. + .distance_type(DistanceType::Cosine), + ), + &["embedding"], + ) .await?; // Query the index @@ -65,8 +90,14 @@ async fn main() -> Result<(), anyhow::Error> { .top_n("My boss says I zindle too much, what does that mean?", 1) .await? 
.into_iter() - .map(|(score, id, doc)| (score, id, doc)) - .collect::>(); + .map(|(score, id, doc)| { + anyhow::Ok(( + score, + id, + serde_json::from_value::(doc)?, + )) + }) + .collect::, _>>()?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 7c099807..1ca2971d 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -1,11 +1,17 @@ -use std::env; +use std::{env, sync::Arc}; +use arrow_array::RecordBatchIterator; +use fixture::{as_record_batch, schema}; use rig::{ - embeddings::EmbeddingsBuilder, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndexDyn}, + vector_store::VectorStoreIndexDyn, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; + +#[path = "./fixtures/lib.rs"] +mod fixture; #[tokio::main] async fn main() -> Result<(), anyhow::Error> { @@ -16,10 +22,6 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Initialize LanceDB locally. 
- let db = lancedb::connect("data/lancedb-store").execute().await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model, &SearchParams::default()).await?; - let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") @@ -27,16 +29,28 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Add embeddings to vector store - // vector_store.add_documents(embeddings).await?; + // Define search_params params that will be used by the vector store to perform the vector search. + let search_params = SearchParams::default(); + + // Initialize LanceDB locally. + let db = lancedb::connect("data/lancedb-store").execute().await?; + + // Create table with embeddings. + let record_batch = as_record_batch(embeddings, model.ndims()); + let table = db + .create_table( + "definitions", + RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + ) + .execute() + .await?; + + let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // Query the index let results = vector_store - .top_n("My boss says I zindle too much, what does that mean?", 1) - .await? 
- .into_iter() - .map(|(score, id, doc)| (score, id, doc)) - .collect::>(); + .top_n_ids("My boss says I zindle too much, what does that mean?", 1) + .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 00ba96ea..b56d9156 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -1,17 +1,28 @@ -use std::env; +use std::{env, sync::Arc}; +use arrow_array::RecordBatchIterator; +use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ completion::Prompt, - embeddings::EmbeddingsBuilder, + embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::{VectorStore, VectorStoreIndexDyn}, + vector_store::VectorStoreIndexDyn, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; +use serde::Deserialize; + +#[path = "./fixtures/lib.rs"] +mod fixture; + +#[derive(Deserialize, Debug)] +pub struct VectorSearchResult { + pub id: String, + pub content: String, +} // Note: see docs to deploy LanceDB on other cloud providers such as google and azure. // https://lancedb.github.io/lancedb/guides/storage/ - #[tokio::main] async fn main() -> Result<(), anyhow::Error> { // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). @@ -21,23 +32,13 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - - // Initialize LanceDB on S3. - // Note: see below docs for more options and IAM permission required to read/write to S3. 
- // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 - let db = lancedb::connect("s3://lancedb-test-829666124233") - .execute() - .await?; - let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?; - // Generate test data for RAG demo let agent = openai_client .agent("gpt-4o") .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. Return ONLY the JSON string generated, nothing else.") .build(); let response = agent - .prompt("Invent at least 100 words and their definitions") + .prompt("Invent 100 words and their definitions") .await?; let mut definitions: Vec = serde_json::from_str(&response)?; @@ -46,7 +47,8 @@ async fn main() -> Result<(), anyhow::Error> { definitions.extend(definitions.clone()); definitions.extend(definitions.clone()); - let embeddings: Vec = EmbeddingsBuilder::new(model.clone()) + // Generate embeddings for the test data. + let embeddings = EmbeddingsBuilder::new(model.clone()) .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") @@ -54,26 +56,53 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Add embeddings to vector store - // vector_store.add_documents(embeddings).await?; + // Define search_params params that will be used by the vector store to perform the vector search. + let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + + // Initialize LanceDB on S3. 
+ // Note: see below docs for more options and IAM permission required to read/write to S3. + // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 + let db = lancedb::connect("s3://lancedb-test-829666124233") + .execute() + .await?; + // Create table with embeddings. + let record_batch = as_record_batch(embeddings, model.ndims()); + let table = db + .create_table( + "definitions", + RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), + ) + .execute() + .await?; + + let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information vector_store - .create_index(lancedb::index::Index::IvfPq( - IvfPqIndexBuilder::default() - // This overrides the default distance type of L2. - // Needs to be the same distance type as the one used in search params. - .distance_type(DistanceType::Cosine), - )) + .create_index( + lancedb::index::Index::IvfPq( + IvfPqIndexBuilder::default() + // This overrides the default distance type of L2. + // Needs to be the same distance type as the one used in search params. + .distance_type(DistanceType::Cosine), + ), + &["embedding"], + ) .await?; // Query the index let results = vector_store - .top_n("My boss says I zindle too much, what does that mean?", 1) + .top_n("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) .await? 
.into_iter() - .map(|(score, id, doc)| (score, id, doc)) - .collect::>(); + .map(|(score, id, doc)| { + anyhow::Ok(( + score, + id, + serde_json::from_value::(doc)?, + )) + }) + .collect::, _>>()?; println!("Results: {:?}", results); diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index a141d8dc..2b4e596d 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,9 +1,6 @@ -use std::sync::Arc; - use lancedb::{ - arrow::arrow_schema::{DataType, Field, Fields, Schema}, index::Index, - query::QueryBase, + query::{QueryBase, VectorQuery}, DistanceType, }; use rig::{ @@ -14,7 +11,6 @@ use serde::Deserialize; use serde_json::Value; use utils::Query; -mod table_schemas; mod utils; fn lancedb_to_rig_error(e: lancedb::Error) -> VectorStoreError { @@ -34,6 +30,41 @@ pub struct LanceDbVectorStore { search_params: SearchParams, } +impl LanceDbVectorStore { + fn build_query(&self, mut query: VectorQuery) -> VectorQuery { + let SearchParams { + distance_type, + search_type, + nprobes, + refine_factor, + post_filter, + } = self.search_params.clone(); + + if let Some(distance_type) = distance_type { + query = query.distance_type(distance_type); + } + + if let Some(SearchType::Flat) = search_type { + query = query.bypass_vector_index(); + } + + if let Some(SearchType::Approximate) = search_type { + if let Some(nprobes) = nprobes { + query = query.nprobes(nprobes); + } + if let Some(refine_factor) = refine_factor { + query = query.refine_factor(refine_factor); + } + } + + if let Some(true) = post_filter { + query = query.postfilter(); + } + + query + } +} + /// See [LanceDB vector search](https://lancedb.github.io/lancedb/search/) for more information. 
#[derive(Debug, Clone)] pub enum SearchType { @@ -94,19 +125,19 @@ impl LanceDbVectorStore { pub async fn new( table: lancedb::Table, model: M, - id_field: String, + id_field: &str, search_params: SearchParams, ) -> Result { Ok(Self { table, model, - id_field, + id_field: id_field.to_string(), search_params, }) } /// Define index on document table `id` field for search optimization. - pub async fn create_document_index( + pub async fn create_index( &self, index: Index, field_names: &[impl AsRef], @@ -123,48 +154,19 @@ impl VectorStoreIndex for LanceDbV ) -> Result, VectorStoreError> { let prompt_embedding = self.model.embed_document(query).await?; - let mut query = self + let query = self .table .vector_search(prompt_embedding.vec.clone()) .map_err(lancedb_to_rig_error)? .limit(n); - let SearchParams { - distance_type, - search_type, - nprobes, - refine_factor, - post_filter, - } = self.search_params.clone(); - - if let Some(distance_type) = distance_type { - query = query.distance_type(distance_type); - } - - if let Some(SearchType::Flat) = search_type { - query = query.bypass_vector_index(); - } - - if let Some(SearchType::Approximate) = search_type { - if let Some(nprobes) = nprobes { - query = query.nprobes(nprobes); - } - if let Some(refine_factor) = refine_factor { - query = query.refine_factor(refine_factor); - } - } - - if let Some(true) = post_filter { - query = query.postfilter(); - } - - query + self.build_query(query) .execute_query() .await? 
.into_iter() .map(|value| { Ok(( - match value.get("distance") { + match value.get("_distance") { Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(), _ => 0.0, }, @@ -183,6 +185,32 @@ impl VectorStoreIndex for LanceDbV query: &str, n: usize, ) -> Result, VectorStoreError> { - todo!() + let prompt_embedding = self.model.embed_document(query).await?; + + let query = self + .table + .query() + .select(lancedb::query::Select::Columns(vec![self.id_field.clone()])) + .nearest_to(prompt_embedding.vec.clone()) + .map_err(lancedb_to_rig_error)? + .limit(n); + + self.build_query(query) + .execute_query() + .await? + .into_iter() + .map(|value| { + Ok(( + match value.get("distance") { + Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(), + _ => 0.0, + }, + match value.get(self.id_field.clone()) { + Some(Value::String(id)) => id.to_string(), + _ => "".to_string(), + }, + )) + }) + .collect() } } diff --git a/rig-lancedb/src/table_schemas/document.rs b/rig-lancedb/src/table_schemas/document.rs deleted file mode 100644 index 384eb4bf..00000000 --- a/rig-lancedb/src/table_schemas/document.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::sync::Arc; - -use arrow_array::{types::Utf8Type, ArrayRef, RecordBatch, StringArray}; -use lancedb::arrow::arrow_schema::ArrowError; -use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; - -use crate::utils::DeserializeByteArray; - -/// Schema of `documents` table in LanceDB defined as a struct. 
-#[derive(Clone, Debug)] -pub struct DocumentRecord { - pub id: String, - pub document: String, -} - -/// Wrapper around `Vec` -#[derive(Debug)] -pub struct DocumentRecords(Vec); - -impl DocumentRecords { - fn new() -> Self { - Self(Vec::new()) - } - - fn records(&self) -> Vec { - self.0.clone() - } - - fn add_records(&mut self, records: Vec) { - self.0.extend(records); - } - - fn documents(&self) -> impl Iterator + '_ { - self.as_iter().map(|doc| doc.document.clone()) - } - - pub fn ids(&self) -> impl Iterator + '_ { - self.as_iter().map(|doc| doc.id.clone()) - } - - pub fn as_iter(&self) -> impl Iterator { - self.0.iter() - } -} - -/// Converts a `DocumentEmbeddings` object to a `DocumentRecord` object. -/// The `DocumentRecord` contains the correct schema required by the `documents` table. -impl TryFrom for DocumentRecord { - type Error = serde_json::Error; - - fn try_from(document: DocumentEmbeddings) -> Result { - Ok(DocumentRecord { - id: document.id, - document: serde_json::to_string(&document.document)?, - }) - } -} - -/// Converts a list of `DocumentEmbeddings` objects to a list of `DocumentRecord` objects. -/// This is useful when we need to write many `DocumentEmbeddings` items to the `documents` table at once. -impl TryFrom> for DocumentRecords { - type Error = serde_json::Error; - - fn try_from(documents: Vec) -> Result { - Ok(Self( - documents - .into_iter() - .map(DocumentRecord::try_from) - .collect::, _>>()?, - )) - } -} - -/// Convert a list of documents (`DocumentRecords`) to a `RecordBatch`, the data structure that needs ot be written to LanceDB. -/// All documents will be written to the database as part of the same batch. 
-impl TryFrom for RecordBatch { - type Error = ArrowError; - - fn try_from(document_records: DocumentRecords) -> Result { - let id = Arc::new(StringArray::from_iter_values(document_records.ids())) as ArrayRef; - let document = - Arc::new(StringArray::from_iter_values(document_records.documents())) as ArrayRef; - - RecordBatch::try_from_iter(vec![("id", id), ("document", document)]) - } -} - -impl From for Vec> { - fn from(documents: DocumentRecords) -> Self { - vec![RecordBatch::try_from(documents)] - } -} - -/// Convert a `RecordBatch` object, read from a lanceDb table, to a list of `DocumentRecord` objects. -/// This allows us to convert the query result to our data format. -impl TryFrom for DocumentRecords { - type Error = ArrowError; - - fn try_from(record_batch: RecordBatch) -> Result { - let binding_0 = record_batch.column(0); - let ids = binding_0.to_str::()?; - - let binding_1 = record_batch.column(1); - let documents = binding_1.to_str::()?; - - Ok(DocumentRecords( - ids.into_iter() - .zip(documents) - .map(|(id, document)| DocumentRecord { - id: id.to_string(), - document: document.to_string(), - }) - .collect(), - )) - } -} - -/// Convert a list of `RecordBatch` objects, read from a lanceDb table, to a list of `DocumentRecord` objects. 
-impl TryFrom> for DocumentRecords { - type Error = VectorStoreError; - - fn try_from(record_batches: Vec) -> Result { - let documents = record_batches - .into_iter() - .map(DocumentRecords::try_from) - .collect::, _>>() - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - - Ok(documents - .into_iter() - .fold(DocumentRecords::new(), |mut acc, document| { - acc.add_records(document.records()); - acc - })) - } -} - -#[cfg(test)] -mod tests { - use arrow_array::RecordBatch; - - use crate::table_schemas::document::{DocumentRecord, DocumentRecords}; - - #[tokio::test] - async fn test_record_batch_conversion() { - let document_records = DocumentRecords(vec![ - DocumentRecord { - id: "ABC".to_string(), - document: serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string(), - }, - DocumentRecord { - id: "DEF".to_string(), - document: serde_json::json!({ - "title": "Sup dog", - "body": "Greetings", - }) - .to_string(), - }, - ]); - - let record_batch = RecordBatch::try_from(document_records).unwrap(); - - let deserialized_record_batch = DocumentRecords::try_from(record_batch).unwrap(); - - assert_eq!(deserialized_record_batch.0.len(), 2); - - assert_eq!(deserialized_record_batch.0[0].id, "ABC"); - assert_eq!( - deserialized_record_batch.0[0].document, - serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string() - ); - } -} diff --git a/rig-lancedb/src/table_schemas/embedding.rs b/rig-lancedb/src/table_schemas/embedding.rs deleted file mode 100644 index 7f74dd12..00000000 --- a/rig-lancedb/src/table_schemas/embedding.rs +++ /dev/null @@ -1,299 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use arrow_array::{ - builder::{FixedSizeListBuilder, Float64Builder}, - types::{Float32Type, Float64Type, Utf8Type}, - ArrayRef, RecordBatch, StringArray, -}; -use lancedb::arrow::arrow_schema::ArrowError; -use rig::{embeddings::DocumentEmbeddings, vector_store::VectorStoreError}; - -use 
crate::utils::{DeserializeByteArray, DeserializeListArray, DeserializePrimitiveArray}; - -/// Data format in the LanceDB table `embeddings` -#[derive(Clone, Debug, PartialEq)] -pub struct EmbeddingRecord { - pub id: String, - pub document_id: String, - pub content: String, - pub embedding: Vec, - /// Distance from prompt. - /// This value is only present after vector search executes and determines the distance - pub distance: Option, -} - -/// Group of EmbeddingRecord objects. This represents the list of embedding objects in a `DocumentEmbeddings` object. -#[derive(Clone, Debug)] -pub struct EmbeddingRecords { - records: Vec, - dimension: i32, -} - -impl EmbeddingRecords { - fn new(records: Vec, dimension: i32) -> Self { - EmbeddingRecords { records, dimension } - } - - fn add_record(&mut self, record: EmbeddingRecord) { - self.records.push(record); - } - - pub fn as_iter(&self) -> impl Iterator { - self.records.iter() - } -} - -/// HashMap where the key is the `DocumentEmbeddings` id -/// and the value is the`EmbeddingRecords` object that corresponds to the document. 
-#[derive(Debug)] -pub struct EmbeddingRecordsBatch(HashMap); - -impl EmbeddingRecordsBatch { - fn as_iter(&self) -> impl Iterator { - self.0.clone().into_values().collect::>().into_iter() - } - - pub fn get_by_id(&self, id: &str) -> Option { - self.0.get(id).cloned() - } - - pub fn document_ids(&self) -> String { - self.0 - .clone() - .into_keys() - .map(|id| format!("'{id}'")) - .collect::>() - .join(",") - } -} - -/// Convert from a `DocumentEmbeddings` to an `EmbeddingRecords` object (a list of `EmbeddingRecord` objects) -impl From for EmbeddingRecords { - fn from(document: DocumentEmbeddings) -> Self { - EmbeddingRecords::new( - document - .embeddings - .clone() - .into_iter() - .enumerate() - .map(move |(i, embedding)| EmbeddingRecord { - id: format!("{}-{i}", document.id), - document_id: document.id.clone(), - content: embedding.document, - embedding: embedding.vec, - distance: None, - }) - .collect(), - document - .embeddings - .first() - .map(|embedding| embedding.vec.len() as i32) - .unwrap_or(0), - ) - } -} - -/// Convert from a list of `DocumentEmbeddings` to an `EmbeddingRecordsBatch` object -/// For each `DocumentEmbeddings`, we create an `EmbeddingRecords` and add it to the -/// hashmap with its corresponding `DocumentEmbeddings` id. -impl From> for EmbeddingRecordsBatch { - fn from(documents: Vec) -> Self { - EmbeddingRecordsBatch( - documents - .into_iter() - .fold(HashMap::new(), |mut acc, document| { - acc.insert(document.id.clone(), EmbeddingRecords::from(document)); - acc - }), - ) - } -} - -/// Convert a list of embeddings (`EmbeddingRecords`) to a `RecordBatch`, the data structure that needs ot be written to LanceDB. -/// All embeddings related to a document will be written to the database as part of the same batch. 
-impl TryFrom for RecordBatch { - fn try_from(embedding_records: EmbeddingRecords) -> Result { - let id = StringArray::from_iter_values( - embedding_records.as_iter().map(|record| record.id.clone()), - ); - let document_id = StringArray::from_iter_values( - embedding_records - .as_iter() - .map(|record| record.document_id.clone()), - ); - let content = StringArray::from_iter_values( - embedding_records - .as_iter() - .map(|record| record.content.clone()), - ); - - let mut builder = - FixedSizeListBuilder::new(Float64Builder::new(), embedding_records.dimension); - embedding_records.as_iter().for_each(|record| { - record - .embedding - .iter() - .for_each(|value| builder.values().append_value(*value)); - builder.append(true); - }); - - RecordBatch::try_from_iter(vec![ - ("id", Arc::new(id) as ArrayRef), - ("document_id", Arc::new(document_id) as ArrayRef), - ("content", Arc::new(content) as ArrayRef), - ("embedding", Arc::new(builder.finish()) as ArrayRef), - ]) - } - - type Error = ArrowError; -} - -impl From for Vec> { - fn from(embeddings: EmbeddingRecordsBatch) -> Self { - embeddings.as_iter().map(RecordBatch::try_from).collect() - } -} - -impl TryFrom for EmbeddingRecords { - type Error = ArrowError; - - fn try_from(record_batch: RecordBatch) -> Result { - let binding_0 = record_batch.column(0); - let ids = binding_0.to_str::()?; - - let binding_1 = record_batch.column(1); - let document_ids = binding_1.to_str::()?; - - let binding_2 = record_batch.column(2); - let contents = binding_2.to_str::()?; - - let embeddings = record_batch.column(3).to_float_list::()?; - - // There is a `_distance` field in the response if the executed query was a VectorQuery - // Otherwise, for normal queries, the `_distance` field is not present in the response. - let distances = if record_batch.num_columns() == 5 { - record_batch - .column(4) - .to_float::()? 
- .into_iter() - .map(Some) - .collect() - } else { - vec![None; record_batch.num_rows()] - }; - - Ok(EmbeddingRecords::new( - ids.into_iter() - .zip(document_ids) - .zip(contents) - .zip(embeddings.clone()) - .zip(distances) - .map( - |((((id, document_id), content), embedding), distance)| EmbeddingRecord { - id: id.to_string(), - document_id: document_id.to_string(), - content: content.to_string(), - embedding, - distance, - }, - ) - .collect(), - embeddings - .iter() - .map(|embedding| embedding.len() as i32) - .next() - .unwrap_or(0), - )) - } -} - -impl TryFrom> for EmbeddingRecordsBatch { - type Error = VectorStoreError; - - fn try_from(record_batches: Vec) -> Result { - let embedding_records = record_batches - .into_iter() - .map(EmbeddingRecords::try_from) - .collect::, _>>() - .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?; - - let grouped_records = - embedding_records - .into_iter() - .fold(HashMap::new(), |mut acc, records| { - records.as_iter().for_each(|record| { - acc.entry(record.document_id.clone()) - .and_modify(|item: &mut EmbeddingRecords| { - item.add_record(record.clone()) - }) - .or_insert(EmbeddingRecords::new( - vec![record.clone()], - record.embedding.len() as i32, - )); - }); - acc - }); - - Ok(EmbeddingRecordsBatch(grouped_records)) - } -} - -#[cfg(test)] -mod tests { - use arrow_array::RecordBatch; - - use crate::table_schemas::embedding::{EmbeddingRecord, EmbeddingRecords}; - - #[tokio::test] - async fn test_record_batch_conversion() { - let embedding_records = EmbeddingRecords::new( - vec![ - EmbeddingRecord { - id: "some_id".to_string(), - document_id: "ABC".to_string(), - content: serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string(), - embedding: vec![1.0, 2.0, 3.0], - distance: None, - }, - EmbeddingRecord { - id: "another_id".to_string(), - document_id: "DEF".to_string(), - content: serde_json::json!({ - "title": "Sup dog", - "body": "Greetings", - }) - .to_string(), - embedding: 
vec![4.0, 5.0, 6.0], - distance: None, - }, - ], - 3, - ); - - let record_batch = RecordBatch::try_from(embedding_records).unwrap(); - - let deserialized_record_batch = EmbeddingRecords::try_from(record_batch).unwrap(); - - assert_eq!(deserialized_record_batch.as_iter().count(), 2); - assert_eq!( - deserialized_record_batch.as_iter().nth(0).unwrap().clone(), - EmbeddingRecord { - id: "some_id".to_string(), - document_id: "ABC".to_string(), - content: serde_json::json!({ - "title": "Hello world", - "body": "Greetings", - }) - .to_string(), - embedding: vec![1.0, 2.0, 3.0], - distance: None - } - ); - - assert!(false) - } -} diff --git a/rig-lancedb/src/table_schemas/mod.rs b/rig-lancedb/src/table_schemas/mod.rs deleted file mode 100644 index bd24dd65..00000000 --- a/rig-lancedb/src/table_schemas/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod document; -pub mod embedding; diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index a35d7e1b..e8db4559 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -1,78 +1,12 @@ pub mod deserializer; -use std::sync::Arc; -use arrow_array::{ - types::ByteArrayType, Array, ArrowPrimitiveType, FixedSizeListArray, GenericByteArray, - PrimitiveArray, RecordBatch, RecordBatchIterator, -}; use deserializer::RecordBatchDeserializer; use futures::TryStreamExt; -use lancedb::{ - arrow::arrow_schema::{ArrowError, Schema}, - query::ExecutableQuery, -}; +use lancedb::query::ExecutableQuery; use rig::vector_store::VectorStoreError; use crate::lancedb_to_rig_error; -/// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. 
-pub trait DeserializePrimitiveArray { - fn to_float( - &self, - ) -> Result::Native>, ArrowError>; -} - -impl DeserializePrimitiveArray for &Arc { - fn to_float( - &self, - ) -> Result::Native>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } - } -} - -/// Trait used to "deserialize" an arrow_array::Array as as list of byte objects. -pub trait DeserializeByteArray { - fn to_str(&self) -> Result::Native>, ArrowError>; -} - -impl DeserializeByteArray for &Arc { - fn to_str(&self) -> Result::Native>, ArrowError> { - match self.as_any().downcast_ref::>() { - Some(array) => Ok((0..array.len()).map(|j| array.value(j)).collect::>()), - None => Err(ArrowError::CastError(format!( - "Can't cast array: {self:?} to float array" - ))), - } - } -} - -/// Trait used to "deserialize" an arrow_array::Array as as list of lists of primitive objects. -pub trait DeserializeListArray { - fn to_float_list( - &self, - ) -> Result::Native>>, ArrowError>; -} - -impl DeserializeListArray for &Arc { - fn to_float_list( - &self, - ) -> Result::Native>>, ArrowError> { - match self.as_any().downcast_ref::() { - Some(list_array) => (0..list_array.len()) - .map(|j| (&list_array.value(j)).to_float::()) - .collect::, _>>(), - None => Err(ArrowError::CastError(format!( - "Can't cast column {self:?} to fixed size list array" - ))), - } - } -} - /// Trait that facilitates the conversion of columnar data returned by a lanceDb query to the desired struct. /// Used whenever a lanceDb table is queried. /// First, execute the query and get the result as a list of RecordBatches (columnar data). @@ -96,18 +30,3 @@ impl Query for lancedb::query::VectorQuery { record_batches.deserialize() } } - -/// Trait that facilitate inserting data defined as Rust structs into lanceDB table which contains columnar data. 
-pub trait Insert { - async fn insert(&self, data: T, schema: Schema) -> Result<(), lancedb::Error>; -} - -impl>>> Insert for lancedb::Table { - async fn insert(&self, data: T, schema: Schema) -> Result<(), lancedb::Error> { - self.add(RecordBatchIterator::new(data.into(), Arc::new(schema))) - .execute() - .await?; - - Ok(()) - } -} From 4a6a87d10f4578f17dbe3f786d06677ef746fe39 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 2 Oct 2024 12:23:56 -0400 Subject: [PATCH 29/39] docs: add doc strings --- rig-lancedb/src/lib.rs | 9 +++++++-- rig-lancedb/src/utils/deserializer.rs | 3 +++ rig-lancedb/src/utils/mod.rs | 6 +----- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 2b4e596d..5da12e38 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -22,15 +22,19 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { } pub struct LanceDbVectorStore { - /// Defines which model is used to generate embeddings for the vector store + /// Defines which model is used to generate embeddings for the vector store. model: M, + /// LanceDB table containing embeddings. table: lancedb::Table, + /// Column name in `table` that contains the id of a record. id_field: String, /// Vector search params that are used during vector search operations. search_params: SearchParams, } impl LanceDbVectorStore { + /// Apply the search_params to the vector query. + /// This is a helper function used by the methods `top_n` and `top_n_ids` of the `VectorStoreIndex` trait. fn build_query(&self, mut query: VectorQuery) -> VectorQuery { let SearchParams { distance_type, @@ -136,7 +140,8 @@ impl LanceDbVectorStore { }) } - /// Define index on document table `id` field for search optimization. + /// Define an index on the specified fields of the lanceDB table for search optimization. + /// Note: it is required to add an index on the column containing the embeddings when performing an ANN type vector search. 
pub async fn create_index( &self, index: Index, diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index 3a686d28..88f0e8de 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -25,6 +25,8 @@ fn arrow_to_rig_error(e: ArrowError) -> VectorStoreError { VectorStoreError::DatastoreError(Box::new(e)) } +/// Trait used to deserialize data returned from LanceDB queries into a serde_json::Value vector. +/// Data returned by LanceDB is a vector of `RecordBatch` items. pub trait RecordBatchDeserializer { fn deserialize(&self) -> Result, VectorStoreError>; } @@ -43,6 +45,7 @@ impl RecordBatchDeserializer for Vec { impl RecordBatchDeserializer for RecordBatch { fn deserialize(&self) -> Result, VectorStoreError> { + /// Recursive function that matches all possible data types store in LanceDB and converts them to serde_json::Value. fn type_matcher(column: &Arc) -> Result, VectorStoreError> { match column.data_type() { DataType::Null => Ok(vec![serde_json::Value::Null]), diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index e8db4559..46aeab31 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -7,16 +7,12 @@ use rig::vector_store::VectorStoreError; use crate::lancedb_to_rig_error; -/// Trait that facilitates the conversion of columnar data returned by a lanceDb query to the desired struct. +/// Trait that facilitates the conversion of columnar data returned by a lanceDb query to serde_json::Value. /// Used whenever a lanceDb table is queried. -/// First, execute the query and get the result as a list of RecordBatches (columnar data). -/// Then, convert the record batches to the desired type using the try_from trait. pub trait Query { async fn execute_query(&self) -> Result, VectorStoreError>; } -/// Same as the above trait but for the VectorQuery type. -/// Used whenever a lanceDb table vector search is executed. 
impl Query for lancedb::query::VectorQuery { async fn execute_query(&self) -> Result, VectorStoreError> { let record_batches = self From ec44d4ad7bb527ae0de01ccdc71744f820359475 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 2 Oct 2024 13:29:09 -0400 Subject: [PATCH 30/39] fix: fix bug in deserializing type run end --- rig-lancedb/src/utils/deserializer.rs | 72 +++++++++++++++++++-------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index 88f0e8de..0b746de4 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -252,45 +252,57 @@ impl RecordBatchDeserializer for RecordBatch { )), ))), }, - DataType::RunEndEncoded(counter_type, ..) => { - let items: Vec> = match counter_type.data_type() { + DataType::RunEndEncoded(index_type, ..) => { + let items = match index_type.data_type() { DataType::Int16 => { - let (counter, v) = column + let (indexes, v) = column .to_run_end::() .map_err(arrow_to_rig_error)?; - counter - .into_iter() + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) .zip(type_matcher(&v)?) - .map(|(n, value)| vec![value; n as usize]) - .collect() + .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() } DataType::Int32 => { - let (counter, v) = column + let (indexes, v) = column .to_run_end::() .map_err(arrow_to_rig_error)?; - counter - .into_iter() + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) .zip(type_matcher(&v)?) 
- .map(|(n, value)| vec![value; n as usize]) - .collect() + .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() } DataType::Int64 => { - let (counter, v) = column + let (indexes, v) = column .to_run_end::() .map_err(arrow_to_rig_error)?; - counter - .into_iter() + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) .zip(type_matcher(&v)?) - .map(|(n, value)| vec![value; n as usize]) - .collect() + .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() } _ => { return Err(VectorStoreError::DatastoreError(Box::new( ArrowError::CastError(format!( - "RunEndEncoded index type is not accepted: {counter_type:?}" + "RunEndEncoded index type is not accepted: {index_type:?}" )), ))) } @@ -867,9 +879,29 @@ mod tests { let array = builder.finish(); let record_batch = - RecordBatch::try_from_iter(vec![("some_dict", Arc::new(array) as ArrayRef)]).unwrap(); + RecordBatch::try_from_iter(vec![("some_run_end", Arc::new(array) as ArrayRef)]) + .unwrap(); - assert_eq!(record_batch.deserialize().unwrap(), vec![json!({})]) + assert_eq!( + record_batch.deserialize().unwrap(), + vec![ + json!({ + "some_run_end": "abc" + }), + json!({ + "some_run_end": "" + }), + json!({ + "some_run_end": "def" + }), + json!({ + "some_run_end": "def" + }), + json!({ + "some_run_end": "abc" + }) + ] + ) } #[tokio::test] From edae69436467b4bb459c1caa08d86f1e07240789 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 2 Oct 2024 14:58:20 -0400 Subject: [PATCH 31/39] docs: add example docstring --- .../examples/vector_search_local_enn.rs | 1 - rig-lancedb/src/lib.rs | 79 +++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/rig-lancedb/examples/vector_search_local_enn.rs b/rig-lancedb/examples/vector_search_local_enn.rs index 1ca2971d..5932dcd0 100644 --- a/rig-lancedb/examples/vector_search_local_enn.rs +++ b/rig-lancedb/examples/vector_search_local_enn.rs @@ -8,7 +8,6 @@ use rig::{ 
vector_store::VectorStoreIndexDyn, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; -use serde::Deserialize; #[path = "./fixtures/lib.rs"] mod fixture; diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 5da12e38..1eec5e3a 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -21,6 +21,74 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { VectorStoreError::JsonError(e) } +/// # Example +/// ``` +/// use std::{env, sync::Arc}; + +/// use arrow_array::RecordBatchIterator; +/// use fixture::{as_record_batch, schema}; +/// use rig::{ +/// embeddings::{EmbeddingModel, EmbeddingsBuilder}, +/// providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, +/// vector_store::VectorStoreIndexDyn, +/// }; +/// use rig_lancedb::{LanceDbVectorStore, SearchParams}; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize, Debug)] +/// pub struct VectorSearchResult { +/// pub id: String, +/// pub content: String, +/// } +/// +/// // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). 
+/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); +/// let openai_client = Client::new(&openai_api_key); + +/// // Select the embedding model and generate our embeddings +/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + +/// let embeddings = EmbeddingsBuilder::new(model.clone()) +/// .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") +/// .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") +/// .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") +/// .build() +/// .await?; + +/// // Define search_params params that will be used by the vector store to perform the vector search. +/// let search_params = SearchParams::default(); + +/// // Initialize LanceDB locally. +/// let db = lancedb::connect("data/lancedb-store").execute().await?; + +/// // Create table with embeddings. +/// let record_batch = as_record_batch(embeddings, model.ndims()); +/// let table = db +/// .create_table( +/// "definitions", +/// RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))), +/// ) +/// .execute() +/// .await?; + +/// let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + +/// // Query the index +/// let results = vector_store +/// .top_n("My boss says I zindle too much, what does that mean?", 1) +/// .await? 
+/// .into_iter() +/// .map(|(score, id, doc)| { +/// anyhow::Ok(( +/// score, +/// id, +/// serde_json::from_value::(doc)?, +/// )) +/// }) +/// .collect::, _>>()?; + +/// println!("Results: {:?}", results); +/// ``` pub struct LanceDbVectorStore { /// Defines which model is used to generate embeddings for the vector store. model: M, @@ -42,6 +110,7 @@ impl LanceDbVectorStore { nprobes, refine_factor, post_filter, + column, } = self.search_params.clone(); if let Some(distance_type) = distance_type { @@ -65,6 +134,10 @@ impl LanceDbVectorStore { query = query.postfilter(); } + if let Some(column) = column { + query = query.column(column.as_str()) + } + query } } @@ -96,6 +169,7 @@ pub struct SearchParams { /// If set to true, filtering will happen after the vector search instead of before /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information post_filter: Option, + column: Option, } impl SearchParams { @@ -123,6 +197,11 @@ impl SearchParams { self.post_filter = Some(post_filter); self } + + pub fn column(mut self, column: &str) -> Self { + self.column = Some(column.to_string()); + self + } } impl LanceDbVectorStore { From 9c3eb0e5c4863c9c62db7ff8e5dae1ce65570789 Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 2 Oct 2024 15:05:02 -0400 Subject: [PATCH 32/39] fix: mongodb vector search - use num_candidates from search params --- rig-core/src/providers/openai.rs | 2 +- rig-lancedb/src/lib.rs | 2 +- rig-mongodb/src/lib.rs | 12 +++++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/rig-core/src/providers/openai.rs b/rig-core/src/providers/openai.rs index 6be27484..8262e6ce 100644 --- a/rig-core/src/providers/openai.rs +++ b/rig-core/src/providers/openai.rs @@ -99,7 +99,7 @@ impl Client { /// // Initialize the OpenAI client /// let openai = Client::new("your-open-ai-api-key"); /// - /// let embedding_model = openai.embedding_model(openai::TEXT_EMBEDDING_3_LARGE, 3072); + /// let 
embedding_model = openai.embedding_model("model-unknown-to-rig", 3072); /// ``` pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel { EmbeddingModel::new(self.clone(), model, ndims) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 1eec5e3a..af2a7071 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -34,7 +34,7 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// }; /// use rig_lancedb::{LanceDbVectorStore, SearchParams}; /// use serde::Deserialize; -/// +/// /// #[derive(Deserialize, Debug)] /// pub struct VectorSearchResult { /// pub id: String, diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index 93e2f756..c0be24e7 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -109,15 +109,21 @@ impl MongoDbVectorIndex { /// Vector search stage of aggregation pipeline of mongoDB collection. /// To be used by implementations of top_n and top_n_ids methods on VectorStoreIndex trait for MongoDbVectorIndex. fn pipeline_search_stage(&self, prompt_embedding: &Embedding, n: usize) -> bson::Document { + let SearchParams { + filter, + exact, + num_candidates, + } = &self.search_params; + doc! 
{ "$vectorSearch": { "index": &self.index_name, "path": "embeddings.vec", "queryVector": &prompt_embedding.vec, - "numCandidates": (n * 10) as u32, + "numCandidates": num_candidates.unwrap_or((n * 10) as u32), "limit": n as u32, - "filter": &self.search_params.filter, - "exact": self.search_params.exact.unwrap_or(false) + "filter": filter, + "exact": exact.unwrap_or(false) } } } From 3eef745b30753a07c2b412b70fbb49a56522b2be Mon Sep 17 00:00:00 2001 From: Garance Date: Wed, 2 Oct 2024 15:12:46 -0400 Subject: [PATCH 33/39] fix(lancedb): replace VectorStoreIndexDyn with VectorStoreIndex in examples --- rig-lancedb/examples/vector_search_local_ann.rs | 15 +++------------ rig-lancedb/examples/vector_search_s3_ann.rs | 15 +++------------ 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 358ead03..ef72b37d 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -3,11 +3,11 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; +use rig::vector_store::VectorStoreIndex; use rig::{ completion::Prompt, embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::VectorStoreIndexDyn, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; @@ -87,17 +87,8 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n("My boss says I zindle too much, what does that mean?", 1) - .await? 
- .into_iter() - .map(|(score, id, doc)| { - anyhow::Ok(( - score, - id, - serde_json::from_value::(doc)?, - )) - }) - .collect::, _>>()?; + .top_n::("My boss says I zindle too much, what does that mean?", 1) + .await?; println!("Results: {:?}", results); diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index b56d9156..38bf4dd9 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -7,7 +7,7 @@ use rig::{ completion::Prompt, embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, - vector_store::VectorStoreIndexDyn, + vector_store::VectorStoreIndex, }; use rig_lancedb::{LanceDbVectorStore, SearchParams}; use serde::Deserialize; @@ -92,17 +92,8 @@ async fn main() -> Result<(), anyhow::Error> { // Query the index let results = vector_store - .top_n("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) - .await? - .into_iter() - .map(|(score, id, doc)| { - anyhow::Ok(( - score, - id, - serde_json::from_value::(doc)?, - )) - }) - .collect::, _>>()?; + .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. 
What's the word for this feeling?", 1) + .await?; println!("Results: {:?}", results); From 27435e432d2da65d3cf6b2af9311298a830a43a3 Mon Sep 17 00:00:00 2001 From: Garance Date: Thu, 3 Oct 2024 16:53:08 -0400 Subject: [PATCH 34/39] fix: make PR changes pt I --- rig-lancedb/Cargo.toml | 2 +- .../examples/vector_search_local_ann.rs | 5 +- rig-lancedb/src/lib.rs | 64 +++++++++---------- rig-lancedb/src/utils/deserializer.rs | 27 ++++---- rig-lancedb/src/utils/mod.rs | 4 +- rig-mongodb/examples/vector_search_mongodb.rs | 2 +- rig-mongodb/src/lib.rs | 15 +++-- 7 files changed, 65 insertions(+), 54 deletions(-) diff --git a/rig-lancedb/Cargo.toml b/rig-lancedb/Cargo.toml index ed7426b0..031df2d3 100644 --- a/rig-lancedb/Cargo.toml +++ b/rig-lancedb/Cargo.toml @@ -13,4 +13,4 @@ futures = "0.3.30" [dev-dependencies] tokio = "1.40.0" -anyhow = "1.0.89" \ No newline at end of file +anyhow = "1.0.89" diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index ef72b37d..d3889e06 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -73,16 +73,17 @@ async fn main() -> Result<(), anyhow::Error> { let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information - vector_store + table .create_index( + &["embedding"], lancedb::index::Index::IvfPq( IvfPqIndexBuilder::default() // This overrides the default distance type of L2. // Needs to be the same distance type as the one used in search params. 
.distance_type(DistanceType::Cosine), ), - &["embedding"], ) + .execute() .await?; // Query the index diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index af2a7071..76807030 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -1,5 +1,4 @@ use lancedb::{ - index::Index, query::{QueryBase, VectorQuery}, DistanceType, }; @@ -9,7 +8,7 @@ use rig::{ }; use serde::Deserialize; use serde_json::Value; -use utils::Query; +use utils::QueryToJson; mod utils; @@ -24,7 +23,7 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// # Example /// ``` /// use std::{env, sync::Arc}; - +/// /// use arrow_array::RecordBatchIterator; /// use fixture::{as_record_batch, schema}; /// use rig::{ @@ -44,23 +43,23 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo). /// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set"); /// let openai_client = Client::new(&openai_api_key); - +/// /// // Select the embedding model and generate our embeddings /// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - +/// /// let embeddings = EmbeddingsBuilder::new(model.clone()) /// .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.") /// .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive") /// .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.") /// .build() /// .await?; - +/// /// // Define search_params params that will be used by the vector store to perform the vector search. 
/// let search_params = SearchParams::default(); - +/// /// // Initialize LanceDB locally. /// let db = lancedb::connect("data/lancedb-store").execute().await?; - +/// /// // Create table with embeddings. /// let record_batch = as_record_batch(embeddings, model.ndims()); /// let table = db @@ -70,9 +69,9 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// ) /// .execute() /// .await?; - +/// /// let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; - +/// /// // Query the index /// let results = vector_store /// .top_n("My boss says I zindle too much, what does that mean?", 1) @@ -86,7 +85,7 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError { /// )) /// }) /// .collect::, _>>()?; - +/// /// println!("Results: {:?}", results); /// ``` pub struct LanceDbVectorStore { @@ -151,53 +150,61 @@ pub enum SearchType { Approximate, } +/// Parameters used to perform a vector search on a LanceDb table. #[derive(Debug, Clone, Default)] pub struct SearchParams { - /// Always set the distance_type to match the value used to train the index - /// By default, set to L2 distance_type: Option, - /// By default, ANN will be used if there is an index on the table. - /// By default, kNN will be used if there is NO index on the table. - /// To use defaults, set to None. search_type: Option, - /// Set this value only when search type is ANN. - /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information nprobes: Option, - /// Set this value only when search type is ANN. 
- /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information refine_factor: Option, - /// If set to true, filtering will happen after the vector search instead of before - /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information post_filter: Option, column: Option, } impl SearchParams { + /// Sets the distance type of the search params. + /// Always set the distance_type to match the value used to train the index. + /// The default is DistanceType::L2. pub fn distance_type(mut self, distance_type: DistanceType) -> Self { self.distance_type = Some(distance_type); self } + /// Sets the search type of the search params. + /// By default, ANN will be used if there is an index on the table and kNN will be used if there is NO index on the table. + /// To use the mentioned defaults, do not set the search type. pub fn search_type(mut self, search_type: SearchType) -> Self { self.search_type = Some(search_type); self } + /// Sets the nprobes of the search params. + /// Only set this value only when the search type is ANN. + /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information. pub fn nprobes(mut self, nprobes: usize) -> Self { self.nprobes = Some(nprobes); self } + /// Sets the refine factor of the search params. + /// Only set this value only when search type is ANN. + /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information. pub fn refine_factor(mut self, refine_factor: u32) -> Self { self.refine_factor = Some(refine_factor); self } + /// Sets the post filter of the search params. + /// If set to true, filtering will happen after the vector search instead of before. + /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information. 
pub fn post_filter(mut self, post_filter: bool) -> Self { self.post_filter = Some(post_filter); self } + /// Sets the column of the search params. + /// Only set this value if there is more than one column that contains lists of floats. + /// If there is only one column of list of floats, this column will be chosen for the vector search automatically. pub fn column(mut self, column: &str) -> Self { self.column = Some(column.to_string()); self @@ -205,6 +212,9 @@ impl SearchParams { } impl LanceDbVectorStore { + /// Create an instance of `LanceDbVectorStore` with an existing table and model. + /// Define the id field name of the table. + /// Define search parameters that will be used to perform vector searches on the table. pub async fn new( table: lancedb::Table, model: M, @@ -218,16 +228,6 @@ impl LanceDbVectorStore { search_params, }) } - - /// Define an index on the specified fields of the lanceDB table for search optimization. - /// Note: it is required to add an index on the column containing the embeddings when performing an ANN type vector search. - pub async fn create_index( - &self, - index: Index, - field_names: &[impl AsRef], - ) -> Result<(), lancedb::Error> { - self.table.create_index(field_names, index).execute().await - } } impl VectorStoreIndex for LanceDbVectorStore { diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index 0b746de4..fe703a8d 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -313,21 +313,24 @@ impl RecordBatchDeserializer for RecordBatch { .map(|item| serde_json::to_value(item).map_err(serde_to_rig_error)) .collect() } - // Not yet fully supported DataType::BinaryView | DataType::Utf8View | DataType::ListView(..) - | DataType::LargeListView(..) => { - todo!() - } - // Currently unstable - DataType::Float16 | DataType::Decimal256(..) 
=> { - todo!() - } - _ => { - println!("Unsupported data type"); - Ok(vec![serde_json::Value::Null]) - } + | DataType::LargeListView(..) => Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "Data type: {} not yet fully supported", + column.data_type() + )), + ))), + DataType::Float16 | DataType::Decimal256(..) => Err( + VectorStoreError::DatastoreError(Box::new(ArrowError::CastError(format!( + "Data type: {} currently unstable", + column.data_type() + )))), + ), + _ => Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!("Unsupported data type: {}", column.data_type())), + ))), } } diff --git a/rig-lancedb/src/utils/mod.rs b/rig-lancedb/src/utils/mod.rs index 46aeab31..80e317dd 100644 --- a/rig-lancedb/src/utils/mod.rs +++ b/rig-lancedb/src/utils/mod.rs @@ -9,11 +9,11 @@ use crate::lancedb_to_rig_error; /// Trait that facilitates the conversion of columnar data returned by a lanceDb query to serde_json::Value. /// Used whenever a lanceDb table is queried. 
-pub trait Query { +pub trait QueryToJson { async fn execute_query(&self) -> Result, VectorStoreError>; } -impl Query for lancedb::query::VectorQuery { +impl QueryToJson for lancedb::query::VectorQuery { async fn execute_query(&self) -> Result, VectorStoreError> { let record_batches = self .execute() diff --git a/rig-mongodb/examples/vector_search_mongodb.rs b/rig-mongodb/examples/vector_search_mongodb.rs index 7d569c4c..3d062de3 100644 --- a/rig-mongodb/examples/vector_search_mongodb.rs +++ b/rig-mongodb/examples/vector_search_mongodb.rs @@ -49,7 +49,7 @@ async fn main() -> Result<(), anyhow::Error> { // Create a vector index on our vector store // IMPORTANT: Reuse the same model that was used to generate the embeddings - let index = vector_store.index(model, "vector_index", SearchParams::new()); + let index = vector_store.index(model, "vector_index", SearchParams::default()); // Query the index let results = index diff --git a/rig-mongodb/src/lib.rs b/rig-mongodb/src/lib.rs index c0be24e7..43869989 100644 --- a/rig-mongodb/src/lib.rs +++ b/rig-mongodb/src/lib.rs @@ -158,16 +158,13 @@ impl MongoDbVectorIndex { /// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information /// on each of the fields pub struct SearchParams { - /// Pre-filter filter: mongodb::bson::Document, - /// Whether to use ANN or ENN search exact: Option, - /// Only set this field if exact is set to false - /// Number of nearest neighbors to use during the search num_candidates: Option, } impl SearchParams { + /// Initializes a new `SearchParams` with default values. pub fn new() -> Self { Self { filter: doc! {}, @@ -176,16 +173,26 @@ impl SearchParams { } } + /// Sets the pre-filter field of the search params. + /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information. 
pub fn filter(mut self, filter: mongodb::bson::Document) -> Self { self.filter = filter; self } + /// Sets the exact field of the search params. + /// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed. + /// By default, exact is false. + /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information. pub fn exact(mut self, exact: bool) -> Self { self.exact = Some(exact); self } + /// Sets the num_candidates field of the search params. + /// Only set this field if exact is set to false. + /// Number of nearest neighbors to use during the search. + /// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information. pub fn num_candidates(mut self, num_candidates: u32) -> Self { self.num_candidates = Some(num_candidates); self From b55e86e25aeb244978bf6497731149901a2aea0b Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 4 Oct 2024 10:02:43 -0400 Subject: [PATCH 35/39] fix: make PR changes Pt II --- .../examples/vector_search_local_ann.rs | 4 +- rig-lancedb/examples/vector_search_s3_ann.rs | 9 +- rig-lancedb/src/lib.rs | 5 +- rig-lancedb/src/utils/deserializer.rs | 580 +++++++++--------- 4 files changed, 299 insertions(+), 299 deletions(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index d3889e06..ceabed1e 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -70,8 +70,6 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; - // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information table .create_index( @@ -86,6 +84,8 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; + let 
vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + // Query the index let results = vector_store .top_n::("My boss says I zindle too much, what does that mean?", 1) diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 38bf4dd9..5f2906b6 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -75,21 +75,22 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; - let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; - // See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information - vector_store + table .create_index( + &["embedding"], lancedb::index::Index::IvfPq( IvfPqIndexBuilder::default() // This overrides the default distance type of L2. // Needs to be the same distance type as the one used in search params. .distance_type(DistanceType::Cosine), ), - &["embedding"], ) + .execute() .await?; + let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; + // Query the index let results = vector_store .top_n::("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1) diff --git a/rig-lancedb/src/lib.rs b/rig-lancedb/src/lib.rs index 76807030..1e8b344a 100644 --- a/rig-lancedb/src/lib.rs +++ b/rig-lancedb/src/lib.rs @@ -248,7 +248,8 @@ impl VectorStoreIndex for LanceDbV .execute_query() .await? 
.into_iter() - .map(|value| { + .enumerate() + .map(|(i, value)| { Ok(( match value.get("_distance") { Some(Value::Number(distance)) => distance.as_f64().unwrap_or_default(), @@ -256,7 +257,7 @@ impl VectorStoreIndex for LanceDbV }, match value.get(self.id_field.clone()) { Some(Value::String(id)) => id.to_string(), - _ => "".to_string(), + _ => format!("unknown{i}"), }, serde_json::from_value(value).map_err(serde_to_rig_error)?, )) diff --git a/rig-lancedb/src/utils/deserializer.rs b/rig-lancedb/src/utils/deserializer.rs index fe703a8d..104d1bf8 100644 --- a/rig-lancedb/src/utils/deserializer.rs +++ b/rig-lancedb/src/utils/deserializer.rs @@ -43,297 +43,9 @@ impl RecordBatchDeserializer for Vec { } } +/// Trait used to deserialize data returned from LanceDB queries into a serde_json::Value vector. impl RecordBatchDeserializer for RecordBatch { fn deserialize(&self) -> Result, VectorStoreError> { - /// Recursive function that matches all possible data types store in LanceDB and converts them to serde_json::Value. 
- fn type_matcher(column: &Arc) -> Result, VectorStoreError> { - match column.data_type() { - DataType::Null => Ok(vec![serde_json::Value::Null]), - DataType::Float32 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Float64 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Int8 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Int16 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Int32 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Int64 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::UInt8 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::UInt16 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::UInt32 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::UInt64 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Date32 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Date64 => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Decimal128(..) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Time32(TimeUnit::Second) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Time32(TimeUnit::Millisecond) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Time64(TimeUnit::Microsecond) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Time64(TimeUnit::Nanosecond) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Timestamp(TimeUnit::Microsecond, ..) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Timestamp(TimeUnit::Millisecond, ..) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Timestamp(TimeUnit::Second, ..) 
=> column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Timestamp(TimeUnit::Nanosecond, ..) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Duration(TimeUnit::Microsecond) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Duration(TimeUnit::Millisecond) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Duration(TimeUnit::Nanosecond) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Duration(TimeUnit::Second) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Interval(IntervalUnit::YearMonth) => column - .to_primitive_value::() - .map_err(serde_to_rig_error), - DataType::Interval(IntervalUnit::DayTime) => Ok(column - .to_primitive::() - .iter() - .map(|IntervalDayTime { days, milliseconds }| { - json!({ - "days": days, - "milliseconds": milliseconds, - }) - }) - .collect()), - DataType::Interval(IntervalUnit::MonthDayNano) => Ok(column - .to_primitive::() - .iter() - .map( - |IntervalMonthDayNano { - months, - days, - nanoseconds, - }| { - json!({ - "months": months, - "days": days, - "nanoseconds": nanoseconds, - }) - }, - ) - .collect()), - DataType::Utf8 => column - .to_str_value::() - .map_err(serde_to_rig_error), - DataType::LargeUtf8 => column - .to_str_value::() - .map_err(serde_to_rig_error), - DataType::Binary => column - .to_str_value::() - .map_err(serde_to_rig_error), - DataType::LargeBinary => column - .to_str_value::() - .map_err(serde_to_rig_error), - DataType::FixedSizeBinary(n) => (0..*n) - .map(|i| serde_json::to_value(column.as_fixed_size_binary().value(i as usize))) - .collect::, _>>() - .map_err(serde_to_rig_error), - DataType::Boolean => { - let bool_array = column.as_boolean(); - (0..bool_array.len()) - .map(|i| bool_array.value(i)) - .map(serde_json::to_value) - .collect::, _>>() - .map_err(serde_to_rig_error) - } - DataType::FixedSizeList(..) 
=> { - column.to_fixed_lists().iter().map(type_matcher).map_ok() - } - DataType::List(..) => column.to_list::().iter().map(type_matcher).map_ok(), - DataType::LargeList(..) => { - column.to_list::().iter().map(type_matcher).map_ok() - } - DataType::Struct(..) => { - let struct_array = column.as_struct(); - let struct_columns = struct_array - .inner_lists() - .iter() - .map(type_matcher) - .collect::, _>>()?; - - Ok(struct_columns - .build_struct(struct_array.num_rows(), struct_array.column_names())) - } - DataType::Map(..) => { - let map_columns = column - .as_map() - .entries() - .inner_lists() - .iter() - .map(type_matcher) - .collect::, _>>()?; - - Ok(map_columns.build_map()) - } - DataType::Dictionary(keys_type, ..) => { - let (keys, v) = match **keys_type { - DataType::Int8 => column.to_dict_values::()?, - DataType::Int16 => column.to_dict_values::()?, - DataType::Int32 => column.to_dict_values::()?, - DataType::Int64 => column.to_dict_values::()?, - DataType::UInt8 => column.to_dict_values::()?, - DataType::UInt16 => column.to_dict_values::()?, - DataType::UInt32 => column.to_dict_values::()?, - DataType::UInt64 => column.to_dict_values::()?, - _ => { - return Err(VectorStoreError::DatastoreError(Box::new( - ArrowError::CastError(format!( - "Dictionary keys type is not accepted: {keys_type:?}" - )), - ))) - } - }; - - let values = type_matcher(v)?; - - Ok(keys - .iter() - .zip(values) - .map(|(k, v)| { - let mut map = serde_json::Map::new(); - map.insert(k.to_string(), v); - map - }) - .map(Value::Object) - .collect()) - } - DataType::Union(..) => match column.as_any().downcast_ref::() { - Some(union_array) => (0..union_array.len()) - .map(|i| union_array.value(i).clone()) - .collect::>() - .iter() - .map(type_matcher) - .map_ok(), - None => Err(VectorStoreError::DatastoreError(Box::new( - ArrowError::CastError(format!( - "Can't cast column {column:?} to union array" - )), - ))), - }, - DataType::RunEndEncoded(index_type, ..) 
=> { - let items = match index_type.data_type() { - DataType::Int16 => { - let (indexes, v) = column - .to_run_end::() - .map_err(arrow_to_rig_error)?; - - let mut prev = vec![0]; - prev.extend(indexes.clone()); - - prev.iter() - .zip(indexes) - .map(|(prev, cur)| cur - prev) - .zip(type_matcher(&v)?) - .flat_map(|(n, value)| vec![value; n as usize]) - .collect::>() - } - DataType::Int32 => { - let (indexes, v) = column - .to_run_end::() - .map_err(arrow_to_rig_error)?; - - let mut prev = vec![0]; - prev.extend(indexes.clone()); - - prev.iter() - .zip(indexes) - .map(|(prev, cur)| cur - prev) - .zip(type_matcher(&v)?) - .flat_map(|(n, value)| vec![value; n as usize]) - .collect::>() - } - DataType::Int64 => { - let (indexes, v) = column - .to_run_end::() - .map_err(arrow_to_rig_error)?; - - let mut prev = vec![0]; - prev.extend(indexes.clone()); - - prev.iter() - .zip(indexes) - .map(|(prev, cur)| cur - prev) - .zip(type_matcher(&v)?) - .flat_map(|(n, value)| vec![value; n as usize]) - .collect::>() - } - _ => { - return Err(VectorStoreError::DatastoreError(Box::new( - ArrowError::CastError(format!( - "RunEndEncoded index type is not accepted: {index_type:?}" - )), - ))) - } - }; - - items - .iter() - .map(|item| serde_json::to_value(item).map_err(serde_to_rig_error)) - .collect() - } - DataType::BinaryView - | DataType::Utf8View - | DataType::ListView(..) - | DataType::LargeListView(..) => Err(VectorStoreError::DatastoreError(Box::new( - ArrowError::CastError(format!( - "Data type: {} not yet fully supported", - column.data_type() - )), - ))), - DataType::Float16 | DataType::Decimal256(..) 
=> Err( - VectorStoreError::DatastoreError(Box::new(ArrowError::CastError(format!( - "Data type: {} currently unstable", - column.data_type() - )))), - ), - _ => Err(VectorStoreError::DatastoreError(Box::new( - ArrowError::CastError(format!("Unsupported data type: {}", column.data_type())), - ))), - } - } - let binding = self.schema(); let column_names = binding .fields() @@ -362,8 +74,294 @@ impl RecordBatchDeserializer for RecordBatch { } } +/// Recursive function that matches all possible data types store in LanceDB and converts them to serde_json::Value vector. +fn type_matcher(column: &Arc) -> Result, VectorStoreError> { + match column.data_type() { + DataType::Null => Ok(vec![serde_json::Value::Null]), + DataType::Float32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Float64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int8 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int16 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Int64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt8 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt16 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::UInt64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Date32 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Date64 => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Decimal128(..) 
=> column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time32(TimeUnit::Second) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time32(TimeUnit::Millisecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time64(TimeUnit::Microsecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Time64(TimeUnit::Nanosecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Microsecond, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Millisecond, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Second, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Timestamp(TimeUnit::Nanosecond, ..) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Microsecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Millisecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Nanosecond) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Duration(TimeUnit::Second) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Interval(IntervalUnit::YearMonth) => column + .to_primitive_value::() + .map_err(serde_to_rig_error), + DataType::Interval(IntervalUnit::DayTime) => Ok(column + .to_primitive::() + .iter() + .map(|IntervalDayTime { days, milliseconds }| { + json!({ + "days": days, + "milliseconds": milliseconds, + }) + }) + .collect()), + DataType::Interval(IntervalUnit::MonthDayNano) => Ok(column + .to_primitive::() + .iter() + .map( + |IntervalMonthDayNano { + months, + days, + nanoseconds, + }| { + json!({ + "months": months, + "days": days, + "nanoseconds": nanoseconds, + }) + }, + ) + 
.collect()), + DataType::Utf8 => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::LargeUtf8 => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::Binary => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::LargeBinary => column + .to_str_value::() + .map_err(serde_to_rig_error), + DataType::FixedSizeBinary(n) => (0..*n) + .map(|i| serde_json::to_value(column.as_fixed_size_binary().value(i as usize))) + .collect::, _>>() + .map_err(serde_to_rig_error), + DataType::Boolean => { + let bool_array = column.as_boolean(); + (0..bool_array.len()) + .map(|i| bool_array.value(i)) + .map(serde_json::to_value) + .collect::, _>>() + .map_err(serde_to_rig_error) + } + DataType::FixedSizeList(..) => column.to_fixed_lists().iter().map(type_matcher).map_ok(), + DataType::List(..) => column.to_list::().iter().map(type_matcher).map_ok(), + DataType::LargeList(..) => column.to_list::().iter().map(type_matcher).map_ok(), + DataType::Struct(..) => { + let struct_array = column.as_struct(); + let struct_columns = struct_array + .inner_lists() + .iter() + .map(type_matcher) + .collect::, _>>()?; + + Ok(struct_columns.build_struct(struct_array.num_rows(), struct_array.column_names())) + } + DataType::Map(..) => { + let map_columns = column + .as_map() + .entries() + .inner_lists() + .iter() + .map(type_matcher) + .collect::, _>>()?; + + Ok(map_columns.build_map()) + } + DataType::Dictionary(keys_type, ..) 
=> { + let (keys, v) = match **keys_type { + DataType::Int8 => column.to_dict_values::()?, + DataType::Int16 => column.to_dict_values::()?, + DataType::Int32 => column.to_dict_values::()?, + DataType::Int64 => column.to_dict_values::()?, + DataType::UInt8 => column.to_dict_values::()?, + DataType::UInt16 => column.to_dict_values::()?, + DataType::UInt32 => column.to_dict_values::()?, + DataType::UInt64 => column.to_dict_values::()?, + _ => { + return Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "Dictionary keys type is not accepted: {keys_type:?}" + )), + ))) + } + }; + + let values = type_matcher(v)?; + + Ok(keys + .iter() + .zip(values) + .map(|(k, v)| { + let mut map = serde_json::Map::new(); + map.insert(k.to_string(), v); + map + }) + .map(Value::Object) + .collect()) + } + DataType::Union(..) => match column.as_any().downcast_ref::() { + Some(union_array) => (0..union_array.len()) + .map(|i| union_array.value(i).clone()) + .collect::>() + .iter() + .map(type_matcher) + .map_ok(), + None => Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!("Can't cast column {column:?} to union array")), + ))), + }, + DataType::RunEndEncoded(index_type, ..) => { + let items = match index_type.data_type() { + DataType::Int16 => { + let (indexes, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) + .zip(type_matcher(&v)?) + .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() + } + DataType::Int32 => { + let (indexes, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) + .zip(type_matcher(&v)?) 
+ .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() + } + DataType::Int64 => { + let (indexes, v) = column + .to_run_end::() + .map_err(arrow_to_rig_error)?; + + let mut prev = vec![0]; + prev.extend(indexes.clone()); + + prev.iter() + .zip(indexes) + .map(|(prev, cur)| cur - prev) + .zip(type_matcher(&v)?) + .flat_map(|(n, value)| vec![value; n as usize]) + .collect::>() + } + _ => { + return Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "RunEndEncoded index type is not accepted: {index_type:?}" + )), + ))) + } + }; + + items + .iter() + .map(|item| serde_json::to_value(item).map_err(serde_to_rig_error)) + .collect() + } + DataType::BinaryView + | DataType::Utf8View + | DataType::ListView(..) + | DataType::LargeListView(..) => Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!( + "Data type: {} not yet fully supported", + column.data_type() + )), + ))), + DataType::Float16 | DataType::Decimal256(..) => Err(VectorStoreError::DatastoreError( + Box::new(ArrowError::CastError(format!( + "Data type: {} currently unstable", + column.data_type() + ))), + )), + _ => Err(VectorStoreError::DatastoreError(Box::new( + ArrowError::CastError(format!("Unsupported data type: {}", column.data_type())), + ))), + } +} + +/////////////////////////////////////////////////////////////////////////////////// +/// Everything below includes helpers for the recursive function `type_matcher`./// +/////////////////////////////////////////////////////////////////////////////////// + /// Trait used to "deserialize" an arrow_array::Array as as list of primitive objects. -pub trait DeserializePrimitiveArray { +trait DeserializePrimitiveArray { /// Downcast arrow Array into a `PrimitiveArray` with items that implement trait `ArrowPrimitiveType`. /// Return the primitive array values. 
fn to_primitive(&self) -> Vec<::Native>; @@ -395,7 +393,7 @@ impl DeserializePrimitiveArray for &Arc { } /// Trait used to "deserialize" an arrow_array::Array as as list of str objects. -pub trait DeserializeByteArray { +trait DeserializeByteArray { /// Downcast arrow Array into a `GenericByteArray` with items that implement trait `ByteArrayType`. /// Return the generic byte array values. fn to_str(&self) -> Vec<&::Native>; From cc5a328093d4c89c1c693dc4cb446f67fac7a866 Mon Sep 17 00:00:00 2001 From: Garance Date: Fri, 4 Oct 2024 10:06:02 -0400 Subject: [PATCH 36/39] fix(ci): install protobuf-compiler in test job --- .github/workflows/ci.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f6e22f50..2aee004b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -70,6 +70,9 @@ jobs: with: tool: nextest + - name: Install Protoc + uses: arduino/setup-protoc@v3 + - name: Test with latest nextest release uses: actions-rs/cargo@v1 with: From 9a310cb1abd1f2421757270e8bc93c7ac929e3c4 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 7 Oct 2024 09:43:41 -0400 Subject: [PATCH 37/39] fix: update lancedb examples test data --- .../examples/vector_search_local_ann.rs | 18 ++++-------------- rig-lancedb/examples/vector_search_s3_ann.rs | 18 ++++-------------- 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index ceabed1e..cab3682d 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -5,7 +5,6 @@ use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::vector_store::VectorStoreIndex; use rig::{ - completion::Prompt, embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, }; @@ -30,20 +29,11 @@ async fn main() -> Result<(), 
anyhow::Error> { // Select an embedding model. let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Generate test data for RAG demo - let agent = openai_client - .agent("gpt-4o") - .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. Return ONLY the JSON string generated, nothing else.") - .build(); - let response = agent - .prompt("Invent 100 words and their definitions") - .await?; - let mut definitions: Vec = serde_json::from_str(&response)?; + // Set up test data for RAG demo + let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - // Note: need at least 256 rows in order to create an index on a table but OpenAI limits the output size - // so we triplicate the vector for testing purposes. - definitions.extend(definitions.clone()); - definitions.extend(definitions.clone()); + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + let definitions = vec![definition; 256]; // Generate embeddings for the test data. 
let embeddings = EmbeddingsBuilder::new(model.clone()) diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index 5f2906b6..da76868a 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -4,7 +4,6 @@ use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; use rig::{ - completion::Prompt, embeddings::{EmbeddingModel, EmbeddingsBuilder}, providers::openai::{Client, TEXT_EMBEDDING_ADA_002}, vector_store::VectorStoreIndex, @@ -32,20 +31,11 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); - // Generate test data for RAG demo - let agent = openai_client - .agent("gpt-4o") - .preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. Return ONLY the JSON string generated, nothing else.") - .build(); - let response = agent - .prompt("Invent 100 words and their definitions") - .await?; - let mut definitions: Vec = serde_json::from_str(&response)?; + // Set up test data for RAG demo + let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); - // Note: need at least 256 rows in order to create an index on a table but OpenAI limits the output size - // so we triplicate the vector for testing purposes. - definitions.extend(definitions.clone()); - definitions.extend(definitions.clone()); + // Note: need at least 256 rows in order to create an index so copy the definition 256 times for testing purposes. + let definitions = vec![definition; 256]; // Generate embeddings for the test data. 
let embeddings = EmbeddingsBuilder::new(model.clone()) From 9b09639fa547f8bb090d60673a07b4c14f02088d Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 7 Oct 2024 09:48:23 -0400 Subject: [PATCH 38/39] refactor: lance db examples --- .../examples/vector_search_local_ann.rs | 18 ++++++------------ rig-lancedb/examples/vector_search_s3_ann.rs | 19 ++++++++++--------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index cab3682d..9e4db9d2 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -29,6 +29,9 @@ async fn main() -> Result<(), anyhow::Error> { // Select an embedding model. let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Initialize LanceDB locally. + let db = lancedb::connect("data/lancedb-store").execute().await?; + // Set up test data for RAG demo let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); @@ -44,12 +47,6 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Define search_params params that will be used by the vector store to perform the vector search. - let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - - // Initialize LanceDB locally. - let db = lancedb::connect("data/lancedb-store").execute().await?; - // Create table with embeddings. let record_batch = as_record_batch(embeddings, model.ndims()); let table = db @@ -64,16 +61,13 @@ async fn main() -> Result<(), anyhow::Error> { table .create_index( &["embedding"], - lancedb::index::Index::IvfPq( - IvfPqIndexBuilder::default() - // This overrides the default distance type of L2. - // Needs to be the same distance type as the one used in search params. 
- .distance_type(DistanceType::Cosine), - ), + lancedb::index::Index::IvfPq(IvfPqIndexBuilder::default()), ) .execute() .await?; + // Define search_params params that will be used by the vector store to perform the vector search. + let search_params = SearchParams::default(); let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // Query the index diff --git a/rig-lancedb/examples/vector_search_s3_ann.rs b/rig-lancedb/examples/vector_search_s3_ann.rs index da76868a..70f0c8c5 100644 --- a/rig-lancedb/examples/vector_search_s3_ann.rs +++ b/rig-lancedb/examples/vector_search_s3_ann.rs @@ -31,6 +31,13 @@ async fn main() -> Result<(), anyhow::Error> { // Select the embedding model and generate our embeddings let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002); + // Initialize LanceDB on S3. + // Note: see below docs for more options and IAM permission required to read/write to S3. + // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 + let db = lancedb::connect("s3://lancedb-test-829666124233") + .execute() + .await?; + // Set up test data for RAG demo let definition = "Definition of *flumbuzzle (verb)*: to bewilder or confuse someone completely, often by using nonsensical or overly complex explanations or instructions.".to_string(); @@ -46,15 +53,6 @@ async fn main() -> Result<(), anyhow::Error> { .build() .await?; - // Define search_params params that will be used by the vector store to perform the vector search. - let search_params = SearchParams::default().distance_type(DistanceType::Cosine); - - // Initialize LanceDB on S3. - // Note: see below docs for more options and IAM permission required to read/write to S3. - // https://lancedb.github.io/lancedb/guides/storage/#aws-s3 - let db = lancedb::connect("s3://lancedb-test-829666124233") - .execute() - .await?; // Create table with embeddings. 
let record_batch = as_record_batch(embeddings, model.ndims()); let table = db @@ -79,6 +77,9 @@ async fn main() -> Result<(), anyhow::Error> { .execute() .await?; + // Define search_params params that will be used by the vector store to perform the vector search. + let search_params = SearchParams::default().distance_type(DistanceType::Cosine); + let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?; // Query the index From d5dc56ad52d469173845df3c714877ed11912146 Mon Sep 17 00:00:00 2001 From: Garance Date: Mon, 7 Oct 2024 09:51:35 -0400 Subject: [PATCH 39/39] style: cargo fmt --- rig-lancedb/examples/vector_search_local_ann.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rig-lancedb/examples/vector_search_local_ann.rs b/rig-lancedb/examples/vector_search_local_ann.rs index 9e4db9d2..3ecd6b23 100644 --- a/rig-lancedb/examples/vector_search_local_ann.rs +++ b/rig-lancedb/examples/vector_search_local_ann.rs @@ -2,7 +2,7 @@ use std::{env, sync::Arc}; use arrow_array::RecordBatchIterator; use fixture::{as_record_batch, schema}; -use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType}; +use lancedb::index::vector::IvfPqIndexBuilder; use rig::vector_store::VectorStoreIndex; use rig::{ embeddings::{EmbeddingModel, EmbeddingsBuilder},