diff --git a/Cargo.lock b/Cargo.lock index 9464489..b54d901 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,6 +57,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.20" @@ -127,6 +142,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" @@ -353,6 +374,30 @@ dependencies = [ "regex-syntax 0.7.5", ] +[[package]] +name = "assert_approx_eq" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c07dab4369547dbe5114677b33fbbf724971019f3818172d59a97a61c774ffd" + +[[package]] +name = "async-compression" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" +dependencies = [ + "bzip2", + "flate2", + "futures-core", + "futures-io", + "memchr", + "pin-project-lite", + "tokio", + "xz2", + "zstd 0.13.2", + "zstd-safe 7.2.1", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -416,6 +461,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -469,6 +520,28 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake3" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -501,6 +574,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "brotli" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -541,12 +635,35 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cc" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -691,6 +808,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -721,6 +844,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -849,6 +981,67 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "datafusion" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7014432223f4d721cb9786cd88bb89e7464e0ba984d4a7f49db7787f5f268674" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-schema 47.0.0", + "async-compression", + "async-trait", + "bytes", + "bzip2", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-optimizer", + 
"datafusion-physical-expr", + "datafusion-physical-plan", + "datafusion-sql", + "flate2", + "futures", + "glob", + "half", + "hashbrown 0.14.5", + "indexmap 2.6.0", + "itertools 0.11.0", + "log", + "num_cpus", + "object_store", + "parking_lot", + "parquet", + "percent-encoding", + "pin-project-lite", + "rand", + "sqlparser", + "tempfile", + "tokio", + "tokio-util", + "url", + "uuid", + "xz2", + "zstd 0.12.4", +] + [[package]] name = "datafusion-common" version = "32.0.0" @@ -863,9 +1056,32 @@ dependencies = [ "chrono", "half", "num_cpus", + "object_store", + "parquet", "sqlparser", ] +[[package]] +name = "datafusion-execution" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780b73b2407050e53f51a9781868593f694102c59e622de9a8aafc0343c4f237" +dependencies = [ + "arrow", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-expr", + "futures", + "hashbrown 0.14.5", + "log", + "object_store", + "parking_lot", + "rand", + "tempfile", + "url", +] + [[package]] name = "datafusion-expr" version = "32.0.0" @@ -881,6 +1097,103 @@ dependencies = [ "strum_macros 0.25.3", ] +[[package]] +name = "datafusion-optimizer" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f2904a432f795484fd45e29ded4537152adb60f636c05691db34fcd94c92c96" +dependencies = [ + "arrow", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-expr", + "hashbrown 0.14.5", + "itertools 0.11.0", + "log", + "regex-syntax 0.7.5", +] + +[[package]] +name = "datafusion-physical-expr" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57b4968e9a998dc0476c4db7a82f280e2026b25f464e4aa0c3bb9807ee63ddfd" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema 47.0.0", + "base64 0.21.7", + "blake2", + "blake3", + "chrono", + "datafusion-common", + "datafusion-expr", + "half", + 
"hashbrown 0.14.5", + "hex", + "indexmap 2.6.0", + "itertools 0.11.0", + "libc", + "log", + "md-5", + "paste", + "petgraph", + "rand", + "regex", + "sha2", + "unicode-segmentation", + "uuid", +] + +[[package]] +name = "datafusion-physical-plan" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efd0d1fe54e37a47a2d58a1232c22786f2c28ad35805fdcd08f0253a8b0aaa90" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema 47.0.0", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr", + "futures", + "half", + "hashbrown 0.14.5", + "indexmap 2.6.0", + "itertools 0.11.0", + "log", + "once_cell", + "parking_lot", + "pin-project-lite", + "rand", + "tokio", + "uuid", +] + +[[package]] +name = "datafusion-sql" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b568d44c87ead99604d704f942e257c8a236ee1bbf890ee3e034ad659dcb2c21" +dependencies = [ + "arrow", + "arrow-schema 47.0.0", + "datafusion-common", + "datafusion-expr", + "log", + "sqlparser", +] + [[package]] name = "der" version = "0.7.9" @@ -925,6 +1238,12 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "dotenvy" version = "0.15.7" @@ -984,6 +1303,12 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flatbuffers" version = "23.5.26" @@ -994,6 +1319,16 @@ dependencies = [ 
"rustc_version", ] +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "flume" version = "0.11.1" @@ -1034,6 +1369,7 @@ checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -1084,6 +1420,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -1105,6 +1452,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1242,6 +1590,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -1443,12 +1797,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "integer-encoding" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = 
"itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.12.1" @@ -1473,6 +1842,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.72" @@ -1606,6 +1984,36 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "lz4" +version = "1.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "lzma-sys" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1775,6 +2183,27 @@ dependencies = [ "libc", ] +[[package]] +name = "num_enum" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = 
"num_enum_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "object" version = "0.36.5" @@ -1784,6 +2213,27 @@ dependencies = [ "memchr", ] +[[package]] +name = "object_store" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f930c88a43b1c3f6e776dfe495b4afab89882dbc81530c632db2ed65451ebcb4" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures", + "humantime", + "itertools 0.11.0", + "parking_lot", + "percent-encoding", + "snafu", + "tokio", + "tracing", + "url", + "walkdir", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -1795,8 +2245,10 @@ name = "optd-cost-model" version = "0.1.0" dependencies = [ "arrow-schema 53.2.0", + "assert_approx_eq", "chrono", "crossbeam", + "datafusion", "datafusion-expr", "itertools 0.13.0", "optd-persistent", @@ -1805,6 +2257,9 @@ dependencies = [ "serde", "serde_json", "serde_with", + "test-case", + "tokio", + "trait-variant", ] [[package]] @@ -1813,6 +2268,7 @@ version = "0.1.0" dependencies = [ "async-stream", "async-trait", + "num_enum", "sea-orm", "sea-orm-migration", "serde_json", @@ -1821,6 +2277,15 @@ dependencies = [ "trait-variant", ] +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + [[package]] name = "ordered-float" version = "3.9.2" @@ -1893,6 +2358,40 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "parquet" +version = "47.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0463cc3b256d5f50408c49a4be3a16674f4c8ceef60941709620a062b1f6bf4d" +dependencies = [ + "ahash 0.8.11", + 
"arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-schema 47.0.0", + "arrow-select", + "base64 0.21.7", + "brotli", + "bytes", + "chrono", + "flate2", + "futures", + "hashbrown 0.14.5", + "lz4", + "num", + "num-bigint", + "object_store", + "paste", + "seq-macro", + "snap", + "thrift", + "tokio", + "twox-hash", + "zstd 0.12.4", +] + [[package]] name = "parse-zoneinfo" version = "0.3.1" @@ -1923,6 +2422,16 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.6.0", +] + [[package]] name = "phf" version = "0.11.2" @@ -2361,6 +2870,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2538,6 +3056,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + [[package]] name = "serde" version = "1.0.215" @@ -2588,7 +3112,7 @@ version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" 
dependencies = [ - "base64", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -2689,6 +3213,34 @@ dependencies = [ "serde", ] +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.5.7" @@ -2855,7 +3407,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -2902,7 +3454,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -3093,6 +3645,39 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "test-case" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb2550dd13afcd286853192af8601920d959b14c401fcece38071d53bf0768a8" +dependencies = [ + "test-case-macros", +] + +[[package]] +name = "test-case-core" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcb7fd841cd518e279be3d5a3eb0636409487998a4aff22f3de87b81e88384f" +dependencies = [ + "cfg-if", + "proc-macro2", + 
"quote", + "syn 2.0.87", +] + +[[package]] +name = "test-case-macros" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", + "test-case-core", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -3123,6 +3708,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "thrift" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" +dependencies = [ + "byteorder", + "integer-encoding", + "ordered-float 2.10.1", +] + [[package]] name = "time" version = "0.3.36" @@ -3198,6 +3794,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -3226,6 +3823,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.6.8" @@ -3301,6 +3911,16 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + [[package]] name = "typenum" version = "1.17.0" @@ -3334,6 +3954,12 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + [[package]] name = "unicode-width" version = "0.2.0" @@ -3387,6 +4013,7 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ + "getrandom", "serde", ] @@ -3402,6 +4029,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3488,6 +4125,15 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -3675,6 +4321,15 @@ dependencies = [ "tap", ] +[[package]] +name = "xz2" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" +dependencies = [ + "lzma-sys", +] + [[package]] name = "yansi" version = "1.0.1" @@ -3774,3 +4429,50 @@ dependencies = [ "quote", "syn 2.0.87", ] + +[[package]] +name = "zstd" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +dependencies = [ + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe 7.2.1", +] + +[[package]] +name = "zstd-safe" +version = "6.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/optd-cost-model/Cargo.lock b/optd-cost-model/Cargo.lock index a38097d..bf0b367 100644 --- a/optd-cost-model/Cargo.lock +++ b/optd-cost-model/Cargo.lock @@ -57,6 +57,21 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + [[package]] name = "allocator-api2" version = "0.2.20" @@ -127,6 +142,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" 
@@ -353,6 +374,24 @@ dependencies = [ "regex-syntax 0.7.5", ] +[[package]] +name = "async-compression" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" +dependencies = [ + "bzip2", + "flate2", + "futures-core", + "futures-io", + "memchr", + "pin-project-lite", + "tokio", + "xz2", + "zstd 0.13.2", + "zstd-safe 7.2.1", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -416,6 +455,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -469,6 +514,28 @@ dependencies = [ "wyz", ] +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + +[[package]] +name = "blake3" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -480,9 +547,9 @@ dependencies = [ [[package]] name = "borsh" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5327f6c99920069d1fe374aa743be1af0031dea9f250852cdf1ae6a0861ee24" +checksum = "2506947f73ad44e344215ccd6403ac2ae18cd8e046e581a441bf8d199f257f03" dependencies = [ "borsh-derive", "cfg_aliases", @@ -490,9 +557,9 @@ dependencies = [ [[package]] name = "borsh-derive" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"10aedd8f1a81a8aafbfde924b0e3061cd6fedd6f6bbcfc6a76e6fd426d7bfe26" +checksum = "c2593a3b8b938bd68373196c9832f516be11fa487ef4ae745eb282e6a56a7244" dependencies = [ "once_cell", "proc-macro-crate", @@ -501,6 +568,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "brotli" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -541,12 +629,35 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cc" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -601,9 +712,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -611,9 +722,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstream", "anstyle", @@ -635,9 +746,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "colorchoice" @@ -647,9 +758,9 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "comfy-table" -version = "7.1.1" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" +checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ "strum 0.26.3", "strum_macros 0.26.4", @@ -691,6 +802,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -721,6 +838,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.4.2" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam" version = "0.8.4" @@ -849,6 +975,67 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "datafusion" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7014432223f4d721cb9786cd88bb89e7464e0ba984d4a7f49db7787f5f268674" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-schema 47.0.0", + "async-compression", + "async-trait", + "bytes", + "bzip2", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-optimizer", + "datafusion-physical-expr", + "datafusion-physical-plan", + "datafusion-sql", + "flate2", + "futures", + "glob", + "half", + "hashbrown 0.14.5", + "indexmap 2.6.0", + "itertools 0.11.0", + "log", + "num_cpus", + "object_store", + "parking_lot", + "parquet", + "percent-encoding", + "pin-project-lite", + "rand", + "sqlparser", + "tempfile", + "tokio", + "tokio-util", + "url", + "uuid", + "xz2", + "zstd 0.12.4", +] + [[package]] name = "datafusion-common" version = "32.0.0" @@ -863,9 +1050,32 @@ dependencies = [ "chrono", "half", "num_cpus", + "object_store", + "parquet", "sqlparser", ] +[[package]] +name = "datafusion-execution" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "780b73b2407050e53f51a9781868593f694102c59e622de9a8aafc0343c4f237" +dependencies = [ + "arrow", + "chrono", + "dashmap", + "datafusion-common", + "datafusion-expr", + "futures", + "hashbrown 0.14.5", + 
"log", + "object_store", + "parking_lot", + "rand", + "tempfile", + "url", +] + [[package]] name = "datafusion-expr" version = "32.0.0" @@ -881,6 +1091,103 @@ dependencies = [ "strum_macros 0.25.3", ] +[[package]] +name = "datafusion-optimizer" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f2904a432f795484fd45e29ded4537152adb60f636c05691db34fcd94c92c96" +dependencies = [ + "arrow", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-expr", + "datafusion-physical-expr", + "hashbrown 0.14.5", + "itertools 0.11.0", + "log", + "regex-syntax 0.7.5", +] + +[[package]] +name = "datafusion-physical-expr" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57b4968e9a998dc0476c4db7a82f280e2026b25f464e4aa0c3bb9807ee63ddfd" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema 47.0.0", + "base64 0.21.7", + "blake2", + "blake3", + "chrono", + "datafusion-common", + "datafusion-expr", + "half", + "hashbrown 0.14.5", + "hex", + "indexmap 2.6.0", + "itertools 0.11.0", + "libc", + "log", + "md-5", + "paste", + "petgraph", + "rand", + "regex", + "sha2", + "unicode-segmentation", + "uuid", +] + +[[package]] +name = "datafusion-physical-plan" +version = "32.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efd0d1fe54e37a47a2d58a1232c22786f2c28ad35805fdcd08f0253a8b0aaa90" +dependencies = [ + "ahash 0.8.11", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-schema 47.0.0", + "async-trait", + "chrono", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-expr", + "futures", + "half", + "hashbrown 0.14.5", + "indexmap 2.6.0", + "itertools 0.11.0", + "log", + "once_cell", + "parking_lot", + "pin-project-lite", + "rand", + "tokio", + "uuid", +] + +[[package]] +name = "datafusion-sql" +version = "32.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b568d44c87ead99604d704f942e257c8a236ee1bbf890ee3e034ad659dcb2c21" +dependencies = [ + "arrow", + "arrow-schema 47.0.0", + "datafusion-common", + "datafusion-expr", + "log", + "sqlparser", +] + [[package]] name = "der" version = "0.7.9" @@ -925,6 +1232,12 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "dotenvy" version = "0.15.7" @@ -984,6 +1297,12 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flatbuffers" version = "23.5.26" @@ -994,6 +1313,16 @@ dependencies = [ "rustc_version", ] +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "flume" version = "0.11.1" @@ -1034,6 +1363,7 @@ checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", + "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -1084,6 +1414,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -1105,6 +1446,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1242,6 +1584,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iana-time-zone" version = "0.1.61" @@ -1443,12 +1791,27 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "integer-encoding" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" + [[package]] name = "is_terminal_polyfill" version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.12.1" @@ -1473,6 +1836,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.72" @@ -1606,6 +1978,36 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "lz4" +version = "1.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "lzma-sys" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1784,6 +2186,27 @@ dependencies = [ "memchr", ] +[[package]] +name = "object_store" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f930c88a43b1c3f6e776dfe495b4afab89882dbc81530c632db2ed65451ebcb4" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures", + "humantime", + "itertools 0.11.0", + "parking_lot", + "percent-encoding", + "snafu", + "tokio", + "tracing", + "url", + "walkdir", +] + [[package]] name = "once_cell" version = "1.20.2" @@ -1797,6 +2220,7 @@ dependencies = [ "arrow-schema 53.2.0", "chrono", "crossbeam", + "datafusion", "datafusion-expr", "itertools 0.13.0", "optd-persistent", @@ -1821,6 +2245,15 @@ dependencies = [ "trait-variant", ] +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + [[package]] name = "ordered-float" version = "3.9.2" @@ -1893,6 +2326,40 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "parquet" +version = "47.0.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "0463cc3b256d5f50408c49a4be3a16674f4c8ceef60941709620a062b1f6bf4d" +dependencies = [ + "ahash 0.8.11", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ipc", + "arrow-schema 47.0.0", + "arrow-select", + "base64 0.21.7", + "brotli", + "bytes", + "chrono", + "flate2", + "futures", + "hashbrown 0.14.5", + "lz4", + "num", + "num-bigint", + "object_store", + "paste", + "seq-macro", + "snap", + "thrift", + "tokio", + "twox-hash", + "zstd 0.12.4", +] + [[package]] name = "parse-zoneinfo" version = "0.3.1" @@ -1923,6 +2390,16 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" +dependencies = [ + "fixedbitset", + "indexmap 2.6.0", +] + [[package]] name = "phf" version = "0.11.2" @@ -2361,6 +2838,15 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2538,6 +3024,12 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + [[package]] name = "serde" version = "1.0.215" @@ -2588,7 
+3080,7 @@ version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -2689,6 +3181,34 @@ dependencies = [ "serde", ] +[[package]] +name = "snafu" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +dependencies = [ + "doc-comment", + "snafu-derive", +] + +[[package]] +name = "snafu-derive" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.5.7" @@ -2855,7 +3375,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -2902,7 +3422,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bigdecimal", "bitflags 2.6.0", "byteorder", @@ -3123,6 +3643,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "thrift" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e54bc85fc7faa8bc175c4bab5b92ba8d9a3ce893d0e9f42cc455c8ab16a9e09" +dependencies = [ + "byteorder", + "integer-encoding", + "ordered-float 2.10.1", +] + [[package]] name = "time" version = "0.3.36" @@ 
-3198,6 +3729,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "socket2", "tokio-macros", @@ -3226,6 +3758,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.6.8" @@ -3301,6 +3846,16 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + [[package]] name = "typenum" version = "1.17.0" @@ -3334,11 +3889,17 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e70f2a8b45122e719eb623c01822704c4e0907e7e426a05927e1a1cfff5b75d0" +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + [[package]] name = "unicode-width" -version = "0.1.14" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" [[package]] name = "unicode_categories" @@ -3387,6 +3948,7 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ + "getrandom", "serde", ] @@ -3402,6 +3964,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3488,6 +4060,15 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -3675,6 +4256,15 @@ dependencies = [ "tap", ] +[[package]] +name = "xz2" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" +dependencies = [ + "lzma-sys", +] + [[package]] name = "yansi" version = "1.0.1" @@ -3774,3 +4364,50 @@ dependencies = [ "quote", "syn 2.0.87", ] + +[[package]] +name = "zstd" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" +dependencies = [ + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe 7.2.1", +] + +[[package]] +name = "zstd-safe" +version = "6.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee98ffd0b48ee95e6c5168188e44a54550b1564d9d530ee21d5f0eaed1069581" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.13+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/optd-cost-model/Cargo.toml b/optd-cost-model/Cargo.toml index 1d41af7..e8b22aa 100644 --- a/optd-cost-model/Cargo.toml +++ b/optd-cost-model/Cargo.toml @@ -2,6 +2,7 @@ name = "optd-cost-model" version = "0.1.0" edition = "2021" +authors = ["Yuanxin Cao", "Lan Lou", "Kunle Li"] [dependencies] optd-persistent = { path = "../optd-persistent", version = "0.1" } @@ -10,10 +11,15 @@ serde_json = "1.0" serde_with = { version = "3.7.0", features = ["json"] } arrow-schema = "53.2.0" datafusion-expr = "32.0.0" +datafusion = "32.0.0" ordered-float = "4.0" chrono = "0.4" itertools = "0.13" +assert_approx_eq = "1.1.0" +trait-variant = "0.1.2" +tokio = { version = "1.0.1", features = ["macros", "rt-multi-thread"] } [dev-dependencies] crossbeam = "0.8" rand = "0.8" +test-case = "3.3" diff --git a/optd-cost-model/src/common/nodes.rs b/optd-cost-model/src/common/nodes.rs index 38e2500..79a47f7 100644 --- a/optd-cost-model/src/common/nodes.rs +++ b/optd-cost-model/src/common/nodes.rs @@ -1,4 +1,5 @@ -use std::sync::Arc; +use core::fmt; +use std::{fmt::Display, sync::Arc}; use arrow_schema::DataType; @@ -24,6 +25,12 @@ pub enum JoinType { RightAnti, } +impl Display for JoinType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + /// TODO: documentation #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PhysicalNodeType { @@ -49,8 +56,7 @@ impl std::fmt::Display for PhysicalNodeType { pub enum PredicateType { List, Constant(ConstantType), - AttributeRef, - ExternAttributeRef, + AttrIndex, UnOp(UnOpType), BinOp(BinOpType), LogOp(LogOpType), @@ -77,7 +83,7 @@ pub struct 
PredicateNode { /// A generic predicate node type pub typ: PredicateType, /// Child predicate nodes, always materialized - pub children: Vec, + pub children: Vec, /// Data associated with the predicate, if any pub data: Option, } @@ -94,3 +100,28 @@ impl std::fmt::Display for PredicateNode { write!(f, ")") } } + +impl PredicateNode { + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.children[idx].clone() + } + + pub fn unwrap_data(&self) -> Value { + self.data.clone().unwrap() + } +} +pub trait ReprPredicateNode: 'static + Clone { + fn into_pred_node(self) -> ArcPredicateNode; + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option; +} + +impl ReprPredicateNode for ArcPredicateNode { + fn into_pred_node(self) -> ArcPredicateNode { + self + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + Some(pred_node) + } +} diff --git a/optd-cost-model/src/common/predicates/attr_index_pred.rs b/optd-cost-model/src/common/predicates/attr_index_pred.rs new file mode 100644 index 0000000..412c7a3 --- /dev/null +++ b/optd-cost-model/src/common/predicates/attr_index_pred.rs @@ -0,0 +1,42 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +/// [`AttributeIndexPred`] represents the position of an attribute in a schema or +/// [`GroupAttrRefs`]. +/// +/// The `data` field holds the index of the attribute in the schema or [`GroupAttrRefs`]. +#[derive(Clone, Debug)] +pub struct AttrIndexPred(pub ArcPredicateNode); + +impl AttrIndexPred { + pub fn new(attr_idx: u64) -> AttrIndexPred { + AttrIndexPred( + PredicateNode { + typ: PredicateType::AttrIndex, + children: vec![], + data: Some(Value::UInt64(attr_idx)), + } + .into(), + ) + } + + /// Gets the attribute index. 
+ pub fn attr_index(&self) -> u64 { + self.0.data.as_ref().unwrap().as_u64() + } +} + +impl ReprPredicateNode for AttrIndexPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if pred_node.typ != PredicateType::AttrIndex { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/bin_op_pred.rs b/optd-cost-model/src/common/predicates/bin_op_pred.rs index 196d987..5c48688 100644 --- a/optd-cost-model/src/common/predicates/bin_op_pred.rs +++ b/optd-cost-model/src/common/predicates/bin_op_pred.rs @@ -1,3 +1,5 @@ +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum BinOpType { @@ -38,3 +40,48 @@ impl BinOpType { ) } } + +#[derive(Clone, Debug)] +pub struct BinOpPred(pub ArcPredicateNode); + +impl BinOpPred { + pub fn new(left: ArcPredicateNode, right: ArcPredicateNode, op_type: BinOpType) -> Self { + BinOpPred( + PredicateNode { + typ: PredicateType::BinOp(op_type), + children: vec![left, right], + data: None, + } + .into(), + ) + } + + pub fn left_child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn right_child(&self) -> ArcPredicateNode { + self.0.child(1) + } + + pub fn op_type(&self) -> BinOpType { + if let PredicateType::BinOp(op_type) = self.0.typ { + op_type + } else { + panic!("not a bin op") + } + } +} + +impl ReprPredicateNode for BinOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::BinOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/cast_pred.rs b/optd-cost-model/src/common/predicates/cast_pred.rs new file mode 100644 index 0000000..2e1ef54 --- /dev/null +++ b/optd-cost-model/src/common/predicates/cast_pred.rs 
@@ -0,0 +1,49 @@ +use arrow_schema::DataType; + +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +use super::data_type_pred::DataTypePred; + +/// [`CastPred`] casts a column from one data type to another. +/// +/// A [`CastPred`] has two children: +/// 1. The original data to cast +/// 2. The target data type to cast to +#[derive(Clone, Debug)] +pub struct CastPred(pub ArcPredicateNode); + +impl CastPred { + pub fn new(child: ArcPredicateNode, cast_to: DataType) -> Self { + CastPred( + PredicateNode { + typ: PredicateType::Cast, + children: vec![child, DataTypePred::new(cast_to).into_pred_node()], + data: None, + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn cast_to(&self) -> DataType { + DataTypePred::from_pred_node(self.0.child(1)) + .unwrap() + .data_type() + } +} + +impl ReprPredicateNode for CastPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::Cast) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/constant_pred.rs b/optd-cost-model/src/common/predicates/constant_pred.rs index 7923ae4..61285f7 100644 --- a/optd-cost-model/src/common/predicates/constant_pred.rs +++ b/optd-cost-model/src/common/predicates/constant_pred.rs @@ -1,5 +1,14 @@ +use std::sync::Arc; + +use arrow_schema::{DataType, IntervalUnit}; +use optd_persistent::cost_model::interface::AttrType; use serde::{Deserialize, Serialize}; +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::{SerializableOrderedF64, Value}, +}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)] pub enum ConstantType { @@ -19,3 +28,193 @@ pub enum ConstantType { Decimal, Binary, } + +impl ConstantType { + pub fn 
get_data_type_from_value(value: &Value) -> Self { + match value { + Value::Bool(_) => ConstantType::Bool, + Value::String(_) => ConstantType::Utf8String, + Value::UInt8(_) => ConstantType::UInt8, + Value::UInt16(_) => ConstantType::UInt16, + Value::UInt32(_) => ConstantType::UInt32, + Value::UInt64(_) => ConstantType::UInt64, + Value::Int8(_) => ConstantType::Int8, + Value::Int16(_) => ConstantType::Int16, + Value::Int32(_) => ConstantType::Int32, + Value::Int64(_) => ConstantType::Int64, + Value::Float(_) => ConstantType::Float64, + Value::Date32(_) => ConstantType::Date, + _ => unimplemented!("get_data_type_from_value() not implemented for value {value}"), + } + } + + // TODO: current DataType and ConstantType are not 1 to 1 mapping + // optd schema stores constantType from data type in catalog.get + // for decimal128, the precision is lost + pub fn from_data_type(data_type: DataType) -> Self { + match data_type { + DataType::Binary => ConstantType::Binary, + DataType::Boolean => ConstantType::Bool, + DataType::UInt8 => ConstantType::UInt8, + DataType::UInt16 => ConstantType::UInt16, + DataType::UInt32 => ConstantType::UInt32, + DataType::UInt64 => ConstantType::UInt64, + DataType::Int8 => ConstantType::Int8, + DataType::Int16 => ConstantType::Int16, + DataType::Int32 => ConstantType::Int32, + DataType::Int64 => ConstantType::Int64, + DataType::Float64 => ConstantType::Float64, + DataType::Date32 => ConstantType::Date, + DataType::Interval(IntervalUnit::MonthDayNano) => ConstantType::IntervalMonthDateNano, + DataType::Utf8 => ConstantType::Utf8String, + DataType::Decimal128(_, _) => ConstantType::Decimal, + _ => unimplemented!("no conversion to ConstantType for DataType {data_type}"), + } + } + + pub fn into_data_type(&self) -> DataType { + match self { + ConstantType::Binary => DataType::Binary, + ConstantType::Bool => DataType::Boolean, + ConstantType::UInt8 => DataType::UInt8, + ConstantType::UInt16 => DataType::UInt16, + ConstantType::UInt32 => 
DataType::UInt32, + ConstantType::UInt64 => DataType::UInt64, + ConstantType::Int8 => DataType::Int8, + ConstantType::Int16 => DataType::Int16, + ConstantType::Int32 => DataType::Int32, + ConstantType::Int64 => DataType::Int64, + ConstantType::Float64 => DataType::Float64, + ConstantType::Date => DataType::Date32, + ConstantType::IntervalMonthDateNano => DataType::Interval(IntervalUnit::MonthDayNano), + ConstantType::Decimal => DataType::Float64, + ConstantType::Utf8String => DataType::Utf8, + } + } + + pub fn from_persistent_attr_type(attr_type: AttrType) -> Self { + match attr_type { + AttrType::Integer => ConstantType::Int32, + AttrType::Float => ConstantType::Float64, + AttrType::Varchar => ConstantType::Utf8String, + AttrType::Boolean => ConstantType::Bool, + } + } +} + +#[derive(Clone, Debug)] +pub struct ConstantPred(pub ArcPredicateNode); + +impl ConstantPred { + pub fn new(value: Value) -> Self { + let typ = ConstantType::get_data_type_from_value(&value); + Self::new_with_type(value, typ) + } + + pub fn new_with_type(value: Value, typ: ConstantType) -> Self { + ConstantPred( + PredicateNode { + typ: PredicateType::Constant(typ), + children: vec![], + data: Some(value), + } + .into(), + ) + } + + pub fn bool(value: bool) -> Self { + Self::new_with_type(Value::Bool(value), ConstantType::Bool) + } + + pub fn string(value: impl AsRef) -> Self { + Self::new_with_type( + Value::String(value.as_ref().into()), + ConstantType::Utf8String, + ) + } + + pub fn uint8(value: u8) -> Self { + Self::new_with_type(Value::UInt8(value), ConstantType::UInt8) + } + + pub fn uint16(value: u16) -> Self { + Self::new_with_type(Value::UInt16(value), ConstantType::UInt16) + } + + pub fn uint32(value: u32) -> Self { + Self::new_with_type(Value::UInt32(value), ConstantType::UInt32) + } + + pub fn uint64(value: u64) -> Self { + Self::new_with_type(Value::UInt64(value), ConstantType::UInt64) + } + + pub fn int8(value: i8) -> Self { + Self::new_with_type(Value::Int8(value), 
ConstantType::Int8) + } + + pub fn int16(value: i16) -> Self { + Self::new_with_type(Value::Int16(value), ConstantType::Int16) + } + + pub fn int32(value: i32) -> Self { + Self::new_with_type(Value::Int32(value), ConstantType::Int32) + } + + pub fn int64(value: i64) -> Self { + Self::new_with_type(Value::Int64(value), ConstantType::Int64) + } + + pub fn interval_month_day_nano(value: i128) -> Self { + Self::new_with_type(Value::Int128(value), ConstantType::IntervalMonthDateNano) + } + + pub fn float64(value: f64) -> Self { + Self::new_with_type( + Value::Float(SerializableOrderedF64(value.into())), + ConstantType::Float64, + ) + } + + pub fn date(value: i64) -> Self { + Self::new_with_type(Value::Int64(value), ConstantType::Date) + } + + pub fn decimal(value: f64) -> Self { + Self::new_with_type( + Value::Float(SerializableOrderedF64(value.into())), + ConstantType::Decimal, + ) + } + + pub fn serialized(value: Arc<[u8]>) -> Self { + Self::new_with_type(Value::Serialized(value), ConstantType::Binary) + } + + /// Gets the constant value. 
+ pub fn value(&self) -> Value { + self.0.data.clone().unwrap() + } + + pub fn constant_type(&self) -> ConstantType { + if let PredicateType::Constant(typ) = self.0.typ { + typ + } else { + panic!("not a constant") + } + } +} + +impl ReprPredicateNode for ConstantPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(rel_node: ArcPredicateNode) -> Option { + if let PredicateType::Constant(_) = rel_node.typ { + Some(Self(rel_node)) + } else { + None + } + } +} diff --git a/optd-cost-model/src/common/predicates/data_type_pred.rs b/optd-cost-model/src/common/predicates/data_type_pred.rs new file mode 100644 index 0000000..fe29336 --- /dev/null +++ b/optd-cost-model/src/common/predicates/data_type_pred.rs @@ -0,0 +1,40 @@ +use arrow_schema::DataType; + +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +#[derive(Clone, Debug)] +pub struct DataTypePred(pub ArcPredicateNode); + +impl DataTypePred { + pub fn new(typ: DataType) -> Self { + DataTypePred( + PredicateNode { + typ: PredicateType::DataType(typ), + children: vec![], + data: None, + } + .into(), + ) + } + + pub fn data_type(&self) -> DataType { + if let PredicateType::DataType(ref data_type) = self.0.typ { + data_type.clone() + } else { + panic!("not a data type") + } + } +} + +impl ReprPredicateNode for DataTypePred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::DataType(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/in_list_pred.rs b/optd-cost-model/src/common/predicates/in_list_pred.rs new file mode 100644 index 0000000..8d3b511 --- /dev/null +++ b/optd-cost-model/src/common/predicates/in_list_pred.rs @@ -0,0 +1,48 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +use 
super::list_pred::ListPred; + +#[derive(Clone, Debug)] +pub struct InListPred(pub ArcPredicateNode); + +impl InListPred { + pub fn new(child: ArcPredicateNode, list: ListPred, negated: bool) -> Self { + InListPred( + PredicateNode { + typ: PredicateType::InList, + children: vec![child, list.into_pred_node()], + data: Some(Value::Bool(negated)), + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn list(&self) -> ListPred { + ListPred::from_pred_node(self.0.child(1)).unwrap() + } + + /// `true` for `NOT IN`. + pub fn negated(&self) -> bool { + self.0.data.as_ref().unwrap().as_bool() + } +} + +impl ReprPredicateNode for InListPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::InList) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/like_pred.rs b/optd-cost-model/src/common/predicates/like_pred.rs new file mode 100644 index 0000000..bf9fe31 --- /dev/null +++ b/optd-cost-model/src/common/predicates/like_pred.rs @@ -0,0 +1,66 @@ +use std::sync::Arc; + +use crate::common::{ + nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}, + values::Value, +}; + +#[derive(Clone, Debug)] +pub struct LikePred(pub ArcPredicateNode); + +impl LikePred { + pub fn new( + negated: bool, + case_insensitive: bool, + child: ArcPredicateNode, + pattern: ArcPredicateNode, + ) -> Self { + // TODO: support multiple values in data. 
+ let negated = if negated { 1 } else { 0 }; + let case_insensitive = if case_insensitive { 1 } else { 0 }; + LikePred( + PredicateNode { + typ: PredicateType::Like, + children: vec![child.into_pred_node(), pattern.into_pred_node()], + data: Some(Value::Serialized(Arc::new([negated, case_insensitive]))), + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn pattern(&self) -> ArcPredicateNode { + self.0.child(1) + } + + /// `true` for `NOT LIKE`. + pub fn negated(&self) -> bool { + match self.0.data.as_ref().unwrap() { + Value::Serialized(data) => data[0] != 0, + _ => panic!("not a serialized value"), + } + } + + pub fn case_insensitive(&self) -> bool { + match self.0.data.as_ref().unwrap() { + Value::Serialized(data) => data[1] != 0, + _ => panic!("not a serialized value"), + } + } +} + +impl ReprPredicateNode for LikePred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::Like) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/list_pred.rs b/optd-cost-model/src/common/predicates/list_pred.rs new file mode 100644 index 0000000..972598d --- /dev/null +++ b/optd-cost-model/src/common/predicates/list_pred.rs @@ -0,0 +1,47 @@ +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +#[derive(Clone, Debug)] +pub struct ListPred(pub ArcPredicateNode); + +impl ListPred { + pub fn new(preds: Vec) -> Self { + ListPred( + PredicateNode { + typ: PredicateType::List, + children: preds, + data: None, + } + .into(), + ) + } + + /// Gets number of expressions in the list + pub fn len(&self) -> usize { + self.0.children.len() + } + + pub fn is_empty(&self) -> bool { + self.0.children.is_empty() + } + + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.0.child(idx) + } + + pub fn to_vec(&self) -> Vec { + 
self.0.children.clone() + } +} + +impl ReprPredicateNode for ListPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if pred_node.typ != PredicateType::List { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/log_op_pred.rs b/optd-cost-model/src/common/predicates/log_op_pred.rs index 88c5746..1899cb1 100644 --- a/optd-cost-model/src/common/predicates/log_op_pred.rs +++ b/optd-cost-model/src/common/predicates/log_op_pred.rs @@ -1,5 +1,9 @@ use std::fmt::Display; +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + +use super::list_pred::ListPred; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum LogOpType { @@ -12,3 +16,70 @@ impl Display for LogOpType { write!(f, "{:?}", self) } } + +#[derive(Clone, Debug)] +pub struct LogOpPred(pub ArcPredicateNode); + +impl LogOpPred { + pub fn new(op_type: LogOpType, preds: Vec) -> Self { + LogOpPred( + PredicateNode { + typ: PredicateType::LogOp(op_type), + children: preds, + data: None, + } + .into(), + ) + } + + /// flatten_nested_logical is a helper function to flatten nested logical operators with same op + /// type eg. 
(a AND (b AND c)) => ExprList([a, b, c]) + /// (a OR (b OR c)) => ExprList([a, b, c]) + /// It assumes the children of the input expr_list are already flattened + /// and can only be used in a bottom-up manner + pub fn new_flattened_nested_logical(op: LogOpType, expr_list: ListPred) -> Self { + // Since we assume that we are building the children bottom up, + // there is no need to call flatten_nested_logical recursively + let mut new_expr_list = Vec::new(); + for child in expr_list.to_vec() { + if let PredicateType::LogOp(child_op) = child.typ { + if child_op == op { + let child_log_op_expr = LogOpPred::from_pred_node(child).unwrap(); + new_expr_list.extend(child_log_op_expr.children().to_vec()); + continue; + } + } + new_expr_list.push(child.clone()); + } + LogOpPred::new(op, new_expr_list) + } + + pub fn children(&self) -> Vec { + self.0.children.clone() + } + + pub fn child(&self, idx: usize) -> ArcPredicateNode { + self.0.child(idx) + } + + pub fn op_type(&self) -> LogOpType { + if let PredicateType::LogOp(op_type) = self.0.typ { + op_type + } else { + panic!("not a log op") + } + } +} + +impl ReprPredicateNode for LogOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::LogOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/predicates/mod.rs b/optd-cost-model/src/common/predicates/mod.rs index 87e6e94..40c64cf 100644 --- a/optd-cost-model/src/common/predicates/mod.rs +++ b/optd-cost-model/src/common/predicates/mod.rs @@ -1,6 +1,12 @@ +pub mod attr_index_pred; pub mod bin_op_pred; +pub mod cast_pred; pub mod constant_pred; +pub mod data_type_pred; pub mod func_pred; +pub mod in_list_pred; +pub mod like_pred; +pub mod list_pred; pub mod log_op_pred; pub mod sort_order_pred; pub mod un_op_pred; diff --git a/optd-cost-model/src/common/predicates/un_op_pred.rs
b/optd-cost-model/src/common/predicates/un_op_pred.rs index d33158f..a3fc270 100644 --- a/optd-cost-model/src/common/predicates/un_op_pred.rs +++ b/optd-cost-model/src/common/predicates/un_op_pred.rs @@ -1,5 +1,7 @@ use std::fmt::Display; +use crate::common::nodes::{ArcPredicateNode, PredicateNode, PredicateType, ReprPredicateNode}; + /// TODO: documentation #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] pub enum UnOpType { @@ -12,3 +14,44 @@ impl Display for UnOpType { write!(f, "{:?}", self) } } + +#[derive(Clone, Debug)] +pub struct UnOpPred(pub ArcPredicateNode); + +impl UnOpPred { + pub fn new(child: ArcPredicateNode, op_type: UnOpType) -> Self { + UnOpPred( + PredicateNode { + typ: PredicateType::UnOp(op_type), + children: vec![child], + data: None, + } + .into(), + ) + } + + pub fn child(&self) -> ArcPredicateNode { + self.0.child(0) + } + + pub fn op_type(&self) -> UnOpType { + if let PredicateType::UnOp(op_type) = self.0.typ { + op_type + } else { + panic!("not a un op") + } + } +} + +impl ReprPredicateNode for UnOpPred { + fn into_pred_node(self) -> ArcPredicateNode { + self.0 + } + + fn from_pred_node(pred_node: ArcPredicateNode) -> Option { + if !matches!(pred_node.typ, PredicateType::UnOp(_)) { + return None; + } + Some(Self(pred_node)) + } +} diff --git a/optd-cost-model/src/common/properties/attr_ref.rs b/optd-cost-model/src/common/properties/attr_ref.rs index eb10fbb..d6105b6 100644 --- a/optd-cost-model/src/common/properties/attr_ref.rs +++ b/optd-cost-model/src/common/properties/attr_ref.rs @@ -23,6 +23,10 @@ pub enum AttrRef { } impl AttrRef { + pub fn new_base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self { + AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) + } + pub fn base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self { AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) } @@ -161,9 +165,9 @@ impl SemanticCorrelation { } /// [`GroupAttrRefs`] represents the attributes of a group in a 
query. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct GroupAttrRefs { - attribute_refs: AttrRefs, + attr_refs: AttrRefs, /// Correlation of the output attributes of the group. output_correlation: Option, } @@ -171,13 +175,13 @@ pub struct GroupAttrRefs { impl GroupAttrRefs { pub fn new(attribute_refs: AttrRefs, output_correlation: Option) -> Self { Self { - attribute_refs, + attr_refs: attribute_refs, output_correlation, } } - pub fn base_table_attribute_refs(&self) -> &AttrRefs { - &self.attribute_refs + pub fn attr_refs(&self) -> &AttrRefs { + &self.attr_refs } pub fn output_correlation(&self) -> Option<&SemanticCorrelation> { diff --git a/optd-cost-model/src/common/properties/mod.rs b/optd-cost-model/src/common/properties/mod.rs index c9acbd1..a90d634 100644 --- a/optd-cost-model/src/common/properties/mod.rs +++ b/optd-cost-model/src/common/properties/mod.rs @@ -21,3 +21,21 @@ impl std::fmt::Display for Attribute { } } } + +impl Attribute { + pub fn new(name: String, typ: ConstantType, nullable: bool) -> Self { + Self { + name, + typ, + nullable, + } + } + + pub fn new_non_null_int64(name: String) -> Self { + Self { + name, + typ: ConstantType::Int64, + nullable: false, + } + } +} diff --git a/optd-cost-model/src/common/properties/schema.rs b/optd-cost-model/src/common/properties/schema.rs index 4ee4fce..d25a23a 100644 --- a/optd-cost-model/src/common/properties/schema.rs +++ b/optd-cost-model/src/common/properties/schema.rs @@ -33,3 +33,9 @@ impl Schema { self.len() == 0 } } + +impl From> for Schema { + fn from(attributes: Vec) -> Self { + Self::new(attributes) + } +} diff --git a/optd-cost-model/src/common/types.rs b/optd-cost-model/src/common/types.rs index 1e92355..fecd143 100644 --- a/optd-cost-model/src/common/types.rs +++ b/optd-cost-model/src/common/types.rs @@ -1,24 +1,27 @@ use std::fmt::Display; +/// TODO: Implement from and to methods for the following types to enable conversion +/// to and from their persistent counterparts. 
+ /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct GroupId(pub usize); +pub struct GroupId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct ExprId(pub usize); +pub struct ExprId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct TableId(pub usize); +pub struct TableId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct AttrId(pub usize); +pub struct AttrId(pub u64); /// TODO: documentation #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default, Hash)] -pub struct EpochId(pub usize); +pub struct EpochId(pub u64); impl Display for GroupId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/optd-cost-model/src/cost/agg.rs b/optd-cost-model/src/cost/agg.rs index 8b13789..f5edc7a 100644 --- a/optd-cost-model/src/cost/agg.rs +++ b/optd-cost-model/src/cost/agg.rs @@ -1 +1,203 @@ +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{attr_index_pred::AttrIndexPred, list_pred::ListPred}, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, + }, + cost_model::CostModelImpl, + stats::DEFAULT_NUM_DISTINCT, + storage::CostModelStorageManager, + CostModelError, CostModelResult, EstimatedStatistic, SemanticError, +}; +impl CostModelImpl { + pub async fn get_agg_row_cnt( + &self, + group_id: GroupId, + group_by: ArcPredicateNode, + ) -> CostModelResult { + let group_by = ListPred::from_pred_node(group_by).unwrap(); + if group_by.is_empty() { + Ok(EstimatedStatistic(1.0)) + } else { + // Multiply the n-distinct of all the group by columns. 
+ // TODO: improve with multi-dimensional n-distinct + let mut row_cnt = 1; + + for node in &group_by.0.children { + match node.typ { + PredicateType::AttrIndex => { + let attr_ref = + AttrIndexPred::from_pred_node(node.clone()).ok_or_else(|| { + SemanticError::InvalidPredicate( + "Expected AttributeRef predicate".to_string(), + ) + })?; + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref.attr_index()) + { + // TODO: Only query ndistinct instead of all kinds of stats. + let stats_option = + self.get_attribute_comb_stats(table_id, &[attr_idx]).await?; + + let ndistinct = match stats_option { + Some(stats) => stats.ndistinct, + None => { + // The column type is not supported or stats are missing. + DEFAULT_NUM_DISTINCT + } + }; + row_cnt *= ndistinct; + } else { + // TODO: Handle derived attributes. + row_cnt *= DEFAULT_NUM_DISTINCT; + } + } + _ => { + // TODO: Consider the case where `GROUP BY 1`. + panic!("GROUP BY must have attribute ref predicate"); + } + } + } + Ok(EstimatedStatistic(row_cnt as f64)) + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, ops::Deref}; + + use crate::{ + common::{ + predicates::constant_pred::ConstantType, + properties::Attribute, + types::{GroupId, TableId}, + values::Value, + }, + cost_model::tests::{ + attr_index, cnst, create_mock_cost_model, create_mock_cost_model_with_attr_types, + empty_list, empty_per_attr_stats, list, TestPerAttributeStats, TEST_ATTR1_BASE_INDEX, + TEST_ATTR2_BASE_INDEX, TEST_ATTR3_BASE_INDEX, TEST_GROUP1_ID, TEST_TABLE1_ID, + }, + stats::{utilities::simple_map::SimpleMap, MostCommonValues, DEFAULT_NUM_DISTINCT}, + EstimatedStatistic, + }; + + #[tokio::test] + async fn test_agg_no_stats() { + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![], + vec![HashMap::from([ + (TEST_ATTR1_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR2_BASE_INDEX, ConstantType::Int32), + ])], +
vec![None], + ); + + // Group by empty list should return 1. + let group_bys = empty_list(); + assert_eq!( + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), + EstimatedStatistic(1.0) + ); + + // Group by single column should return the default value since there are no stats. + let group_bys = list(vec![attr_index(0)]); + assert_eq!( + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), + EstimatedStatistic(DEFAULT_NUM_DISTINCT as f64) + ); + + // Group by two columns should return the default value squared since there are no stats. + let group_bys = list(vec![attr_index(0), attr_index(1)]); + assert_eq!( + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), + EstimatedStatistic((DEFAULT_NUM_DISTINCT * DEFAULT_NUM_DISTINCT) as f64) + ); + } + + #[tokio::test] + async fn test_agg_with_stats() { + let attr1_ndistinct = 12; + let attr2_ndistinct = 645; + let attr1_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + None, + attr1_ndistinct, + 0.0, + ); + let attr2_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + None, + attr2_ndistinct, + 0.0, + ); + + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![HashMap::from([ + (TEST_ATTR1_BASE_INDEX, attr1_stats), + (TEST_ATTR2_BASE_INDEX, attr2_stats), + ])], + vec![HashMap::from([ + (TEST_ATTR1_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR2_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR3_BASE_INDEX, ConstantType::Int32), + ])], + vec![None], + ); + + // Group by empty list should return 1. + let group_bys = empty_list(); + assert_eq!( + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), + EstimatedStatistic(1.0) + ); + + // Group by single column should return the n-distinct of the column. 
+ let group_bys = list(vec![attr_index(0)]); + assert_eq!( + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), + EstimatedStatistic(attr1_ndistinct as f64) + ); + + // Group by two columns should return the product of the n-distinct of the columns. + let group_bys = list(vec![attr_index(0), attr_index(1)]); + assert_eq!( + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), + EstimatedStatistic((attr1_ndistinct * attr2_ndistinct) as f64) + ); + + // Group by multiple columns should return the product of the n-distinct of the columns. If one of the columns + // does not have stats, it should use the default value instead. + let group_bys = list(vec![attr_index(0), attr_index(1), attr_index(2)]); + assert_eq!( + cost_model + .get_agg_row_cnt(TEST_GROUP1_ID, group_bys) + .await + .unwrap(), + EstimatedStatistic((attr1_ndistinct * attr2_ndistinct * DEFAULT_NUM_DISTINCT) as f64) + ); + } +} diff --git a/optd-cost-model/src/cost/filter.rs b/optd-cost-model/src/cost/filter.rs deleted file mode 100644 index 8b13789..0000000 --- a/optd-cost-model/src/cost/filter.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/optd-cost-model/src/cost/filter/attribute.rs b/optd-cost-model/src/cost/filter/attribute.rs new file mode 100644 index 0000000..7a082b7 --- /dev/null +++ b/optd-cost-model/src/cost/filter/attribute.rs @@ -0,0 +1,183 @@ +use std::ops::Bound; + +use crate::{ + common::{types::TableId, values::Value}, + cost_model::CostModelImpl, + stats::{AttributeCombValue, AttributeCombValueStats, DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL}, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + /// Get the selectivity of an expression of the form "attribute equals value" (or "value equals + /// attribute") Will handle the case of statistics missing + /// Equality predicates are handled entirely differently from range predicates so this is its + /// own function + /// Also, get_attribute_equality_selectivity is a 
subroutine when computing range + /// selectivity, which is another reason for separating these into two functions + /// is_eq means whether it's == or != + /// + /// Currently, we only support calculating the equality selectivity for an existing attribute, + /// not a derived attribute. + /// TODO: Support derived attributes. + pub(crate) async fn get_attribute_equality_selectivity( + &self, + table_id: TableId, + attr_base_index: u64, + value: &Value, + is_eq: bool, + ) -> CostModelResult { + let ret_sel = { + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_base_index]) + .await? + { + let eq_freq = + if let Some(freq) = attribute_stats.mcvs.freq(&vec![Some(value.clone())]) { + freq + } else { + let non_mcv_freq = 1.0 - attribute_stats.mcvs.total_freq(); + // always safe because usize is at least as large as i32 + let ndistinct_as_usize = attribute_stats.ndistinct as usize; + let non_mcv_cnt = ndistinct_as_usize - attribute_stats.mcvs.cnt(); + if non_mcv_cnt == 0 { + return Ok(0.0); + } + // note that nulls are not included in ndistinct so we don't need to do non_mcv_cnt + // - 1 if null_frac > 0 + (non_mcv_freq - attribute_stats.null_frac) / (non_mcv_cnt as f64) + }; + if is_eq { + eq_freq + } else { + 1.0 - eq_freq - attribute_stats.null_frac + } + } else { + #[allow(clippy::collapsible_else_if)] + if is_eq { + DEFAULT_EQ_SEL + } else { + 1.0 - DEFAULT_EQ_SEL + } + } + }; + + assert!( + (0.0..=1.0).contains(&ret_sel), + "ret_sel ({}) should be in [0, 1]", + ret_sel + ); + Ok(ret_sel) + } + + /// Compute the frequency of values in an attribute less than or equal to the given value.
+ fn get_attribute_leq_value_freq( + per_attribute_stats: &AttributeCombValueStats, + value: &Value, + ) -> f64 { + // because distr does not include the values in MCVs, we need to compute the CDFs there as + // well. Because nulls return false in any comparison, they are never included when + // computing range selectivity + let distr_leq_freq = per_attribute_stats.distr.as_ref().unwrap().cdf(value); + let value = value.clone(); + let pred = Box::new(move |val: &AttributeCombValue| *val[0].as_ref().unwrap() <= value); + let mcvs_leq_freq = per_attribute_stats.mcvs.freq_over_pred(pred); + let ret_freq = distr_leq_freq + mcvs_leq_freq; + assert!( + (0.0..=1.0).contains(&ret_freq), + "ret_freq ({}) should be in [0, 1]", + ret_freq + ); + ret_freq + } + + /// Compute the frequency of values in an attribute less than the given value. + /// + /// Currently, we only support calculating the equality selectivity for an existing attribute, + /// not a derived attribute. + /// TODO: Support derived attributes. + async fn get_attribute_lt_value_freq( + &self, + attribute_stats: &AttributeCombValueStats, + table_id: TableId, + attr_base_index: u64, + value: &Value, + ) -> CostModelResult { + // depending on whether value is in mcvs or not, we use different logic to turn total_leq_cdf + // into total_lt_cdf. This logic just so happens to be the exact same logic as + // get_attribute_equality_selectivity implements + let ret_freq = Self::get_attribute_leq_value_freq(attribute_stats, value) + - self + .get_attribute_equality_selectivity(table_id, attr_base_index, value, true) + .await?; + assert!( + (0.0..=1.0).contains(&ret_freq), + "ret_freq ({}) should be in [0, 1]", + ret_freq + ); + Ok(ret_freq) + } + + /// Get the selectivity of an expression of the form "attribute </<=/>=/> value" (or "value + /// </<=/>=/> attribute"). Computes selectivity based off of statistics. + /// Range predicates are handled entirely differently from equality predicates so this is its + /// own function.
If it is unable to find the statistics, it returns DEFAULT_INEQ_SEL. + /// The selectivity is computed as quantile of the right bound minus quantile of the left bound. + /// + /// Currently, we only support calculating the range selectivity for an existing attribute, + /// not a derived attribute. + /// TODO: Support derived attributes. + pub(crate) async fn get_attribute_range_selectivity( + &self, + table_id: TableId, + attr_base_index: u64, + start: Bound<&Value>, + end: Bound<&Value>, + ) -> CostModelResult { + // TODO: Consider attribute is a derived attribute + if let Some(attribute_stats) = self + .get_attribute_comb_stats(table_id, &[attr_base_index]) + .await? + { + let left_quantile = match start { + Bound::Unbounded => 0.0, + Bound::Included(value) => { + self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + ) + .await? + } + Bound::Excluded(value) => { + Self::get_attribute_leq_value_freq(&attribute_stats, value) + } + }; + let right_quantile = match end { + Bound::Unbounded => 1.0, + Bound::Included(value) => { + Self::get_attribute_leq_value_freq(&attribute_stats, value) + } + Bound::Excluded(value) => { + self.get_attribute_lt_value_freq( + &attribute_stats, + table_id, + attr_base_index, + value, + ) + .await?
+ } + }; + assert!( + left_quantile <= right_quantile, + "left_quantile ({}) should be <= right_quantile ({})", + left_quantile, + right_quantile + ); + Ok(right_quantile - left_quantile) + } else { + Ok(DEFAULT_INEQ_SEL) + } + } +} diff --git a/optd-cost-model/src/cost/filter/comp_op.rs b/optd-cost-model/src/cost/filter/comp_op.rs new file mode 100644 index 0000000..5270819 --- /dev/null +++ b/optd-cost-model/src/cost/filter/comp_op.rs @@ -0,0 +1,280 @@ +use std::ops::Bound; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{ + attr_index_pred::AttrIndexPred, bin_op_pred::BinOpType, cast_pred::CastPred, + constant_pred::ConstantPred, + }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, + values::Value, + }, + cost_model::CostModelImpl, + stats::{DEFAULT_EQ_SEL, DEFAULT_INEQ_SEL, UNIMPLEMENTED_SEL}, + storage::CostModelStorageManager, + CostModelResult, SemanticError, +}; + +impl CostModelImpl { + /// Comparison operators are the base case for recursion in get_filter_selectivity() + pub(crate) async fn get_comp_op_selectivity( + &self, + group_id: GroupId, + comp_bin_op_typ: BinOpType, + left: ArcPredicateNode, + right: ArcPredicateNode, + ) -> CostModelResult { + assert!(comp_bin_op_typ.is_comparison()); + + // I intentionally performed moves on left and right. This way, we don't accidentally use + // them after this block + let semantic_res = self.get_semantic_nodes(group_id, left, right).await; + if semantic_res.is_err() { + return Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)); + } + let (attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref) = semantic_res.unwrap(); + + // Handle the different cases of semantic nodes. 
+ if attr_ref_exprs.is_empty() { + Ok(UNIMPLEMENTED_SEL) + } else if attr_ref_exprs.len() == 1 { + let attr_ref_expr = attr_ref_exprs + .first() + .expect("we just checked that attr_ref_exprs.len() == 1"); + let attr_ref_idx = attr_ref_expr.attr_index(); + + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) + { + if values.len() == 1 { + let value = values + .first() + .expect("we just checked that values.len() == 1"); + match comp_bin_op_typ { + BinOpType::Eq => { + self.get_attribute_equality_selectivity(table_id, attr_idx, value, true) + .await + } + BinOpType::Neq => { + self.get_attribute_equality_selectivity( + table_id, + attr_ref_idx, + value, + false, + ) + .await + } + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => { + let start = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Unbounded, + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Unbounded, + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Excluded(value), + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Included(value), + _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), + }; + let end = match (comp_bin_op_typ, is_left_attr_ref) { + (BinOpType::Lt, true) | (BinOpType::Geq, false) => Bound::Excluded(value), + (BinOpType::Leq, true) | (BinOpType::Gt, false) => Bound::Included(value), + (BinOpType::Gt, true) | (BinOpType::Leq, false) => Bound::Unbounded, + (BinOpType::Geq, true) | (BinOpType::Lt, false) => Bound::Unbounded, + _ => unreachable!("all comparison BinOpTypes were enumerated. this should be unreachable"), + }; + self.get_attribute_range_selectivity(table_id, attr_ref_idx, start, end) + .await + } + _ => unreachable!( + "all comparison BinOpTypes were enumerated. 
this should be unreachable" + ), + } + } else { + let non_attr_ref_expr = non_attr_ref_exprs.first().expect( + "non_attr_ref_exprs should have a value since attr_ref_exprs.len() == 1", + ); + + match non_attr_ref_expr.as_ref().typ { + PredicateType::BinOp(_) => { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } + PredicateType::Cast => Ok(UNIMPLEMENTED_SEL), + PredicateType::Constant(_) => { + unreachable!( + "we should have handled this in the values.len() == 1 branch" + ) + } + _ => unimplemented!( + "unhandled case of comparing a attribute ref node to {}", + non_attr_ref_expr.as_ref().typ + ), + } + } + } else { + // TODO: attribute is derived + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } + } else if attr_ref_exprs.len() == 2 { + Ok(Self::get_default_comparison_op_selectivity(comp_bin_op_typ)) + } else { + unreachable!("we could have at most pushed left and right into attr_ref_exprs") + } + } + + /// Convert the left and right child nodes of some operation to what they semantically are. + /// This is convenient to avoid repeating the same logic just with "left" and "right" swapped. + /// The last return value is true when the input node (left) is a AttributeRefPred. + #[allow(clippy::type_complexity)] + async fn get_semantic_nodes( + &self, + group_id: GroupId, + left: ArcPredicateNode, + right: ArcPredicateNode, + ) -> CostModelResult<(Vec, Vec, Vec, bool)> { + let mut attr_ref_exprs = vec![]; + let mut values = vec![]; + let mut non_attr_ref_exprs = vec![]; + let is_left_attr_ref; + + // Recursively unwrap casts as much as we can. 
+ let mut uncasted_left = left; + let mut uncasted_right = right; + loop { + // println!("loop {}, uncasted_left={:?}, uncasted_right={:?}", Local::now(), + // uncasted_left, uncasted_right); + if uncasted_left.as_ref().typ == PredicateType::Cast + && uncasted_right.as_ref().typ == PredicateType::Cast + { + let left_cast_expr = CastPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Cast"); + let right_cast_expr = CastPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Cast"); + assert!(left_cast_expr.cast_to() == right_cast_expr.cast_to()); + uncasted_left = left_cast_expr.child().into_pred_node(); + uncasted_right = right_cast_expr.child().into_pred_node(); + } else if uncasted_left.as_ref().typ == PredicateType::Cast + || uncasted_right.as_ref().typ == PredicateType::Cast + { + let is_left_cast = uncasted_left.as_ref().typ == PredicateType::Cast; + let (mut cast_node, mut non_cast_node) = if is_left_cast { + (uncasted_left, uncasted_right) + } else { + (uncasted_right, uncasted_left) + }; + + let cast_expr = CastPred::from_pred_node(cast_node) + .expect("we already checked that the type is Cast"); + let cast_expr_child = cast_expr.child().into_pred_node(); + let cast_expr_cast_to = cast_expr.cast_to(); + + let should_break = match cast_expr_child.typ { + PredicateType::Constant(_) => { + cast_node = ConstantPred::new( + ConstantPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is Constant") + .value() + .convert_to_type(cast_expr_cast_to), + ) + .into_pred_node(); + false + } + PredicateType::AttrIndex => { + let attr_ref_expr = AttrIndexPred::from_pred_node(cast_expr_child) + .expect("we already checked that the type is AttributeRef"); + let attr_ref_idx = attr_ref_expr.attr_index(); + cast_node = attr_ref_expr.into_pred_node(); + // The "invert" cast is to invert the cast so that we're casting the + // non_cast_node to the attribute's original type. 
+ let attribute_info = self.memo.get_attribute_info(group_id, attr_ref_idx); + let invert_cast_data_type = &attribute_info.typ.into_data_type(); + + match non_cast_node.typ { + PredicateType::AttrIndex => { + // In general, there's no way to remove the Cast here. We can't move + // the Cast to the other AttributeRef + // because that would lead to an infinite loop. Thus, we just leave + // the cast where it is and break. + true + } + _ => { + non_cast_node = + CastPred::new(non_cast_node, invert_cast_data_type.clone()) + .into_pred_node(); + false + } + } + } + _ => todo!(), + }; + + (uncasted_left, uncasted_right) = if is_left_cast { + (cast_node, non_cast_node) + } else { + (non_cast_node, cast_node) + }; + + if should_break { + break; + } + } else { + break; + } + } + + // Sort nodes into attr_ref_exprs, values, and non_attr_ref_exprs + match uncasted_left.as_ref().typ { + PredicateType::AttrIndex => { + is_left_attr_ref = true; + attr_ref_exprs.push( + AttrIndexPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is AttributeRef"), + ); + } + PredicateType::Constant(_) => { + is_left_attr_ref = false; + values.push( + ConstantPred::from_pred_node(uncasted_left) + .expect("we already checked that the type is Constant") + .value(), + ) + } + _ => { + is_left_attr_ref = false; + non_attr_ref_exprs.push(uncasted_left); + } + } + match uncasted_right.as_ref().typ { + PredicateType::AttrIndex => { + attr_ref_exprs.push( + AttrIndexPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is AttributeRef"), + ); + } + PredicateType::Constant(_) => values.push( + ConstantPred::from_pred_node(uncasted_right) + .expect("we already checked that the type is Constant") + .value(), + ), + _ => { + non_attr_ref_exprs.push(uncasted_right); + } + } + + assert!(attr_ref_exprs.len() + values.len() + non_attr_ref_exprs.len() == 2); + Ok((attr_ref_exprs, values, non_attr_ref_exprs, is_left_attr_ref)) + } + + /// The default 
selectivity of a comparison expression + /// Used when one side of the comparison is a attribute while the other side is something too + /// complex/impossible to evaluate (subquery, UDF, another attribute, we have no stats, etc.) + fn get_default_comparison_op_selectivity(comp_bin_op_typ: BinOpType) -> f64 { + assert!(comp_bin_op_typ.is_comparison()); + match comp_bin_op_typ { + BinOpType::Eq => DEFAULT_EQ_SEL, + BinOpType::Neq => 1.0 - DEFAULT_EQ_SEL, + BinOpType::Lt | BinOpType::Leq | BinOpType::Gt | BinOpType::Geq => DEFAULT_INEQ_SEL, + _ => unreachable!( + "all comparison BinOpTypes were enumerated. this should be unreachable" + ), + } + } +} diff --git a/optd-cost-model/src/cost/filter/constant.rs b/optd-cost-model/src/cost/filter/constant.rs new file mode 100644 index 0000000..e131bde --- /dev/null +++ b/optd-cost-model/src/cost/filter/constant.rs @@ -0,0 +1,38 @@ +use crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType}, + predicates::constant_pred::ConstantType, + values::Value, + }, + cost_model::CostModelImpl, + storage::CostModelStorageManager, +}; + +impl CostModelImpl { + pub(crate) fn get_constant_selectivity(const_node: ArcPredicateNode) -> f64 { + if let PredicateType::Constant(const_typ) = const_node.typ { + if matches!(const_typ, ConstantType::Bool) { + let value = const_node + .as_ref() + .data + .as_ref() + .expect("constants should have data"); + if let Value::Bool(bool_value) = value { + if *bool_value { + 1.0 + } else { + 0.0 + } + } else { + unreachable!( + "if the typ is ConstantType::Bool, the value should be a Value::Bool" + ) + } + } else { + panic!("selectivity is not defined on constants which are not bools") + } + } else { + panic!("get_constant_selectivity must be called on a constant") + } + } +} diff --git a/optd-cost-model/src/cost/filter/core.rs b/optd-cost-model/src/cost/filter/core.rs new file mode 100644 index 0000000..05363e4 --- /dev/null +++ b/optd-cost-model/src/cost/filter/core.rs @@ -0,0 +1,877 @@ +use 
crate::{ + common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{in_list_pred::InListPred, like_pred::LikePred, un_op_pred::UnOpType}, + types::GroupId, + }, + cost_model::CostModelImpl, + stats::UNIMPLEMENTED_SEL, + storage::CostModelStorageManager, + CostModelResult, EstimatedStatistic, +}; + +impl CostModelImpl { + // TODO: is it a good design to pass table_id here? I think it needs to be refactored. + // Consider to remove table_id. + pub async fn get_filter_row_cnt( + &self, + child_row_cnt: EstimatedStatistic, + group_id: GroupId, + cond: ArcPredicateNode, + ) -> CostModelResult { + let selectivity = { self.get_filter_selectivity(group_id, cond).await? }; + Ok(EstimatedStatistic((child_row_cnt.0 * selectivity).max(1.0))) + } + + pub async fn get_filter_selectivity( + &self, + group_id: GroupId, + expr_tree: ArcPredicateNode, + ) -> CostModelResult { + Box::pin(async move { + match &expr_tree.typ { + PredicateType::Constant(_) => Ok(Self::get_constant_selectivity(expr_tree)), + PredicateType::AttrIndex => unimplemented!("check bool type or else panic"), + PredicateType::UnOp(un_op_typ) => { + assert!(expr_tree.children.len() == 1); + let child = expr_tree.child(0); + match un_op_typ { + // not doesn't care about nulls so there's no complex logic. 
it just reverses + // the selectivity for instance, != _will not_ include nulls + // but "NOT ==" _will_ include nulls + UnOpType::Not => Ok(1.0 - self.get_filter_selectivity(group_id, child).await?), + UnOpType::Neg => panic!( + "the selectivity of operations that return numerical values is undefined" + ), + } + } + PredicateType::BinOp(bin_op_typ) => { + assert!(expr_tree.children.len() == 2); + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + + if bin_op_typ.is_comparison() { + self.get_comp_op_selectivity(group_id, *bin_op_typ, left_child, right_child).await + } else if bin_op_typ.is_numerical() { + panic!( + "the selectivity of operations that return numerical values is undefined" + ) + } else { + unreachable!("all BinOpTypes should be true for at least one is_*() function") + } + } + PredicateType::LogOp(log_op_typ) => { + self.get_log_op_selectivity(group_id, *log_op_typ, &expr_tree.children).await + } + PredicateType::Func(_) => unimplemented!("check bool type or else panic"), + PredicateType::SortOrder(_) => { + panic!("the selectivity of sort order expressions is undefined") + } + PredicateType::Between => Ok(UNIMPLEMENTED_SEL), + PredicateType::Cast => unimplemented!("check bool type or else panic"), + PredicateType::Like => { + let like_expr = LikePred::from_pred_node(expr_tree).unwrap(); + self.get_like_selectivity(group_id, &like_expr).await + } + PredicateType::DataType(_) => { + panic!("the selectivity of a data type is not defined") + } + PredicateType::InList => { + let in_list_expr = InListPred::from_pred_node(expr_tree).unwrap(); + self.get_in_list_selectivity(group_id, &in_list_expr).await + } + _ => unreachable!( + "all expression DfPredType were enumerated. 
this should be unreachable" + ), + } + }).await + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{ + predicates::{ + bin_op_pred::BinOpType, constant_pred::ConstantType, log_op_pred::LogOpType, + un_op_pred::UnOpType, + }, + properties::Attribute, + types::TableId, + values::Value, + }, + cost_model::tests::*, + memo_ext::tests::MemoGroupInfo, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + Distribution, MostCommonValues, DEFAULT_EQ_SEL, + }, + }; + use arrow_schema::DataType; + + #[tokio::test] + async fn test_const() { + let cost_model = create_mock_cost_model( + vec![TableId(0)], + vec![HashMap::from([(0, empty_per_attr_stats())])], + vec![None], + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, cnst(Value::Bool(true))) + .await + .unwrap(), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, cnst(Value::Bool(false))) + .await + .unwrap(), + 0.0 + ); + } + + #[tokio::test] + async fn test_attr_ref_eq_constint_in_mcv() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.0, + ); + let table_id = TableId(0); + let cost_model = create_mock_cost_model( + vec![table_id], + vec![HashMap::from([(0, per_attribute_stats)])], + vec![None], + ); + + let expr_tree = bin_op( + BinOpType::Eq, + attr_index(0), // TODO: Fix this + cnst(Value::Int32(1)), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cnst(Value::Int32(1)), + attr_index(0), // TODO: Fix this + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + #[tokio::test] + async fn 
test_attr_ref_eq_constint_not_in_mcv() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.2), + (vec![Some(Value::Int32(3))], 0.44), + ])), + None, + 5, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(2))); + let expr_tree_rev = bin_op(BinOpType::Eq, cnst(Value::Int32(2)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.12 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.12 + ); + } + + /// I only have one test for NEQ since I'll assume that it uses the same underlying logic as EQ + #[tokio::test] + async fn test_attr_ref_neq_constint_in_mcv() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Neq, attr_index(0), cnst(Value::Int32(1))); + let expr_tree_rev = bin_op(BinOpType::Neq, cnst(Value::Int32(1)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 1.0 - 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.3 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_no_mcvs_in_range() { + let per_attribute_stats = TestPerAttributeStats::new( + 
MostCommonValues::SimpleFrequency(SimpleMap::default()), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.7 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.7 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_with_mcvs_in_range_not_at_border() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(17))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.85 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.85 + ); + } + + #[tokio::test] + async fn test_attr_ref_leq_constint_with_mcv_at_border() { + let 
per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(15))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Gt, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.93 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.93 + ); + } + + #[tokio::test] + async fn test_attr_ref_lt_constint_no_mcvs_in_range() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.6 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.6 + ); + } + + #[tokio::test] + async fn 
test_attr_ef_lt_constint_with_mcvs_in_range_not_at_border() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(17))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 11, /* there are 4 MCVs which together add up to 0.3. With 11 total ndistinct, each + * remaining value has freq 0.1 */ + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.75 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.75 + ); + } + + #[tokio::test] + async fn test_attr_ref_lt_constint_with_mcv_at_border() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(6))], 0.05), + (vec![Some(Value::Int32(10))], 0.1), + (vec![Some(Value::Int32(15))], 0.08), + (vec![Some(Value::Int32(25))], 0.07), + ])), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 11, /* there are 4 MCVs which together add up to 0.3. 
With 11 total ndistinct, each + * remaining value has freq 0.1 */ + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Lt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Geq, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.85 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.85 + ); + } + + /// I have fewer tests for GT since I'll assume that it uses the same underlying logic as LEQ + /// The only interesting thing to test is that if there are nulls, those aren't included in GT + #[tokio::test] + async fn test_attr_ref_gt_constint() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Gt, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Leq, cnst(Value::Int32(15)), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 1.0 - 0.7 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.7 + ); + } + + #[tokio::test] + async fn test_attr_ref_geq_constint() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + 
Some(Distribution::SimpleDistribution(SimpleMap::new(vec![( + Value::Int32(15), + 0.7, + )]))), + 10, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op(BinOpType::Geq, attr_index(0), cnst(Value::Int32(15))); + let expr_tree_rev = bin_op(BinOpType::Lt, cnst(Value::Int32(15)), attr_index(0)); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 1.0 - 0.6 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 1.0 - 0.6 + ); + } + + #[tokio::test] + async fn test_and() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.3), + (vec![Some(Value::Int32(5))], 0.5), + (vec![Some(Value::Int32(8))], 0.2), + ])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let eq1 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))); + let eq5 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(5))); + let eq8 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(8))); + let expr_tree = log_op(LogOpType::And, vec![eq1.clone(), eq5.clone(), eq8.clone()]); + let expr_tree_shift1 = log_op(LogOpType::And, vec![eq5.clone(), eq8.clone(), eq1.clone()]); + let expr_tree_shift2 = log_op(LogOpType::And, vec![eq8.clone(), eq1.clone(), eq5.clone()]); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.03 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift1) + .await + .unwrap(), + 0.03 + ); + 
assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift2) + .await + .unwrap(), + 0.03 + ); + } + + #[tokio::test] + async fn test_or() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.3), + (vec![Some(Value::Int32(5))], 0.5), + (vec![Some(Value::Int32(8))], 0.2), + ])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let eq1 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))); + let eq5 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(5))); + let eq8 = bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(8))); + let expr_tree = log_op(LogOpType::Or, vec![eq1.clone(), eq5.clone(), eq8.clone()]); + let expr_tree_shift1 = log_op(LogOpType::Or, vec![eq5.clone(), eq8.clone(), eq1.clone()]); + let expr_tree_shift2 = log_op(LogOpType::Or, vec![eq8.clone(), eq1.clone(), eq5.clone()]); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.72 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift1) + .await + .unwrap(), + 0.72 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_shift2) + .await + .unwrap(), + 0.72 + ); + } + + #[tokio::test] + async fn test_not() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = un_op( + UnOpType::Not, + 
bin_op(BinOpType::Eq, attr_index(0), cnst(Value::Int32(1))), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.7 + ); + } + + // I didn't test any non-unique cases with filter. The non-unique tests without filter should + // cover that + + #[tokio::test] + async fn test_attr_ref_eq_cast_value() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + let expr_tree = bin_op( + BinOpType::Eq, + attr_index(0), + cast(cnst(Value::Int64(1)), DataType::Int32), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cast(cnst(Value::Int64(1)), DataType::Int32), + attr_index(0), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + #[tokio::test] + async fn test_cast_attr_ref_eq_value() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![( + vec![Some(Value::Int32(1))], + 0.3, + )])), + None, + 0, + 0.1, + ); + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + ConstantType::Int32, + )])], + vec![None], + ); + + let expr_tree = bin_op( + BinOpType::Eq, + cast(attr_index(0), DataType::Int64), // TODO: Fix this + cnst(Value::Int64(1)), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + cnst(Value::Int64(1)), + cast(attr_index(0), 
DataType::Int64), // TODO: Fix this + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + 0.3 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + 0.3 + ); + } + + /// In this case, we should leave the Cast as is. + /// + /// Note that the test only checks the selectivity and thus doesn't explicitly test that the + /// Cast is indeed left as is. However, if get_filter_selectivity() doesn't crash, that's a + /// pretty good signal that the Cast was left as is. + #[tokio::test] + async fn test_cast_attr_ref_eq_attr_ref() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::default()), + None, + 0, + 0.0, + ); + let table_id = TableId(0); + let cost_model = create_mock_cost_model_with_attr_types( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![HashMap::from([ + (TEST_ATTR1_BASE_INDEX, ConstantType::Int32), + (TEST_ATTR2_BASE_INDEX, ConstantType::Int64), + ])], + vec![None], + ); + + let expr_tree = bin_op( + BinOpType::Eq, + cast(attr_index(0), DataType::Int64), + attr_index(1), + ); + let expr_tree_rev = bin_op( + BinOpType::Eq, + attr_index(1), + cast(attr_index(0), DataType::Int64), + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree) + .await + .unwrap(), + DEFAULT_EQ_SEL + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_filter_selectivity(TEST_GROUP1_ID, expr_tree_rev) + .await + .unwrap(), + DEFAULT_EQ_SEL + ); + } +} diff --git a/optd-cost-model/src/cost/filter/in_list.rs b/optd-cost-model/src/cost/filter/in_list.rs new file mode 100644 index 0000000..f056fb1 --- /dev/null +++ b/optd-cost-model/src/cost/filter/in_list.rs @@ -0,0 +1,167 @@ +use crate::{ + common::{ + nodes::{PredicateType, ReprPredicateNode}, + 
predicates::{ + attr_index_pred::AttrIndexPred, constant_pred::ConstantPred, in_list_pred::InListPred, + }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, + }, + cost_model::CostModelImpl, + stats::UNIMPLEMENTED_SEL, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + /// Only support attrA in (val1, val2, val3) where attrA is a attribute ref and + /// val1, val2, val3 are constants. + pub(crate) async fn get_in_list_selectivity( + &self, + group_id: GroupId, + expr: &InListPred, + ) -> CostModelResult { + let child = expr.child(); + + // Check child is a attribute ref. + if !matches!(child.typ, PredicateType::AttrIndex) { + return Ok(UNIMPLEMENTED_SEL); + } + + // Check all expressions in the list are constants. + let list_exprs = expr.list().to_vec(); + if list_exprs + .iter() + .any(|expr| !matches!(expr.typ, PredicateType::Constant(_))) + { + return Ok(UNIMPLEMENTED_SEL); + } + + // Convert child and const expressions to concrete types. + let attr_ref_pred = AttrIndexPred::from_pred_node(child).unwrap(); + let attr_ref_idx = attr_ref_pred.attr_index(); + + let list_exprs = list_exprs + .into_iter() + .map(|expr| { + ConstantPred::from_pred_node(expr) + .expect("we already checked all list elements are constants") + }) + .collect::>(); + let negated = expr.negated(); + + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) + { + let mut in_sel = 0.0; + for expr in &list_exprs { + let selectivity = self + .get_attribute_equality_selectivity( + table_id, + attr_idx, + &expr.value(), + /* is_equality */ true, + ) + .await?; + in_sel += selectivity; + } + in_sel = in_sel.min(1.0); + if negated { + Ok(1.0 - in_sel) + } else { + Ok(in_sel) + } + } else { + // TODO: Child is a derived attribute. 
+ Ok(UNIMPLEMENTED_SEL) + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{ + types::{GroupId, TableId}, + values::Value, + }, + cost_model::tests::*, + memo_ext::tests::MemoGroupInfo, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + MostCommonValues, + }, + }; + + #[tokio::test] + async fn test_in_list() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::Int32(1))], 0.8), + (vec![Some(Value::Int32(2))], 0.2), + ])), + None, + 2, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(1)], false)) + .await + .unwrap(), + 0.8 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity( + TEST_GROUP1_ID, + &in_list(0, vec![Value::Int32(1), Value::Int32(2)], false) + ) + .await + .unwrap(), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(3)], false)) + .await + .unwrap(), + 0.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(1)], true)) + .await + .unwrap(), + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity( + TEST_GROUP1_ID, + &in_list(0, vec![Value::Int32(1), Value::Int32(2)], true) + ) + .await + .unwrap(), + 0.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_in_list_selectivity(TEST_GROUP1_ID, &in_list(0, vec![Value::Int32(3)], true)) // TODO: Fix this + .await + .unwrap(), + 1.0 + ); + } +} diff --git a/optd-cost-model/src/cost/filter/like.rs b/optd-cost-model/src/cost/filter/like.rs new file mode 100644 index 
0000000..32800e4 --- /dev/null +++ b/optd-cost-model/src/cost/filter/like.rs @@ -0,0 +1,210 @@ +use datafusion::arrow::{array::StringArray, compute::like}; + +use crate::{ + common::{ + nodes::{PredicateType, ReprPredicateNode}, + predicates::{ + attr_index_pred::AttrIndexPred, constant_pred::ConstantPred, like_pred::LikePred, + }, + properties::attr_ref::{AttrRef, BaseTableAttrRef}, + types::GroupId, + }, + cost_model::CostModelImpl, + stats::{ + AttributeCombValue, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, UNIMPLEMENTED_SEL, + }, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + /// Compute the selectivity of a (NOT) LIKE expression. + /// + /// The logic is somewhat similar to Postgres but different. Postgres first estimates the + /// histogram part of the population and then add up data for any MCV values. If the + /// histogram is large enough, it just uses the number of matches in the histogram, + /// otherwise it estimates the fixed prefix and remainder of pattern separately and + /// combine them. + /// + /// Our approach is simpler and less selective. Firstly, we don't use histogram. The selectivity + /// is composed of MCV frequency and non-MCV selectivity. MCV frequency is computed by + /// adding up frequencies of MCVs that match the pattern. Non-MCV selectivity is computed + /// in the same way that Postgres computes selectivity for the wildcard part of the pattern. + pub(crate) async fn get_like_selectivity( + &self, + group_id: GroupId, + like_expr: &LikePred, + ) -> CostModelResult { + let child = like_expr.child(); + + // Check child is a attribute ref. + if !matches!(child.typ, PredicateType::AttrIndex) { + return Ok(UNIMPLEMENTED_SEL); + } + + // Check pattern is a constant. 
+ let pattern = like_expr.pattern(); + if !matches!(pattern.typ, PredicateType::Constant(_)) { + return Ok(UNIMPLEMENTED_SEL); + } + + let attr_ref_pred = AttrIndexPred::from_pred_node(child).unwrap(); + let attr_ref_idx = attr_ref_pred.attr_index(); + + if let AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) = + self.memo.get_attribute_ref(group_id, attr_ref_idx) + { + let pattern = ConstantPred::from_pred_node(pattern) + .expect("we already checked pattern is a constant") + .value() + .as_str(); + + // Compute the selectivity exculuding MCVs. + // See Postgres `like_selectivity`. + let non_mcv_sel = pattern + .chars() + .fold(1.0, |acc, c| { + if c == '%' { + acc * FULL_WILDCARD_SEL_FACTOR + } else { + acc * FIXED_CHAR_SEL_FACTOR + } + }) + .min(1.0); + + // Compute the selectivity in MCVs. + // TODO: Handle the case where `attribute_stats` is None. + let (mut mcv_freq, mut null_frac) = (0.0, 0.0); + if let Some(attribute_stats) = + self.get_attribute_comb_stats(table_id, &[attr_idx]).await? + { + (mcv_freq, null_frac) = { + let pred = Box::new(move |val: &AttributeCombValue| { + let string = + StringArray::from(vec![val[0].as_ref().unwrap().as_str().as_ref()]); + let pattern = StringArray::from(vec![pattern.as_ref()]); + like(&string, &pattern).unwrap().value(0) + }); + ( + attribute_stats.mcvs.freq_over_pred(pred), + attribute_stats.null_frac, + ) + }; + } + let result = non_mcv_sel + mcv_freq; + + Ok(if like_expr.negated() { + 1.0 - result - null_frac + } else { + result + } + // Postgres clamps the result after histogram and before MCV. See Postgres + // `patternsel_common`. 
+ .clamp(0.0001, 0.9999)) + } else { + // TOOD: derived attribute + Ok(UNIMPLEMENTED_SEL) + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{ + types::{GroupId, TableId}, + values::Value, + }, + cost_model::tests::*, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + MostCommonValues, FIXED_CHAR_SEL_FACTOR, FULL_WILDCARD_SEL_FACTOR, + }, + }; + + #[tokio::test] + async fn test_like_no_nulls() { + let per_attribute_stats = TestPerAttributeStats::new( + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![ + (vec![Some(Value::String("abcd".into()))], 0.1), + (vec![Some(Value::String("abc".into()))], 0.1), + ])), + None, + 2, + 0.0, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity( + TEST_GROUP1_ID, + &like(TEST_ATTR1_BASE_INDEX, "%abcd%", false) + ) // TODO: Fix this + .await + .unwrap(), + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(TEST_GROUP1_ID, &like(TEST_ATTR1_BASE_INDEX, "%abc%", false)) // TODO: Fix this + .await + .unwrap(), + 0.1 + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(TEST_GROUP1_ID, &like(TEST_ATTR1_BASE_INDEX, "%abc%", true)) // TODO: Fix this + .await + .unwrap(), + 1.0 - (0.1 + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(3)) + ); + } + + #[tokio::test] + async fn test_like_with_nulls() { + let null_frac = 0.5; + let mut mcvs_counts = HashMap::new(); + mcvs_counts.insert(vec![Some(Value::String("abcd".into()))], 1); + let mcvs_total_count = 10; + let per_attribute_stats = TestPerAttributeStats::new( + 
MostCommonValues::Counter(Counter::new_from_existing(mcvs_counts, mcvs_total_count)), + None, + 2, + null_frac, + ); + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([( + TEST_ATTR1_BASE_INDEX, + per_attribute_stats, + )])], + vec![None], + ); + + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(TEST_GROUP1_ID, &like(0, "%abcd%", false)) // TODO: Fix this + .await + .unwrap(), + 0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4) + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_like_selectivity(TEST_GROUP1_ID, &like(0, "%abcd%", true)) // TODO: Fix this + .await + .unwrap(), + 1.0 - (0.1 + FULL_WILDCARD_SEL_FACTOR.powi(2) * FIXED_CHAR_SEL_FACTOR.powi(4)) + - null_frac + ); + } +} diff --git a/optd-cost-model/src/cost/filter/log_op.rs b/optd-cost-model/src/cost/filter/log_op.rs new file mode 100644 index 0000000..61862a2 --- /dev/null +++ b/optd-cost-model/src/cost/filter/log_op.rs @@ -0,0 +1,34 @@ +use crate::{ + common::{nodes::ArcPredicateNode, predicates::log_op_pred::LogOpType, types::GroupId}, + cost_model::CostModelImpl, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + pub(crate) async fn get_log_op_selectivity( + &self, + group_id: GroupId, + log_op_typ: LogOpType, + children: &[ArcPredicateNode], + ) -> CostModelResult { + match log_op_typ { + LogOpType::And => { + let mut and_sel = 1.0; + for child in children { + let selectivity = self.get_filter_selectivity(group_id, child.clone()).await?; + and_sel *= selectivity; + } + Ok(and_sel) + } + LogOpType::Or => { + let mut or_sel_neg = 1.0; + for child in children { + let selectivity = self.get_filter_selectivity(group_id, child.clone()).await?; + or_sel_neg *= (1.0 - selectivity); + } + Ok(1.0 - or_sel_neg) + } + } + } +} diff --git a/optd-cost-model/src/cost/filter/mod.rs b/optd-cost-model/src/cost/filter/mod.rs new file mode 100644 index 0000000..00ea653 --- /dev/null +++ 
b/optd-cost-model/src/cost/filter/mod.rs @@ -0,0 +1,7 @@ +pub mod attribute; +pub mod comp_op; +pub mod constant; +pub mod core; +pub mod in_list; +pub mod like; +pub mod log_op; diff --git a/optd-cost-model/src/cost/join.rs b/optd-cost-model/src/cost/join.rs deleted file mode 100644 index 8b13789..0000000 --- a/optd-cost-model/src/cost/join.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/optd-cost-model/src/cost/join/core.rs b/optd-cost-model/src/cost/join/core.rs new file mode 100644 index 0000000..c68c1db --- /dev/null +++ b/optd-cost-model/src/cost/join/core.rs @@ -0,0 +1,1275 @@ +use std::collections::HashSet; + +use itertools::Itertools; + +use crate::{ + common::{ + nodes::{ArcPredicateNode, JoinType, PredicateType, ReprPredicateNode}, + predicates::{ + attr_index_pred::AttrIndexPred, + bin_op_pred::BinOpType, + list_pred::ListPred, + log_op_pred::{LogOpPred, LogOpType}, + }, + properties::attr_ref::{ + self, AttrRef, AttrRefs, BaseTableAttrRef, EqPredicate, GroupAttrRefs, + SemanticCorrelation, + }, + types::GroupId, + }, + cost::join::get_on_attr_ref_pair, + cost_model::CostModelImpl, + stats::DEFAULT_NUM_DISTINCT, + storage::CostModelStorageManager, + CostModelResult, +}; + +impl CostModelImpl { + /// The expr_tree input must be a "mixed expression tree", just like with + /// `get_filter_selectivity`. + /// + /// This is a "wrapper" to separate the equality conditions from the filter conditions before + /// calling the "main" `get_join_selectivity_core` function. 
+ #[allow(clippy::too_many_arguments)] + pub(crate) async fn get_join_selectivity_from_expr_tree( + &self, + join_typ: JoinType, + group_id: GroupId, + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + ) -> CostModelResult { + if expr_tree.typ == PredicateType::LogOp(LogOpType::And) { + let mut on_attr_ref_pairs = vec![]; + let mut filter_expr_trees = vec![]; + for child_expr_tree in &expr_tree.children { + if let Some(on_attr_ref_pair) = + get_on_attr_ref_pair(child_expr_tree.clone(), attr_refs) + { + on_attr_ref_pairs.push(on_attr_ref_pair) + } else { + let child_expr = child_expr_tree.clone(); + filter_expr_trees.push(child_expr); + } + } + assert!(on_attr_ref_pairs.len() + filter_expr_trees.len() == expr_tree.children.len()); + let filter_expr_tree = if filter_expr_trees.is_empty() { + None + } else { + Some(LogOpPred::new(LogOpType::And, filter_expr_trees).into_pred_node()) + }; + self.get_join_selectivity_core( + join_typ, + group_id, + on_attr_ref_pairs, + filter_expr_tree, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } else { + #[allow(clippy::collapsible_else_if)] + if let Some(on_attr_ref_pair) = get_on_attr_ref_pair(expr_tree.clone(), attr_refs) { + self.get_join_selectivity_core( + join_typ, + group_id, + vec![on_attr_ref_pair], + None, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } else { + self.get_join_selectivity_core( + join_typ, + group_id, + vec![], + Some(expr_tree), + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + 0, + ) + .await + } + } + } + + /// A wrapper to convert the join keys to the format expected by get_join_selectivity_core() + #[allow(clippy::too_many_arguments)] + pub(crate) async fn get_join_selectivity_from_keys( + &self, + join_typ: JoinType, + group_id: GroupId, + left_keys: ListPred, + right_keys: ListPred, + attr_refs: &AttrRefs, + 
input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + left_attr_cnt: usize, + ) -> CostModelResult { + assert!(left_keys.len() == right_keys.len()); + // I assume that the keys are already in the right order + // s.t. the ith key of left_keys corresponds with the ith key of right_keys + let on_attr_ref_pairs = left_keys + .to_vec() + .into_iter() + .zip(right_keys.to_vec()) + .map(|(left_key, right_key)| { + ( + AttrIndexPred::from_pred_node(left_key).expect("keys should be AttrRefPreds"), + AttrIndexPred::from_pred_node(right_key).expect("keys should be AttrRefPreds"), + ) + }) + .collect_vec(); + self.get_join_selectivity_core( + join_typ, + group_id, + on_attr_ref_pairs, + None, + attr_refs, + input_correlation, + left_row_cnt, + right_row_cnt, + left_attr_cnt, + ) + .await + } + + /// The core logic of join selectivity which assumes we've already separated the expression + /// into the on conditions and the filters. + /// + /// Hash join and NLJ reference right table attributes differently, hence the + /// `right_attr_ref_offset` parameter. + /// + /// For hash join, the right table attributes indices are with respect to the right table, + /// which means #0 is the first attribute of the right table. + /// + /// For NLJ, the right table attributes indices are with respect to the output of the join. + /// For example, if the left table has 3 attributes, the first attribute of the right table + /// is #3 instead of #0. 
+ #[allow(clippy::too_many_arguments)] + async fn get_join_selectivity_core( + &self, + join_typ: JoinType, + group_id: GroupId, + on_attr_ref_pairs: Vec<(AttrIndexPred, AttrIndexPred)>, + filter_expr_tree: Option, + attr_refs: &AttrRefs, + input_correlation: Option, + left_row_cnt: f64, + right_row_cnt: f64, + right_attr_ref_offset: usize, + ) -> CostModelResult { + let join_on_selectivity = self + .get_join_on_selectivity( + &on_attr_ref_pairs, + attr_refs, + input_correlation, + right_attr_ref_offset, + ) + .await?; + // Currently, there is no difference in how we handle a join filter and a select filter, + // so we use the same function. + // + // One difference (that we *don't* care about right now) is that join filters can contain + // expressions from multiple different tables. Currently, this doesn't affect the + // get_filter_selectivity() function, but this may change in the future. + let join_filter_selectivity = match filter_expr_tree { + Some(filter_expr_tree) => { + self.get_filter_selectivity(group_id, filter_expr_tree) + .await? + } + None => 1.0, + }; + let inner_join_selectivity = join_on_selectivity * join_filter_selectivity; + + Ok(match join_typ { + JoinType::Inner => inner_join_selectivity, + JoinType::LeftOuter => f64::max(inner_join_selectivity, 1.0 / right_row_cnt), + JoinType::RightOuter => f64::max(inner_join_selectivity, 1.0 / left_row_cnt), + JoinType::Cross => { + assert!( + on_attr_ref_pairs.is_empty(), + "Cross joins should not have on attributes" + ); + join_filter_selectivity + } + _ => unimplemented!("join_typ={} is not implemented", join_typ), + }) + } + + /// Get the selectivity of one attribute eq predicate, e.g. attrA = attrB. 
+ async fn get_join_selectivity_from_on_attr_ref_pair( + &self, + left: &AttrRef, + right: &AttrRef, + ) -> CostModelResult { + // the formula for each pair is min(1 / ndistinct1, 1 / ndistinct2) + // (see https://postgrespro.com/blog/pgsql/5969618) + let mut ndistincts = vec![]; + for attr_ref in [left, right] { + let ndistinct = match attr_ref { + AttrRef::BaseTableAttrRef(base_attr_ref) => { + match self + .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) + .await? + { + Some(per_attr_stats) => per_attr_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + } + } + AttrRef::Derived => DEFAULT_NUM_DISTINCT, + }; + ndistincts.push(ndistinct); + } + + // using reduce(f64::min) is the idiomatic workaround to min() because + // f64 does not implement Ord due to NaN + let selectivity = ndistincts.into_iter().map(|ndistinct| 1.0 / ndistinct as f64).reduce(f64::min).expect("reduce() only returns None if the iterator is empty, which is impossible since attr_ref_exprs.len() == 2"); + assert!( + !selectivity.is_nan(), + "it should be impossible for selectivity to be NaN since n-distinct is never 0" + ); + Ok(selectivity) + } + + /// Given a set of N attributes involved in a multi-equality, find the total selectivity + /// of the multi-equality. + /// + /// This is a generalization of get_join_selectivity_from_on_attr_ref_pair(). + async fn get_join_selectivity_from_most_selective_attrs( + &self, + base_attr_refs: HashSet, + ) -> CostModelResult { + assert!(base_attr_refs.len() > 1); + let num_base_attr_refs = base_attr_refs.len(); + + let mut ndistincts = vec![]; + for base_attr_ref in base_attr_refs.iter() { + let ndistinct = match self + .get_attribute_comb_stats(base_attr_ref.table_id, &[base_attr_ref.attr_idx]) + .await? 
+ { + Some(per_attr_stats) => per_attr_stats.ndistinct, + None => DEFAULT_NUM_DISTINCT, + }; + ndistincts.push(ndistinct); + } + + Ok(ndistincts + .into_iter() + .map(|ndistinct| 1.0 / ndistinct as f64) + .sorted_by(|a, b| { + a.partial_cmp(b) + .expect("No floats should be NaN since n-distinct is never 0") + }) + .take(num_base_attr_refs - 1) + .product()) + } + + /// A predicate set defines a "multi-equality graph", which is an unweighted undirected graph. + /// The nodes are attributes while edges are predicates. The old graph is defined by + /// `past_eq_attrs` while the `predicate` is the new addition to this graph. This + /// unweighted undirected graph consists of a number of connected components, where each + /// connected component represents attributes that are set to be equal to each other. Single + /// nodes not connected to anything are considered standalone connected components. + /// + /// The selectivity of each connected component of N nodes is equal to the product of + /// 1/ndistinct of the N-1 nodes with the highest ndistinct values. You can see this if you + /// imagine that all attributes being joined are unique attributes and that they follow the + /// inclusion principle (every element of the smaller tables is present in the larger + /// tables). When these assumptions are not true, the selectivity may not be completely + /// accurate. However, it is still fairly accurate. + /// + /// However, we cannot simply add `predicate` to the multi-equality graph and compute the + /// selectivity of the entire connected component, because this would be "double counting" a + /// lot of nodes. The join(s) before this join would already have a selectivity value. Thus, + /// we compute the selectivity of the join(s) before this join (the first block of the + /// function) and then the selectivity of the connected component after this join. The + /// quotient is the "adjustment" factor. 
+ /// + /// NOTE: This function modifies `past_eq_attrs` by adding `predicate` to it. + async fn get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + &self, + predicate: &EqPredicate, + past_eq_attrs: &mut SemanticCorrelation, + ) -> CostModelResult { + if predicate.left == predicate.right { + // self-join, TODO: is this correct? + return Ok(1.0); + } + // To find the adjustment, we need to know the selectivity of the graph before `predicate` + // is added. + // + // There are two cases: (1) adding `predicate` does not change the # of connected + // components, and (2) adding `predicate` reduces the # of connected by 1. Note that + // attributes not involved in any predicates are considered a part of the graph and are + // a connected component on their own. + let children_pred_sel = { + if past_eq_attrs.is_eq(&predicate.left, &predicate.right) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left), + ) + .await? + } else { + let left_sel = if past_eq_attrs.contains(&predicate.left) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left), + ) + .await? + } else { + 1.0 + }; + let right_sel = if past_eq_attrs.contains(&predicate.right) { + self.get_join_selectivity_from_most_selective_attrs( + past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.right), + ) + .await? + } else { + 1.0 + }; + left_sel * right_sel + } + }; + + // Add predicate to past_eq_attrs and compute the selectivity of the connected component + // it creates. + past_eq_attrs.add_predicate(predicate.clone()); + let new_pred_sel = { + let attrs = past_eq_attrs.find_attrs_for_eq_attribute_set(&predicate.left); + self.get_join_selectivity_from_most_selective_attrs(attrs) + } + .await?; + + // Compute the adjustment factor. + Ok(new_pred_sel / children_pred_sel) + } + + /// Get the selectivity of the on conditions. 
+ /// + /// Note that the selectivity of the on conditions does not depend on join type. + /// Join type is accounted for separately in get_join_selectivity_core(). + /// + /// We also check if each predicate is correlated with any of the previous predicates. + /// + /// More specifically, we are checking if the predicate can be expressed with other existing + /// predicates. E.g. if we have a predicate like A = B and B = C is equivalent to A = C. + // + /// However, we don't just throw away A = C, because we want to pick the most selective + /// predicates. For details on how we do this, see + /// `get_join_selectivity_from_redundant_predicates`. + async fn get_join_on_selectivity( + &self, + on_attr_ref_pairs: &[(AttrIndexPred, AttrIndexPred)], + attr_refs: &AttrRefs, + input_correlation: Option, + right_attr_ref_offset: usize, + ) -> CostModelResult { + let mut past_eq_attrs = input_correlation.unwrap_or_default(); + + // Multiply the selectivities of all individual conditions together + let mut selectivity = 1.0; + for on_attr_ref_pair in on_attr_ref_pairs { + let left_attr_ref = &attr_refs[on_attr_ref_pair.0.attr_index() as usize]; + let right_attr_ref = + &attr_refs[on_attr_ref_pair.1.attr_index() as usize + right_attr_ref_offset]; + + selectivity *= + if let (AttrRef::BaseTableAttrRef(left), AttrRef::BaseTableAttrRef(right)) = + (left_attr_ref, right_attr_ref) + { + let predicate = EqPredicate::new(left.clone(), right.clone()); + self.get_join_selectivity_adjustment_when_adding_to_multi_equality_graph( + &predicate, + &mut past_eq_attrs, + ) + .await? + } else { + self.get_join_selectivity_from_on_attr_ref_pair(left_attr_ref, right_attr_ref) + .await? 
+ }; + } + + Ok(selectivity) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + common::{ + predicates::{attr_index_pred, constant_pred::ConstantType}, + properties::Attribute, + types::TableId, + values::Value, + }, + cost_model::tests::{ + attr_index, bin_op, cnst, create_four_table_mock_cost_model, create_mock_cost_model, + create_three_table_mock_cost_model, create_two_table_mock_cost_model, + create_two_table_mock_cost_model_custom_row_cnts, empty_per_attr_stats, log_op, + per_attr_stats_with_dist_and_ndistinct, per_attr_stats_with_ndistinct, + TestOptCostModelMock, TestPerAttributeStats, TEST_ATTR1_NAME, TEST_ATTR2_NAME, + TEST_TABLE1_ID, TEST_TABLE2_ID, TEST_TABLE3_ID, TEST_TABLE4_ID, + }, + memo_ext::tests::MemoGroupInfo, + stats::DEFAULT_EQ_SEL, + }; + + use super::*; + + const JOIN_GROUP_ID: GroupId = GroupId(10); + + /// A wrapper around get_join_selectivity_from_expr_tree that extracts the + /// table row counts from the cost model. + async fn test_get_join_selectivity( + cost_model: &TestOptCostModelMock, + reverse_tables: bool, + join_typ: JoinType, + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, + input_correlation: Option, + ) -> f64 { + let table1_row_cnt = cost_model.get_row_count(TEST_TABLE1_ID) as f64; + let table2_row_cnt = cost_model.get_row_count(TEST_TABLE2_ID) as f64; + + if !reverse_tables { + cost_model + .get_join_selectivity_from_expr_tree( + join_typ, + JOIN_GROUP_ID, + expr_tree, + attr_refs, + input_correlation, + table1_row_cnt, + table2_row_cnt, + ) + .await + .unwrap() + } else { + cost_model + .get_join_selectivity_from_expr_tree( + join_typ, + JOIN_GROUP_ID, + expr_tree, + attr_refs, + input_correlation, + table2_row_cnt, + table1_row_cnt, + ) + .await + .unwrap() + } + } + + #[tokio::test] + async fn test_inner_const() { + let cost_model = create_mock_cost_model( + vec![TEST_TABLE1_ID], + vec![HashMap::from([(0, empty_per_attr_stats())])], + vec![None], + ); + 
assert_approx_eq::assert_approx_eq!( + cost_model + .get_join_selectivity_from_expr_tree( + JoinType::Inner, + JOIN_GROUP_ID, + cnst(Value::Bool(true)), + &vec![], + None, + f64::NAN, + f64::NAN + ) + .await + .unwrap(), + 1.0 + ); + assert_approx_eq::assert_approx_eq!( + cost_model + .get_join_selectivity_from_expr_tree( + JoinType::Inner, + JOIN_GROUP_ID, + cnst(Value::Bool(false)), + &vec![], + None, + f64::NAN, + f64::NAN + ) + .await + .unwrap(), + 0.0 + ); + } + + #[tokio::test] + async fn test_inner_oncond() { + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + None, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None, + ) + .await, + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &attr_refs, + None, + ) + .await, + 0.2 + ); + } + + #[tokio::test] + async fn test_inner_and_of_onconds() { + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + None, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let eq0and1 = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let eq1and0 = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and0.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq1and0.clone(), eq0and1.clone()]); + + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + 
&cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None, + ) + .await, + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &attr_refs, + None + ) + .await, + 0.2 + ); + } + + #[tokio::test] + async fn test_inner_and_of_oncond_and_filter() { + let join_memo = HashMap::from([( + JOIN_GROUP_ID, + MemoGroupInfo::new( + vec![ + Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()), + Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()), + ] + .into(), + GroupAttrRefs::new( + vec![ + AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0), + ], + None, + ), + ), + )]); + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + Some(join_memo), + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let eq0and1 = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let eq100 = bin_op(BinOpType::Eq, attr_index(1), cnst(Value::Int32(100))); + let expr_tree = log_op(LogOpType::And, vec![eq0and1.clone(), eq100.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), eq0and1.clone()]); + + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None + ) + .await, + 0.05 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &attr_refs, + None + ) + .await, + 0.05 + ); + } + + #[tokio::test] + async fn test_inner_and_of_filters() { + let join_memo = HashMap::from([( + JOIN_GROUP_ID, + MemoGroupInfo::new( + vec![ + Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()), + Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()), + ] + .into(), + GroupAttrRefs::new( + vec![ + 
AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0), + ], + None, + ), + ), + )]); + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + Some(join_memo), + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let neq12 = bin_op(BinOpType::Neq, attr_index(0), cnst(Value::Int32(12))); + let eq100 = bin_op(BinOpType::Eq, attr_index(1), cnst(Value::Int32(100))); + let expr_tree = log_op(LogOpType::And, vec![neq12.clone(), eq100.clone()]); + let expr_tree_rev = log_op(LogOpType::And, vec![eq100.clone(), neq12.clone()]); + + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None, + ) + .await, + 0.2 + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev, + &attr_refs, + None + ) + .await, + 0.2 + ); + } + + #[tokio::test] + async fn test_inner_colref_eq_colref_same_table_is_not_oncond() { + let cost_model = create_two_table_mock_cost_model( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + None, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(0)); + + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree, + &attr_refs, + None + ) + .await, + DEFAULT_EQ_SEL + ); + } + + // We don't test joinsel or with oncond because if there is an oncond (on condition), the + // top-level operator must be an AND + + /// I made this helper function to avoid copying all eight lines over and over + async fn assert_outer_selectivities( + cost_model: &TestOptCostModelMock, + 
expr_tree: ArcPredicateNode, + expr_tree_rev: ArcPredicateNode, + attr_refs: &AttrRefs, + expected_table1_outer_sel: f64, + expected_table2_outer_sel: f64, + ) { + // all table 1 outer combinations + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::LeftOuter, + expr_tree.clone(), + attr_refs, + None + ) + .await, + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::LeftOuter, + expr_tree_rev.clone(), + attr_refs, + None + ) + .await, + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::RightOuter, + expr_tree.clone(), + attr_refs, + None + ) + .await, + expected_table1_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::RightOuter, + expr_tree_rev.clone(), + attr_refs, + None + ) + .await, + expected_table1_outer_sel + ); + // all table 2 outer combinations + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::LeftOuter, + expr_tree.clone(), + attr_refs, + None + ) + .await, + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + true, + JoinType::LeftOuter, + expr_tree_rev.clone(), + attr_refs, + None + ) + .await, + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::RightOuter, + expr_tree.clone(), + attr_refs, + None + ) + .await, + expected_table2_outer_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + cost_model, + false, + JoinType::RightOuter, + expr_tree_rev.clone(), + attr_refs, + None + ) + .await, + expected_table2_outer_sel + ); + } + + /// Unique oncond means an oncondition on columns which are unique in both tables + /// There's only one case if both columns 
are unique and have different row counts: the inner + /// will be < 1 / row count of one table and = 1 / row count of another + #[tokio::test] + async fn test_outer_unique_oncond() { + let cost_model = create_two_table_mock_cost_model_custom_row_cnts( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + 5, + 4, + None, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + + // sanity check the expected inner sel + let expected_inner_sel = 0.2; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities(&cost_model, expr_tree, expr_tree_rev, &attr_refs, 0.25, 0.2); + } + + /// Non-unique oncond means the column is not unique in either table + /// Inner always >= row count means that the inner join result is >= 1 / the row count of both + /// tables + #[tokio::test] + async fn test_outer_nonunique_oncond_inner_always_geq_rowcnt() { + let cost_model = create_two_table_mock_cost_model_custom_row_cnts( + per_attr_stats_with_ndistinct(5), + per_attr_stats_with_ndistinct(4), + 10, + 8, + None, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = 
bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + + // sanity check the expected inner sel + let expected_inner_sel = 0.2; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_rev.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities(&cost_model, expr_tree, expr_tree_rev, &attr_refs, 0.2, 0.2) + .await; + } + + /// Non-unique oncond means the column is not unique in either table + /// Inner sometimes < row count means that the inner join result < 1 / the row count of exactly + /// one table. Note that without a join filter, it's impossible to be less than the row + /// count of both tables + #[tokio::test] + async fn test_outer_nonunique_oncond_inner_sometimes_lt_rowcnt() { + let cost_model = create_two_table_mock_cost_model_custom_row_cnts( + per_attr_stats_with_ndistinct(10), + per_attr_stats_with_ndistinct(2), + 20, + 4, + None, + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let expr_tree = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let expr_tree_rev = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + + // sanity check the expected inner sel + let expected_inner_sel = 0.1; + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + 
false, + JoinType::Inner, + expr_tree_rev.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities(&cost_model, expr_tree, expr_tree_rev, &attr_refs, 0.25, 0.1) + .await; + } + + /// Unique oncond means an oncondition on columns which are unique in both tables + /// Filter means we're adding a join filter + /// There's only one case if both columns are unique and there's a filter: + /// the inner will be < 1 / row count of both tables + #[tokio::test] + async fn test_outer_unique_oncond_filter() { + let join_memo = HashMap::from([( + JOIN_GROUP_ID, + MemoGroupInfo::new( + vec![ + Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string()), + Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string()), + ] + .into(), + GroupAttrRefs::new( + vec![ + AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0), + ], + None, + ), + ), + )]); + let cost_model = create_two_table_mock_cost_model_custom_row_cnts( + per_attr_stats_with_dist_and_ndistinct(vec![(Value::Int32(128), 0.4)], 50), + per_attr_stats_with_ndistinct(4), + 50, + 4, + Some(join_memo), + ); + + let attr_refs = vec![ + AttrRef::base_table_attr_ref(TEST_TABLE1_ID, 0), + AttrRef::base_table_attr_ref(TEST_TABLE2_ID, 0), + ]; + // the left/right of the join refers to the tables, not the order of columns in the + // predicate + let eq0and1 = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let eq1and0 = bin_op(BinOpType::Eq, attr_index(1), attr_index(0)); + let filter = bin_op(BinOpType::Leq, attr_index(0), cnst(Value::Int32(128))); + let expr_tree = log_op(LogOpType::And, vec![eq0and1, filter.clone()]); + // inner rev means its the inner expr (the eq op) whose children are being reversed, as + // opposed to the and op + let expr_tree_inner_rev = log_op(LogOpType::And, vec![eq1and0, filter.clone()]); + + // sanity check the expected inner sel + let expected_inner_sel = 0.008; + 
assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + assert_approx_eq::assert_approx_eq!( + test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree_inner_rev.clone(), + &attr_refs, + None + ) + .await, + expected_inner_sel + ); + // check the outer sels + assert_outer_selectivities( + &cost_model, + expr_tree, + expr_tree_inner_rev, + &attr_refs, + 0.25, + 0.02, + ) + .await; + } + + /// Test all possible permutations of three-table joins. + /// A three-table join consists of at least two joins. `join1_on_cond` is the condition of the + /// first join. There can only be one condition because only two tables are involved at + /// the time of the first join. + #[tokio::test] + #[test_case::test_case(&[(0, 1)])] + #[test_case::test_case(&[(0, 2)])] + #[test_case::test_case(&[(1, 2)])] + #[test_case::test_case(&[(0, 1), (0, 2)])] + #[test_case::test_case(&[(0, 1), (1, 2)])] + #[test_case::test_case(&[(0, 2), (1, 2)])] + #[test_case::test_case(&[(0, 1), (0, 2), (1, 2)])] + async fn test_three_table_join_for_initial_join_on_conds( + initial_join_on_conds: &[(usize, usize)], + ) { + assert!( + !initial_join_on_conds.is_empty(), + "initial_join_on_conds should be non-empty" + ); + assert_eq!( + initial_join_on_conds.len(), + initial_join_on_conds.iter().collect::>().len(), + "initial_join_on_conds shouldn't contain duplicates" + ); + let cost_model = create_three_table_mock_cost_model( + per_attr_stats_with_ndistinct(2), + per_attr_stats_with_ndistinct(3), + per_attr_stats_with_ndistinct(4), + ); + + let base_attr_refs = vec![ + BaseTableAttrRef { + table_id: TEST_TABLE1_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE2_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE3_ID, + attr_idx: 0, + }, + ]; + let attr_refs = base_attr_refs + .clone() + .into_iter() + 
.map(AttrRef::BaseTableAttrRef) + .collect(); + + let mut eq_columns = SemanticCorrelation::new(); + for initial_join_on_cond in initial_join_on_conds { + eq_columns.add_predicate(EqPredicate::new( + base_attr_refs[initial_join_on_cond.0].clone(), + base_attr_refs[initial_join_on_cond.1].clone(), + )); + } + let initial_selectivity = { + if initial_join_on_conds.len() == 1 { + let initial_join_on_cond = initial_join_on_conds.first().unwrap(); + if initial_join_on_cond == &(0, 1) { + 1.0 / 3.0 + } else if initial_join_on_cond == &(0, 2) || initial_join_on_cond == &(1, 2) { + 1.0 / 4.0 + } else { + panic!(); + } + } else { + 1.0 / 12.0 + } + }; + + let input_correlation = Some(eq_columns); + + // Try all join conditions of the final join which would lead to all three tables being + // joined. + let eq0and1 = bin_op(BinOpType::Eq, attr_index(0), attr_index(1)); + let eq0and2 = bin_op(BinOpType::Eq, attr_index(0), attr_index(2)); + let eq1and2 = bin_op(BinOpType::Eq, attr_index(1), attr_index(2)); + let and_01_02 = log_op(LogOpType::And, vec![eq0and1.clone(), eq0and2.clone()]); + let and_01_12 = log_op(LogOpType::And, vec![eq0and1.clone(), eq1and2.clone()]); + let and_02_12 = log_op(LogOpType::And, vec![eq0and2.clone(), eq1and2.clone()]); + let and_01_02_12 = log_op( + LogOpType::And, + vec![eq0and1.clone(), eq0and2.clone(), eq1and2.clone()], + ); + let mut join2_expr_trees = vec![and_01_02, and_01_12, and_02_12, and_01_02_12]; + if initial_join_on_conds.len() == 1 { + let initial_join_on_cond = initial_join_on_conds.first().unwrap(); + if initial_join_on_cond == &(0, 1) { + join2_expr_trees.push(eq0and2); + join2_expr_trees.push(eq1and2); + } else if initial_join_on_cond == &(0, 2) { + join2_expr_trees.push(eq0and1); + join2_expr_trees.push(eq1and2); + } else if initial_join_on_cond == &(1, 2) { + join2_expr_trees.push(eq0and1); + join2_expr_trees.push(eq0and2); + } else { + panic!(); + } + } + for expr_tree in join2_expr_trees { + let overall_selectivity = 
initial_selectivity + * test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + expr_tree.clone(), + &attr_refs, + input_correlation.clone(), + ) + .await; + assert_approx_eq::assert_approx_eq!(overall_selectivity, 1.0 / 12.0); + } + } + + #[tokio::test] + async fn test_join_which_connects_two_components_together() { + let cost_model = create_four_table_mock_cost_model( + per_attr_stats_with_ndistinct(2), + per_attr_stats_with_ndistinct(3), + per_attr_stats_with_ndistinct(4), + per_attr_stats_with_ndistinct(5), + ); + let base_attr_refs = vec![ + BaseTableAttrRef { + table_id: TEST_TABLE1_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE2_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE3_ID, + attr_idx: 0, + }, + BaseTableAttrRef { + table_id: TEST_TABLE4_ID, + attr_idx: 0, + }, + ]; + let attr_refs = base_attr_refs + .clone() + .into_iter() + .map(AttrRef::BaseTableAttrRef) + .collect(); + + let mut eq_columns = SemanticCorrelation::new(); + eq_columns.add_predicate(EqPredicate::new( + base_attr_refs[0].clone(), + base_attr_refs[1].clone(), + )); + eq_columns.add_predicate(EqPredicate::new( + base_attr_refs[2].clone(), + base_attr_refs[3].clone(), + )); + let initial_selectivity = 1.0 / (3.0 * 5.0); + let input_correlation = Some(eq_columns); + + let eq1and2 = bin_op(BinOpType::Eq, attr_index(1), attr_index(2)); + let overall_selectivity = initial_selectivity + * test_get_join_selectivity( + &cost_model, + false, + JoinType::Inner, + eq1and2.clone(), + &attr_refs, + input_correlation, + ) + .await; + assert_approx_eq::assert_approx_eq!(overall_selectivity, 1.0 / (3.0 * 4.0 * 5.0)); + } +} diff --git a/optd-cost-model/src/cost/join/hash_join.rs b/optd-cost-model/src/cost/join/hash_join.rs new file mode 100644 index 0000000..47c9ebd --- /dev/null +++ b/optd-cost-model/src/cost/join/hash_join.rs @@ -0,0 +1,55 @@ +use itertools::Itertools; + +use crate::{ + common::{ + nodes::{JoinType, ReprPredicateNode}, + 
predicates::{attr_index_pred::AttrIndexPred, list_pred::ListPred}, + properties::attr_ref::{AttrRefs, SemanticCorrelation}, + types::GroupId, + }, + cost_model::CostModelImpl, + storage::CostModelStorageManager, + CostModelResult, EstimatedStatistic, +}; + +use super::get_input_correlation; + +impl CostModelImpl { + #[allow(clippy::too_many_arguments)] + pub async fn get_hash_join_row_cnt( + &self, + join_typ: JoinType, + group_id: GroupId, + left_row_cnt: f64, + right_row_cnt: f64, + left_group_id: GroupId, + right_group_id: GroupId, + left_keys: ListPred, + right_keys: ListPred, + ) -> CostModelResult { + let selectivity = { + let output_attr_refs = self.memo.get_attribute_refs(group_id); + let left_attr_refs = self.memo.get_attribute_refs(left_group_id); + let right_attr_refs = self.memo.get_attribute_refs(right_group_id); + let left_attr_cnt = left_attr_refs.attr_refs().len(); + // there may be more than one expression tree in a group. + // see comment in PredicateType::PhysicalFilter(_) for more information + let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs); + self.get_join_selectivity_from_keys( + join_typ, + group_id, + left_keys, + right_keys, + output_attr_refs.attr_refs(), + input_correlation, + left_row_cnt, + right_row_cnt, + left_attr_cnt, + ) + .await? 
+ }; + Ok(EstimatedStatistic( + (left_row_cnt * right_row_cnt * selectivity).max(1.0), + )) + } +} diff --git a/optd-cost-model/src/cost/join/mod.rs b/optd-cost-model/src/cost/join/mod.rs new file mode 100644 index 0000000..71b991b --- /dev/null +++ b/optd-cost-model/src/cost/join/mod.rs @@ -0,0 +1,72 @@ +use crate::common::{ + nodes::{ArcPredicateNode, PredicateType, ReprPredicateNode}, + predicates::{attr_index_pred::AttrIndexPred, bin_op_pred::BinOpType}, + properties::attr_ref::{ + AttrRef, AttrRefs, BaseTableAttrRef, GroupAttrRefs, SemanticCorrelation, + }, +}; + +pub mod core; +pub mod hash_join; +pub mod nested_loop_join; + +pub(crate) fn get_input_correlation( + left_prop: GroupAttrRefs, + right_prop: GroupAttrRefs, +) -> Option { + SemanticCorrelation::merge( + left_prop.output_correlation().cloned(), + right_prop.output_correlation().cloned(), + ) +} + +/// Check if an expr_tree is a join condition, returning the join on attr ref pair if it is. +/// The reason the check and the info are in the same function is because their code is almost +/// identical. It only picks out equality conditions between two attribute refs on different +/// tables +pub(crate) fn get_on_attr_ref_pair( + expr_tree: ArcPredicateNode, + attr_refs: &AttrRefs, +) -> Option<(AttrIndexPred, AttrIndexPred)> { + // 1. Check that it's equality + if expr_tree.typ == PredicateType::BinOp(BinOpType::Eq) { + let left_child = expr_tree.child(0); + let right_child = expr_tree.child(1); + // 2. Check that both sides are attribute refs + if left_child.typ == PredicateType::AttrIndex && right_child.typ == PredicateType::AttrIndex + { + // 3. 
Check that both sides don't belong to the same table (if we don't know, that + // means they don't belong) + let left_attr_ref_expr = AttrIndexPred::from_pred_node(left_child) + .expect("we already checked that the type is AttrRef"); + let right_attr_ref_expr = AttrIndexPred::from_pred_node(right_child) + .expect("we already checked that the type is AttrRef"); + let left_attr_ref = &attr_refs[left_attr_ref_expr.attr_index() as usize]; + let right_attr_ref = &attr_refs[right_attr_ref_expr.attr_index() as usize]; + let is_same_table = if let ( + AttrRef::BaseTableAttrRef(BaseTableAttrRef { + table_id: left_table_id, + .. + }), + AttrRef::BaseTableAttrRef(BaseTableAttrRef { + table_id: right_table_id, + .. + }), + ) = (left_attr_ref, right_attr_ref) + { + left_table_id == right_table_id + } else { + false + }; + if !is_same_table { + Some((left_attr_ref_expr, right_attr_ref_expr)) + } else { + None + } + } else { + None + } + } else { + None + } +} diff --git a/optd-cost-model/src/cost/join/nested_loop_join.rs b/optd-cost-model/src/cost/join/nested_loop_join.rs new file mode 100644 index 0000000..ebb70c9 --- /dev/null +++ b/optd-cost-model/src/cost/join/nested_loop_join.rs @@ -0,0 +1,48 @@ +use crate::{ + common::{ + nodes::{ArcPredicateNode, JoinType, PredicateType, ReprPredicateNode}, + predicates::log_op_pred::{LogOpPred, LogOpType}, + properties::attr_ref::{AttrRefs, SemanticCorrelation}, + types::GroupId, + }, + cost_model::CostModelImpl, + storage::CostModelStorageManager, + CostModelResult, EstimatedStatistic, +}; + +use super::get_input_correlation; + +impl CostModelImpl { + #[allow(clippy::too_many_arguments)] + pub async fn get_nlj_row_cnt( + &self, + join_typ: JoinType, + group_id: GroupId, + left_row_cnt: f64, + right_row_cnt: f64, + left_group_id: GroupId, + right_group_id: GroupId, + join_cond: ArcPredicateNode, + ) -> CostModelResult { + let selectivity = { + let output_attr_refs = self.memo.get_attribute_refs(group_id); + let left_attr_refs = 
self.memo.get_attribute_refs(left_group_id); + let right_attr_refs = self.memo.get_attribute_refs(right_group_id); + let input_correlation = get_input_correlation(left_attr_refs, right_attr_refs); + + self.get_join_selectivity_from_expr_tree( + join_typ, + group_id, + join_cond, + output_attr_refs.attr_refs(), + input_correlation, + left_row_cnt, + right_row_cnt, + ) + .await? + }; + Ok(EstimatedStatistic( + (left_row_cnt * right_row_cnt * selectivity).max(1.0), + )) + } +} diff --git a/optd-cost-model/src/cost/limit.rs b/optd-cost-model/src/cost/limit.rs new file mode 100644 index 0000000..c63c0e0 --- /dev/null +++ b/optd-cost-model/src/cost/limit.rs @@ -0,0 +1,28 @@ +use crate::{ + common::{ + nodes::{ArcPredicateNode, ReprPredicateNode}, + predicates::constant_pred::ConstantPred, + }, + cost_model::CostModelImpl, + storage::CostModelStorageManager, + CostModelResult, EstimatedStatistic, +}; + +impl CostModelImpl { + pub(crate) fn get_limit_row_cnt( + &self, + child_row_cnt: EstimatedStatistic, + fetch_expr: ArcPredicateNode, + ) -> CostModelResult { + let fetch = ConstantPred::from_pred_node(fetch_expr) + .unwrap() + .value() + .as_u64(); + // u64::MAX represents None + if fetch == u64::MAX { + Ok(child_row_cnt) + } else { + Ok(EstimatedStatistic(child_row_cnt.0.min(fetch as f64))) + } + } +} diff --git a/optd-cost-model/src/cost/mod.rs b/optd-cost-model/src/cost/mod.rs index 795ed3e..c98d7d7 100644 --- a/optd-cost-model/src/cost/mod.rs +++ b/optd-cost-model/src/cost/mod.rs @@ -1,3 +1,6 @@ +#![allow(unused)] + pub mod agg; pub mod filter; pub mod join; +pub mod limit; diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index e933add..4583484 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -13,33 +13,34 @@ use crate::{ types::{AttrId, EpochId, ExprId, TableId}, }, memo_ext::MemoExt, + stats::AttributeCombValueStats, storage::CostModelStorageManager, ComputeCostContext, Cost, CostModel, 
CostModelResult, EstimatedStatistic, StatValue, }; /// TODO: documentation -pub struct CostModelImpl { - storage_manager: CostModelStorageManager, - default_catalog_source: CatalogSource, - _memo: Arc, +pub struct CostModelImpl { + pub storage_manager: S, + pub default_catalog_source: CatalogSource, + pub memo: Arc, } -impl CostModelImpl { +impl CostModelImpl { /// TODO: documentation pub fn new( - storage_manager: CostModelStorageManager, + storage_manager: S, default_catalog_source: CatalogSource, memo: Arc, ) -> Self { Self { storage_manager, default_catalog_source, - _memo: memo, + memo, } } } -impl CostModel for CostModelImpl { +impl CostModel for CostModelImpl { fn compute_operation_cost( &self, node: &PhysicalNodeType, @@ -71,7 +72,6 @@ impl CostModel for CostModelImpl { fn get_table_statistic_for_analysis( &self, - // TODO: i32 should be changed to TableId. table_id: TableId, stat_type: StatType, epoch_id: Option, @@ -96,3 +96,534 @@ impl CostModel for CostModelImpl { todo!() } } + +impl CostModelImpl { + /// TODO: documentation + /// TODO: if we have memory cache, + /// we should add the reference. 
(&AttributeCombValueStats) + pub(crate) async fn get_attribute_comb_stats( + &self, + table_id: TableId, + attr_comb: &[u64], + ) -> CostModelResult> { + self.storage_manager + .get_attributes_comb_statistics(table_id, attr_comb) + .await + } +} + +/// I thought about using the system's own parser and planner to generate these expression trees, +/// but this is not currently feasible because it would create a cyclic dependency between +/// optd-datafusion-bridge and optd-datafusion-repr +#[cfg(test)] +pub mod tests { + use std::{collections::HashMap, hash::Hash}; + + use arrow_schema::DataType; + use itertools::Itertools; + use optd_persistent::cost_model::interface::CatalogSource; + use serde::{Deserialize, Serialize}; + + use crate::{ + common::{ + nodes::ReprPredicateNode, + predicates::{ + attr_index_pred::AttrIndexPred, + bin_op_pred::{BinOpPred, BinOpType}, + cast_pred::CastPred, + constant_pred::{ConstantPred, ConstantType}, + in_list_pred::InListPred, + like_pred::LikePred, + list_pred::ListPred, + log_op_pred::{LogOpPred, LogOpType}, + un_op_pred::{UnOpPred, UnOpType}, + }, + properties::{ + attr_ref::{AttrRef, GroupAttrRefs}, + schema::Schema, + Attribute, + }, + types::GroupId, + values::Value, + }, + memo_ext::tests::{MemoGroupInfo, MockMemoExtImpl}, + stats::{ + utilities::{counter::Counter, simple_map::SimpleMap}, + AttributeCombValueStats, Distribution, MostCommonValues, + }, + storage::mock::{CostModelStorageMockManagerImpl, TableStats}, + }; + + use super::*; + + pub const TEST_TABLE1_ID: TableId = TableId(0); + pub const TEST_TABLE2_ID: TableId = TableId(1); + pub const TEST_TABLE3_ID: TableId = TableId(2); + pub const TEST_TABLE4_ID: TableId = TableId(3); + + pub const TEST_GROUP1_ID: GroupId = GroupId(0); + pub const TEST_GROUP2_ID: GroupId = GroupId(1); + pub const TEST_GROUP3_ID: GroupId = GroupId(2); + pub const TEST_GROUP4_ID: GroupId = GroupId(3); + + // This is base index rather than ref index. 
+ pub const TEST_ATTR1_BASE_INDEX: u64 = 0; + pub const TEST_ATTR2_BASE_INDEX: u64 = 1; + pub const TEST_ATTR3_BASE_INDEX: u64 = 2; + + pub const TEST_ATTR1_NAME: &str = "attr1"; + pub const TEST_ATTR2_NAME: &str = "attr2"; + pub const TEST_ATTR3_NAME: &str = "attr3"; + pub const TEST_ATTR4_NAME: &str = "attr4"; + + pub type TestPerAttributeStats = AttributeCombValueStats; + // TODO: add tests for non-mock storage manager + pub type TestOptCostModelMock = CostModelImpl; + + // Use this method, we only create one group `TEST_GROUP1_ID` in the memo. + // We put the first attribute in the first table as the ref index 0 in the group. + // And put the second attribute in the first table as the ref index 1 in the group. + // etc. + // The orders of attributes and tables are defined by the order of their ids (smaller first). + pub fn create_mock_cost_model( + table_id: Vec, + // u64 should be base attribute index. + per_attribute_stats: Vec>, + row_counts: Vec>, + ) -> TestOptCostModelMock { + let attr_ids: Vec<(TableId, u64, Option)> = per_attribute_stats + .iter() + .enumerate() + .map(|(idx, m)| (table_id[idx], m)) + .flat_map(|(table_id, m)| { + m.iter() + .map(|(attr_idx, _)| (table_id, *attr_idx, None)) + .collect_vec() + }) + .sorted_by_key(|(table_id, attr_idx, _)| (*table_id, *attr_idx)) + .collect(); + create_mock_cost_model_with_memo( + table_id.clone(), + per_attribute_stats, + row_counts, + create_one_group_all_base_attributes_mock_memo(attr_ids), + ) + } + + pub fn create_mock_cost_model_with_attr_types( + table_id: Vec, + // u64 should be base attribute index. 
+ per_attribute_stats: Vec>, + attributes: Vec>, + row_counts: Vec>, + ) -> TestOptCostModelMock { + let attr_ids: Vec<(TableId, u64, Option)> = attributes + .iter() + .enumerate() + .map(|(idx, m)| (table_id[idx], m)) + .flat_map(|(table_id, m)| { + m.iter() + .map(|(attr_idx, typ)| (table_id, *attr_idx, Some(*typ))) + .collect_vec() + }) + .sorted_by_key(|(table_id, attr_idx, _)| (*table_id, *attr_idx)) + .collect(); + create_mock_cost_model_with_memo( + table_id.clone(), + per_attribute_stats, + row_counts, + create_one_group_all_base_attributes_mock_memo(attr_ids), + ) + } + + pub fn create_mock_cost_model_with_memo( + table_id: Vec, + per_attribute_stats: Vec>, + row_counts: Vec>, + memo: MockMemoExtImpl, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + table_id + .into_iter() + .zip(per_attribute_stats) + .zip(row_counts) + .map(|((table_id, per_attr_stats), row_count)| { + ( + table_id, + TableStats::new( + row_count.unwrap_or(100), + per_attr_stats + .into_iter() + .map(|(attr_idx, stats)| (vec![attr_idx], stats)) + .collect(), + ), + ) + }) + .collect(), + ); + CostModelImpl::new(storage_manager, CatalogSource::Mock, Arc::new(memo)) + } + + // attributes: Vec<(TableId, AttrBaseIndex)> + pub fn create_one_group_all_base_attributes_mock_memo( + attr_ids: Vec<(TableId, u64, Option)>, + ) -> MockMemoExtImpl { + let group_info = MemoGroupInfo::new( + Schema::new( + attr_ids + .clone() + .into_iter() + .map(|(_, _, typ)| Attribute { + name: "attr".to_string(), + typ: typ.unwrap_or(ConstantType::Int64), + nullable: false, + }) + .collect(), + ), + GroupAttrRefs::new( + attr_ids + .into_iter() + .map(|(table_id, attr_base_index, _)| { + AttrRef::new_base_table_attr_ref(table_id, attr_base_index) + }) + .collect(), + None, + ), + ); + MockMemoExtImpl::from(HashMap::from([(TEST_GROUP1_ID, group_info)])) + } + + /// Create a cost model two tables, each with one attribute. Each attribute has 100 values. 
+ pub fn create_two_table_mock_cost_model( + tbl1_per_attr_stats: TestPerAttributeStats, + tbl2_per_attr_stats: TestPerAttributeStats, + additional_memo: Option>, + ) -> TestOptCostModelMock { + create_two_table_mock_cost_model_custom_row_cnts( + tbl1_per_attr_stats, + tbl2_per_attr_stats, + 100, + 100, + additional_memo, + ) + } + + /// Create a cost model three tables, each with one attribute. Each attribute has 100 values. + pub fn create_three_table_mock_cost_model( + tbl1_per_column_stats: TestPerAttributeStats, + tbl2_per_column_stats: TestPerAttributeStats, + tbl3_per_column_stats: TestPerAttributeStats, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![ + ( + TEST_TABLE1_ID, + TableStats::new( + 100, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE2_ID, + TableStats::new( + 100, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE3_ID, + TableStats::new( + 100, + vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), + ); + let memo = HashMap::from([ + ( + TEST_GROUP1_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP2_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP3_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR3_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE3_ID, 0)], + None, + ), + ), + ), + ]); + CostModelImpl::new( + storage_manager, + CatalogSource::Mock, + Arc::new(MockMemoExtImpl::from(memo)), + ) + } + + /// Create a cost model four tables, each with 
one attribute. Each attribute has 100 values. + pub fn create_four_table_mock_cost_model( + tbl1_per_column_stats: TestPerAttributeStats, + tbl2_per_column_stats: TestPerAttributeStats, + tbl3_per_column_stats: TestPerAttributeStats, + tbl4_per_column_stats: TestPerAttributeStats, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![ + ( + TEST_TABLE1_ID, + TableStats::new( + 100, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE2_ID, + TableStats::new( + 100, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE3_ID, + TableStats::new( + 100, + vec![(vec![0], tbl3_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE4_ID, + TableStats::new( + 100, + vec![(vec![0], tbl4_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), + ); + let memo = HashMap::from([ + ( + TEST_GROUP1_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP2_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP3_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR3_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE3_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP4_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR4_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE4_ID, 0)], + None, + ), + ), + ), + ]); + CostModelImpl::new( + storage_manager, + CatalogSource::Mock, + Arc::new(MockMemoExtImpl::from(memo)), + ) + } + + /// We need custom row counts because 
some join algorithms rely on the row cnt + pub fn create_two_table_mock_cost_model_custom_row_cnts( + tbl1_per_column_stats: TestPerAttributeStats, + tbl2_per_column_stats: TestPerAttributeStats, + tbl1_row_cnt: u64, + tbl2_row_cnt: u64, + additional_memo: Option>, + ) -> TestOptCostModelMock { + let storage_manager = CostModelStorageMockManagerImpl::new( + vec![ + ( + TEST_TABLE1_ID, + TableStats::new( + tbl1_row_cnt, + vec![(vec![0], tbl1_per_column_stats)].into_iter().collect(), + ), + ), + ( + TEST_TABLE2_ID, + TableStats::new( + tbl2_row_cnt, + vec![(vec![0], tbl2_per_column_stats)].into_iter().collect(), + ), + ), + ] + .into_iter() + .collect(), + ); + let mut memo = HashMap::from([ + ( + TEST_GROUP1_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR1_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE1_ID, 0)], + None, + ), + ), + ), + ( + TEST_GROUP2_ID, + MemoGroupInfo::new( + vec![Attribute::new_non_null_int64(TEST_ATTR2_NAME.to_string())].into(), + GroupAttrRefs::new( + vec![AttrRef::new_base_table_attr_ref(TEST_TABLE2_ID, 0)], + None, + ), + ), + ), + ]); + if let Some(additional_memo) = additional_memo { + memo.extend(additional_memo); + } + CostModelImpl::new( + storage_manager, + CatalogSource::Mock, + Arc::new(MockMemoExtImpl::from(memo)), + ) + } + + impl TestOptCostModelMock { + pub fn get_row_count(&self, table_id: TableId) -> u64 { + self.storage_manager + .per_table_stats_map + .get(&table_id) + .map(|stats| stats.row_cnt) + .unwrap_or(0) + } + + pub fn get_attr_refs(&self, group_id: GroupId) -> GroupAttrRefs { + self.memo.get_attribute_refs(group_id) + } + } + + pub fn attr_index(attr_index: u64) -> ArcPredicateNode { + AttrIndexPred::new(attr_index).into_pred_node() + } + + pub fn cnst(value: Value) -> ArcPredicateNode { + ConstantPred::new(value).into_pred_node() + } + + pub fn cast(child: ArcPredicateNode, cast_type: DataType) -> ArcPredicateNode { + CastPred::new(child, 
cast_type).into_pred_node() + } + + pub fn bin_op( + op_type: BinOpType, + left: ArcPredicateNode, + right: ArcPredicateNode, + ) -> ArcPredicateNode { + BinOpPred::new(left, right, op_type).into_pred_node() + } + + pub fn log_op(op_type: LogOpType, children: Vec) -> ArcPredicateNode { + LogOpPred::new(op_type, children).into_pred_node() + } + + pub fn un_op(op_type: UnOpType, child: ArcPredicateNode) -> ArcPredicateNode { + UnOpPred::new(child, op_type).into_pred_node() + } + + pub fn empty_list() -> ArcPredicateNode { + ListPred::new(vec![]).into_pred_node() + } + + pub fn list(children: Vec) -> ArcPredicateNode { + ListPred::new(children).into_pred_node() + } + + pub fn in_list(attr_idx: u64, list: Vec, negated: bool) -> InListPred { + InListPred::new( + attr_index(attr_idx), + ListPred::new(list.into_iter().map(cnst).collect_vec()), + negated, + ) + } + + pub fn like(attr_idx: u64, pattern: &str, negated: bool) -> LikePred { + LikePred::new( + negated, + false, + attr_index(attr_idx), + cnst(Value::String(pattern.into())), + ) + } + + pub(crate) fn empty_per_attr_stats() -> TestPerAttributeStats { + TestPerAttributeStats::new( + MostCommonValues::empty(), + Some(Distribution::empty()), + 0, + 0.0, + ) + } + + pub(crate) fn per_attr_stats_with_ndistinct(ndistinct: u64) -> TestPerAttributeStats { + TestPerAttributeStats::new( + MostCommonValues::empty(), + Some(Distribution::empty()), + ndistinct, + 0.0, + ) + } + + pub(crate) fn per_attr_stats_with_dist_and_ndistinct( + dist: Vec<(Value, f64)>, + ndistinct: u64, + ) -> TestPerAttributeStats { + TestPerAttributeStats::new( + MostCommonValues::empty(), + Some(Distribution::SimpleDistribution(SimpleMap::new(dist))), + ndistinct, + 0.0, + ) + } +} diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index 5417f1c..13774b2 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -33,7 +33,8 @@ pub struct Cost(pub Vec); /// Estimated statistic calculated by the cost model. 
/// It is the estimated output row count of the targeted expression. -pub struct EstimatedStatistic(pub u64); +#[derive(PartialEq, PartialOrd, Debug)] +pub struct EstimatedStatistic(pub f64); pub type CostModelResult = Result; @@ -42,12 +43,13 @@ pub enum SemanticError { // TODO: Add more error types UnknownStatisticType, VersionedStatisticNotFound, - AttributeNotFound(TableId, i32), // (table_id, attribute_base_index) + AttributeNotFound(TableId, u64), // (table_id, attribute_base_index) + // FIXME: not sure if this should be put here + InvalidPredicate(String), } #[derive(Debug)] pub enum CostModelError { - // TODO: Add more error types ORMError(BackendError), SemanticError(SemanticError), } @@ -58,6 +60,12 @@ impl From for CostModelError { } } +impl From for CostModelError { + fn from(err: SemanticError) -> Self { + CostModelError::SemanticError(err) + } +} + pub trait CostModel: 'static + Send + Sync { /// TODO: documentation fn compute_operation_cost( diff --git a/optd-cost-model/src/memo_ext.rs b/optd-cost-model/src/memo_ext.rs index 16cddca..c7827c5 100644 --- a/optd-cost-model/src/memo_ext.rs +++ b/optd-cost-model/src/memo_ext.rs @@ -1,5 +1,9 @@ use crate::common::{ - properties::{attr_ref::GroupAttrRefs, schema::Schema, Attribute}, + properties::{ + attr_ref::{AttrRef, GroupAttrRefs}, + schema::Schema, + Attribute, + }, types::GroupId, }; @@ -13,10 +17,78 @@ use crate::common::{ pub trait MemoExt: Send + Sync + 'static { /// Get the schema of a group in the memo. fn get_schema(&self, group_id: GroupId) -> Schema; - /// Get the attribute reference of a group in the memo. - fn get_attribute_ref(&self, group_id: GroupId) -> GroupAttrRefs; - /// Get the attribute information of a given attribute in a group in the memo. + /// Get the attribute info of a given attribute in a group in the memo. fn get_attribute_info(&self, group_id: GroupId, attr_ref_idx: u64) -> Attribute; + /// Get the attribute reference of a group in the memo. 
+ fn get_attribute_refs(&self, group_id: GroupId) -> GroupAttrRefs; + /// Get the attribute reference of a given attribute in a group in the memo. + fn get_attribute_ref(&self, group_id: GroupId, attr_ref_idx: u64) -> AttrRef; // TODO: Figure out what other information is needed to compute the cost... } + +#[cfg(test)] +pub mod tests { + use std::collections::HashMap; + + use crate::common::{ + properties::{ + attr_ref::{AttrRef, GroupAttrRefs}, + schema::Schema, + Attribute, + }, + types::GroupId, + }; + + pub struct MemoGroupInfo { + pub schema: Schema, + pub attr_refs: GroupAttrRefs, + } + + impl MemoGroupInfo { + pub fn new(schema: Schema, attr_refs: GroupAttrRefs) -> Self { + Self { schema, attr_refs } + } + } + + #[derive(Default)] + pub struct MockMemoExtImpl { + memo: HashMap, + } + + impl MockMemoExtImpl { + pub fn add_group_info( + &mut self, + group_id: GroupId, + schema: Schema, + attr_ref: GroupAttrRefs, + ) { + self.memo + .insert(group_id, MemoGroupInfo::new(schema, attr_ref)); + } + } + + impl super::MemoExt for MockMemoExtImpl { + fn get_schema(&self, group_id: GroupId) -> Schema { + self.memo.get(&group_id).unwrap().schema.clone() + } + + fn get_attribute_info(&self, group_id: GroupId, attr_ref_idx: u64) -> Attribute { + self.memo.get(&group_id).unwrap().schema.attributes[attr_ref_idx as usize].clone() + } + + fn get_attribute_refs(&self, group_id: GroupId) -> GroupAttrRefs { + self.memo.get(&group_id).unwrap().attr_refs.clone() + } + + fn get_attribute_ref(&self, group_id: GroupId, attr_ref_idx: u64) -> AttrRef { + self.memo.get(&group_id).unwrap().attr_refs.attr_refs()[attr_ref_idx as usize].clone() + } + } + + impl From> for MockMemoExtImpl { + fn from(memo: HashMap) -> Self { + Self { memo } + } + } +} diff --git a/optd-cost-model/src/stats/mod.rs b/optd-cost-model/src/stats/mod.rs index 0b1396a..7ec2510 100644 --- a/optd-cost-model/src/stats/mod.rs +++ b/optd-cost-model/src/stats/mod.rs @@ -1,12 +1,15 @@ #![allow(unused)] mod arith_encoder; 
-pub mod counter; -pub mod tdigest; +pub mod utilities; use crate::common::values::Value; -use counter::Counter; use serde::{Deserialize, Serialize}; +use utilities::counter::Counter; +use utilities::{ + simple_map::{self, SimpleMap}, + tdigest::TDigest, +}; // Default n-distinct estimate for derived columns or columns lacking statistics pub const DEFAULT_NUM_DISTINCT: u64 = 200; @@ -27,10 +30,12 @@ pub const FIXED_CHAR_SEL_FACTOR: f64 = 0.2; pub type AttributeCombValue = Vec>; -#[derive(Serialize, Deserialize, Debug)] +// TODO: remove the clone, see the comment in the [`AttributeCombValueStats`] +#[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "type")] pub enum MostCommonValues { Counter(Counter), + SimpleFrequency(SimpleMap), // Add more types here... } @@ -43,12 +48,14 @@ impl MostCommonValues { pub fn freq(&self, value: &AttributeCombValue) -> Option { match self { MostCommonValues::Counter(counter) => counter.frequencies().get(value).copied(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.get(value).copied(), } } pub fn total_freq(&self) -> f64 { match self { MostCommonValues::Counter(counter) => counter.frequencies().values().sum(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.values().sum(), } } @@ -60,6 +67,12 @@ impl MostCommonValues { .filter(|(val, _)| pred(val)) .map(|(_, freq)| freq) .sum(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map + .m + .iter() + .filter(|(val, _)| pred(val)) + .map(|(_, freq)| freq) + .sum(), } } @@ -67,14 +80,21 @@ impl MostCommonValues { pub fn cnt(&self) -> usize { match self { MostCommonValues::Counter(counter) => counter.frequencies().len(), + MostCommonValues::SimpleFrequency(simple_map) => simple_map.m.len(), } } + + pub fn empty() -> Self { + MostCommonValues::SimpleFrequency(SimpleMap::new(vec![])) + } } -#[derive(Serialize, Deserialize, Debug)] +// TODO: remove the clone, see the comment in the [`AttributeCombValueStats`] +#[derive(Serialize, 
Deserialize, Debug, Clone)] #[serde(tag = "type")] pub enum Distribution { - TDigest(tdigest::TDigest), + TDigest(TDigest), + SimpleDistribution(SimpleMap), // Add more types here... } @@ -89,11 +109,25 @@ impl Distribution { tdigest.centroids.len() as f64 * tdigest.cdf(value) / nb_rows as f64 } } + Distribution::SimpleDistribution(simple_distribution) => { + *simple_distribution.m.get(value).unwrap_or(&0.0) + } } } + + pub fn empty() -> Self { + Distribution::SimpleDistribution(SimpleMap::new(vec![])) + } } -#[derive(Serialize, Deserialize, Debug)] +// TODO: Remove the clone. Now I have to add this because +// persistent.rs doesn't have a memory cache, so we have to +// return AttributeCombValueStats rather than &AttributeCombValueStats. +// But this poses a problem for mock.rs when testing, since mock storage +// only has memory hash map, so we need to return a clone of AttributeCombValueStats. +// Later, if memory cache is added, we should change this to return a reference. +// **and** remove the clone. +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct AttributeCombValueStats { pub mcvs: MostCommonValues, // Does NOT contain full nulls. pub distr: Option, // Does NOT contain mcvs; optional. @@ -104,9 +138,9 @@ pub struct AttributeCombValueStats { impl AttributeCombValueStats { pub fn new( mcvs: MostCommonValues, + distr: Option, ndistinct: u64, null_frac: f64, - distr: Option, ) -> Self { Self { mcvs, diff --git a/optd-cost-model/src/stats/counter.rs b/optd-cost-model/src/stats/utilities/counter.rs similarity index 95% rename from optd-cost-model/src/stats/counter.rs rename to optd-cost-model/src/stats/utilities/counter.rs index 65a2d63..368700c 100644 --- a/optd-cost-model/src/stats/counter.rs +++ b/optd-cost-model/src/stats/utilities/counter.rs @@ -5,8 +5,9 @@ use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; /// The Counter structure to track exact frequencies of fixed elements. 
+/// TODO: remove the clone, see the comment in the [`AttributeCombValueStats`] #[serde_with::serde_as] -#[derive(Default, Serialize, Deserialize, Debug)] +#[derive(Default, Serialize, Deserialize, Debug, Clone)] pub struct Counter { #[serde_as(as = "HashMap")] counts: HashMap, // The exact counts of an element T. @@ -32,6 +33,13 @@ where } } + pub fn new_from_existing(counts: HashMap, total_count: i32) -> Self { + Counter:: { + counts, + total_count, + } + } + // Inserts an element in the Counter if it is being tracked. fn insert_element(&mut self, elem: T, occ: i32) { if let Some(frequency) = self.counts.get_mut(&elem) { diff --git a/optd-cost-model/src/stats/utilities/mod.rs b/optd-cost-model/src/stats/utilities/mod.rs new file mode 100644 index 0000000..0a7903b --- /dev/null +++ b/optd-cost-model/src/stats/utilities/mod.rs @@ -0,0 +1,3 @@ +pub mod counter; +pub mod simple_map; +pub mod tdigest; diff --git a/optd-cost-model/src/stats/utilities/simple_map.rs b/optd-cost-model/src/stats/utilities/simple_map.rs new file mode 100644 index 0000000..d04439e --- /dev/null +++ b/optd-cost-model/src/stats/utilities/simple_map.rs @@ -0,0 +1,21 @@ +use std::collections::HashMap; +use std::hash::Hash; + +use serde::{Deserialize, Serialize}; + +use crate::common::values::Value; + +/// TODO: documentation +/// Now it is mainly for testing purposes. 
+#[derive(Clone, Serialize, Deserialize, Debug, Default)] +pub struct SimpleMap { + pub(crate) m: HashMap, +} + +impl SimpleMap { + pub fn new(v: Vec<(K, f64)>) -> Self { + Self { + m: v.into_iter().collect(), + } + } +} diff --git a/optd-cost-model/src/stats/tdigest.rs b/optd-cost-model/src/stats/utilities/tdigest.rs similarity index 99% rename from optd-cost-model/src/stats/tdigest.rs rename to optd-cost-model/src/stats/utilities/tdigest.rs index 83dc9b5..96a2269 100644 --- a/optd-cost-model/src/stats/tdigest.rs +++ b/optd-cost-model/src/stats/utilities/tdigest.rs @@ -15,9 +15,7 @@ use std::marker::PhantomData; use itertools::Itertools; use serde::{Deserialize, Serialize}; -use crate::common::values::Value; - -use super::arith_encoder; +use crate::{common::values::Value, stats::arith_encoder}; pub const DEFAULT_COMPRESSION: f64 = 200.0; diff --git a/optd-cost-model/src/storage/mock.rs b/optd-cost-model/src/storage/mock.rs new file mode 100644 index 0000000..d878bcb --- /dev/null +++ b/optd-cost-model/src/storage/mock.rs @@ -0,0 +1,61 @@ +#![allow(unused_variables, dead_code)] +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; + +use super::CostModelStorageManager; + +pub type AttrIndices = Vec; + +#[serde_with::serde_as] +#[derive(Serialize, Deserialize, Debug)] +pub struct TableStats { + pub row_cnt: u64, + #[serde_as(as = "HashMap")] + pub column_comb_stats: HashMap, +} + +impl TableStats { + pub fn new( + row_cnt: u64, + column_comb_stats: HashMap, + ) -> Self { + Self { + row_cnt, + column_comb_stats, + } + } +} + +pub type BaseTableStats = HashMap; + +pub struct CostModelStorageMockManagerImpl { + pub(crate) per_table_stats_map: BaseTableStats, +} + +impl CostModelStorageMockManagerImpl { + pub fn new(per_table_stats_map: BaseTableStats) -> Self { + Self { + per_table_stats_map, + } + } +} + +impl CostModelStorageManager for 
CostModelStorageMockManagerImpl { + async fn get_attributes_comb_statistics( + &self, + table_id: TableId, + attr_base_indices: &[u64], + ) -> CostModelResult> { + let table_stats = self.per_table_stats_map.get(&table_id); + match table_stats { + None => Ok(None), + Some(table_stats) => match table_stats.column_comb_stats.get(attr_base_indices) { + None => Ok(None), + Some(stats) => Ok(Some(stats.clone())), + }, + } + } +} diff --git a/optd-cost-model/src/storage/mod.rs b/optd-cost-model/src/storage/mod.rs new file mode 100644 index 0000000..d3d26cd --- /dev/null +++ b/optd-cost-model/src/storage/mod.rs @@ -0,0 +1,13 @@ +use crate::{common::types::TableId, stats::AttributeCombValueStats, CostModelResult}; + +pub mod mock; +pub mod persistent; + +#[trait_variant::make(Send)] +pub trait CostModelStorageManager { + async fn get_attributes_comb_statistics( + &self, + table_id: TableId, + attr_base_indices: &[u64], + ) -> CostModelResult>; +} diff --git a/optd-cost-model/src/storage.rs b/optd-cost-model/src/storage/persistent.rs similarity index 73% rename from optd-cost-model/src/storage.rs rename to optd-cost-model/src/storage/persistent.rs index 5538618..dede7f3 100644 --- a/optd-cost-model/src/storage.rs +++ b/optd-cost-model/src/storage/persistent.rs @@ -1,43 +1,31 @@ #![allow(unused_variables)] use std::sync::Arc; -use optd_persistent::{ - cost_model::interface::{Attr, StatType}, - CostModelStorageLayer, -}; +use optd_persistent::{cost_model::interface::StatType, CostModelStorageLayer}; use crate::{ common::types::TableId, - stats::{counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, + stats::{utilities::counter::Counter, AttributeCombValueStats, Distribution, MostCommonValues}, CostModelResult, }; +use super::CostModelStorageManager; + /// TODO: documentation -pub struct CostModelStorageManager { +pub struct CostModelStorageManagerImpl { pub backend_manager: Arc, // TODO: in-memory cache } -impl CostModelStorageManager { +impl 
CostModelStorageManagerImpl { pub fn new(backend_manager: Arc) -> Self { Self { backend_manager } } +} - /// Gets the attribute information for a given table and attribute base index. - /// - /// TODO: if we have memory cache, - /// we should add the reference. (&Attr) - pub async fn get_attribute_info( - &self, - table_id: TableId, - attr_base_index: i32, - ) -> CostModelResult> { - Ok(self - .backend_manager - .get_attribute(table_id.into(), attr_base_index) - .await?) - } - +impl CostModelStorageManager + for CostModelStorageManagerImpl +{ /// Gets the latest statistics for a given table. /// /// TODO: Currently, in `AttributeCombValueStats`, only `Distribution` is optional. @@ -50,16 +38,19 @@ impl CostModelStorageManager { /// /// TODO: Shall we pass in an epoch here to make sure that the statistics are from the same /// epoch? - pub async fn get_attributes_comb_statistics( + /// + /// TODO(IMPORTANT): what if table is a derived (temporary) table? And what if + /// the attribute is a derived attribute? 
+ async fn get_attributes_comb_statistics( &self, table_id: TableId, - attr_base_indices: &[i32], + attr_base_indices: &[u64], ) -> CostModelResult> { let dist: Option = self .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::Distribution, None, ) @@ -70,7 +61,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::MostCommonValues, None, ) @@ -82,7 +73,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::Cardinality, None, ) @@ -94,7 +85,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::TableRowCount, None, ) @@ -105,7 +96,7 @@ impl CostModelStorageManager { .backend_manager .get_stats_for_attr_indices_based( table_id.into(), - attr_base_indices.to_vec(), + attr_base_indices.iter().map(|&x| x as i32).collect(), StatType::NonNullCount, None, ) @@ -123,9 +114,9 @@ impl CostModelStorageManager { }; Ok(Some(AttributeCombValueStats::new( - mcvs, ndistinct, null_frac, dist, + mcvs, dist, ndistinct, null_frac, ))) } -} -// TODO: add some tests, especially cover the error cases. + // TODO: Support querying for a specific type of statistics. 
+} diff --git a/optd-persistent/Cargo.toml b/optd-persistent/Cargo.toml index c576100..50af728 100644 --- a/optd-persistent/Cargo.toml +++ b/optd-persistent/Cargo.toml @@ -21,3 +21,4 @@ trait-variant = "0.1.2" async-trait = "0.1.43" async-stream = "0.3.1" strum = "0.26.1" +num_enum = "0.7.3" diff --git a/optd-persistent/src/cost_model/interface.rs b/optd-persistent/src/cost_model/interface.rs index a03087f..ee767d7 100644 --- a/optd-persistent/src/cost_model/interface.rs +++ b/optd-persistent/src/cost_model/interface.rs @@ -4,6 +4,7 @@ use crate::entities::cascades_group; use crate::entities::logical_expression; use crate::entities::physical_expression; use crate::StorageResult; +use num_enum::{IntoPrimitive, TryFromPrimitive}; use sea_orm::prelude::Json; use sea_orm::*; use sea_orm_migration::prelude::*; @@ -16,6 +17,7 @@ pub type AttrId = i32; pub type ExprId = i32; pub type EpochId = i32; pub type StatId = i32; +pub type AttrIndex = i32; /// TODO: documentation pub enum CatalogSource { @@ -24,8 +26,10 @@ pub enum CatalogSource { } /// TODO: documentation +#[repr(i32)] +#[derive(Copy, Clone, Debug, PartialEq, IntoPrimitive, TryFromPrimitive)] pub enum AttrType { - Integer, + Integer = 1, Float, Varchar, Boolean, @@ -96,7 +100,7 @@ pub struct Attr { pub table_id: i32, pub name: String, pub compression_method: String, - pub attr_type: i32, + pub attr_type: AttrType, pub base_index: i32, pub nullable: bool, } @@ -149,7 +153,7 @@ pub trait CostModelStorageLayer { async fn get_stats_for_attr_indices_based( &self, table_id: TableId, - attr_base_indices: Vec, + attr_base_indices: Vec, stat_type: StatType, epoch_id: Option, ) -> StorageResult>; @@ -165,6 +169,6 @@ pub trait CostModelStorageLayer { async fn get_attribute( &self, table_id: TableId, - attribute_base_index: i32, + attribute_base_index: AttrIndex, ) -> StorageResult>; } diff --git a/optd-persistent/src/cost_model/orm.rs b/optd-persistent/src/cost_model/orm.rs index 5b56476..d5b7ad6 100644 --- 
a/optd-persistent/src/cost_model/orm.rs +++ b/optd-persistent/src/cost_model/orm.rs @@ -14,7 +14,8 @@ use serde_json::json; use super::catalog::mock_catalog::{self, MockCatalog}; use super::interface::{ - Attr, AttrId, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, StatType, TableId, + Attr, AttrId, AttrIndex, AttrType, CatalogSource, EpochId, EpochOption, ExprId, Stat, StatId, + StatType, TableId, }; impl BackendManager { @@ -434,7 +435,7 @@ impl CostModelStorageLayer for BackendManager { async fn get_stats_for_attr_indices_based( &self, table_id: TableId, - attr_base_indices: Vec, + attr_base_indices: Vec, stat_type: StatType, epoch_id: Option, ) -> StorageResult> { @@ -549,21 +550,30 @@ impl CostModelStorageLayer for BackendManager { async fn get_attribute( &self, table_id: TableId, - attribute_base_index: i32, + attribute_base_index: AttrIndex, ) -> StorageResult> { - Ok(Attribute::find() + let attr_res = Attribute::find() .filter(attribute::Column::TableId.eq(table_id)) .filter(attribute::Column::BaseAttributeNumber.eq(attribute_base_index)) .one(&self.db) - .await? - .map(|attr| Attr { - table_id, - name: attr.name, - compression_method: attr.compression_method, - attr_type: attr.variant_tag, - base_index: attribute_base_index, - nullable: !attr.is_not_null, - })) + .await?; + match attr_res { + Some(attr) => match AttrType::try_from(attr.variant_tag) { + Ok(attr_type) => Ok(Some(Attr { + table_id: attr.table_id, + name: attr.name, + compression_method: attr.compression_method, + attr_type, + base_index: attr.base_attribute_number, + nullable: !attr.is_not_null, + })), + Err(_) => Err(BackendError::BackendError(format!( + "Failed to convert variant tag {} to AttrType", + attr.variant_tag + ))), + }, + None => Ok(None), + } } }