diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 0f45d51835f4..22d2f2187dd0 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -28,18 +28,18 @@ runs: - name: Install Build Dependencies shell: bash run: | - RETRY="ci/scripts/retry" - "${RETRY}" apt-get update - "${RETRY}" apt-get install -y protobuf-compiler + RETRY=("ci/scripts/retry" timeout 120) + "${RETRY[@]}" apt-get update + "${RETRY[@]}" apt-get install -y protobuf-compiler - name: Setup Rust toolchain shell: bash # rustfmt is needed for the substrait build script run: | - RETRY="ci/scripts/retry" + RETRY=("ci/scripts/retry" timeout 120) echo "Installing ${{ inputs.rust-version }}" - "${RETRY}" rustup toolchain install ${{ inputs.rust-version }} - "${RETRY}" rustup default ${{ inputs.rust-version }} - "${RETRY}" rustup component add rustfmt + "${RETRY[@]}" rustup toolchain install ${{ inputs.rust-version }} + "${RETRY[@]}" rustup default ${{ inputs.rust-version }} + "${RETRY[@]}" rustup component add rustfmt - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime - name: Fixup git permissions diff --git a/Cargo.toml b/Cargo.toml index 001153915632..e947afff8f4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,22 +74,22 @@ version = "43.0.0" ahash = { version = "0.8", default-features = false, features = [ "runtime-rng", ] } -arrow = { version = "53.2.0", features = [ +arrow = { version = "53.3.0", features = [ "prettyprint", ] } -arrow-array = { version = "53.2.0", default-features = false, features = [ +arrow-array = { version = "53.3.0", default-features = false, features = [ "chrono-tz", ] } -arrow-buffer = { version = "53.2.0", default-features = false } -arrow-flight = { version = "53.2.0", features = [ +arrow-buffer = { version = "53.3.0", default-features = false } +arrow-flight = { version = "53.3.0", features = [ "flight-sql-experimental", ] } -arrow-ipc = { version = "53.2.0", default-features = false, features = [ +arrow-ipc = { version = "53.3.0", default-features = false, features = [ "lz4", ] } -arrow-ord = { version = "53.2.0", default-features = false } -arrow-schema = { version = "53.2.0", default-features = false } -arrow-string = { version = "53.2.0", default-features = false } +arrow-ord = { version = "53.3.0", default-features = false } +arrow-schema = { version = "53.3.0", default-features = false } +arrow-string = { version = "53.3.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" @@ -131,7 +131,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "53.2.0", default-features = false, features = [ +parquet = { version = "53.3.0", default-features = false, features = [ "arrow", "async", "object_store", diff --git a/ci/scripts/retry b/ci/scripts/retry index 0569dea58c94..411dc532ca52 100755 --- a/ci/scripts/retry +++ b/ci/scripts/retry @@ -7,7 +7,7 @@ x() { "$@" } -max_retry_time_seconds=$(( 3 * 60 )) +max_retry_time_seconds=$(( 5 * 60 )) retry_delay_seconds=10 END=$(( $(date +%s) + ${max_retry_time_seconds} )) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index bfd0411798c9..8afb096df55f 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -173,9 +173,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4caf25cdc4a985f91df42ed9e9308e1adbcd341a31a72605c697033fcef163e3" +checksum = "c91839b07e474b3995035fd8ac33ee54f9c9ccbbb1ea33d9909c71bffdf1259d" dependencies = [ "arrow-arith", "arrow-array", @@ -194,9 +194,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91f2dfd1a7ec0aca967dfaa616096aec49779adc8eccec005e2f5e4111b1192a" +checksum = "855c57c4efd26722b044dcd3e348252560e3e0333087fb9f6479dc0bf744054f" dependencies = [ "arrow-array", "arrow-buffer", @@ -209,9 +209,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d39387ca628be747394890a6e47f138ceac1aa912eab64f02519fed24b637af8" +checksum = "bd03279cea46569acf9295f6224fbc370c5df184b4d2ecfe97ccb131d5615a7f" dependencies = [ "ahash", "arrow-buffer", @@ -220,15 +220,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.1", "num", ] [[package]] name = "arrow-buffer" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e51e05228852ffe3eb391ce7178a0f97d2cf80cc6ef91d3c4a6b3cb688049ec" +checksum = "9e4a9b9b1d6d7117f6138e13bc4dd5daa7f94e671b70e8c9c4dc37b4f5ecfc16" dependencies = [ "bytes", "half", @@ -237,9 +237,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d09aea56ec9fa267f3f3f6cdab67d8a9974cbba90b3aa38c8fe9d0bb071bd8c1" +checksum = "bc70e39916e60c5b7af7a8e2719e3ae589326039e1e863675a008bee5ffe90fd" dependencies = [ "arrow-array", "arrow-buffer", @@ -258,9 +258,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c07b5232be87d115fde73e32f2ca7f1b353bff1b44ac422d3c6fc6ae38f11f0d" +checksum = "789b2af43c1049b03a8d088ff6b2257cdcea1756cd76b174b1f2600356771b97" dependencies = [ "arrow-array", "arrow-buffer", @@ -277,9 +277,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98ae0af50890b494cebd7d6b04b35e896205c1d1df7b29a6272c5d0d0249ef5" +checksum = "e4e75edf21ffd53744a9b8e3ed11101f610e7ceb1a29860432824f1834a1f623" dependencies = [ "arrow-buffer", "arrow-schema", @@ -289,9 +289,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed91bdeaff5a1c00d28d8f73466bcb64d32bbd7093b5a30156b4b9f4dba3eee" +checksum = "d186a909dece9160bf8312f5124d797884f608ef5435a36d9d608e0b2a9bcbf8" dependencies = [ "arrow-array", "arrow-buffer", @@ -304,9 +304,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0471f51260a5309307e5d409c9dc70aede1cd9cf1d4ff0f0a1e8e1a2dd0e0d3c" +checksum = "b66ff2fedc1222942d0bd2fd391cb14a85baa3857be95c9373179bd616753b85" dependencies = [ "arrow-array", "arrow-buffer", @@ -324,9 +324,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2883d7035e0b600fb4c30ce1e50e66e53d8656aa729f2bfa4b51d359cf3ded52" +checksum = "ece7b5bc1180e6d82d1a60e1688c199829e8842e38497563c3ab6ea813e527fd" dependencies = [ "arrow-array", "arrow-buffer", @@ -339,9 +339,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "552907e8e587a6fde4f8843fd7a27a576a260f65dab6c065741ea79f633fc5be" +checksum = "745c114c8f0e8ce211c83389270de6fbe96a9088a7b32c2a041258a443fe83ff" dependencies = [ "ahash", "arrow-array", @@ -353,15 +353,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "539ada65246b949bd99ffa0881a9a15a4a529448af1a07a9838dd78617dafab1" +checksum = "b95513080e728e4cec37f1ff5af4f12c9688d47795d17cda80b6ec2cf74d4678" [[package]] name = "arrow-select" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6259e566b752da6dceab91766ed8b2e67bf6270eb9ad8a6e07a33c1bede2b125" +checksum = "8e415279094ea70323c032c6e739c48ad8d80e78a09bef7117b8718ad5bf3722" dependencies = [ "ahash", "arrow-array", @@ -373,9 +373,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3179ccbd18ebf04277a095ba7321b93fd1f774f18816bd5f6b3ce2f594edb6c" +checksum = "11d956cae7002eb8d83a27dbd34daaea1cf5b75852f0b84deb4d93a276e92bbf" dependencies = [ "arrow-array", "arrow-buffer", @@ -567,9 +567,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.49.0" +version = "1.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53dcf5e7d9bd1517b8b998e170e650047cea8a2b85fe1835abe3210713e541b7" +checksum = "6ada54e5f26ac246dc79727def52f7f8ed38915cb47781e2a72213957dc3a7d5" dependencies = [ "aws-credential-types", "aws-runtime", @@ -857,9 +857,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" +checksum = "1a68f1f47cdf0ec8ee4b941b2eee2a80cb796db73118c0dd09ac63fbe405be22" dependencies = [ "memchr", "regex-automata", @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aeb932158bd710538c73702db6945cb68a8fb08c519e6e12706b94263b36db8" +checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" dependencies = [ "jobserver", "libc", @@ -980,9 +980,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" +checksum = "fb3b4b9e5a7c7514dfa52869339ee98b3156b0bfb4e8a77c4ff4babb64b1604f" dependencies = [ "clap_builder", "clap_derive", @@ -990,9 +990,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.20" +version = "4.5.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +checksum = "b17a95aa67cc7b5ebd32aa5370189aa0d79069ef1c64ce893bd30fb24bff20ec" dependencies = [ "anstream", "anstyle", @@ -1014,9 +1014,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.2" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" +checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7" [[package]] name = "clipboard-win" @@ -1035,9 +1035,9 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "comfy-table" -version = "7.1.2" +version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d05af1e006a2407bedef5af410552494ce5be9090444dbbcb57258c1af3d56" +checksum = "24f165e7b643266ea80cb858aed492ad9280e3e05ce24d4a99d7d7b889b6a4d9" dependencies = [ "strum 0.26.3", "strum_macros 0.26.4", @@ -1158,9 +1158,9 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" +checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" dependencies = [ "quote", "syn", @@ -1537,9 +1537,11 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr-common", + "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-plan", "itertools", + "log", "recursive", ] @@ -1749,9 +1751,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.34" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" dependencies = [ "crc32fast", "miniz_oxide", @@ -1932,9 +1934,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" dependencies = [ "atomic-waker", "bytes", @@ -2118,14 +2120,14 @@ dependencies = [ [[package]] name = "hyper" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" +checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -2160,9 +2162,9 @@ checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.5.0", + "hyper 1.5.1", "hyper-util", - "rustls 0.23.16", + "rustls 0.23.17", "rustls-native-certs 0.8.0", "rustls-pki-types", "tokio", @@ -2181,7 +2183,7 @@ dependencies = [ "futures-util", "http 1.1.0", "http-body 1.0.1", - "hyper 1.5.0", + "hyper 1.5.1", "pin-project-lite", "socket2", "tokio", @@ -2390,9 +2392,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2" [[package]] name = "jobserver" @@ -2484,9 +2486,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.162" +version = "0.2.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18d287de67fe55fd7e1581fe933d965a5a9477b38e949cfa9f8574ef01506398" +checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f" [[package]] name = "libflate" @@ -2776,7 +2778,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.5.0", + "hyper 1.5.1", "itertools", "md-5", "parking_lot", @@ -2853,9 +2855,9 @@ dependencies = [ [[package]] name = "parquet" -version = "53.2.0" +version = "53.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dea02606ba6f5e856561d8d507dba8bac060aefca2a6c0f1aa1d361fed91ff3e" +checksum = "2b449890367085eb65d7d3321540abc3d7babbd179ce31df0016e90719114191" dependencies = [ "ahash", "arrow-array", @@ -2872,7 +2874,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.14.5", + "hashbrown 0.15.1", "lz4_flex", "num", "num-bigint", @@ -3073,7 +3075,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.16", + "rustls 0.23.17", "socket2", "thiserror 2.0.3", "tokio", @@ -3091,7 +3093,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls 0.23.16", + "rustls 0.23.17", "rustls-pki-types", "slab", "thiserror 2.0.3", @@ -3254,11 +3256,11 @@ dependencies = [ "bytes", "futures-core", "futures-util", - "h2 0.4.6", + "h2 0.4.7", "http 1.1.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.0", + "hyper 1.5.1", "hyper-rustls 0.27.3", "hyper-util", "ipnet", @@ -3269,7 +3271,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.16", + "rustls 0.23.17", "rustls-native-certs 0.8.0", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -3363,9 +3365,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.40" +version = "0.38.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99e4ea3e1cdc4b559b8e5650f9c8e5998e3e5c1343b4eaf034565f32318d63c0" +checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" dependencies = [ "bitflags 2.6.0", "errno", @@ -3388,9 +3390,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "7f1a745511c54ba6d4465e8d5dfbd81b45791756de28d4981af70d6dca128f1e" dependencies = [ "once_cell", "ring", @@ -3518,9 +3520,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" dependencies = [ "windows-sys 0.59.0", ] @@ -3598,9 +3600,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.132" +version = "1.0.133" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377" dependencies = [ "itoa", "memchr", @@ -3822,9 +3824,9 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" dependencies = [ "futures-core", ] @@ -4019,7 +4021,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.16", + "rustls 0.23.17", "rustls-pki-types", "tokio", ] @@ -4135,9 +4137,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-ident" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] name = "unicode-segmentation" diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index ca3a2bef882e..d771930de25d 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -247,6 +247,9 @@ pub trait TableProvider: Debug + Sync + Send { } /// Get statistics for this table, if available + /// Although not presently used in mainline DataFusion, this allows implementation specific + /// behavior for downstream repositories, in conjunction with specialized optimizer rules to + /// perform operations such as re-ordering of joins. fn statistics(&self) -> Option { None } diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index d855198fa7c6..c47ed2815906 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -21,7 +21,7 @@ use arrow_schema::{Field, FieldRef}; use crate::error::_schema_err; use crate::utils::{parse_identifiers_normalized, quote_identifier}; -use crate::{DFSchema, DataFusionError, Result, SchemaError, TableReference}; +use crate::{DFSchema, Result, SchemaError, TableReference}; use std::collections::HashSet; use std::convert::Infallible; use std::fmt; diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 05988d6c6da4..4fac7298c455 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -598,9 +598,9 @@ macro_rules! arrow_err { #[macro_export] macro_rules! schema_datafusion_err { ($ERR:expr) => { - DataFusionError::SchemaError( + $crate::error::DataFusionError::SchemaError( $ERR, - Box::new(Some(DataFusionError::get_back_trace())), + Box::new(Some($crate::error::DataFusionError::get_back_trace())), ) }; } @@ -609,9 +609,9 @@ macro_rules! schema_datafusion_err { #[macro_export] macro_rules! schema_err { ($ERR:expr) => { - Err(DataFusionError::SchemaError( + Err($crate::error::DataFusionError::SchemaError( $ERR, - Box::new(Some(DataFusionError::get_back_trace())), + Box::new(Some($crate::error::DataFusionError::get_back_trace())), )) }; } diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 8bd646626e06..e18d70844d32 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -32,7 +32,7 @@ use arrow_buffer::IntervalMonthDayNano; use crate::cast::{ as_binary_view_array, as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, as_large_list_array, as_list_array, as_map_array, - as_primitive_array, as_string_array, as_string_view_array, as_struct_array, + as_string_array, as_string_view_array, as_struct_array, }; use crate::error::Result; #[cfg(not(feature = "force_hash_collisions"))] @@ -392,14 +392,6 @@ pub fn create_hashes<'a>( let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); hash_array(array, random_state, hashes_buffer, rehash) } - DataType::Decimal128(_, _) => { - let array = as_primitive_array::(array)?; - hash_array_primitive(array, random_state, hashes_buffer, rehash) - } - DataType::Decimal256(_, _) => { - let array = as_primitive_array::(array)?; - hash_array_primitive(array, random_state, hashes_buffer, rehash) - } DataType::Dictionary(_, _) => downcast_dictionary_array! { array => hash_dictionary(array, random_state, hashes_buffer, rehash)?, _ => unreachable!() diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c8ec7f18339a..0c153583e34b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -17,11 +17,12 @@ //! [`TreeNode`] for visiting and rewriting expression and plan trees +use crate::Result; use recursive::recursive; +use std::collections::HashMap; +use std::hash::Hash; use std::sync::Arc; -use crate::Result; - /// These macros are used to determine continuation during transforming traversals. macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ @@ -769,6 +770,297 @@ impl Transformed { } } +/// [`TreeNodeContainer`] contains elements that a function can be applied on or mapped. +/// The elements of the container are siblings so the continuation rules are similar to +/// [`TreeNodeRecursion::visit_sibling`] / [`Transformed::transform_sibling`]. +pub trait TreeNodeContainer<'a, T: 'a>: Sized { + /// Applies `f` to all elements of the container. + /// This method is usually called from [`TreeNode::apply_children`] implementations as + /// a node is actually a container of the node's children. + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result; + + /// Maps all elements of the container with `f`. + /// This method is usually called from [`TreeNode::map_children`] implementations as + /// a node is actually a container of the node's children. + fn map_elements Result>>( + self, + f: F, + ) -> Result>; +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Box { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.as_ref().apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + (*self).map_elements(f)?.map_data(|c| Ok(Self::new(c))) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T> + Clone> TreeNodeContainer<'a, T> for Arc { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.as_ref().apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + Arc::unwrap_or_clone(self) + .map_elements(f)? + .map_data(|c| Ok(Arc::new(c))) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Option { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + match self { + Some(t) => t.apply_elements(f), + None => Ok(TreeNodeRecursion::Continue), + } + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.map_or(Ok(Transformed::no(None)), |c| { + c.map_elements(f)?.map_data(|c| Ok(Some(c))) + }) + } +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> for Vec { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.into_iter() + .map(|c| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + c.map_elements(&mut f).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + }) + } + TreeNodeRecursion::Stop => Ok(c), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } +} + +impl<'a, T: 'a, K: Eq + Hash, C: TreeNodeContainer<'a, T>> TreeNodeContainer<'a, T> + for HashMap +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self.values() { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + self.into_iter() + .map(|(k, c)| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + c.map_elements(&mut f).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + (k, result.data) + }) + } + TreeNodeRecursion::Stop => Ok((k, c)), + }) + .collect::>>() + .map(|data| Transformed::new(data, transformed, tnr)) + } +} + +impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> + TreeNodeContainer<'a, T> for (C0, C1) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1)))? + .transform_sibling(|(new_c0, c1)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1))) + }) + } +} + +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + > TreeNodeContainer<'a, T> for (C0, C1, C2) +{ + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f)) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + self.0 + .map_elements(&mut f)? + .map_data(|new_c0| Ok((new_c0, self.1, self.2)))? + .transform_sibling(|(new_c0, c1, c2)| { + c1.map_elements(&mut f)? + .map_data(|new_c1| Ok((new_c0, new_c1, c2))) + })? + .transform_sibling(|(new_c0, new_c1, c2)| { + c2.map_elements(&mut f)? + .map_data(|new_c2| Ok((new_c0, new_c1, new_c2))) + }) + } +} + +/// [`TreeNodeRefContainer`] contains references to elements that a function can be +/// applied on. The elements of the container are siblings so the continuation rules are +/// similar to [`TreeNodeRecursion::visit_sibling`]. +/// +/// This container is similar to [`TreeNodeContainer`], but the lifetime of the reference +/// elements (`T`) are not derived from the container's lifetime. +/// A typical usage of this container is in `Expr::apply_children` when we need to +/// construct a temporary container to be able to call `apply_ref_elements` on a +/// collection of tree node references. But in that case the container's temporary +/// lifetime is different to the lifetime of tree nodes that we put into it. +/// Please find an example usecase in `Expr::apply_children` with the `Expr::Case` case. +/// +/// Most of the cases we don't need to create a temporary container with +/// `TreeNodeRefContainer`, but we can just call `TreeNodeContainer::apply_elements`. +/// Please find an example usecase in `Expr::apply_children` with the `Expr::GroupingSet` +/// case. +pub trait TreeNodeRefContainer<'a, T: 'a>: Sized { + /// Applies `f` to all elements of the container. + /// This method is usually called from [`TreeNode::apply_children`] implementations as + /// a node is actually a container of the node's children. + fn apply_ref_elements Result>( + &self, + f: F, + ) -> Result; +} + +impl<'a, T: 'a, C: TreeNodeContainer<'a, T>> TreeNodeRefContainer<'a, T> for Vec<&'a C> { + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; + for c in self { + tnr = c.apply_elements(&mut f)?; + match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + } + Ok(tnr) + } +} + +impl<'a, T: 'a, C0: TreeNodeContainer<'a, T>, C1: TreeNodeContainer<'a, T>> + TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f)) + } +} + +impl< + 'a, + T: 'a, + C0: TreeNodeContainer<'a, T>, + C1: TreeNodeContainer<'a, T>, + C2: TreeNodeContainer<'a, T>, + > TreeNodeRefContainer<'a, T> for (&'a C0, &'a C1, &'a C2) +{ + fn apply_ref_elements Result>( + &self, + mut f: F, + ) -> Result { + self.0 + .apply_elements(&mut f)? + .visit_sibling(|| self.1.apply_elements(&mut f))? + .visit_sibling(|| self.2.apply_elements(&mut f)) + } +} + /// Transformation helper to process a sequence of iterable tree nodes that are siblings. pub trait TreeNodeIterator: Iterator { /// Apples `f` to each item in this iterator @@ -843,50 +1135,6 @@ impl TreeNodeIterator for I { } } -/// Transformation helper to process a heterogeneous sequence of tree node containing -/// expressions. -/// -/// This macro is very similar to [TreeNodeIterator::map_until_stop_and_collect] to -/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and -/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its -/// transformation (`F`). -/// -/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the -/// first element and further elements from the sequence of pairs. An element from a pair -/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on -/// the `Transformed.tnr` result of previous `F`s (`F0` initially). -/// -/// # Returns -/// Error if any of the transformations returns an error -/// -/// Ok(Transformed<(data0, ..., dataN)>) such that: -/// 1. `transformed` is true if any of the transformations had transformed true -/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and -/// `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F` -/// 3. `tnr` from `F0` or the last invocation of `F` -#[macro_export] -macro_rules! map_until_stop_and_collect { - ($F0:expr, $($EXPR:expr, $F:expr),*) => {{ - $F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| { - let all_datas = ( - data0, - $( - if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump { - $F.map(|result| { - tnr = result.tnr; - transformed |= result.transformed; - result.data - })? - } else { - $EXPR - }, - )* - ); - Ok(Transformed::new(all_datas, transformed, tnr)) - }) - }} -} - /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. /// /// # Example @@ -1021,7 +1269,7 @@ pub(crate) mod tests { use std::fmt::Display; use crate::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use crate::Result; @@ -1054,7 +1302,7 @@ pub(crate) mod tests { &'n self, f: F, ) -> Result { - self.children.iter().apply_until_stop(f) + self.children.apply_elements(f) } fn map_children Result>>( @@ -1063,8 +1311,7 @@ pub(crate) mod tests { ) -> Result> { Ok(self .children - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(|new_children| Self { children: new_children, ..self @@ -1072,6 +1319,22 @@ pub(crate) mod tests { } } + impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } + } + // J // | // I diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index df5ede5e8391..8df5ef82cd0c 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -46,7 +46,6 @@ use object_store::{GetOptions, GetRange, GetResultPayload, ObjectStore}; /// Execution plan for scanning Arrow data source #[derive(Debug, Clone)] -#[allow(dead_code)] pub struct ArrowExec { base_config: FileScanConfig, projected_statistics: Statistics, diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 2e83be212f8b..68d219ef0e5e 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -34,7 +34,6 @@ use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; /// Execution plan for scanning Avro data source #[derive(Debug, Clone)] -#[allow(dead_code)] pub struct AvroExec { base_config: FileScanConfig, projected_statistics: Statistics, diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index 9fc081dd5329..e99cf8222381 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -296,7 +296,9 @@ impl SessionState { .resolve(&catalog.default_catalog, &catalog.default_schema) } - pub(crate) fn schema_for_ref( + /// Retrieve the [`SchemaProvider`] for a specific [`TableReference`], if it + /// esists. + pub fn schema_for_ref( &self, table_ref: impl Into, ) -> datafusion_common::Result> { diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index b2df32a62e44..d049e774d7c6 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -382,14 +382,14 @@ //! //! Calling [`execute`] produces 1 or more partitions of data, //! as a [`SendableRecordBatchStream`], which implements a pull based execution -//! API. Calling `.next().await` will incrementally compute and return the next +//! API. Calling [`next()`]`.await` will incrementally compute and return the next //! [`RecordBatch`]. Balanced parallelism is achieved using [Volcano style] //! "Exchange" operations implemented by [`RepartitionExec`]. //! //! While some recent research such as [Morsel-Driven Parallelism] describes challenges //! with the pull style Volcano execution model on NUMA architectures, in practice DataFusion achieves -//! similar scalability as systems that use morsel driven approach such as DuckDB. -//! See the [DataFusion paper submitted to SIGMOD] for more details. +//! similar scalability as systems that use push driven schedulers [such as DuckDB]. +//! See the [DataFusion paper in SIGMOD 2024] for more details. //! //! [`execute`]: physical_plan::ExecutionPlan::execute //! [`SendableRecordBatchStream`]: crate::physical_plan::SendableRecordBatchStream @@ -403,24 +403,189 @@ //! [`RepartitionExec`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/repartition/struct.RepartitionExec.html //! [Volcano style]: https://w6113.github.io/files/papers/volcanoparallelism-89.pdf //! [Morsel-Driven Parallelism]: https://db.in.tum.de/~leis/papers/morsels.pdf -//! [DataFusion paper submitted SIGMOD]: https://github.com/apache/datafusion/files/13874720/DataFusion_Query_Engine___SIGMOD_2024.pdf +//! [DataFusion paper in SIGMOD 2024]: https://github.com/apache/datafusion/files/15149988/DataFusion_Query_Engine___SIGMOD_2024-FINAL-mk4.pdf +//! [such as DuckDB]: https://github.com/duckdb/duckdb/issues/1583 //! [implementors of `ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#implementors //! -//! ## Thread Scheduling +//! ## Streaming Execution //! -//! DataFusion incrementally computes output from a [`SendableRecordBatchStream`] -//! with `target_partitions` threads. Parallelism is implementing using multiple -//! [Tokio] [`task`]s, which are executed by threads managed by a tokio Runtime. -//! While tokio is most commonly used -//! for asynchronous network I/O, its combination of an efficient, work-stealing -//! scheduler, first class compiler support for automatic continuation generation, -//! and exceptional performance makes it a compelling choice for CPU intensive -//! applications as well. This is explained in more detail in [Using Rustlang’s Async Tokio -//! Runtime for CPU-Bound Tasks]. +//! DataFusion is a "streaming" query engine which means `ExecutionPlan`s incrementally +//! read from their input(s) and compute output one [`RecordBatch`] at a time +//! by continually polling [`SendableRecordBatchStream`]s. Output and +//! intermediate `RecordBatch`s each have approximately `batch_size` rows, +//! which amortizes per-batch overhead of execution. +//! +//! Note that certain operations, sometimes called "pipeline breakers", +//! (for example full sorts or hash aggregations) are fundamentally non streaming and +//! must read their input fully before producing **any** output. As much as possible, +//! other operators read a single [`RecordBatch`] from their input to produce a +//! single `RecordBatch` as output. +//! +//! For example, given this SQL query: +//! +//! ```sql +//! SELECT date_trunc('month', time) FROM data WHERE id IN (10,20,30); +//! ``` +//! +//! The diagram below shows the call sequence when a consumer calls [`next()`] to +//! get the next `RecordBatch` of output. While it is possible that some +//! steps run on different threads, typically tokio will use the same thread +//! that called `next()` to read from the input, apply the filter, and +//! return the results without interleaving any other operations. This results +//! in excellent cache locality as the same CPU core that produces the data often +//! consumes it immediately as well. +//! +//! ```text +//! +//! Step 3: FilterExec calls next() Step 2: ProjectionExec calls +//! on input Stream next() on input Stream +//! ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +//! │ Step 1: Consumer +//! ▼ ▼ │ calls next() +//! ┏━━━━━━━━━━━━━━┓ ┏━━━━━┻━━━━━━━━━━━━━┓ ┏━━━━━━━━━━━━━━━━━━━━━━━━┓ +//! ┃ ┃ ┃ ┃ ┃ ◀ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! ┃ DataSource ┃ ┃ ┃ ┃ ┃ +//! ┃ (e.g. ┃ ┃ FilterExec ┃ ┃ ProjectionExec ┃ +//! ┃ ParquetExec) ┃ ┃id IN (10, 20, 30) ┃ ┃date_bin('month', time) ┃ +//! ┃ ┃ ┃ ┃ ┃ ┣ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ▶ +//! ┃ ┃ ┃ ┃ ┃ ┃ +//! ┗━━━━━━━━━━━━━━┛ ┗━━━━━━━━━━━┳━━━━━━━┛ ┗━━━━━━━━━━━━━━━━━━━━━━━━┛ +//! │ ▲ ▲ Step 6: ProjectionExec +//! │ │ │ computes date_trunc into a +//! └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ new RecordBatch returned +//! ┌─────────────────────┐ ┌─────────────┐ from client +//! │ RecordBatch │ │ RecordBatch │ +//! └─────────────────────┘ └─────────────┘ +//! +//! Step 4: DataSource returns a Step 5: FilterExec returns a new +//! single RecordBatch RecordBatch with only matching rows +//! ``` +//! +//! [`next()`]: futures::StreamExt::next +//! +//! ## Thread Scheduling, CPU / IO Thread Pools, and [Tokio] [`Runtime`]s +//! +//! DataFusion automatically runs each plan with multiple CPU cores using +//! a [Tokio] [`Runtime`] as a thread pool. While tokio is most commonly used +//! for asynchronous network I/O, the combination of an efficient, work-stealing +//! scheduler and first class compiler support for automatic continuation +//! generation (`async`), also makes it a compelling choice for CPU intensive +//! applications as explained in the [Using Rustlang’s Async Tokio +//! Runtime for CPU-Bound Tasks] blog. +//! +//! The number of cores used is determined by the `target_partitions` +//! configuration setting, which defaults to the number of CPU cores. +//! While preparing for execution, DataFusion tries to create this many distinct +//! `async` [`Stream`]s for each `ExecutionPlan`. +//! The `Stream`s for certain `ExecutionPlans`, such as as [`RepartitionExec`] +//! and [`CoalescePartitionsExec`], spawn [Tokio] [`task`]s, that are run by +//! threads managed by the `Runtime`. +//! Many DataFusion `Stream`s perform CPU intensive processing. +//! +//! Using `async` for CPU intensive tasks makes it easy for [`TableProvider`]s +//! to perform network I/O using standard Rust `async` during execution. +//! However, this design also makes it very easy to mix CPU intensive and latency +//! sensitive I/O work on the same thread pool ([`Runtime`]). +//! Using the same (default) `Runtime` is convenient, and often works well for +//! initial development and processing local files, but it can lead to problems +//! under load and/or when reading from network sources such as AWS S3. +//! +//! If your system does not fully utilize either the CPU or network bandwidth +//! during execution, or you see significantly higher tail (e.g. p99) latencies +//! responding to network requests, **it is likely you need to use a different +//! `Runtime` for CPU intensive DataFusion plans**. This effect can be especially +//! pronounced when running several queries concurrently. +//! +//! As shown in the following figure, using the same `Runtime` for both CPU +//! intensive processing and network requests can introduce significant +//! delays in responding to those network requests. Delays in processing network +//! requests can and does lead network flow control to throttle the available +//! bandwidth in response. +//! +//! ```text +//! Legend +//! +//! ┏━━━━━━┓ +//! Processing network request ┃ ┃ CPU bound work +//! is delayed due to processing ┗━━━━━━┛ +//! CPU bound work ┌─┐ +//! │ │ Network request +//! ││ └─┘ processing +//! +//! ││ +//! ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! │ │ +//! +//! ▼ ▼ +//! ┌─────────────┐ ┌─┐┌─┐┏━━━━━━━━━━━━━━━━━━━┓┏━━━━━━━━━━━━━━━━━━━┓┌─┐ +//! │ │thread 1 │ ││ │┃ Decoding ┃┃ Filtering ┃│ │ +//! │ │ └─┘└─┘┗━━━━━━━━━━━━━━━━━━━┛┗━━━━━━━━━━━━━━━━━━━┛└─┘ +//! │ │ ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │Tokio Runtime│thread 2 ┃ Decoding ┃ Filtering ┃ Decoding ┃ ... +//! │(thread pool)│ ┗━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! │ │ ... ... +//! │ │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓┌─┐ ┏━━━━━━━━━━━━━━┓ +//! │ │thread N ┃ Decoding ┃ Filtering ┃│ │ ┃ Decoding ┃ +//! └─────────────┘ ┗━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┛└─┘ ┗━━━━━━━━━━━━━━┛ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//! ``` +//! +//! The bottleneck resulting from network throttling can be avoided +//! by using separate [`Runtime`]s for the different types of work, as shown +//! in the diagram below. +//! +//! ```text +//! A separate thread pool processes network Legend +//! requests, reducing the latency for +//! processing each request ┏━━━━━━┓ +//! ┃ ┃ CPU bound work +//! │ ┗━━━━━━┛ +//! │ ┌─┐ +//! ┌ ─ ─ ─ ─ ┘ │ │ Network request +//! ┌ ─ ─ ─ ┘ └─┘ processing +//! │ +//! ▼ ▼ +//! ┌─────────────┐ ┌─┐┌─┐┌─┐ +//! │ │thread 1 │ ││ ││ │ +//! │ │ └─┘└─┘└─┘ +//! │Tokio Runtime│ ... +//! │(thread pool)│thread 2 +//! │ │ +//! │"IO Runtime" │ ... +//! │ │ ┌─┐ +//! │ │thread N │ │ +//! └─────────────┘ └─┘ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//! +//! ┌─────────────┐ ┏━━━━━━━━━━━━━━━━━━━┓┏━━━━━━━━━━━━━━━━━━━┓ +//! │ │thread 1 ┃ Decoding ┃┃ Filtering ┃ +//! │ │ ┗━━━━━━━━━━━━━━━━━━━┛┗━━━━━━━━━━━━━━━━━━━┛ +//! │Tokio Runtime│ ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │(thread pool)│thread 2 ┃ Decoding ┃ Filtering ┃ Decoding ┃ ... +//! │ │ ┗━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! │ CPU Runtime │ ... ... +//! │ │ ┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┓ +//! │ │thread N ┃ Decoding ┃ Filtering ┃ Decoding ┃ +//! └─────────────┘ ┗━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━━━━━━┻━━━━━━━━━━━━━━┛ +//! ─────────────────────────────────────────────────────────────▶ +//! time +//!``` +//! +//! Note that DataFusion does not use [`tokio::task::spawn_blocking`] for +//! CPU-bounded work, because `spawn_blocking` is designed for blocking **IO**, +//! not designed CPU bound tasks. Among other challenges, spawned blocking +//! tasks can't yield waiting for input (can't call `await`) so they +//! can't be used to limit the number of concurrent CPU bound tasks or +//! keep the processing pipeline to the same core. //! //! [Tokio]: https://tokio.rs +//! [`Runtime`]: tokio::runtime::Runtime //! [`task`]: tokio::task //! [Using Rustlang’s Async Tokio Runtime for CPU-Bound Tasks]: https://thenewstack.io/using-rustlangs-async-tokio-runtime-for-cpu-bound-tasks/ +//! [`RepartitionExec`]: physical_plan::repartition::RepartitionExec +//! [`CoalescePartitionsExec`]: physical_plan::coalesce_partitions::CoalescePartitionsExec //! //! ## State Management and Configuration //! diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index a9f6f30dc175..000c27effdb6 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -27,7 +27,6 @@ pub mod enforce_sorting; pub mod join_selection; pub mod optimizer; pub mod projection_pushdown; -pub mod pruning; pub mod replace_with_order_preserving_variants; pub mod sanity_checker; #[cfg(test)] diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index c32b4951db44..bff74252df7b 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -480,11 +480,6 @@ fn type_union_resolution_coercion( let new_value_type = type_union_resolution_coercion(value_type, other_type); new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) } - (DataType::List(lhs), DataType::List(rhs)) => { - let new_item_type = - type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); - new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) - } (DataType::Struct(lhs), DataType::Struct(rhs)) => { if lhs.len() != rhs.len() { return None; @@ -529,6 +524,7 @@ fn type_union_resolution_coercion( // Numeric coercion is the same as comparison coercion, both find the narrowest type // that can accommodate both types binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) @@ -1138,27 +1134,46 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + let data_types = vec![lhs_field.data_type().clone(), rhs_field.data_type().clone()]; + Some(Arc::new( + (**lhs_field) + .clone() + .with_data_type(type_union_resolution(&data_types)?) + .with_nullable(lhs_field.is_nullable() || rhs_field.is_nullable()), + )) +} + /// Coercion rules for list types. fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (List(_), List(_)) => Some(lhs_type.clone()), - (LargeList(_), List(_)) => Some(lhs_type.clone()), - (List(_), LargeList(_)) => Some(rhs_type.clone()), - (LargeList(_), LargeList(_)) => Some(lhs_type.clone()), - (List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), - (FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()), // Coerce to the left side FixedSizeList type if the list lengths are the same, // otherwise coerce to list with the left type for dynamic length - (FixedSizeList(lf, ls), FixedSizeList(_, rs)) => { + (FixedSizeList(lhs_field, ls), FixedSizeList(rhs_field, rs)) => { if ls == rs { - Some(lhs_type.clone()) + Some(FixedSizeList( + coerce_list_children(lhs_field, rhs_field)?, + *rs, + )) } else { - Some(List(Arc::clone(lf))) + Some(List(coerce_list_children(lhs_field, rhs_field)?)) } } - (LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), - (FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()), + // LargeList on any side + ( + LargeList(lhs_field), + List(rhs_field) | LargeList(rhs_field) | FixedSizeList(rhs_field, _), + ) + | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => { + Some(LargeList(coerce_list_children(lhs_field, rhs_field)?)) + } + // Lists on both sides + (List(lhs_field), List(rhs_field) | FixedSizeList(rhs_field, _)) + | (FixedSizeList(lhs_field, _), List(rhs_field)) => { + Some(List(coerce_list_children(lhs_field, rhs_field)?)) + } _ => None, } } @@ -2105,10 +2120,36 @@ mod tests { DataType::List(Arc::clone(&inner_field)) ); + // Negative test: inner_timestamp_field and inner_field are not compatible because their inner types are not compatible + let inner_timestamp_field = Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )); + let result_type = get_input_types( + &DataType::List(Arc::clone(&inner_field)), + &Operator::Eq, + &DataType::List(Arc::clone(&inner_timestamp_field)), + ); + assert!(result_type.is_err()); + // TODO add other data type Ok(()) } + #[test] + fn test_list_coercion() { + let lhs_type = DataType::List(Arc::new(Field::new("lhs", DataType::Int8, false))); + + let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true))); + + let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap(); + assert_eq!( + coerced_type, + DataType::List(Arc::new(Field::new("lhs", DataType::Int64, true))) + ); // nullable because the RHS is nullable + } + #[test] fn test_type_coercion_logical_op() -> Result<()> { test_coercion_binary_rule!( diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 83d35c3d25b1..8490c08a70bb 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -32,7 +32,7 @@ use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF}; use arrow::datatypes::{DataType, FieldRef}; use datafusion_common::cse::HashNode; use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, + Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; use datafusion_common::{ plan_err, Column, DFSchema, HashMap, Result, ScalarValue, TableReference, @@ -351,6 +351,22 @@ impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr { } } +impl<'a> TreeNodeContainer<'a, Self> for Expr { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } +} + /// UNNEST expression. #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { @@ -653,6 +669,24 @@ impl Display for Sort { } } +impl<'a> TreeNodeContainer<'a, Expr> for Sort { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.expr.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.expr + .map_elements(f)? + .map_data(|expr| Ok(Self { expr, ..self })) + } +} + /// Aggregate function /// /// See also [`ExprFunctionExt`] to set these fields on `Expr` diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 93e8b5fd045e..8c64a017988e 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -26,7 +26,10 @@ use std::{ use crate::expr::Sort; use arrow::datatypes::DataType; -use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; +use datafusion_common::tree_node::{Transformed, TreeNodeContainer, TreeNodeRecursion}; +use datafusion_common::{ + Constraints, DFSchemaRef, Result, SchemaReference, TableReference, +}; use sqlparser::ast::Ident; /// Various types of DDL (CREATE / DROP) catalog manipulation @@ -487,6 +490,28 @@ pub struct OperateFunctionArg { pub data_type: DataType, pub default_expr: Option, } + +impl<'a> TreeNodeContainer<'a, Expr> for OperateFunctionArg { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.default_expr.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.default_expr.map_elements(f)?.map_data(|default_expr| { + Ok(Self { + default_expr, + ..self + }) + }) + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct CreateFunctionBody { /// LANGUAGE lang_name @@ -497,6 +522,29 @@ pub struct CreateFunctionBody { pub function_body: Option, } +impl<'a> TreeNodeContainer<'a, Expr> for CreateFunctionBody { + fn apply_elements Result>( + &'a self, + f: F, + ) -> Result { + self.function_body.apply_elements(f) + } + + fn map_elements Result>>( + self, + f: F, + ) -> Result> { + self.function_body + .map_elements(f)? + .map_data(|function_body| { + Ok(Self { + function_body, + ..self + }) + }) + } +} + #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct DropFunction { pub name: String, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6ee99b22c7f3..e9f4f1f80972 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -45,7 +45,9 @@ use crate::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, +}; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, @@ -287,6 +289,22 @@ impl Default for LogicalPlan { } } +impl<'a> TreeNodeContainer<'a, Self> for LogicalPlan { + fn apply_elements Result>( + &'a self, + mut f: F, + ) -> Result { + f(self) + } + + fn map_elements Result>>( + self, + mut f: F, + ) -> Result> { + f(self) + } +} + impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index 05e2b1af14d3..26df379f5e4a 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -16,12 +16,10 @@ // under the License. use arrow::datatypes::DataType; -use datafusion_common::tree_node::{Transformed, TreeNodeIterator}; -use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use datafusion_common::{DFSchema, DFSchemaRef}; use std::fmt::{self, Display}; use std::sync::{Arc, OnceLock}; -use super::tree_node::rewrite_arc; use crate::{expr_vec_fmt, Expr, LogicalPlan}; /// Statements have a unchanging empty schema. @@ -80,53 +78,6 @@ impl Statement { } } - /// Rewrites input LogicalPlans in the current `Statement` using `f`. - pub(super) fn map_inputs< - F: FnMut(LogicalPlan) -> Result>, - >( - self, - f: F, - ) -> Result> { - match self { - Statement::Prepare(Prepare { - input, - name, - data_types, - }) => Ok(rewrite_arc(input, f)?.update_data(|input| { - Statement::Prepare(Prepare { - input, - name, - data_types, - }) - })), - _ => Ok(Transformed::no(self)), - } - } - - /// Returns a iterator over all expressions in the current `Statement`. - pub(super) fn expression_iter(&self) -> impl Iterator { - match self { - Statement::Execute(Execute { parameters, .. }) => parameters.iter(), - _ => [].iter(), - } - } - - /// Rewrites all expressions in the current `Statement` using `f`. - pub(super) fn map_expressions Result>>( - self, - f: F, - ) -> Result> { - match self { - Statement::Execute(Execute { name, parameters }) => Ok(parameters - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|parameters| { - Statement::Execute(Execute { parameters, name }) - })), - _ => Ok(Transformed::no(self)), - } - } - /// Return a `format`able structure with the a human readable /// description of this LogicalPlan node per node, not including /// children. diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index e7dfe8791924..6850c30f4f81 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -36,32 +36,30 @@ //! (Re)creation APIs (these require substantial cloning and thus are slow): //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions + use crate::{ dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, - Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, Join, Limit, - LogicalPlan, Partitioning, Projection, RecursiveQuery, Repartition, Sort, Subquery, - SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, Values, Window, + Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, + Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, + Sort, Statement, Subquery, SubqueryAlias, TableScan, Union, Unnest, + UserDefinedLogicalNode, Values, Window, }; +use datafusion_common::tree_node::TreeNodeRefContainer; use recursive::recursive; -use std::ops::Deref; -use std::sync::Arc; use crate::expr::{Exists, InSubquery}; -use crate::tree_node::{transform_sort_option_vec, transform_sort_vec}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, -}; -use datafusion_common::{ - internal_err, map_until_stop_and_collect, DataFusionError, Result, + Transformed, TreeNode, TreeNodeContainer, TreeNodeIterator, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; +use datafusion_common::{internal_err, Result}; impl TreeNode for LogicalPlan { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, f: F, ) -> Result { - self.inputs().into_iter().apply_until_stop(f) + self.inputs().apply_ref_elements(f) } /// Applies `f` to each child (input) of this plan node, rewriting them *in place.* @@ -74,14 +72,14 @@ impl TreeNode for LogicalPlan { /// [`Expr::Exists`]: crate::Expr::Exists fn map_children Result>>( self, - mut f: F, + f: F, ) -> Result> { Ok(match self { LogicalPlan::Projection(Projection { expr, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Projection(Projection { expr, input, @@ -92,7 +90,7 @@ impl TreeNode for LogicalPlan { predicate, input, having, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Filter(Filter { predicate, input, @@ -102,7 +100,7 @@ impl TreeNode for LogicalPlan { LogicalPlan::Repartition(Repartition { input, partitioning_scheme, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -112,7 +110,7 @@ impl TreeNode for LogicalPlan { input, window_expr, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Window(Window { input, window_expr, @@ -124,7 +122,7 @@ impl TreeNode for LogicalPlan { group_expr, aggr_expr, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Aggregate(Aggregate { input, group_expr, @@ -132,7 +130,8 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => rewrite_arc(input, f)? + LogicalPlan::Sort(Sort { expr, input, fetch }) => input + .map_elements(f)? .update_data(|input| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Join(Join { left, @@ -143,12 +142,7 @@ impl TreeNode for LogicalPlan { join_constraint, schema, null_equals_null, - }) => map_until_stop_and_collect!( - rewrite_arc(left, &mut f), - right, - rewrite_arc(right, &mut f) - )? - .update_data(|(left, right)| { + }) => (left, right).map_elements(f)?.update_data(|(left, right)| { LogicalPlan::Join(Join { left, right, @@ -160,12 +154,13 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)? + LogicalPlan::Limit(Limit { skip, fetch, input }) => input + .map_elements(f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, - }) => rewrite_arc(subquery, f)?.update_data(|subquery| { + }) => subquery.map_elements(f)?.update_data(|subquery| { LogicalPlan::Subquery(Subquery { subquery, outer_ref_columns, @@ -175,7 +170,7 @@ impl TreeNode for LogicalPlan { input, alias, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, @@ -184,17 +179,18 @@ impl TreeNode for LogicalPlan { }), LogicalPlan::Extension(extension) => rewrite_extension_inputs(extension, f)? .update_data(LogicalPlan::Extension), - LogicalPlan::Union(Union { inputs, schema }) => rewrite_arcs(inputs, f)? + LogicalPlan::Union(Union { inputs, schema }) => inputs + .map_elements(f)? .update_data(|inputs| LogicalPlan::Union(Union { inputs, schema })), LogicalPlan::Distinct(distinct) => match distinct { - Distinct::All(input) => rewrite_arc(input, f)?.update_data(Distinct::All), + Distinct::All(input) => input.map_elements(f)?.update_data(Distinct::All), Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { Distinct::On(DistinctOn { on_expr, select_expr, @@ -211,7 +207,7 @@ impl TreeNode for LogicalPlan { stringified_plans, schema, logical_optimization_succeeded, - }) => rewrite_arc(plan, f)?.update_data(|plan| { + }) => plan.map_elements(f)?.update_data(|plan| { LogicalPlan::Explain(Explain { verbose, plan, @@ -224,7 +220,7 @@ impl TreeNode for LogicalPlan { verbose, input, schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Analyze(Analyze { verbose, input, @@ -237,7 +233,7 @@ impl TreeNode for LogicalPlan { op, input, output_schema, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Dml(DmlStatement { table_name, table_schema, @@ -252,7 +248,7 @@ impl TreeNode for LogicalPlan { partition_by, file_type, options, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Copy(CopyTo { input, output_url, @@ -271,7 +267,7 @@ impl TreeNode for LogicalPlan { or_replace, column_defaults, temporary, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, constraints, @@ -288,7 +284,7 @@ impl TreeNode for LogicalPlan { or_replace, definition, temporary, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { DdlStatement::CreateView(CreateView { name, input, @@ -318,7 +314,7 @@ impl TreeNode for LogicalPlan { dependency_indices, schema, options, - }) => rewrite_arc(input, f)?.update_data(|input| { + }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Unnest(Unnest { input, exec_columns: input_columns, @@ -334,22 +330,24 @@ impl TreeNode for LogicalPlan { static_term, recursive_term, is_distinct, - }) => map_until_stop_and_collect!( - rewrite_arc(static_term, &mut f), - recursive_term, - rewrite_arc(recursive_term, &mut f) - )? - .update_data(|(static_term, recursive_term)| { - LogicalPlan::RecursiveQuery(RecursiveQuery { - name, - static_term, - recursive_term, - is_distinct, - }) - }), - LogicalPlan::Statement(stmt) => { - stmt.map_inputs(f)?.update_data(LogicalPlan::Statement) + }) => (static_term, recursive_term).map_elements(f)?.update_data( + |(static_term, recursive_term)| { + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term, + recursive_term, + is_distinct, + }) + }, + ), + LogicalPlan::Statement(stmt) => match stmt { + Statement::Prepare(p) => p + .input + .map_elements(f)? + .update_data(|input| Statement::Prepare(Prepare { input, ..p })), + _ => Transformed::no(stmt), } + .update_data(LogicalPlan::Statement), // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::EmptyRelation { .. } @@ -359,24 +357,6 @@ impl TreeNode for LogicalPlan { } } -/// Applies `f` to rewrite a `Arc` without copying, if possible -pub(super) fn rewrite_arc Result>>( - plan: Arc, - mut f: F, -) -> Result>> { - f(Arc::unwrap_or_clone(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan))) -} - -/// rewrite a `Vec` of `Arc` without copying, if possible -fn rewrite_arcs Result>>( - input_plans: Vec>, - mut f: F, -) -> Result>>> { - input_plans - .into_iter() - .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f)) -} - /// Rewrites all inputs for an Extension node "in place" /// (it currently has to copy values because there are no APIs for in place modification) /// @@ -423,54 +403,40 @@ impl LogicalPlan { mut f: F, ) -> Result { match self { - LogicalPlan::Projection(Projection { expr, .. }) => { - expr.iter().apply_until_stop(f) - } - LogicalPlan::Values(Values { values, .. }) => values - .iter() - .apply_until_stop(|value| value.iter().apply_until_stop(&mut f)), + LogicalPlan::Projection(Projection { expr, .. }) => expr.apply_elements(f), + LogicalPlan::Values(Values { values, .. }) => values.apply_elements(f), LogicalPlan::Filter(Filter { predicate, .. }) => f(predicate), LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { Partitioning::Hash(expr, _) | Partitioning::DistributeBy(expr) => { - expr.iter().apply_until_stop(f) + expr.apply_elements(f) } Partitioning::RoundRobinBatch(_) => Ok(TreeNodeRecursion::Continue), }, LogicalPlan::Window(Window { window_expr, .. }) => { - window_expr.iter().apply_until_stop(f) + window_expr.apply_elements(f) } LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. - }) => group_expr - .iter() - .chain(aggr_expr.iter()) - .apply_until_stop(f), + }) => (group_expr, aggr_expr).apply_ref_elements(f), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. // 2. the second part is non-equijoin(filter). LogicalPlan::Join(Join { on, filter, .. }) => { - on.iter() - // TODO: why we need to create an `Expr::eq`? Cloning `Expr` is costly... - // it not ideal to create an expr here to analyze them, but could cache it on the Join itself - .map(|(l, r)| Expr::eq(l.clone(), r.clone())) - .apply_until_stop(|e| f(&e))? - .visit_sibling(|| filter.iter().apply_until_stop(f)) - } - LogicalPlan::Sort(Sort { expr, .. }) => { - expr.iter().apply_until_stop(|sort| f(&sort.expr)) + (on, filter).apply_ref_elements(f) } + LogicalPlan::Sort(Sort { expr, .. }) => expr.apply_elements(f), LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - extension.node.expressions().iter().apply_until_stop(f) + extension.node.expressions().apply_elements(f) } LogicalPlan::TableScan(TableScan { filters, .. }) => { - filters.iter().apply_until_stop(f) + filters.apply_elements(f) } LogicalPlan::Unnest(unnest) => { let columns = unnest.exec_columns.clone(); @@ -479,24 +445,23 @@ impl LogicalPlan { .iter() .map(|c| Expr::Column(c.clone())) .collect::>(); - exprs.iter().apply_until_stop(f) + exprs.apply_elements(f) } LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, .. - })) => on_expr - .iter() - .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) - .apply_until_stop(f), - LogicalPlan::Limit(Limit { skip, fetch, .. }) => skip - .iter() - .chain(fetch.iter()) - .map(|e| e.deref()) - .apply_until_stop(f), - LogicalPlan::Statement(stmt) => stmt.expression_iter().apply_until_stop(f), + })) => (on_expr, select_expr, sort_expr).apply_ref_elements(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + (skip, fetch).apply_ref_elements(f) + } + LogicalPlan::Statement(stmt) => match stmt { + Statement::Execute(Execute { parameters, .. }) => { + parameters.apply_elements(f) + } + _ => Ok(TreeNodeRecursion::Continue), + }, // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) @@ -529,21 +494,15 @@ impl LogicalPlan { expr, input, schema, - }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) - }), + }) => expr.map_elements(f)?.update_data(|expr| { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) + }), LogicalPlan::Values(Values { schema, values }) => values - .into_iter() - .map_until_stop_and_collect(|value| { - value.into_iter().map_until_stop_and_collect(&mut f) - })? + .map_elements(f)? .update_data(|values| LogicalPlan::Values(Values { schema, values })), LogicalPlan::Filter(Filter { predicate, @@ -561,12 +520,10 @@ impl LogicalPlan { partitioning_scheme, }) => match partitioning_scheme { Partitioning::Hash(expr, usize) => expr - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(|expr| Partitioning::Hash(expr, usize)), Partitioning::DistributeBy(expr) => expr - .into_iter() - .map_until_stop_and_collect(f)? + .map_elements(f)? .update_data(Partitioning::DistributeBy), Partitioning::RoundRobinBatch(_) => Transformed::no(partitioning_scheme), } @@ -580,34 +537,28 @@ impl LogicalPlan { input, window_expr, schema, - }) => window_expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|window_expr| { - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) - }), + }) => window_expr.map_elements(f)?.update_data(|window_expr| { + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) + }), LogicalPlan::Aggregate(Aggregate { input, group_expr, aggr_expr, schema, - }) => map_until_stop_and_collect!( - group_expr.into_iter().map_until_stop_and_collect(&mut f), - aggr_expr, - aggr_expr.into_iter().map_until_stop_and_collect(&mut f) - )? - .update_data(|(group_expr, aggr_expr)| { - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) - }), + }) => (group_expr, aggr_expr).map_elements(f)?.update_data( + |(group_expr, aggr_expr)| { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) + }, + ), // There are two part of expression for join, equijoin(on) and non-equijoin(filter). // 1. the first part is `on.len()` equijoin expressions, and the struct of each expr is `left-on = right-on`. @@ -621,16 +572,7 @@ impl LogicalPlan { join_constraint, schema, null_equals_null, - }) => map_until_stop_and_collect!( - on.into_iter().map_until_stop_and_collect( - |on| map_until_stop_and_collect!(f(on.0), on.1, f(on.1)) - ), - filter, - filter.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }) - )? - .update_data(|(on, filter)| { + }) => (on, filter).map_elements(f)?.update_data(|(on, filter)| { LogicalPlan::Join(Join { left, right, @@ -642,17 +584,13 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => { - transform_sort_vec(expr, &mut f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })) - } + LogicalPlan::Sort(Sort { expr, input, fetch }) => expr + .map_elements(f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - let exprs = node - .expressions() - .into_iter() - .map_until_stop_and_collect(f)?; + let exprs = node.expressions().map_elements(f)?; let plan = LogicalPlan::Extension(Extension { node: UserDefinedLogicalNode::with_exprs_and_inputs( node.as_ref(), @@ -669,64 +607,47 @@ impl LogicalPlan { projected_schema, filters, fetch, - }) => filters - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|filters| { - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) - }), + }) => filters.map_elements(f)?.update_data(|filters| { + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) + }), LogicalPlan::Distinct(Distinct::On(DistinctOn { on_expr, select_expr, sort_expr, input, schema, - })) => map_until_stop_and_collect!( - on_expr.into_iter().map_until_stop_and_collect(&mut f), - select_expr, - select_expr.into_iter().map_until_stop_and_collect(&mut f), - sort_expr, - transform_sort_option_vec(sort_expr, &mut f) - )? - .update_data(|(on_expr, select_expr, sort_expr)| { - LogicalPlan::Distinct(Distinct::On(DistinctOn { - on_expr, - select_expr, - sort_expr, - input, - schema, - })) - }), - LogicalPlan::Limit(Limit { skip, fetch, input }) => { - let skip = skip.map(|e| *e); - let fetch = fetch.map(|e| *e); - map_until_stop_and_collect!( - skip.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }), - fetch, - fetch.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { - Ok(f(e)?.update_data(Some)) - }) - )? - .update_data(|(skip, fetch)| { - LogicalPlan::Limit(Limit { - skip: skip.map(Box::new), - fetch: fetch.map(Box::new), + })) => (on_expr, select_expr, sort_expr) + .map_elements(f)? + .update_data(|(on_expr, select_expr, sort_expr)| { + LogicalPlan::Distinct(Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, input, - }) + schema, + })) + }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => { + (skip, fetch).map_elements(f)?.update_data(|(skip, fetch)| { + LogicalPlan::Limit(Limit { skip, fetch, input }) }) } - LogicalPlan::Statement(stmt) => { - stmt.map_expressions(f)?.update_data(LogicalPlan::Statement) + LogicalPlan::Statement(stmt) => match stmt { + Statement::Execute(e) => { + e.parameters.map_elements(f)?.update_data(|parameters| { + Statement::Execute(Execute { parameters, ..e }) + }) + } + _ => Transformed::no(stmt), } + .update_data(LogicalPlan::Statement), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Unnest(_) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index e964091aae66..eacace5ed046 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -19,14 +19,14 @@ use crate::expr::{ AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, + InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, }; -use datafusion_common::{map_until_stop_and_collect, Result}; +use datafusion_common::Result; /// Implementation of the [`TreeNode`] trait /// @@ -42,9 +42,9 @@ impl TreeNode for Expr { &'n self, f: F, ) -> Result { - let children = match self { - Expr::Alias(Alias{expr,..}) - | Expr::Unnest(Unnest{expr}) + match self { + Expr::Alias(Alias { expr, .. }) + | Expr::Unnest(Unnest { expr }) | Expr::Not(expr) | Expr::IsNotNull(expr) | Expr::IsTrue(expr) @@ -57,78 +57,50 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()], + | Expr::InSubquery(InSubquery { expr, .. }) => expr.apply_elements(f), Expr::GroupingSet(GroupingSet::Rollup(exprs)) - | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), - Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { - args.iter().collect() + | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.apply_elements(f), + Expr::ScalarFunction(ScalarFunction { args, .. }) => { + args.apply_elements(f) } Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { - lists_of_exprs.iter().flatten().collect() + lists_of_exprs.apply_elements(f) } Expr::Column(_) // Treat OuterReferenceColumn as a leaf expression | Expr::OuterReferenceColumn(_, _) | Expr::ScalarVariable(_, _) | Expr::Literal(_) - | Expr::Exists {..} + | Expr::Exists { .. } | Expr::ScalarSubquery(_) - | Expr::Wildcard {..} - | Expr::Placeholder (_) => vec![], + | Expr::Wildcard { .. } + | Expr::Placeholder(_) => Ok(TreeNodeRecursion::Continue), Expr::BinaryExpr(BinaryExpr { left, right, .. }) => { - vec![left.as_ref(), right.as_ref()] + (left, right).apply_ref_elements(f) } Expr::Like(Like { expr, pattern, .. }) | Expr::SimilarTo(Like { expr, pattern, .. }) => { - vec![expr.as_ref(), pattern.as_ref()] + (expr, pattern).apply_ref_elements(f) } Expr::Between(Between { - expr, low, high, .. - }) => vec![expr.as_ref(), low.as_ref(), high.as_ref()], - Expr::Case(case) => { - let mut expr_vec = vec![]; - if let Some(expr) = case.expr.as_ref() { - expr_vec.push(expr.as_ref()); - }; - for (when, then) in case.when_then_expr.iter() { - expr_vec.push(when.as_ref()); - expr_vec.push(then.as_ref()); - } - if let Some(else_expr) = case.else_expr.as_ref() { - expr_vec.push(else_expr.as_ref()); - } - expr_vec - } - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - => { - let mut expr_vec = args.iter().collect::>(); - if let Some(f) = filter { - expr_vec.push(f.as_ref()); - } - if let Some(order_by) = order_by { - expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); - } - expr_vec - } + expr, low, high, .. + }) => (expr, low, high).apply_ref_elements(f), + Expr::Case(Case { expr, when_then_expr, else_expr }) => + (expr, when_then_expr, else_expr).apply_ref_elements(f), + Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { - args, - partition_by, - order_by, - .. - }) => { - let mut expr_vec = args.iter().collect::>(); - expr_vec.extend(partition_by); - expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); - expr_vec + args, + partition_by, + order_by, + .. + }) => { + (args, partition_by, order_by).apply_ref_elements(f) } Expr::InList(InList { expr, list, .. }) => { - let mut expr_vec = vec![expr.as_ref()]; - expr_vec.extend(list); - expr_vec + (expr, list).apply_ref_elements(f) } - }; - - children.into_iter().apply_until_stop(f) + } } /// Maps each child of `self` using the provided closure `f`. @@ -148,137 +120,103 @@ impl TreeNode for Expr { | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) | Expr::Literal(_) => Transformed::no(self), - Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)? - .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))), + Expr::Unnest(Unnest { expr, .. }) => expr + .map_elements(f)? + .update_data(|expr| Expr::Unnest(Unnest { expr })), Expr::Alias(Alias { expr, relation, name, - }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, relation, name))), + }) => f(*expr)?.update_data(|e| e.alias_qualified(relation, name)), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => transform_box(expr, &mut f)?.update_data(|be| { + }) => expr.map_elements(f)?.update_data(|be| { Expr::InSubquery(InSubquery::new(be, subquery, negated)) }), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - map_until_stop_and_collect!( - transform_box(left, &mut f), - right, - transform_box(right, &mut f) - )? + Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, right) + .map_elements(f)? .update_data(|(new_left, new_right)| { Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) - }) - } + }), Expr::Like(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - pattern, - transform_box(pattern, &mut f) - )? - .update_data(|(new_expr, new_pattern)| { - Expr::Like(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), + }) => { + (expr, pattern) + .map_elements(f)? + .update_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }) + } Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - pattern, - transform_box(pattern, &mut f) - )? - .update_data(|(new_expr, new_pattern)| { - Expr::SimilarTo(Like::new( - negated, - new_expr, - new_pattern, - escape_char, - case_insensitive, - )) - }), - Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not), - Expr::IsNotNull(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotNull) - } - Expr::IsNull(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsNull), - Expr::IsTrue(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsTrue), - Expr::IsFalse(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsFalse) - } - Expr::IsUnknown(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsUnknown) - } - Expr::IsNotTrue(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue) - } - Expr::IsNotFalse(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse) + }) => { + (expr, pattern) + .map_elements(f)? + .update_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }) } + Expr::Not(expr) => expr.map_elements(f)?.update_data(Expr::Not), + Expr::IsNotNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNotNull), + Expr::IsNull(expr) => expr.map_elements(f)?.update_data(Expr::IsNull), + Expr::IsTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsTrue), + Expr::IsFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsFalse), + Expr::IsUnknown(expr) => expr.map_elements(f)?.update_data(Expr::IsUnknown), + Expr::IsNotTrue(expr) => expr.map_elements(f)?.update_data(Expr::IsNotTrue), + Expr::IsNotFalse(expr) => expr.map_elements(f)?.update_data(Expr::IsNotFalse), Expr::IsNotUnknown(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown) - } - Expr::Negative(expr) => { - transform_box(expr, &mut f)?.update_data(Expr::Negative) + expr.map_elements(f)?.update_data(Expr::IsNotUnknown) } + Expr::Negative(expr) => expr.map_elements(f)?.update_data(Expr::Negative), Expr::Between(Between { expr, negated, low, high, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - low, - transform_box(low, &mut f), - high, - transform_box(high, &mut f) - )? - .update_data(|(new_expr, new_low, new_high)| { - Expr::Between(Between::new(new_expr, negated, new_low, new_high)) - }), + }) => (expr, low, high).map_elements(f)?.update_data( + |(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }, + ), Expr::Case(Case { expr, when_then_expr, else_expr, - }) => map_until_stop_and_collect!( - transform_option_box(expr, &mut f), - when_then_expr, - when_then_expr - .into_iter() - .map_until_stop_and_collect(|(when, then)| { - map_until_stop_and_collect!( - transform_box(when, &mut f), - then, - transform_box(then, &mut f) - ) - }), - else_expr, - transform_option_box(else_expr, &mut f) - )? - .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { - Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) - }), - Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? + }) => (expr, when_then_expr, else_expr) + .map_elements(f)? + .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), + Expr::Cast(Cast { expr, data_type }) => expr + .map_elements(f)? .update_data(|be| Expr::Cast(Cast::new(be, data_type))), - Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? + Expr::TryCast(TryCast { expr, data_type }) => expr + .map_elements(f)? .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::ScalarFunction(ScalarFunction { func, args }) => { - transform_vec(args, &mut f)?.map_data(|new_args| { + args.map_elements(f)?.map_data(|new_args| { Ok(Expr::ScalarFunction(ScalarFunction::new_udf( func, new_args, ))) @@ -291,22 +229,17 @@ impl TreeNode for Expr { order_by, window_frame, null_treatment, - }) => map_until_stop_and_collect!( - transform_vec(args, &mut f), - partition_by, - transform_vec(partition_by, &mut f), - order_by, - transform_sort_vec(order_by, &mut f) - )? - .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new(fun, new_args)) - .partition_by(new_partition_by) - .order_by(new_order_by) - .window_frame(window_frame) - .null_treatment(null_treatment) - .build() - .unwrap() - }), + }) => (args, partition_by, order_by).map_elements(f)?.update_data( + |(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() + }, + ), Expr::AggregateFunction(AggregateFunction { args, func, @@ -314,31 +247,27 @@ impl TreeNode for Expr { filter, order_by, null_treatment, - }) => map_until_stop_and_collect!( - transform_vec(args, &mut f), - filter, - transform_option_box(filter, &mut f), - order_by, - transform_sort_option_vec(order_by, &mut f) - )? - .map_data(|(new_args, new_filter, new_order_by)| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - func, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) - })?, + }) => (args, filter, order_by).map_elements(f)?.map_data( + |(new_args, new_filter, new_order_by)| { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func, + new_args, + distinct, + new_filter, + new_order_by, + null_treatment, + ))) + }, + )?, Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? + GroupingSet::Rollup(exprs) => exprs + .map_elements(f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), - GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? + GroupingSet::Cube(exprs) => exprs + .map_elements(f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs - .into_iter() - .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))? + .map_elements(f)? .update_data(|new_lists_of_exprs| { Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) }), @@ -347,70 +276,11 @@ impl TreeNode for Expr { expr, list, negated, - }) => map_until_stop_and_collect!( - transform_box(expr, &mut f), - list, - transform_vec(list, &mut f) - )? - .update_data(|(new_expr, new_list)| { - Expr::InList(InList::new(new_expr, new_list, negated)) - }), + }) => (expr, list) + .map_elements(f)? + .update_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), }) } } - -/// Transforms a boxed expression by applying the provided closure `f`. -fn transform_box Result>>( - be: Box, - f: &mut F, -) -> Result>> { - Ok(f(*be)?.update_data(Box::new)) -} - -/// Transforms an optional boxed expression by applying the provided closure `f`. -fn transform_option_box Result>>( - obe: Option>, - f: &mut F, -) -> Result>>> { - obe.map_or(Ok(Transformed::no(None)), |be| { - Ok(transform_box(be, f)?.update_data(Some)) - }) -} - -/// &mut transform a Option<`Vec` of `Expr`s> -pub fn transform_option_vec Result>>( - ove: Option>, - f: &mut F, -) -> Result>>> { - ove.map_or(Ok(Transformed::no(None)), |ve| { - Ok(transform_vec(ve, f)?.update_data(Some)) - }) -} - -/// &mut transform a `Vec` of `Expr`s -fn transform_vec Result>>( - ve: Vec, - f: &mut F, -) -> Result>> { - ve.into_iter().map_until_stop_and_collect(f) -} - -/// Transforms an optional vector of sort expressions by applying the provided closure `f`. -pub fn transform_sort_option_vec Result>>( - sorts_option: Option>, - f: &mut F, -) -> Result>>> { - sorts_option.map_or(Ok(Transformed::no(None)), |sorts| { - Ok(transform_sort_vec(sorts, f)?.update_data(Some)) - }) -} - -/// Transforms an vector of sort expressions by applying the provided closure `f`. -pub fn transform_sort_vec Result>>( - sorts: Vec, - f: &mut F, -) -> Result>> { - sorts.into_iter().map_until_stop_and_collect(|s| { - Ok(f(s.expr)?.update_data(|e| Sort { expr: e, ..s })) - }) -} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c22ee244fe28..6f7c5d379260 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -29,7 +29,7 @@ use crate::{ }; use datafusion_expr_common::signature::{Signature, TypeSignature}; -use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -958,7 +958,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( /// Can this data type be used in hash join equal conditions?? /// Data types here come from function 'equal_rows', if more data types are supported -/// in equal_rows(hash join), add those data types here to generate join logical plan. +/// in create_hashes, add those data types here to generate join logical plan. pub fn can_hash(data_type: &DataType) -> bool { match data_type { DataType::Null => true, @@ -971,31 +971,38 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::UInt16 => true, DataType::UInt32 => true, DataType::UInt64 => true, + DataType::Float16 => true, DataType::Float32 => true, DataType::Float64 => true, - DataType::Timestamp(time_unit, _) => match time_unit { - TimeUnit::Second => true, - TimeUnit::Millisecond => true, - TimeUnit::Microsecond => true, - TimeUnit::Nanosecond => true, - }, + DataType::Decimal128(_, _) => true, + DataType::Decimal256(_, _) => true, + DataType::Timestamp(_, _) => true, DataType::Utf8 => true, DataType::LargeUtf8 => true, DataType::Utf8View => true, - DataType::Decimal128(_, _) => true, + DataType::Binary => true, + DataType::LargeBinary => true, + DataType::BinaryView => true, DataType::Date32 => true, DataType::Date64 => true, + DataType::Time32(_) => true, + DataType::Time64(_) => true, + DataType::Duration(_) => true, + DataType::Interval(_) => true, DataType::FixedSizeBinary(_) => true, - DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - DataType::is_dictionary_key_type(key_type) + DataType::Dictionary(key_type, value_type) => { + DataType::is_dictionary_key_type(key_type) && can_hash(value_type) } - DataType::List(_) => true, - DataType::LargeList(_) => true, - DataType::FixedSizeList(_, _) => true, + DataType::List(value_type) => can_hash(value_type.data_type()), + DataType::LargeList(value_type) => can_hash(value_type.data_type()), + DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()), + DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()), DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), - _ => false, + + DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::Union(_, _) + | DataType::RunEndEncoded(_, _) => false, } } @@ -1403,6 +1410,7 @@ mod tests { test::function_stub::max_udaf, test::function_stub::min_udaf, test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; + use arrow::datatypes::{UnionFields, UnionMode}; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { @@ -1805,4 +1813,21 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn test_can_hash() { + let union_fields: UnionFields = [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + ] + .into_iter() + .collect(); + + let union_type = DataType::Union(union_fields, UnionMode::Sparse); + assert!(!can_hash(&union_type)); + + let list_union_type = + DataType::List(Arc::new(Field::new("my_union", union_type, true))); + assert!(!can_hash(&list_union_type)); + } } diff --git a/datafusion/functions-aggregate-common/Cargo.toml b/datafusion/functions-aggregate-common/Cargo.toml index 9b299c1a11d7..664746808fb4 100644 --- a/datafusion/functions-aggregate-common/Cargo.toml +++ b/datafusion/functions-aggregate-common/Cargo.toml @@ -43,3 +43,10 @@ datafusion-common = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } rand = { workspace = true } + +[dev-dependencies] +criterion = "0.5" + +[[bench]] +harness = false +name = "accumulate" diff --git a/datafusion/functions-aggregate-common/benches/accumulate.rs b/datafusion/functions-aggregate-common/benches/accumulate.rs new file mode 100644 index 000000000000..f422f8a2a7bf --- /dev/null +++ b/datafusion/functions-aggregate-common/benches/accumulate.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray, Int64Array}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; + +fn generate_group_indices(len: usize) -> Vec { + (0..len).collect() +} + +fn generate_values(len: usize, has_null: bool) -> ArrayRef { + if has_null { + let values = (0..len) + .map(|i| if i % 7 == 0 { None } else { Some(i as i64) }) + .collect::>(); + Arc::new(Int64Array::from(values)) + } else { + let values = (0..len).map(|i| Some(i as i64)).collect::>(); + Arc::new(Int64Array::from(values)) + } +} + +fn generate_filter(len: usize) -> Option { + let values = (0..len) + .map(|i| { + if i % 7 == 0 { + None + } else if i % 5 == 0 { + Some(false) + } else { + Some(true) + } + }) + .collect::>(); + Some(BooleanArray::from(values)) +} + +fn criterion_benchmark(c: &mut Criterion) { + let len = 500_000; + let group_indices = generate_group_indices(len); + let rows_count = group_indices.len(); + let values = generate_values(len, true); + let opt_filter = generate_filter(len); + let mut counts: Vec = vec![0; rows_count]; + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + + c.bench_function("Handle both nulls and filter", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); + + c.bench_function("Handle nulls only", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + None, + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); + + let values = generate_values(len, false); + c.bench_function("Handle filter only", |b| { + b.iter(|| { + accumulate_indices( + &group_indices, + values.logical_nulls().as_ref(), + opt_filter.as_ref(), + |group_index| { + counts[group_index] += 1; + }, + ); + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 3efd348937ed..ac4d0e75535e 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -395,19 +395,41 @@ pub fn accumulate_indices( } } (None, Some(filter)) => { - assert_eq!(filter.len(), group_indices.len()); - // The performance with a filter could be improved by - // iterating over the filter in chunks, rather than a single - // iterator. TODO file a ticket - let iter = group_indices.iter().zip(filter.iter()); - for (&group_index, filter_value) in iter { - if let Some(true) = filter_value { - index_fn(group_index) - } - } + debug_assert_eq!(filter.len(), group_indices.len()); + let group_indices_chunks = group_indices.chunks_exact(64); + let bit_chunks = filter.values().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks.zip(bit_chunks.iter()).for_each( + |(group_index_chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }, + ); + + // handle any remaining bits (after the initial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + index_fn(group_index) + } + }); } (Some(valids), None) => { - assert_eq!(valids.len(), group_indices.len()); + debug_assert_eq!(valids.len(), group_indices.len()); // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum // iterate over in chunks of 64 bits for more efficient null checking let group_indices_chunks = group_indices.chunks_exact(64); @@ -444,20 +466,44 @@ pub fn accumulate_indices( } (Some(valids), Some(filter)) => { - assert_eq!(filter.len(), group_indices.len()); - assert_eq!(valids.len(), group_indices.len()); - // The performance with a filter could likely be improved by - // iterating over the filter in chunks, rather than using - // iterators. TODO file a ticket - filter + debug_assert_eq!(filter.len(), group_indices.len()); + debug_assert_eq!(valids.len(), group_indices.len()); + + let group_indices_chunks = group_indices.chunks_exact(64); + let valid_bit_chunks = valids.inner().bit_chunks(); + let filter_bit_chunks = filter.values().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + + group_indices_chunks + .zip(valid_bit_chunks.iter()) + .zip(filter_bit_chunks.iter()) + .for_each(|((group_index_chunk, valid_mask), filter_mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().for_each(|&group_index| { + // valid bit was set, real vale + let is_valid = (valid_mask & filter_mask & index_mask) != 0; + if is_valid { + index_fn(group_index); + } + index_mask <<= 1; + }) + }); + + // handle any remaining bits (after the initial 64) + let remainder_valid_bits = valid_bit_chunks.remainder_bits(); + let remainder_filter_bits = filter_bit_chunks.remainder_bits(); + group_indices_remainder .iter() - .zip(group_indices.iter()) - .zip(valids.iter()) - .for_each(|((filter_value, &group_index), is_valid)| { - if let (Some(true), true) = (filter_value, is_valid) { + .enumerate() + .for_each(|(i, &group_index)| { + let is_valid = + remainder_valid_bits & remainder_filter_bits & (1 << i) != 0; + if is_valid { index_fn(group_index) } - }) + }); } } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 52181372698f..8fdd702b5b7c 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -467,7 +467,8 @@ impl GroupsAccumulator for CountGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&BooleanArray>, + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + _opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "one argument to merge_batch"); @@ -480,22 +481,11 @@ impl GroupsAccumulator for CountGroupsAccumulator { // Adds the counts with the partial counts self.counts.resize(total_num_groups, 0); - match opt_filter { - Some(filter) => filter - .iter() - .zip(group_indices.iter()) - .zip(partial_counts.iter()) - .for_each(|((filter_value, &group_index), partial_count)| { - if let Some(true) = filter_value { - self.counts[group_index] += partial_count; - } - }), - None => group_indices.iter().zip(partial_counts.iter()).for_each( - |(&group_index, partial_count)| { - self.counts[group_index] += partial_count; - }, - ), - } + group_indices.iter().zip(partial_counts.iter()).for_each( + |(&group_index, partial_count)| { + self.counts[group_index] += partial_count; + }, + ); Ok(()) } diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs index 8daa85a5cc83..55d4181a96df 100644 --- a/datafusion/functions-aggregate/src/variance.rs +++ b/datafusion/functions-aggregate/src/variance.rs @@ -460,7 +460,7 @@ impl VarianceGroupsAccumulator { counts: &UInt64Array, means: &Float64Array, m2s: &Float64Array, - opt_filter: Option<&BooleanArray>, + _opt_filter: Option<&BooleanArray>, mut value_fn: F, ) where F: FnMut(usize, u64, f64, f64) + Send, @@ -469,33 +469,14 @@ impl VarianceGroupsAccumulator { assert_eq!(means.null_count(), 0); assert_eq!(m2s.null_count(), 0); - match opt_filter { - None => { - group_indices - .iter() - .zip(counts.values().iter()) - .zip(means.values().iter()) - .zip(m2s.values().iter()) - .for_each(|(((&group_index, &count), &mean), &m2)| { - value_fn(group_index, count, mean, m2); - }); - } - Some(filter) => { - group_indices - .iter() - .zip(counts.values().iter()) - .zip(means.values().iter()) - .zip(m2s.values().iter()) - .zip(filter.iter()) - .for_each( - |((((&group_index, &count), &mean), &m2), filter_value)| { - if let Some(true) = filter_value { - value_fn(group_index, count, mean, m2); - } - }, - ); - } - } + group_indices + .iter() + .zip(counts.values().iter()) + .zip(means.values().iter()) + .zip(m2s.values().iter()) + .for_each(|(((&group_index, &count), &mean), &m2)| { + value_fn(group_index, count, mean, m2); + }); } pub fn variance( @@ -554,7 +535,8 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&BooleanArray>, + // Since aggregate filter should be applied in partial stage, in final stage there should be no filter + _opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 3, "two arguments to merge_batch"); @@ -569,7 +551,7 @@ impl GroupsAccumulator for VarianceGroupsAccumulator { partial_counts, partial_means, partial_m2s, - opt_filter, + None, |group_index, partial_count, partial_mean, partial_m2| { if partial_count == 0 { return; diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index de67b0ae3874..c84b6f010968 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -26,7 +26,7 @@ use arrow_array::{ new_null_array, Array, ArrayRef, GenericListArray, NullArray, OffsetSizeTrait, }; use arrow_buffer::OffsetBuffer; -use arrow_schema::DataType::{LargeList, List, Null}; +use arrow_schema::DataType::{List, Null}; use arrow_schema::{DataType, Field}; use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; use datafusion_expr::binary::{ @@ -198,7 +198,6 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { let array = new_null_array(&DataType::Int64, length); Ok(Arc::new(array_into_list_array_nullable(array))) } - LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), } } diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index 0f18fd47b4cf..0e8ff1cd3192 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -64,11 +64,11 @@ impl ScalarUDFImpl for Md5Func { fn return_type(&self, arg_types: &[DataType]) -> Result { use DataType::*; Ok(match &arg_types[0] { - LargeUtf8 | LargeBinary => LargeUtf8, + LargeUtf8 | LargeBinary => Utf8, Utf8View | Utf8 | Binary => Utf8, Null => Null, Dictionary(_, t) => match **t { - LargeUtf8 | LargeBinary => LargeUtf8, + LargeUtf8 | LargeBinary => Utf8, Utf8 | Binary => Utf8, Null => Null, _ => { diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index b659e477f67e..1519c54dbf68 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -39,7 +39,7 @@ use datafusion_expr::{ use crate::optimize_projections::required_indices::RequiredIndicies; use crate::utils::NamePreserver; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, }; /// Optimizer rule to prune unnecessary columns from intermediate schemas @@ -484,7 +484,7 @@ fn merge_consecutive_projections(proj: Projection) -> Result(&self, statistics: &S) -> Result> { let mut builder = BoolVecBuilder::new(statistics.num_containers()); @@ -653,7 +650,7 @@ impl PruningPredicate { // this is only used by `parquet` feature right now #[allow(dead_code)] - pub(crate) fn required_columns(&self) -> &RequiredColumns { + pub fn required_columns(&self) -> &RequiredColumns { &self.required_columns } @@ -762,7 +759,7 @@ fn is_always_true(expr: &Arc) -> bool { /// Handles creating references to the min/max statistics /// for columns as well as recording which statistics are needed #[derive(Debug, Default, Clone)] -pub(crate) struct RequiredColumns { +pub struct RequiredColumns { /// The statistics required to evaluate this predicate: /// * The unqualified column in the input schema /// * Statistics type (e.g. Min or Max or Null_Count) @@ -786,7 +783,7 @@ impl RequiredColumns { /// * `true` returns None #[allow(dead_code)] // this fn is only used by `parquet` feature right now, thus the `allow(dead_code)` - pub(crate) fn single_column(&self) -> Option<&phys_expr::Column> { + pub fn single_column(&self) -> Option<&phys_expr::Column> { if self.columns.windows(2).all(|w| { // check if all columns are the same (ignoring statistics and field) let c1 = &w[0].0; @@ -1664,15 +1661,14 @@ mod tests { use std::ops::{Not, Rem}; use super::*; - use crate::assert_batches_eq; - use crate::logical_expr::{col, lit}; + use datafusion_common::assert_batches_eq; + use datafusion_expr::{col, lit}; use arrow::array::Decimal128Array; use arrow::{ - array::{BinaryArray, Int32Array, Int64Array, StringArray}, + array::{BinaryArray, Int32Array, Int64Array, StringArray, UInt64Array}, datatypes::TimeUnit, }; - use arrow_array::UInt64Array; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; use datafusion_functions_nested::expr_fn::{array_has, make_array}; @@ -3536,7 +3532,7 @@ mod tests { // more complex case with unknown column let input = known_expression.clone().and(input.clone()); let expected = phys_expr::BinaryExpr::new( - known_expression_transformed.clone(), + Arc::::clone(&known_expression_transformed), Operator::And, logical2physical(&lit(42), &schema), ); @@ -3552,7 +3548,7 @@ mod tests { // more complex case with unknown expression let input = known_expression.and(input); let expected = phys_expr::BinaryExpr::new( - known_expression_transformed.clone(), + Arc::::clone(&known_expression_transformed), Operator::And, logical2physical(&lit(42), &schema), ); @@ -4038,7 +4034,7 @@ mod tests { ) { println!("Pruning with expr: {}", expr); let expr = logical2physical(&expr, schema); - let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let p = PruningPredicate::try_new(expr, Arc::::clone(schema)).unwrap(); let result = p.prune(statistics).unwrap(); assert_eq!(result, expected); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index a816203b6812..ae528daad53c 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -18,7 +18,13 @@ //! [`GroupValues`] trait for storing and interning group keys use arrow::record_batch::RecordBatch; +use arrow_array::types::{ + Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; use arrow_array::{downcast_primitive, ArrayRef}; +use arrow_schema::TimeUnit; use arrow_schema::{DataType, SchemaRef}; use datafusion_common::Result; @@ -142,6 +148,28 @@ pub(crate) fn new_group_values( } match d { + DataType::Date32 => { + downcast_helper!(Date32Type, d); + } + DataType::Date64 => { + downcast_helper!(Date64Type, d); + } + DataType::Time32(t) => match t { + TimeUnit::Second => downcast_helper!(Time32SecondType, d), + TimeUnit::Millisecond => downcast_helper!(Time32MillisecondType, d), + _ => {} + }, + DataType::Time64(t) => match t { + TimeUnit::Microsecond => downcast_helper!(Time64MicrosecondType, d), + TimeUnit::Nanosecond => downcast_helper!(Time64NanosecondType, d), + _ => {} + }, + DataType::Timestamp(t, _) => match t { + TimeUnit::Second => downcast_helper!(TimestampSecondType, d), + TimeUnit::Millisecond => downcast_helper!(TimestampMillisecondType, d), + TimeUnit::Microsecond => downcast_helper!(TimestampMicrosecondType, d), + TimeUnit::Nanosecond => downcast_helper!(TimestampNanosecondType, d), + }, DataType::Utf8 => { return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); } diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 83b0f9d77369..10b00cf74fdb 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -32,12 +32,14 @@ use ahash::RandomState; use arrow::compute::cast; use arrow::datatypes::{ BinaryViewType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, - Int32Type, Int64Type, Int8Type, StringViewType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, + Int32Type, Int64Type, Int8Type, StringViewType, Time32MillisecondType, + Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow::record_batch::RecordBatch; use arrow_array::{Array, ArrayRef}; -use arrow_schema::{DataType, Schema, SchemaRef}; +use arrow_schema::{DataType, Schema, SchemaRef, TimeUnit}; use datafusion_common::hash_utils::create_hashes; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; @@ -913,6 +915,38 @@ impl GroupValues for GroupValuesColumn { } &DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type), &DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type), + &DataType::Time32(t) => match t { + TimeUnit::Second => { + instantiate_primitive!(v, nullable, Time32SecondType) + } + TimeUnit::Millisecond => { + instantiate_primitive!(v, nullable, Time32MillisecondType) + } + _ => {} + }, + &DataType::Time64(t) => match t { + TimeUnit::Microsecond => { + instantiate_primitive!(v, nullable, Time64MicrosecondType) + } + TimeUnit::Nanosecond => { + instantiate_primitive!(v, nullable, Time64NanosecondType) + } + _ => {} + }, + &DataType::Timestamp(t, _) => match t { + TimeUnit::Second => { + instantiate_primitive!(v, nullable, TimestampSecondType) + } + TimeUnit::Millisecond => { + instantiate_primitive!(v, nullable, TimestampMillisecondType) + } + TimeUnit::Microsecond => { + instantiate_primitive!(v, nullable, TimestampMicrosecondType) + } + TimeUnit::Nanosecond => { + instantiate_primitive!(v, nullable, TimestampNanosecondType) + } + }, &DataType::Utf8 => { let b = ByteGroupValueBuilder::::new(OutputType::Utf8); v.push(Box::new(b) as _) @@ -1125,6 +1159,8 @@ fn supported_type(data_type: &DataType) -> bool { | DataType::LargeBinary | DataType::Date32 | DataType::Date64 + | DataType::Time32(_) + | DataType::Timestamp(_, _) | DataType::Utf8View | DataType::BinaryView ) diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index de0ae2e07dd2..8e0f0a3d6507 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -27,6 +27,7 @@ use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use hashbrown::raw::RawTable; +use log::debug; use std::mem::size_of; use std::sync::Arc; @@ -80,6 +81,9 @@ pub struct GroupValuesRows { impl GroupValuesRows { pub fn try_new(schema: SchemaRef) -> Result { + // Print a debugging message, so it is clear when the (slower) fallback + // GroupValuesRows is used. + debug!("Creating GroupValuesRows for schema: {}", schema); let row_converter = RowConverter::new( schema .fields() diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 0fa9f206f13d..965adbb8c780 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -859,14 +859,13 @@ impl GroupedHashAggregateStream { )?; } _ => { + if opt_filter.is_some() { + return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage"); + } + // if aggregation is over intermediate states, // use merge - acc.merge_batch( - values, - group_indices, - opt_filter, - total_num_groups, - )?; + acc.merge_batch(values, group_indices, None, total_num_groups)?; } } } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index e9f17ddebabc..2e97334493dd 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -932,7 +932,6 @@ impl ExecutionPlan for SortExec { context.session_config().batch_size(), context.runtime_env(), &self.metrics_set, - partition, )?; Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 27bb3b2b36b9..0f722ec143ff 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -95,9 +95,7 @@ pub struct TopK { impl TopK { /// Create a new [`TopK`] that stores the top `k` values, as /// defined by the sort expressions in `expr`. - // TODO: make a builder or some other nicer API to avoid the - // clippy warning - #[allow(clippy::too_many_arguments)] + // TODO: make a builder or some other nicer API pub fn try_new( partition_id: usize, schema: SchemaRef, @@ -106,7 +104,6 @@ impl TopK { batch_size: usize, runtime: Arc, metrics: &ExecutionPlanMetricsSet, - partition: usize, ) -> Result { let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]")) .register(&runtime.memory_pool); @@ -133,7 +130,7 @@ impl TopK { Ok(Self { schema: Arc::clone(&schema), - metrics: TopKMetrics::new(metrics, partition), + metrics: TopKMetrics::new(metrics, partition_id), reservation, batch_size, expr, diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 06288a1f7041..0615e6738a1f 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -36,7 +36,7 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{cast, is_not_null, kernels, sum}; use arrow::datatypes::{DataType, Int64Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_array::{Int64Array, Scalar, StructArray}; +use arrow_array::{new_null_array, Int64Array, Scalar, StructArray}; use arrow_ord::cmp::lt; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, HashMap, HashSet, Result, UnnestOptions, @@ -453,16 +453,36 @@ fn list_unnest_at_level( // Create the take indices array for other columns let take_indices = create_take_indicies(unnested_length, total_length); - - // Dimension of arrays in batch is untouched, but the values are repeated - // as the side effect of unnesting - let ret = repeat_arrs_from_indices(batch, &take_indices)?; unnested_temp_arrays .into_iter() .zip(list_unnest_specs.iter()) .for_each(|(flatten_arr, unnesting)| { temp_unnested_arrs.insert(*unnesting, flatten_arr); }); + + let repeat_mask: Vec = batch + .iter() + .enumerate() + .map(|(i, _)| { + // Check if the column is needed in future levels (levels below the current one) + let needed_in_future_levels = list_type_unnests.iter().any(|unnesting| { + unnesting.index_in_input_schema == i && unnesting.depth < level_to_unnest + }); + + // Check if the column is involved in unnesting at any level + let is_involved_in_unnesting = list_type_unnests + .iter() + .any(|unnesting| unnesting.index_in_input_schema == i); + + // Repeat columns needed in future levels or not unnested. + needed_in_future_levels || !is_involved_in_unnesting + }) + .collect(); + + // Dimension of arrays in batch is untouched, but the values are repeated + // as the side effect of unnesting + let ret = repeat_arrs_from_indices(batch, &take_indices, &repeat_mask)?; + Ok((ret, total_length)) } struct UnnestingResult { @@ -859,8 +879,11 @@ fn create_take_indicies( builder.finish() } -/// Create the batch given an arrays and a `indices` array -/// that is used by the take kernel to copy values. +/// Create a batch of arrays based on an input `batch` and a `indices` array. +/// The `indices` array is used by the take kernel to repeat values in the arrays +/// that are marked with `true` in the `repeat_mask`. Arrays marked with `false` +/// in the `repeat_mask` will be replaced with arrays filled with nulls of the +/// appropriate length. /// /// For example if we have the following batch: /// @@ -890,14 +913,35 @@ fn create_take_indicies( /// c2: 'a', 'b', 'c', 'c', 'c', null, 'd', 'd' /// ``` /// +/// The `repeat_mask` determines whether an array's values are repeated or replaced with nulls. +/// For example, if the `repeat_mask` is: +/// +/// ```ignore +/// [true, false] +/// ``` +/// +/// The final batch will look like: +/// +/// ```ignore +/// c1: 1, null, 2, 3, 4, null, 5, 6 // Repeated using `indices` +/// c2: null, null, null, null, null, null, null, null // Replaced with nulls +/// fn repeat_arrs_from_indices( batch: &[ArrayRef], indices: &PrimitiveArray, + repeat_mask: &[bool], ) -> Result>> { batch .iter() - .map(|arr| Ok(kernels::take::take(arr, indices, None)?)) - .collect::>() + .zip(repeat_mask.iter()) + .map(|(arr, &repeat)| { + if repeat { + Ok(kernels::take::take(arr, indices, None)?) + } else { + Ok(new_null_array(arr.data_type(), arr.len())) + } + }) + .collect() } #[cfg(test)] diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index a6f7c4fd1100..f1f28258f9bd 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -180,25 +180,7 @@ impl Unparser<'_> { }) } Expr::Cast(Cast { expr, data_type }) => { - let inner_expr = self.expr_to_sql_inner(expr)?; - match data_type { - DataType::Dictionary(_, _) => match inner_expr { - // Dictionary values don't need to be cast to other types when rewritten back to sql - ast::Expr::Value(_) => Ok(inner_expr), - _ => Ok(ast::Expr::Cast { - kind: ast::CastKind::Cast, - expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, - format: None, - }), - }, - _ => Ok(ast::Expr::Cast { - kind: ast::CastKind::Cast, - expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, - format: None, - }), - } + Ok(self.cast_to_sql(expr, data_type)?) } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), @@ -901,6 +883,31 @@ impl Unparser<'_> { }) } + // Explicit type cast on ast::Expr::Value is not needed by underlying engine for certain types + // For example: CAST(Utf8("binary_value") AS Binary) and CAST(Utf8("dictionary_value") AS Dictionary) + fn cast_to_sql(&self, expr: &Expr, data_type: &DataType) -> Result { + let inner_expr = self.expr_to_sql_inner(expr)?; + match inner_expr { + ast::Expr::Value(_) => match data_type { + DataType::Dictionary(_, _) | DataType::Binary | DataType::BinaryView => { + Ok(inner_expr) + } + _ => Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }), + }, + _ => Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }), + } + } + /// DataFusion ScalarValues sometimes require a ast::Expr to construct. /// For example ScalarValue::Date32(d) corresponds to the ast::Expr CAST('datestr' as DATE) fn scalar_to_sql(&self, v: &ScalarValue) -> Result { @@ -1451,9 +1458,7 @@ impl Unparser<'_> { } DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()), DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()), - DataType::Utf8View => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") - } + DataType::Utf8View => Ok(self.dialect.utf8_cast_dtype()), DataType::List(_) => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } @@ -1513,7 +1518,7 @@ mod tests { use datafusion_common::TableReference; use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ - case, col, cube, exists, grouping_set, interval_datetime_lit, + case, cast, col, cube, exists, grouping_set, interval_datetime_lit, interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, WindowFunctionDefinition, @@ -2237,6 +2242,39 @@ mod tests { } } + #[test] + fn test_cast_value_to_binary_expr() { + let tests = [ + ( + Expr::Cast(Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + "blah".to_string(), + )))), + data_type: DataType::Binary, + }), + "'blah'", + ), + ( + Expr::Cast(Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + "blah".to_string(), + )))), + data_type: DataType::BinaryView, + }), + "'blah'", + ), + ]; + for (value, expected) in tests { + let dialect = CustomDialectBuilder::new().build(); + let unparser = Unparser::new(&dialect); + + let ast = unparser.expr_to_sql(&value).expect("to be unparsed"); + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } + #[test] fn custom_dialect_use_char_for_utf8_cast() -> Result<()> { let default_dialect = CustomDialectBuilder::default().build(); @@ -2500,4 +2538,50 @@ mod tests { } Ok(()) } + + #[test] + fn test_utf8_view_to_sql() -> Result<()> { + let dialect = CustomDialectBuilder::new() + .with_utf8_cast_dtype(ast::DataType::Char(None)) + .build(); + let unparser = Unparser::new(&dialect); + + let ast_dtype = unparser.arrow_dtype_to_ast_dtype(&DataType::Utf8View)?; + + assert_eq!(ast_dtype, ast::DataType::Char(None)); + + let expr = cast(col("a"), DataType::Utf8View); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = r#"CAST(a AS CHAR)"#.to_string(); + + assert_eq!(actual, expected); + + let expr = col("a").eq(lit(ScalarValue::Utf8View(Some("hello".to_string())))); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = r#"(a = 'hello')"#.to_string(); + + assert_eq!(actual, expected); + + let expr = col("a").is_not_null(); + + let ast = unparser.expr_to_sql(&expr)?; + let actual = format!("{}", ast); + let expected = r#"a IS NOT NULL"#.to_string(); + + assert_eq!(actual, expected); + + let expr = col("a").is_null(); + + let ast = unparser.expr_to_sql(&expr)?; + let actual = format!("{}", ast); + let expected = r#"a IS NULL"#.to_string(); + + assert_eq!(actual, expected); + + Ok(()) + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 433c456855a3..81e47ed939f2 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -876,7 +876,16 @@ impl Unparser<'_> { constraint: ast::JoinConstraint, ) -> Result { Ok(match join_type { - JoinType::Inner => ast::JoinOperator::Inner(constraint), + JoinType::Inner => match &constraint { + ast::JoinConstraint::On(_) + | ast::JoinConstraint::Using(_) + | ast::JoinConstraint::Natural => ast::JoinOperator::Inner(constraint), + ast::JoinConstraint::None => { + // Inner joins with no conditions or filters are not valid SQL in most systems, + // return a CROSS JOIN instead + ast::JoinOperator::CrossJoin + } + }, JoinType::Left => ast::JoinOperator::LeftOuter(constraint), JoinType::Right => ast::JoinOperator::RightOuter(constraint), JoinType::Full => ast::JoinOperator::FullOuter(constraint), diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index 6b3b999ba04b..68af121a4117 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -18,11 +18,12 @@ use std::{collections::HashSet, sync::Arc}; use arrow_schema::Schema; +use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, Column, HashMap, Result, TableReference, }; -use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec}; +use datafusion_expr::expr::Alias; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; @@ -83,17 +84,18 @@ pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result /// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { - let sort_exprs = transform_sort_vec(exprs, &mut |expr| { - expr.transform_up(|expr| { - if let Expr::Column(mut col) = expr { - col.relation = None; - Ok(Transformed::yes(Expr::Column(col))) - } else { - Ok(Transformed::no(expr)) - } + let sort_exprs = exprs + .map_elements(&mut |expr: Expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } + }) }) - }) - .data()?; + .data()?; Ok(sort_exprs) } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 4f43d7333dd1..f9d97cdc74af 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -296,7 +296,7 @@ fn roundtrip_statement_with_dialect() -> Result<()> { TestStatementWithDialect { sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", expected: - "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` CROSS JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", parser_dialect: Box::new(MySqlDialect {}), unparser_dialect: Box::new(UnparserMySqlDialect {}), }, @@ -1253,3 +1253,17 @@ fn test_unnest_to_sql() { r#"SELECT UNNEST([1, 2, 2, 5, NULL]) AS u1"#, ); } + +#[test] +fn test_join_with_no_conditions() { + sql_round_trip( + GenericDialect {}, + "SELECT * FROM j1 JOIN j2", + "SELECT * FROM j1 CROSS JOIN j2", + ); + sql_round_trip( + GenericDialect {}, + "SELECT * FROM j1 CROSS JOIN j2", + "SELECT * FROM j1 CROSS JOIN j2", + ); +} diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index 81682558d0a9..ed2b9c49715e 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -51,7 +51,7 @@ object_store = { workspace = true } postgres-protocol = { version = "0.6.4", optional = true } postgres-types = { version = "0.2.4", optional = true } rust_decimal = { version = "1.27.0" } -sqllogictest = "0.22.0" +sqllogictest = "0.23.0" sqlparser = { workspace = true } tempfile = { workspace = true } thiserror = "2.0.0" diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index 477f225443e2..2466303c32a9 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -106,6 +106,8 @@ impl TestContext { let example_udf = create_example_udf(); test_ctx.ctx.register_udf(example_udf); register_partition_table(&mut test_ctx).await; + info!("Registering table with many types"); + register_table_with_many_types(test_ctx.session_ctx()).await; } "metadata.slt" => { info!("Registering metadata table tables"); @@ -251,8 +253,11 @@ pub async fn register_table_with_many_types(ctx: &SessionContext) { .unwrap(); ctx.register_catalog("my_catalog", Arc::new(catalog)); - ctx.register_table("my_catalog.my_schema.t2", table_with_many_types()) - .unwrap(); + ctx.register_table( + "my_catalog.my_schema.table_with_many_types", + table_with_many_types(), + ) + .unwrap(); } pub async fn register_table_with_map(ctx: &SessionContext) { diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 1e60699a1f65..e6676d683f91 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1155,6 +1155,18 @@ select column1, column5 from arrays_values_without_nulls; [21, 22, 23, 24, 25, 26, 27, 28, 29, 30] [6, 7] [31, 32, 33, 34, 35, 26, 37, 38, 39, 40] [8, 9] +# make array with arrays of different types +query ? +select make_array(make_array(1), arrow_cast(make_array(-1), 'LargeList(Int8)')) +---- +[[1], [-1]] + +query T +select arrow_typeof(make_array(make_array(1), arrow_cast(make_array(-1), 'LargeList(Int8)'))); +---- +List(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + query ??? select make_array(column1), make_array(column1, column5), diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index bc974a57b2db..5d8c4dfd05b4 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5272,6 +5272,201 @@ drop view t statement ok drop table source; +# Test multi group by int + Date32 +statement ok +create table source as values +(1, '2020-01-01'), +(1, '2020-01-01'), +(2, '2020-01-02'), +(2, '2020-01-03'), +(3, '2020-01-04'), +(3, '2020-01-04'), +(2, '2020-01-03'), +(null, null), +(null, '2020-01-01'), +(null, null), +(null, '2020-01-01'), +(2, '2020-01-02'), +(2, '2020-01-02'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Date32') as b from source; + +query IDI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 2020-01-01 2 +1 NULL 1 +2 2020-01-02 3 +2 2020-01-03 2 +3 2020-01-04 2 +NULL 2020-01-01 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + +# Test multi group by int + Date64 +statement ok +create table source as values +(1, '2020-01-01'), +(1, '2020-01-01'), +(2, '2020-01-02'), +(2, '2020-01-03'), +(3, '2020-01-04'), +(3, '2020-01-04'), +(2, '2020-01-03'), +(null, null), +(null, '2020-01-01'), +(null, null), +(null, '2020-01-01'), +(2, '2020-01-02'), +(2, '2020-01-02'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Date64') as b from source; + +query IDI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 2020-01-01T00:00:00 2 +1 NULL 1 +2 2020-01-02T00:00:00 3 +2 2020-01-03T00:00:00 2 +3 2020-01-04T00:00:00 2 +NULL 2020-01-01T00:00:00 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + +# Test multi group by int + Time32 +statement ok +create table source as values +(1, '12:34:56'), +(1, '12:34:56'), +(2, '13:00:00'), +(2, '14:15:00'), +(3, '23:59:59'), +(3, '23:59:59'), +(2, '14:15:00'), +(null, null), +(null, '12:00:00'), +(null, null), +(null, '12:00:00'), +(2, '13:00:00'), +(2, '13:00:00'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Time32(Second)') as b from source; + +query IDI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 12:34:56 2 +1 NULL 1 +2 13:00:00 3 +2 14:15:00 2 +3 23:59:59 2 +NULL 12:00:00 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + +# Test multi group by int + Time64 +statement ok +create table source as values +(1, '12:34:56.123456'), +(1, '12:34:56.123456'), +(2, '13:00:00.000001'), +(2, '14:15:00.999999'), +(3, '23:59:59.500000'), +(3, '23:59:59.500000'), +(2, '14:15:00.999999'), +(null, null), +(null, '12:00:00.000000'), +(null, null), +(null, '12:00:00.000000'), +(2, '13:00:00.000001'), +(2, '13:00:00.000001'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Time64(Microsecond)') as b from source; + +query IDI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 12:34:56.123456 2 +1 NULL 1 +2 13:00:00.000001 3 +2 14:15:00.999999 2 +3 23:59:59.500 2 +NULL 12:00:00 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + +# Test multi group by int + Timestamp +statement ok +create table source as values +(1, '2020-01-01 12:34:56'), +(1, '2020-01-01 12:34:56'), +(2, '2020-01-02 13:00:00'), +(2, '2020-01-03 14:15:00'), +(3, '2020-01-04 23:59:59'), +(3, '2020-01-04 23:59:59'), +(2, '2020-01-03 14:15:00'), +(null, null), +(null, '2020-01-01 12:00:00'), +(null, null), +(null, '2020-01-01 12:00:00'), +(2, '2020-01-02 13:00:00'), +(2, '2020-01-02 13:00:00'), +(1, null) +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Timestamp(Nanosecond, None)') as b from source; + +query IPI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 2020-01-01T12:34:56 2 +1 NULL 1 +2 2020-01-02T13:00:00 3 +2 2020-01-03T14:15:00 2 +3 2020-01-04T23:59:59 2 +NULL 2020-01-01T12:00:00 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; + # Test whether min, max accumulator produces NaN result when input is NaN. # See https://github.com/apache/datafusion/issues/13415 for rationale statement ok @@ -5287,3 +5482,4 @@ query RR SELECT max(input_table.x), min(input_table.x) from input_table GROUP BY input_table."row"; ---- NaN NaN + diff --git a/datafusion/sqllogictest/test_files/information_schema_columns.slt b/datafusion/sqllogictest/test_files/information_schema_columns.slt index 7cf845c16d73..d348a764fa85 100644 --- a/datafusion/sqllogictest/test_files/information_schema_columns.slt +++ b/datafusion/sqllogictest/test_files/information_schema_columns.slt @@ -37,17 +37,17 @@ query TTTTITTTIIIIIIT rowsort SELECT * from information_schema.columns; ---- my_catalog my_schema t1 i 0 NULL YES Int32 NULL NULL 32 2 NULL NULL NULL -my_catalog my_schema t2 binary_col 4 NULL NO Binary NULL 2147483647 NULL NULL NULL NULL NULL -my_catalog my_schema t2 float64_col 1 NULL YES Float64 NULL NULL 24 2 NULL NULL NULL -my_catalog my_schema t2 int32_col 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL -my_catalog my_schema t2 large_binary_col 5 NULL NO LargeBinary NULL 9223372036854775807 NULL NULL NULL NULL NULL -my_catalog my_schema t2 large_utf8_col 3 NULL NO LargeUtf8 NULL 9223372036854775807 NULL NULL NULL NULL NULL -my_catalog my_schema t2 timestamp_nanos 6 NULL NO Timestamp(Nanosecond, None) NULL NULL NULL NULL NULL NULL NULL -my_catalog my_schema t2 utf8_col 2 NULL YES Utf8 NULL 2147483647 NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types binary_col 4 NULL NO Binary NULL 2147483647 NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types float64_col 1 NULL YES Float64 NULL NULL 24 2 NULL NULL NULL +my_catalog my_schema table_with_many_types int32_col 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL +my_catalog my_schema table_with_many_types large_binary_col 5 NULL NO LargeBinary NULL 9223372036854775807 NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types large_utf8_col 3 NULL NO LargeUtf8 NULL 9223372036854775807 NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types timestamp_nanos 6 NULL NO Timestamp(Nanosecond, None) NULL NULL NULL NULL NULL NULL NULL +my_catalog my_schema table_with_many_types utf8_col 2 NULL YES Utf8 NULL 2147483647 NULL NULL NULL NULL NULL # Cleanup statement ok drop table t1 statement ok -drop table t2 +drop table table_with_many_types diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index d45dbc7ee1ae..e636e93007a4 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -4292,3 +4292,24 @@ query T select * from table1 as t1 natural join table1_stringview as t2; ---- foo + +query TT +EXPLAIN SELECT count(*) +FROM my_catalog.my_schema.table_with_many_types AS l +JOIN my_catalog.my_schema.table_with_many_types AS r ON l.binary_col = r.binary_col +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] +02)--Projection: +03)----Inner Join: l.binary_col = r.binary_col +04)------SubqueryAlias: l +05)--------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col] +06)------SubqueryAlias: r +07)--------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col] +physical_plan +01)AggregateExec: mode=Single, gby=[], aggr=[count(*)] +02)--ProjectionExec: expr=[] +03)----CoalesceBatchesExec: target_batch_size=3 +04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)--------MemoryExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index 493da64063bc..145081f91a30 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -901,7 +901,7 @@ SELECT '\' LIKE '\\', '\\' LIKE '\\' ---- -false true false true false false true +true false false true false true false # if "%%" in the pattern was simplified to "%", the pattern semantics would change query BBBBB @@ -1002,7 +1002,7 @@ NULL \%abc NULL \ NULL NULL \ (empty) false \ \ true -\ \\ false +\ \\ true \ \\\ false \ \\\\ false \ a false @@ -1010,10 +1010,10 @@ NULL \%abc NULL \ \\a false \ % true \ \% false -\ \\% false +\ \\% true \ %% true \ \%% false -\ \\%% false +\ \\%% true \ _ true \ \_ false \ \\_ false @@ -1028,21 +1028,21 @@ NULL \%abc NULL \\ NULL NULL \\ (empty) false \\ \ false -\\ \\ true -\\ \\\ false -\\ \\\\ false +\\ \\ false +\\ \\\ true +\\ \\\\ true \\ a false \\ \a false \\ \\a false \\ % true \\ \% false -\\ \\% false +\\ \\% true \\ %% true \\ \%% false -\\ \\%% false +\\ \\%% true \\ _ false \\ \_ false -\\ \\_ false +\\ \\_ true \\ __ true \\ \__ false \\ \\__ false @@ -1055,23 +1055,23 @@ NULL \%abc NULL \\\ (empty) false \\\ \ false \\\ \\ false -\\\ \\\ true +\\\ \\\ false \\\ \\\\ false \\\ a false \\\ \a false \\\ \\a false \\\ % true \\\ \% false -\\\ \\% false +\\\ \\% true \\\ %% true \\\ \%% false -\\\ \\%% false +\\\ \\%% true \\\ _ false \\\ \_ false \\\ \\_ false \\\ __ false \\\ \__ false -\\\ \\__ false +\\\ \\__ true \\\ abc false \\\ a_c false \\\ a\_c false @@ -1082,16 +1082,16 @@ NULL \%abc NULL \\\\ \ false \\\\ \\ false \\\\ \\\ false -\\\\ \\\\ true +\\\\ \\\\ false \\\\ a false \\\\ \a false \\\\ \\a false \\\\ % true \\\\ \% false -\\\\ \\% false +\\\\ \\% true \\\\ %% true \\\\ \%% false -\\\\ \\%% false +\\\\ \\%% true \\\\ _ false \\\\ \_ false \\\\ \\_ false @@ -1110,7 +1110,7 @@ a \\ false a \\\ false a \\\\ false a a true -a \a false +a \a true a \\a false a % true a \% false @@ -1136,17 +1136,17 @@ a \%abc false \a \\\ false \a \\\\ false \a a false -\a \a true -\a \\a false +\a \a false +\a \\a true \a % true \a \% false -\a \\% false +\a \\% true \a %% true \a \%% false -\a \\%% false +\a \\%% true \a _ false \a \_ false -\a \\_ false +\a \\_ true \a __ true \a \__ false \a \\__ false @@ -1163,19 +1163,19 @@ a \%abc false \\a \\\\ false \\a a false \\a \a false -\\a \\a true +\\a \\a false \\a % true \\a \% false -\\a \\% false +\\a \\% true \\a %% true \\a \%% false -\\a \\%% false +\\a \\%% true \\a _ false \\a \_ false \\a \\_ false \\a __ false \\a \__ false -\\a \\__ false +\\a \\__ true \\a abc false \\a a_c false \\a a\_c false @@ -1224,7 +1224,7 @@ a \%abc false \% \\%% true \% _ false \% \_ false -\% \\_ false +\% \\_ true \% __ true \% \__ false \% \\__ false @@ -1244,16 +1244,16 @@ a \%abc false \\% \\a false \\% % true \\% \% false -\\% \\% false +\\% \\% true \\% %% true \\% \%% false -\\% \\%% false +\\% \\%% true \\% _ false \\% \_ false \\% \\_ false \\% __ false \\% \__ false -\\% \\__ false +\\% \\__ true \\% abc false \\% a_c false \\% a\_c false @@ -1296,7 +1296,7 @@ a \%abc false \%% \\a false \%% % true \%% \% false -\%% \\% false +\%% \\% true \%% %% true \%% \%% false \%% \\%% true @@ -1305,7 +1305,7 @@ a \%abc false \%% \\_ false \%% __ false \%% \__ false -\%% \\__ false +\%% \\__ true \%% abc false \%% a_c false \%% a\_c false @@ -1322,10 +1322,10 @@ a \%abc false \\%% \\a false \\%% % true \\%% \% false -\\%% \\% false +\\%% \\% true \\%% %% true \\%% \%% false -\\%% \\%% false +\\%% \\%% true \\%% _ false \\%% \_ false \\%% \\_ false @@ -1374,10 +1374,10 @@ _ \%abc false \_ \\a false \_ % true \_ \% false -\_ \\% false +\_ \\% true \_ %% true \_ \%% false -\_ \\%% false +\_ \\%% true \_ _ false \_ \_ false \_ \\_ true @@ -1400,16 +1400,16 @@ _ \%abc false \\_ \\a false \\_ % true \\_ \% false -\\_ \\% false +\\_ \\% true \\_ %% true \\_ \%% false -\\_ \\%% false +\\_ \\%% true \\_ _ false \\_ \_ false \\_ \\_ false \\_ __ false \\_ \__ false -\\_ \\__ false +\\_ \\__ true \\_ abc false \\_ a_c false \\_ a\_c false @@ -1452,10 +1452,10 @@ __ \%abc false \__ \\a false \__ % true \__ \% false -\__ \\% false +\__ \\% true \__ %% true \__ \%% false -\__ \\%% false +\__ \\%% true \__ _ false \__ \_ false \__ \\_ false @@ -1478,10 +1478,10 @@ __ \%abc false \\__ \\a false \\__ % true \\__ \% false -\\__ \\% false +\\__ \\% true \\__ %% true \\__ \%% false -\\__ \\%% false +\\__ \\%% true \\__ _ false \\__ \_ false \\__ \\_ false @@ -1608,7 +1608,7 @@ a\_c \%abc false \%abc \\a false \%abc % true \%abc \% false -\%abc \\% false +\%abc \\% true \%abc %% true \%abc \%% false \%abc \\%% true diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part index f781b9dc33ca..c42a9384c5d0 100644 --- a/datafusion/sqllogictest/test_files/string/string_query.slt.part +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -1373,3 +1373,112 @@ p percent NULL pan Tadeusz ma iść w kąt pan Tadeusz ma iść w kąt NULL _ _ NULL (empty) (empty) NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test md5 +# -------------------------------------- + +query T +select md5(ascii_1) from test_basic_operator; +---- +8aae3a73a9a43ee6b04dfd986fe9d136 +76515af83bcb9d6336fe42dba18e716d +84fc7720d5e7bf07115d91762843b8ad +e0c4c75d58916b22a41b6ea9bc46231f +354f047ba64552895b016bbdd60ab174 +d41d8cd98f00b204e9800998ecf8427e +0bcef9c45bd8a48eda1b26eb0c61c869 +b14a7b8059d9c055954c92674ce60032 +NULL +NULL + +# -------------------------------------- +# Test sha244 +# -------------------------------------- + +query ? +select sha224(ascii_1) from test_basic_operator; +---- +abd8be3961e5dbe324bc67f9a0211d5f7d81e556baadaff6218e4bfa +87a20c95932524a54a0263a621fe791a5d5fbc0e40242b59732d6bf5 +8dd0c8021fe87bbc1c0701bd3130e27a639dcd93083c3f1989ffdf26 +8f6caa44143a080541f083bb762107ce12224b271bfa8b36ece002ab +951336d101e034714ba1ca0535688f0300613e235814ed938cd25115 +d14a028c2a3a2bc9476102bb288234c415a2b01f828ea62ac5b3e42f +fda2a4d4c5fb67cfd7fc817f59b543ae42f650aa4abd79934ca5ac55 +d365e3c7512c311d0df0528a850e6c827cbe508d13235fa91b545389 +NULL +NULL + +# -------------------------------------- +# Test sha256 +# -------------------------------------- + +query ? +select sha256(ascii_1) from test_basic_operator; +---- +c10873196eb1124ed74461c20a67094e395f2310f6305607b9694ee6b1ee8b43 +ec792d2e89af0d5b05c88ee1e5fe041ce2db94f84c3aabac4f7cfe20f00cd032 +053e9c5f1a29bea66ff896d7a8f217bf380b8e3973e7f13c1acbe14ef7fc947e +d8071166bbe6131a0acaf86019eeeca31c87ee4fda23b80eda0d094dbffee521 +fd86717aca41c558c78c19ab2b50691179a57ba5200bc7e3317be70efd4043ad +e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 +bbf3f11cb5b43e700273a78d12de55e4a7eab741ed2abf13787a4d2dc832b8ec +d2e2adf7177b7a8afddbc12d1634cf23ea1a71020f6a1308070a16400fb68fde +NULL +NULL + +# -------------------------------------- +# Test sha384 +# -------------------------------------- + +query ? +select sha384(ascii_1) from test_basic_operator; +---- +33a2a749758403660d131256e08647f52e4efba74840e7ad55c77012ade611ec0dc815ab3fa777e98710d43f3345222b +7b525a4147696421c6119df0e983ee3d9ebcfa13b3e1dce2fb308f91863e236fde55b56b89936908999332f5a453845c +359ee4b366b1965e9ceb0bd529edcdc08c33b0348aa4cc2cf4114c7f18069d53f6a798482626393c46ed340995c34b4e +fe417fcff1b9b8cdbc4fba45fedcd882ccbeef438497647052809fd73f43bcf1a6214f543a91e7183d56c6ae8e7cb30e +7791b34dcc841235a8a074052bc12aa7090c0d72f09ec41b1521a67fa09b026a9c02d159b42428d7b528aa5ff7598fd4 +38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b +bba987e661a4158451c5e9870fe91f483064574a0d7485caef40f48d7846579859c7dddebd418cbc99ccaa1ebd3619ea +586b0fd9f8ec935c69a7dceb5560742f368962833023906d30fe1cf49c96ea6d22cea8c2b63cd18e7af08fbf9e47c3f9 +NULL +NULL + + +# -------------------------------------- +# Test sha512 +# -------------------------------------- + +query ? +select sha512(ascii_1) from test_basic_operator; +---- +93262eb44d649a02a83b78889fd813ce819759daabcee2ac433f1ea7feef44f521ac0eba5b5359d47c7a7146afbe064b55134a63ac713c0fcc4c48e11eed7109 +f02c73afb1e433d6cc7e9137bb4ed40791e8c6e7877ae26e7a1edc4ce98a945a61bdf883d985adbc03d74d67ac18d4981529be5f4f53a35ff7fcd3e9814592d7 +2f25e277902f07a4c5cdb54485487b50bae3acdd615cd5551f71f4e3d97077fbccfbf0c85f88d6766d132069a343b732c6e81080a2c3ed59caff0c6947f4c57a +cafc51edc3a949179a74a805be8d0c7991bfc849b01f773f4bcd5e7dbe51b6d71d65921d8025d375d501af6a1c1026ab76cd7f4811b91bb4544f7dcbb710fa1f +2f845edf0e9c9728fae627d4678dc8c35c9a7f22809d355aa5ddf96d9ca3539973ac7ff96bfc6720ce6a973f93b716e265ad719ee38a85e44d9316ac1b6c89a4 +cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e +91972aa34055bca20ddb643b9f817a547e5d4ad49b7ff16a7f828a8d72c4cb4a5679cff4da00f9fb6b2833de7eb3480b3b4a7c7c7b85a39028de55acaf2d8812 +bbbe7f2559c7953d281fba7f25258063dbc8a55c5b9fdfcd334ecd64a8d7d8980c6f6ee0457bf496bcff747991f741446f1814222678dfa7457f1ad3a6f848b3 +NULL +NULL + +# -------------------------------------- +# Test DIGEST +# -------------------------------------- + +query ? +select DIGEST(ascii_1, 'sha256') from test_basic_operator; +---- +c10873196eb1124ed74461c20a67094e395f2310f6305607b9694ee6b1ee8b43 +ec792d2e89af0d5b05c88ee1e5fe041ce2db94f84c3aabac4f7cfe20f00cd032 +053e9c5f1a29bea66ff896d7a8f217bf380b8e3973e7f13c1acbe14ef7fc947e +d8071166bbe6131a0acaf86019eeeca31c87ee4fda23b80eda0d094dbffee521 +fd86717aca41c558c78c19ab2b50691179a57ba5200bc7e3317be70efd4043ad +e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 +bbf3f11cb5b43e700273a78d12de55e4a7eab741ed2abf13787a4d2dc832b8ec +d2e2adf7177b7a8afddbc12d1634cf23ea1a71020f6a1308070a16400fb68fde +NULL +NULL \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 5a08f3f5447a..aa41cbb8119e 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -39,8 +39,19 @@ drop table test_source # TODO: Revisit this issue after upgrading to the arrow-rs version that includes apache/arrow-rs#6671. # see issue https://github.com/apache/datafusion/issues/13329 -query error DataFusion error: Arrow error: Compute error: bit_length not supported for Utf8View +query IIII select bit_length(ascii_1), bit_length(ascii_2), bit_length(unicode_1), bit_length(unicode_2) from test_basic_operator; +---- +48 8 144 32 +72 72 176 176 +56 8 240 64 +88 88 104 256 +56 24 216 288 +0 8 0 0 +8 16 0 0 +8 16 0 0 +NULL 8 NULL NULL +NULL 8 NULL 32 # # common test for string-like functions and operators diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index fb7afdda2ea8..b5e82f613a46 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -761,3 +761,19 @@ SELECT NULL WHERE FALSE; ---- 0.5 1 + +# Test Union of List Types. Issue: https://github.com/apache/datafusion/issues/12291 +query error DataFusion error: type_coercion\ncaused by\nError during planning: Incompatible inputs for Union: Previous inputs were of type List(.*), but got incompatible type List(.*) on column 'x' +SELECT make_array(2) x UNION ALL SELECT make_array(now()) x; + +query ? +select make_array(arrow_cast(2, 'UInt8')) x UNION ALL SELECT make_array(arrow_cast(-2, 'Int8')) x; +---- +[-2] +[2] + +query ? +select make_array(make_array(1)) x UNION ALL SELECT make_array(arrow_cast(make_array(-1), 'LargeList(Int8)')) x; +---- +[[-1]] +[[1]] diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 8ebed5b25ca9..2e1b8b87cc42 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -853,3 +853,9 @@ select unnest(u.column5), j.* except(column2, column3) from unnest_table u join 1 2 1 3 4 2 NULL NULL 4 + +## Issue: https://github.com/apache/datafusion/issues/13237 +query I +select count(*) from (select unnest(range(0, 100000)) id) t inner join (select unnest(range(0, 100000)) id) t1 on t.id = t1.id; +---- +100000 diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 192fe26d6cef..61cdf3e91e3c 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -34,6 +34,7 @@ workspace = true [dependencies] arrow-buffer = { workspace = true } async-recursion = "1.0" +async-trait = { workspace = true } chrono = { workspace = true } datafusion = { workspace = true, default-features = true } itertools = { workspace = true } diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index a6f7c033f9d0..1389cac75b99 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -64,10 +64,10 @@ //! let plan = df.into_optimized_plan()?; //! //! // Convert the plan into a substrait (protobuf) Plan -//! let substrait_plan = logical_plan::producer::to_substrait_plan(&plan, &ctx)?; +//! let substrait_plan = logical_plan::producer::to_substrait_plan(&plan, &ctx.state())?; //! //! // Receive a substrait protobuf from somewhere, and turn it into a LogicalPlan -//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?; +//! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx.state(), &substrait_plan).await?; //! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?; //! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); //! # Ok(()) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 1cce228527ec..77e9eb81f546 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -26,7 +26,7 @@ use datafusion::common::{ not_impl_err, plan_datafusion_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, }; -use datafusion::execution::FunctionRegistry; +use datafusion::datasource::provider_as_source; use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ @@ -56,7 +56,6 @@ use crate::variation_const::{ use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::arrow::temporal_conversions::NANOSECONDS; use datafusion::common::scalar::ScalarStructBuilder; -use datafusion::dataframe::DataFrame; use datafusion::logical_expr::builder::project; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ @@ -66,9 +65,7 @@ use datafusion::logical_expr::{ use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ - error::Result, - logical_expr::utils::split_conjunction, - prelude::{Column, SessionContext}, + error::Result, logical_expr::utils::split_conjunction, prelude::Column, scalar::ScalarValue, }; use std::collections::HashSet; @@ -102,6 +99,8 @@ use substrait::proto::{ }; use substrait::proto::{ExtendedExpression, FunctionArgument, SortField}; +use super::state::SubstraitPlanningState; + // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which // is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone // results in correct points on the timeline, and we pick UTC as a reasonable default. @@ -203,15 +202,15 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( async fn union_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(ctx, &rels[0], extensions).await?, + from_substrait_rel(state, &rels[0], extensions).await?, )); for input in &rels[1..] { - let rel_plan = from_substrait_rel(ctx, input, extensions).await?; + let rel_plan = from_substrait_rel(state, input, extensions).await?; union_builder = if is_all { union_builder?.union(rel_plan) @@ -224,16 +223,16 @@ async fn union_rels( async fn intersect_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::intersect( rel, - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, is_all, )? } @@ -243,16 +242,16 @@ async fn intersect_rels( async fn except_rels( rels: &[Rel], - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::except( rel, - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, is_all, )? } @@ -262,7 +261,7 @@ async fn except_rels( /// Convert Substrait Plan to DataFusion LogicalPlan pub async fn from_substrait_plan( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, plan: &Plan, ) -> Result { // Register function extension @@ -277,10 +276,10 @@ pub async fn from_substrait_plan( match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, rel, &extensions).await?) + Ok(from_substrait_rel(state, rel, &extensions).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; + let plan = from_substrait_rel(state, root.input.as_ref().unwrap(), &extensions).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -341,7 +340,7 @@ pub struct ExprContainer { /// between systems. This is often useful for scenarios like pushdown where filter /// expressions need to be sent to remote systems. pub async fn from_substrait_extended_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extended_expr: &ExtendedExpression, ) -> Result { // Register function extension @@ -370,7 +369,7 @@ pub async fn from_substrait_extended_expr( } }?; let expr = - from_substrait_rex(ctx, scalar_expr, &input_schema, &extensions).await?; + from_substrait_rex(state, scalar_expr, &input_schema, &extensions).await?; let (output_type, expected_nullability) = expr.data_type_and_nullable(&input_schema)?; let output_field = Field::new("", output_type, expected_nullability); @@ -561,7 +560,7 @@ fn make_renamed_schema( #[allow(deprecated)] #[async_recursion] pub async fn from_substrait_rel( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, rel: &Rel, extensions: &Extensions, ) -> Result { @@ -569,7 +568,7 @@ pub async fn from_substrait_rel( Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { let mut input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let original_schema = input.schema().clone(); @@ -587,9 +586,13 @@ pub async fn from_substrait_rel( let mut explicit_exprs: Vec = vec![]; for expr in &p.expressions { - let e = - from_substrait_rex(ctx, expr, input.clone().schema(), extensions) - .await?; + let e = from_substrait_rex( + state, + expr, + input.clone().schema(), + extensions, + ) + .await?; // if the expression is WindowFunction, wrap in a Window relation if let Expr::WindowFunction(_) = &e { // Adding the same expression here and in the project below @@ -617,11 +620,11 @@ pub async fn from_substrait_rel( Some(RelType::Filter(filter)) => { if let Some(input) = filter.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); if let Some(condition) = filter.condition.as_ref() { let expr = - from_substrait_rex(ctx, condition, input.schema(), extensions) + from_substrait_rex(state, condition, input.schema(), extensions) .await?; input.filter(expr)?.build() } else { @@ -634,7 +637,7 @@ pub async fn from_substrait_rel( Some(RelType::Fetch(fetch)) => { if let Some(input) = fetch.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let offset = fetch.offset as usize; // -1 means that ALL records should be returned @@ -651,10 +654,10 @@ pub async fn from_substrait_rel( Some(RelType::Sort(sort)) => { if let Some(input) = sort.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let sorts = - from_substrait_sorts(ctx, &sort.sorts, input.schema(), extensions) + from_substrait_sorts(state, &sort.sorts, input.schema(), extensions) .await?; input.sort(sorts)?.build() } else { @@ -664,13 +667,13 @@ pub async fn from_substrait_rel( Some(RelType::Aggregate(agg)) => { if let Some(input) = agg.input.as_ref() { let input = LogicalPlanBuilder::from( - from_substrait_rel(ctx, input, extensions).await?, + from_substrait_rel(state, input, extensions).await?, ); let mut ref_group_exprs = vec![]; for e in &agg.grouping_expressions { let x = - from_substrait_rex(ctx, e, input.schema(), extensions).await?; + from_substrait_rex(state, e, input.schema(), extensions).await?; ref_group_exprs.push(x); } @@ -681,7 +684,7 @@ pub async fn from_substrait_rel( 1 => { group_exprs.extend_from_slice( &from_substrait_grouping( - ctx, + state, &agg.groupings[0], &ref_group_exprs, input.schema(), @@ -694,7 +697,7 @@ pub async fn from_substrait_rel( let mut grouping_sets = vec![]; for grouping in &agg.groupings { let grouping_set = from_substrait_grouping( - ctx, + state, grouping, &ref_group_exprs, input.schema(), @@ -716,7 +719,7 @@ pub async fn from_substrait_rel( for m in &agg.measures { let filter = match &m.filter { Some(fil) => Some(Box::new( - from_substrait_rex(ctx, fil, input.schema(), extensions) + from_substrait_rex(state, fil, input.schema(), extensions) .await?, )), None => None, @@ -739,7 +742,7 @@ pub async fn from_substrait_rel( let order_by = if !f.sorts.is_empty() { Some( from_substrait_sorts( - ctx, + state, &f.sorts, input.schema(), extensions, @@ -751,7 +754,7 @@ pub async fn from_substrait_rel( }; from_substrait_agg_func( - ctx, + state, f, input.schema(), extensions, @@ -780,10 +783,12 @@ pub async fn from_substrait_rel( } let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - from_substrait_rel(ctx, join.left.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, join.left.as_ref().unwrap(), extensions) + .await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, join.right.as_ref().unwrap(), extensions) + .await?, ); let (left, right) = requalify_sides_if_needed(left, right)?; @@ -796,7 +801,7 @@ pub async fn from_substrait_rel( // Otherwise, build join with only the filter, without join keys match &join.expression.as_ref() { Some(expr) => { - let on = from_substrait_rex(ctx, expr, &in_join_schema, extensions) + let on = from_substrait_rex(state, expr, &in_join_schema, extensions) .await?; // The join expression can contain both equal and non-equal ops. // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. @@ -831,26 +836,44 @@ pub async fn from_substrait_rel( } Some(RelType::Cross(cross)) => { let left = LogicalPlanBuilder::from( - from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, + from_substrait_rel(state, cross.left.as_ref().unwrap(), extensions) + .await?, ); let right = LogicalPlanBuilder::from( - from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) + from_substrait_rel(state, cross.right.as_ref().unwrap(), extensions) .await?, ); let (left, right) = requalify_sides_if_needed(left, right)?; left.cross_join(right.build()?)?.build() } Some(RelType::Read(read)) => { - fn read_with_schema( - df: DataFrame, + async fn read_with_schema( + state: &dyn SubstraitPlanningState, + table_ref: TableReference, schema: DFSchema, projection: &Option, ) -> Result { - ensure_schema_compatability(df.schema().to_owned(), schema.clone())?; + let schema = schema.replace_qualifier(table_ref.clone()); + + let plan = { + let provider = match state.table(&table_ref).await? { + Some(ref provider) => Arc::clone(provider), + _ => return plan_err!("No table named '{table_ref}'"), + }; + + LogicalPlanBuilder::scan( + table_ref, + provider_as_source(Arc::clone(&provider)), + None, + )? + .build()? + }; + + ensure_schema_compatability(plan.schema(), schema.clone())?; let schema = apply_masking(schema, projection)?; - apply_projection(df, schema) + apply_projection(plan, schema) } let named_struct = read.base_schema.as_ref().ok_or_else(|| { @@ -879,12 +902,13 @@ pub async fn from_substrait_rel( }, }; - let t = ctx.table(table_reference.clone()).await?; - - let substrait_schema = - substrait_schema.replace_qualifier(table_reference); - - read_with_schema(t, substrait_schema, &read.projection) + read_with_schema( + state, + table_reference, + substrait_schema, + &read.projection, + ) + .await } Some(ReadType::VirtualTable(vt)) => { if vt.values.is_empty() { @@ -960,12 +984,14 @@ pub async fn from_substrait_rel( let name = filename.unwrap(); // directly use unwrap here since we could determine it is a valid one let table_reference = TableReference::Bare { table: name.into() }; - let t = ctx.table(table_reference.clone()).await?; - - let substrait_schema = - substrait_schema.replace_qualifier(table_reference); - read_with_schema(t, substrait_schema, &read.projection) + read_with_schema( + state, + table_reference, + substrait_schema, + &read.projection, + ) + .await } _ => { not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) @@ -979,31 +1005,31 @@ pub async fn from_substrait_rel( } else { match set_op { set_rel::SetOp::UnionAll => { - union_rels(&set.inputs, ctx, extensions, true).await + union_rels(&set.inputs, state, extensions, true).await } set_rel::SetOp::UnionDistinct => { - union_rels(&set.inputs, ctx, extensions, false).await + union_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::IntersectionPrimary => { LogicalPlanBuilder::intersect( - from_substrait_rel(ctx, &set.inputs[0], extensions) + from_substrait_rel(state, &set.inputs[0], extensions) .await?, - union_rels(&set.inputs[1..], ctx, extensions, true) + union_rels(&set.inputs[1..], state, extensions, true) .await?, false, ) } set_rel::SetOp::IntersectionMultiset => { - intersect_rels(&set.inputs, ctx, extensions, false).await + intersect_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::IntersectionMultisetAll => { - intersect_rels(&set.inputs, ctx, extensions, true).await + intersect_rels(&set.inputs, state, extensions, true).await } set_rel::SetOp::MinusPrimary => { - except_rels(&set.inputs, ctx, extensions, false).await + except_rels(&set.inputs, state, extensions, false).await } set_rel::SetOp::MinusPrimaryAll => { - except_rels(&set.inputs, ctx, extensions, true).await + except_rels(&set.inputs, state, extensions, true).await } _ => not_impl_err!("Unsupported set operator: {set_op:?}"), } @@ -1015,8 +1041,7 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; Ok(LogicalPlan::Extension(Extension { node: plan })) @@ -1025,8 +1050,7 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let Some(input_rel) = &extension.input else { @@ -1034,7 +1058,7 @@ pub async fn from_substrait_rel( "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" ); }; - let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; + let input_plan = from_substrait_rel(state, input_rel, extensions).await?; let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) @@ -1043,13 +1067,12 @@ pub async fn from_substrait_rel( let Some(ext_detail) = &extension.detail else { return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); }; - let plan = ctx - .state() + let plan = state .serializer_registry() .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; let mut inputs = Vec::with_capacity(extension.inputs.len()); for input in &extension.inputs { - let input_plan = from_substrait_rel(ctx, input, extensions).await?; + let input_plan = from_substrait_rel(state, input, extensions).await?; inputs.push(input_plan); } let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; @@ -1059,7 +1082,7 @@ pub async fn from_substrait_rel( let Some(input) = exchange.input.as_ref() else { return substrait_err!("Unexpected empty input in ExchangeRel"); }; - let input = Arc::new(from_substrait_rel(ctx, input, extensions).await?); + let input = Arc::new(from_substrait_rel(state, input, extensions).await?); let Some(exchange_kind) = &exchange.exchange_kind else { return substrait_err!("Unexpected empty input in ExchangeRel"); @@ -1237,7 +1260,7 @@ impl NameTracker { /// DataFusion schema may have MORE fields, but not the other way around. /// 2. All fields are compatible. See [`ensure_field_compatability`] for details fn ensure_schema_compatability( - table_schema: DFSchema, + table_schema: &DFSchema, substrait_schema: DFSchema, ) -> Result<()> { substrait_schema @@ -1253,16 +1276,19 @@ fn ensure_schema_compatability( /// This function returns a DataFrame with fields adjusted if necessary in the event that the /// Substrait schema is a subset of the DataFusion schema. -fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result { - let df_schema = table.schema().to_owned(); - - let t = table.into_unoptimized_plan(); +fn apply_projection( + plan: LogicalPlan, + substrait_schema: DFSchema, +) -> Result { + let df_schema = plan.schema(); if df_schema.logically_equivalent_names_and_types(&substrait_schema) { - return Ok(t); + return Ok(plan); } - match t { + let df_schema = df_schema.to_owned(); + + match plan { LogicalPlan::TableScan(mut scan) => { let column_indices: Vec = substrait_schema .strip_qualifiers() @@ -1389,7 +1415,7 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, substrait_sorts: &Vec, input_schema: &DFSchema, extensions: &Extensions, @@ -1397,7 +1423,7 @@ pub async fn from_substrait_sorts( let mut sorts: Vec = vec![]; for s in substrait_sorts { let expr = - from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) + from_substrait_rex(state, s.expr.as_ref().unwrap(), input_schema, extensions) .await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { @@ -1439,14 +1465,15 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &Vec, input_schema: &DFSchema, extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; + let expression = + from_substrait_rex(state, expr, input_schema, extensions).await?; expressions.push(expression); } Ok(expressions) @@ -1454,7 +1481,7 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substrait_func_args( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, arguments: &Vec, input_schema: &DFSchema, extensions: &Extensions, @@ -1463,7 +1490,7 @@ pub async fn from_substrait_func_args( for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await + from_substrait_rex(state, e, input_schema, extensions).await } _ => not_impl_err!("Function argument non-Value type not supported"), }; @@ -1474,7 +1501,7 @@ pub async fn from_substrait_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, f: &AggregateFunction, input_schema: &DFSchema, extensions: &Extensions, @@ -1483,7 +1510,7 @@ pub async fn from_substrait_agg_func( distinct: bool, ) -> Result> { let args = - from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; + from_substrait_func_args(state, &f.arguments, input_schema, extensions).await?; let Some(function_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( @@ -1494,7 +1521,7 @@ pub async fn from_substrait_agg_func( let function_name = substrait_fun_name(function_name); // try udaf first, then built-in aggr fn. - if let Ok(fun) = ctx.udaf(function_name) { + if let Ok(fun) = state.udaf(function_name) { // deal with situation that count(*) got no arguments let args = if fun.name() == "count" && args.is_empty() { vec![Expr::Literal(ScalarValue::Int64(Some(1)))] @@ -1517,7 +1544,7 @@ pub async fn from_substrait_agg_func( /// Convert Substrait Rex to DataFusion Expr #[async_recursion] pub async fn from_substrait_rex( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, e: &Expression, input_schema: &DFSchema, extensions: &Extensions, @@ -1528,11 +1555,11 @@ pub async fn from_substrait_rex( let substrait_list = s.options.as_ref(); Ok(Expr::InList(InList { expr: Box::new( - from_substrait_rex(ctx, substrait_expr, input_schema, extensions) + from_substrait_rex(state, substrait_expr, input_schema, extensions) .await?, ), list: from_substrait_rex_vec( - ctx, + state, substrait_list, input_schema, extensions, @@ -1555,7 +1582,7 @@ pub async fn from_substrait_rex( if if_expr.then.is_none() { expr = Some(Box::new( from_substrait_rex( - ctx, + state, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -1568,7 +1595,7 @@ pub async fn from_substrait_rex( when_then_expr.push(( Box::new( from_substrait_rex( - ctx, + state, if_expr.r#if.as_ref().unwrap(), input_schema, extensions, @@ -1577,7 +1604,7 @@ pub async fn from_substrait_rex( ), Box::new( from_substrait_rex( - ctx, + state, if_expr.then.as_ref().unwrap(), input_schema, extensions, @@ -1589,7 +1616,7 @@ pub async fn from_substrait_rex( // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(ctx, e, input_schema, extensions).await?, + from_substrait_rex(state, e, input_schema, extensions).await?, )), None => None, }; @@ -1609,12 +1636,12 @@ pub async fn from_substrait_rex( let fn_name = substrait_fun_name(fn_name); let args = - from_substrait_func_args(ctx, &f.arguments, input_schema, extensions) + from_substrait_func_args(state, &f.arguments, input_schema, extensions) .await?; // try to first match the requested function into registered udfs, then built-in ops // and finally built-in expressions - if let Some(func) = ctx.state().scalar_functions().get(fn_name) { + if let Ok(func) = state.udf(fn_name) { Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( func.to_owned(), args, @@ -1644,7 +1671,7 @@ pub async fn from_substrait_rex( Ok(combined_expr) } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { - builder.build(ctx, f, input_schema, extensions).await + builder.build(state, f, input_schema, extensions).await } else { not_impl_err!("Unsupported function name: {fn_name:?}") } @@ -1657,7 +1684,7 @@ pub async fn from_substrait_rex( Some(output_type) => Ok(Expr::Cast(Cast::new( Box::new( from_substrait_rex( - ctx, + state, cast.as_ref().input.as_ref().unwrap().as_ref(), input_schema, extensions, @@ -1679,9 +1706,9 @@ pub async fn from_substrait_rex( let fn_name = substrait_fun_name(fn_name); // check udwf first, then udaf, then built-in window and aggregate functions - let fun = if let Ok(udwf) = ctx.udwf(fn_name) { + let fun = if let Ok(udwf) = state.udwf(fn_name) { Ok(WindowFunctionDefinition::WindowUDF(udwf)) - } else if let Ok(udaf) = ctx.udaf(fn_name) { + } else if let Ok(udaf) = state.udaf(fn_name) { Ok(WindowFunctionDefinition::AggregateUDF(udaf)) } else { not_impl_err!( @@ -1692,7 +1719,7 @@ pub async fn from_substrait_rex( }?; let order_by = - from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) + from_substrait_sorts(state, &window.sorts, input_schema, extensions) .await?; let bound_units = @@ -1715,14 +1742,14 @@ pub async fn from_substrait_rex( Ok(Expr::WindowFunction(expr::WindowFunction { fun, args: from_substrait_func_args( - ctx, + state, &window.arguments, input_schema, extensions, ) .await?, partition_by: from_substrait_rex_vec( - ctx, + state, &window.partitions, input_schema, extensions, @@ -1747,13 +1774,13 @@ pub async fn from_substrait_rex( let haystack_expr = &in_predicate.haystack; if let Some(haystack_expr) = haystack_expr { let haystack_expr = - from_substrait_rel(ctx, haystack_expr, extensions) + from_substrait_rel(state, haystack_expr, extensions) .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); Ok(Expr::InSubquery(InSubquery { expr: Box::new( from_substrait_rex( - ctx, + state, needle_expr, input_schema, extensions, @@ -1773,7 +1800,7 @@ pub async fn from_substrait_rex( } SubqueryType::Scalar(query) => { let plan = from_substrait_rel( - ctx, + state, &(query.input.clone()).unwrap_or_default(), extensions, ) @@ -1790,7 +1817,7 @@ pub async fn from_substrait_rex( PredicateOp::Exists => { let relation = &predicate.tuples; let plan = from_substrait_rel( - ctx, + state, &relation.clone().unwrap_or_default(), extensions, ) @@ -2772,7 +2799,7 @@ fn from_substrait_null( #[allow(deprecated)] async fn from_substrait_grouping( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, grouping: &Grouping, expressions: &[Expr], input_schema: &DFSchemaRef, @@ -2781,7 +2808,7 @@ async fn from_substrait_grouping( let mut group_exprs = vec![]; if !grouping.grouping_expressions.is_empty() { for e in &grouping.grouping_expressions { - let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?; + let expr = from_substrait_rex(state, e, input_schema, extensions).await?; group_exprs.push(expr); } return Ok(group_exprs); @@ -2834,23 +2861,29 @@ impl BuiltinExprBuilder { pub async fn build( self, - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, f: &ScalarFunction, input_schema: &DFSchema, extensions: &Extensions, ) -> Result { match self.expr_name.as_str() { "like" => { - Self::build_like_expr(ctx, false, f, input_schema, extensions).await + Self::build_like_expr(state, false, f, input_schema, extensions).await } "ilike" => { - Self::build_like_expr(ctx, true, f, input_schema, extensions).await + Self::build_like_expr(state, true, f, input_schema, extensions).await } "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { - Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) - .await + Self::build_unary_expr( + state, + &self.expr_name, + f, + input_schema, + extensions, + ) + .await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -2859,7 +2892,7 @@ impl BuiltinExprBuilder { } async fn build_unary_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, @@ -2872,7 +2905,7 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; let arg = - from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; + from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; let arg = Box::new(arg); let expr = match fn_name { @@ -2893,7 +2926,7 @@ impl BuiltinExprBuilder { } async fn build_like_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, @@ -2908,12 +2941,13 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let expr = - from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; + from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = - from_substrait_rex(ctx, pattern_substrait, input_schema, extensions).await?; + from_substrait_rex(state, pattern_substrait, input_schema, extensions) + .await?; // Default case: escape character is Literal(Utf8(None)) let escape_char = if f.arguments.len() == 3 { @@ -2922,9 +2956,13 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let escape_char_expr = - from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) - .await?; + let escape_char_expr = from_substrait_rex( + state, + escape_char_substrait, + input_schema, + extensions, + ) + .await?; match escape_char_expr { Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { diff --git a/datafusion/substrait/src/logical_plan/mod.rs b/datafusion/substrait/src/logical_plan/mod.rs index 6f8b8e493f52..9e2fa9fa49de 100644 --- a/datafusion/substrait/src/logical_plan/mod.rs +++ b/datafusion/substrait/src/logical_plan/mod.rs @@ -17,3 +17,4 @@ pub mod consumer; pub mod producer; +pub mod state; diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 4d864e4334ce..29019dfd74f3 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -29,7 +29,7 @@ use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, error::{DataFusionError, Result}, logical_expr::{WindowFrame, WindowFrameBound}, - prelude::{JoinType, SessionContext}, + prelude::JoinType, scalar::ScalarValue, }; @@ -100,8 +100,13 @@ use substrait::{ version, }; +use super::state::SubstraitPlanningState; + /// Convert DataFusion LogicalPlan to Substrait Plan -pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result> { +pub fn to_substrait_plan( + plan: &LogicalPlan, + state: &dyn SubstraitPlanningState, +) -> Result> { let mut extensions = Extensions::default(); // Parse relation nodes // Generate PlanRel(s) @@ -113,7 +118,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result Result Result> { let mut extensions = Extensions::default(); @@ -152,7 +157,7 @@ pub fn to_substrait_extended_expr( .iter() .map(|(expr, field)| { let substrait_expr = to_substrait_rex( - ctx, + state, expr, schema, /*col_ref_offset=*/ 0, @@ -183,7 +188,7 @@ pub fn to_substrait_extended_expr( #[allow(deprecated)] pub fn to_substrait_rel( plan: &LogicalPlan, - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, extensions: &mut Extensions, ) -> Result> { match plan { @@ -284,7 +289,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) + .map(|e| to_substrait_rex(state, e, p.input.schema(), 0, extensions)) .collect::>>()?; let emit_kind = create_project_remapping( @@ -300,16 +305,16 @@ pub fn to_substrait_rel( Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { common: Some(common), - input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), + input: Some(to_substrait_rel(p.input.as_ref(), state, extensions)?), expressions, advanced_extension: None, }))), })) } LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(filter.input.as_ref(), state, extensions)?; let filter_expr = to_substrait_rex( - ctx, + state, &filter.predicate, filter.input.schema(), 0, @@ -325,7 +330,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(limit.input.as_ref(), state, extensions)?; let FetchType::Literal(fetch) = limit.get_fetch_type()? else { return not_impl_err!("Non-literal limit fetch"); }; @@ -344,11 +349,11 @@ pub fn to_substrait_rel( })) } LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(sort.input.as_ref(), state, extensions)?; let sort_fields = sort .expr .iter() - .map(|e| substrait_sort_field(ctx, e, sort.input.schema(), extensions)) + .map(|e| substrait_sort_field(state, e, sort.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -360,9 +365,9 @@ pub fn to_substrait_rel( })) } LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(agg.input.as_ref(), state, extensions)?; let (grouping_expressions, groupings) = to_substrait_groupings( - ctx, + state, &agg.group_expr, agg.input.schema(), extensions, @@ -370,7 +375,9 @@ pub fn to_substrait_rel( let measures = agg .aggr_expr .iter() - .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(), extensions)) + .map(|e| { + to_substrait_agg_measure(state, e, agg.input.schema(), extensions) + }) .collect::>>()?; Ok(Box::new(Rel { @@ -386,7 +393,7 @@ pub fn to_substrait_rel( } LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(plan.as_ref(), state, extensions)?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) @@ -407,8 +414,8 @@ pub fn to_substrait_rel( })) } LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?; - let right = to_substrait_rel(join.right.as_ref(), ctx, extensions)?; + let left = to_substrait_rel(join.left.as_ref(), state, extensions)?; + let right = to_substrait_rel(join.right.as_ref(), state, extensions)?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported match join.join_constraint { @@ -421,7 +428,7 @@ pub fn to_substrait_rel( let in_join_schema = join.left.schema().join(join.right.schema())?; let join_filter = match &join.filter { Some(filter) => Some(to_substrait_rex( - ctx, + state, filter, &Arc::new(in_join_schema), 0, @@ -438,7 +445,7 @@ pub fn to_substrait_rel( Operator::Eq }; let join_on = to_substrait_join_expr( - ctx, + state, &join.on, eq_op, join.left.schema(), @@ -479,13 +486,13 @@ pub fn to_substrait_rel( LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), ctx, extensions) + to_substrait_rel(alias.input.as_ref(), state, extensions) } LogicalPlan::Union(union) => { let input_rels = union .inputs .iter() - .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions)) + .map(|input| to_substrait_rel(input.as_ref(), state, extensions)) .collect::>>()? .into_iter() .map(|ptr| *ptr) @@ -500,7 +507,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(window.input.as_ref(), state, extensions)?; // create a field reference for each input field let mut expressions = (0..window.input.schema().fields().len()) @@ -510,7 +517,7 @@ pub fn to_substrait_rel( // process and add each window function expression for expr in &window.window_expr { expressions.push(to_substrait_rex( - ctx, + state, expr, window.input.schema(), 0, @@ -539,7 +546,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Repartition(repartition) => { - let input = to_substrait_rel(repartition.input.as_ref(), ctx, extensions)?; + let input = to_substrait_rel(repartition.input.as_ref(), state, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, Partitioning::Hash(_, num) => num, @@ -585,8 +592,7 @@ pub fn to_substrait_rel( })) } LogicalPlan::Extension(extension_plan) => { - let extension_bytes = ctx - .state() + let extension_bytes = state .serializer_registry() .serialize_logical_plan(extension_plan.node.as_ref())?; let detail = ProtoAny { @@ -597,7 +603,7 @@ pub fn to_substrait_rel( .node .inputs() .into_iter() - .map(|plan| to_substrait_rel(plan, ctx, extensions)) + .map(|plan| to_substrait_rel(plan, state, extensions)) .collect::>>()?; let rel_type = match inputs_rel.len() { 0 => RelType::ExtensionLeaf(ExtensionLeafRel { @@ -687,7 +693,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { } fn to_substrait_join_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, join_conditions: &Vec<(Expr, Expr)>, eq_op: Operator, left_schema: &DFSchemaRef, @@ -698,10 +704,10 @@ fn to_substrait_join_expr( let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?; + let l = to_substrait_rex(state, left, left_schema, 0, extensions)?; // Parse right let r = to_substrait_rex( - ctx, + state, right, right_schema, left_schema.fields().len(), // offset to return the correct index @@ -770,7 +776,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { #[allow(deprecated)] pub fn parse_flat_grouping_exprs( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &[Expr], schema: &DFSchemaRef, extensions: &mut Extensions, @@ -780,7 +786,7 @@ pub fn parse_flat_grouping_exprs( let mut grouping_expressions = vec![]; for e in exprs { - let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?; + let rex = to_substrait_rex(state, e, schema, 0, extensions)?; grouping_expressions.push(rex.clone()); ref_group_exprs.push(rex); expression_references.push((ref_group_exprs.len() - 1) as u32); @@ -792,7 +798,7 @@ pub fn parse_flat_grouping_exprs( } pub fn to_substrait_groupings( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, exprs: &[Expr], schema: &DFSchemaRef, extensions: &mut Extensions, @@ -808,7 +814,7 @@ pub fn to_substrait_groupings( .iter() .map(|set| { parse_flat_grouping_exprs( - ctx, + state, set, schema, extensions, @@ -826,7 +832,7 @@ pub fn to_substrait_groupings( .rev() .map(|set| { parse_flat_grouping_exprs( - ctx, + state, set, schema, extensions, @@ -837,7 +843,7 @@ pub fn to_substrait_groupings( } }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, + state, exprs, schema, extensions, @@ -845,7 +851,7 @@ pub fn to_substrait_groupings( )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( - ctx, + state, exprs, schema, extensions, @@ -857,7 +863,7 @@ pub fn to_substrait_groupings( #[allow(deprecated)] pub fn to_substrait_agg_measure( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, expr: &Expr, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -865,13 +871,13 @@ pub fn to_substrait_agg_measure( match expr { Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(state, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(state, arg, schema, 0, extensions)?)) }); } let function_anchor = extensions.register_function(func.name().to_string()); Ok(Measure { @@ -889,14 +895,14 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), + Some(f) => Some(to_substrait_rex(state, f, schema, 0, extensions)?), None => None } }) } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(ctx, expr, schema, extensions) + to_substrait_agg_measure(state, expr, schema, extensions) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -908,7 +914,7 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -920,7 +926,7 @@ fn to_substrait_sort_field( (false, false) => SortDirection::DescNullsLast, }; Ok(SortField { - expr: Some(to_substrait_rex(ctx, &sort.expr, schema, 0, extensions)?), + expr: Some(to_substrait_rex(state, &sort.expr, schema, 0, extensions)?), sort_kind: Some(SortKind::Direction(sort_kind.into())), }) } @@ -977,7 +983,7 @@ pub fn make_binary_op_scalar_func( /// * `extensions` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, @@ -991,10 +997,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extensions)) + .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, extensions)) .collect::>>()?; let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -1026,7 +1032,7 @@ pub fn to_substrait_rex( for arg in &fun.args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( - ctx, + state, arg, schema, col_ref_offset, @@ -1055,11 +1061,11 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -1083,11 +1089,11 @@ pub fn to_substrait_rex( } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -1115,8 +1121,8 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extensions)?; - let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extensions)?; + let l = to_substrait_rex(state, left, schema, col_ref_offset, extensions)?; + let r = to_substrait_rex(state, right, schema, col_ref_offset, extensions)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } @@ -1131,7 +1137,7 @@ pub fn to_substrait_rex( // Base expression exists ifs.push(IfClause { r#if: Some(to_substrait_rex( - ctx, + state, e, schema, col_ref_offset, @@ -1144,14 +1150,14 @@ pub fn to_substrait_rex( for (r#if, then) in when_then_expr { ifs.push(IfClause { r#if: Some(to_substrait_rex( - ctx, + state, r#if, schema, col_ref_offset, extensions, )?), then: Some(to_substrait_rex( - ctx, + state, then, schema, col_ref_offset, @@ -1163,7 +1169,7 @@ pub fn to_substrait_rex( // Parse outer `else` let r#else: Option> = match else_expr { Some(e) => Some(Box::new(to_substrait_rex( - ctx, + state, e, schema, col_ref_offset, @@ -1182,7 +1188,7 @@ pub fn to_substrait_rex( substrait::proto::expression::Cast { r#type: Some(to_substrait_type(data_type, true)?), input: Some(Box::new(to_substrait_rex( - ctx, + state, expr, schema, col_ref_offset, @@ -1195,7 +1201,7 @@ pub fn to_substrait_rex( } Expr::Literal(value) => to_substrait_literal_expr(value, extensions), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions) + to_substrait_rex(state, expr, schema, col_ref_offset, extensions) } Expr::WindowFunction(WindowFunction { fun, @@ -1212,7 +1218,7 @@ pub fn to_substrait_rex( for arg in args { arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex( - ctx, + state, arg, schema, col_ref_offset, @@ -1223,12 +1229,12 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extensions)) + .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, extensions)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(ctx, e, schema, extensions)) + .map(|e| substrait_sort_field(state, e, schema, extensions)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1249,7 +1255,7 @@ pub fn to_substrait_rex( escape_char, case_insensitive, }) => make_substrait_like_expr( - ctx, + state, *case_insensitive, *negated, expr, @@ -1265,10 +1271,10 @@ pub fn to_substrait_rex( negated, }) => { let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?; + to_substrait_rel(subquery.subquery.as_ref(), state, extensions)?; let substrait_subquery = Expression { rex_type: Some(RexType::Subquery(Box::new(Subquery { @@ -1301,7 +1307,7 @@ pub fn to_substrait_rex( } } Expr::Not(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "not", arg, schema, @@ -1309,7 +1315,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_null", arg, schema, @@ -1317,7 +1323,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_null", arg, schema, @@ -1325,7 +1331,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_true", arg, schema, @@ -1333,7 +1339,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_false", arg, schema, @@ -1341,7 +1347,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_unknown", arg, schema, @@ -1349,7 +1355,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_true", arg, schema, @@ -1357,7 +1363,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_false", arg, schema, @@ -1365,7 +1371,7 @@ pub fn to_substrait_rex( extensions, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "is_not_unknown", arg, schema, @@ -1373,7 +1379,7 @@ pub fn to_substrait_rex( extensions, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( - ctx, + state, "negate", arg, schema, @@ -1674,7 +1680,7 @@ fn make_substrait_window_function( #[allow(deprecated)] #[allow(clippy::too_many_arguments)] fn make_substrait_like_expr( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, ignore_case: bool, negated: bool, expr: &Expr, @@ -1689,8 +1695,8 @@ fn make_substrait_like_expr( } else { extensions.register_function("like".to_string()) }; - let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; - let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extensions)?; + let expr = to_substrait_rex(state, expr, schema, col_ref_offset, extensions)?; + let pattern = to_substrait_rex(state, pattern, schema, col_ref_offset, extensions)?; let escape_char = to_substrait_literal_expr( &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), extensions, @@ -2088,7 +2094,7 @@ fn to_substrait_literal_expr( /// Util to generate substrait [RexType::ScalarFunction] with one argument fn to_substrait_unary_scalar_fn( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, fn_name: &str, arg: &Expr, schema: &DFSchemaRef, @@ -2096,7 +2102,8 @@ fn to_substrait_unary_scalar_fn( extensions: &mut Extensions, ) -> Result { let function_anchor = extensions.register_function(fn_name.to_string()); - let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extensions)?; + let substrait_expr = + to_substrait_rex(state, arg, schema, col_ref_offset, extensions)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -2137,7 +2144,7 @@ fn try_to_substrait_field_reference( } fn substrait_sort_field( - ctx: &SessionContext, + state: &dyn SubstraitPlanningState, sort: &Sort, schema: &DFSchemaRef, extensions: &mut Extensions, @@ -2147,7 +2154,7 @@ fn substrait_sort_field( asc, nulls_first, } = sort; - let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; + let e = to_substrait_rex(state, expr, schema, 0, extensions)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, @@ -2190,6 +2197,7 @@ mod test { use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::common::DFSchema; + use datafusion::execution::SessionStateBuilder; #[test] fn round_trip_literals() -> Result<()> { @@ -2433,15 +2441,15 @@ mod test { #[tokio::test] async fn extended_expressions() -> Result<()> { - let ctx = SessionContext::new(); + let state = SessionStateBuilder::default().build(); // One expression, empty input schema let expr = Expr::Literal(ScalarValue::Int32(Some(42))); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); let substrait = - to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx)?; - let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state)?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; assert_eq!(roundtrip_expr.input_schema, empty_schema); assert_eq!(roundtrip_expr.exprs.len(), 1); @@ -2463,9 +2471,9 @@ mod test { let substrait = to_substrait_extended_expr( &[(&expr1, &out1), (&expr2, &out2)], &input_schema, - &ctx, + &state, )?; - let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + let roundtrip_expr = from_substrait_extended_expr(&state, &substrait).await?; assert_eq!(roundtrip_expr.input_schema, input_schema); assert_eq!(roundtrip_expr.exprs.len(), 2); @@ -2485,14 +2493,14 @@ mod test { #[tokio::test] async fn invalid_extended_expression() { - let ctx = SessionContext::new(); + let state = SessionStateBuilder::default().build(); // Not ok if input schema is missing field referenced by expr let expr = Expr::Column("missing".into()); let field = Field::new("out", DataType::Int32, false); let empty_schema = DFSchemaRef::new(DFSchema::empty()); - let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx); + let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &state); assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } diff --git a/datafusion/substrait/src/logical_plan/state.rs b/datafusion/substrait/src/logical_plan/state.rs new file mode 100644 index 000000000000..0bd749c1105d --- /dev/null +++ b/datafusion/substrait/src/logical_plan/state.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion::{ + catalog::TableProvider, + error::{DataFusionError, Result}, + execution::{registry::SerializerRegistry, FunctionRegistry, SessionState}, + sql::TableReference, +}; + +/// This trait provides the context needed to transform a substrait plan into a +/// [`datafusion::logical_expr::LogicalPlan`] (via [`super::consumer::from_substrait_plan`]) +/// and back again into a substrait plan (via [`super::producer::to_substrait_plan`]). +/// +/// The context is declared as a trait to decouple the substrait plan encoder / +/// decoder from the [`SessionState`], potentially allowing users to define +/// their own slimmer context just for serializing and deserializing substrait. +/// +/// [`SessionState`] implements this trait. +#[async_trait] +pub trait SubstraitPlanningState: Sync + Send + FunctionRegistry { + /// Return [SerializerRegistry] for extensions + fn serializer_registry(&self) -> &Arc; + + async fn table( + &self, + reference: &TableReference, + ) -> Result>>; +} + +#[async_trait] +impl SubstraitPlanningState for SessionState { + fn serializer_registry(&self) -> &Arc { + self.serializer_registry() + } + + async fn table( + &self, + reference: &TableReference, + ) -> Result>, DataFusionError> { + let table = reference.table().to_string(); + let schema = self.schema_for_ref(reference.clone())?; + let table_provider = schema.table(&table).await?; + Ok(table_provider) + } +} diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs index 6b81e33dfc37..4278671777fd 100644 --- a/datafusion/substrait/src/serializer.rs +++ b/datafusion/substrait/src/serializer.rs @@ -38,7 +38,7 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<() pub async fn serialize_bytes(sql: &str, ctx: &SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = producer::to_substrait_plan(&plan, ctx)?; + let proto = producer::to_substrait_plan(&plan, &ctx.state())?; let mut protobuf_out = Vec::::new(); proto.encode(&mut protobuf_out).map_err(|e| { diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs index bc38ef82977f..219f656bb471 100644 --- a/datafusion/substrait/tests/cases/consumer_integration.rs +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -41,7 +41,7 @@ mod tests { .expect("failed to parse json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; - let plan = from_substrait_plan(&ctx, &proto).await?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; Ok(format!("{}", plan)) } diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs index ac66177ed796..08537d0d110f 100644 --- a/datafusion/substrait/tests/cases/emit_kind_tests.rs +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -33,7 +33,7 @@ mod tests { "tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json", ); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); @@ -51,7 +51,7 @@ mod tests { "tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json", ); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); @@ -91,8 +91,8 @@ mod tests { \n TableScan: data" ); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; // note how the Projections are not flattened assert_eq!( format!("{}", plan2), @@ -115,8 +115,8 @@ mod tests { \n TableScan: data" ); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan1str = format!("{plan}"); let plan2str = format!("{plan2}"); diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs index b136b0af19c2..043808456176 100644 --- a/datafusion/substrait/tests/cases/function_test.rs +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -29,7 +29,7 @@ mod tests { async fn contains_function_test() -> Result<()> { let proto_plan = read_json("tests/testdata/contains_plan.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; let plan_str = format!("{}", plan); diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index f4e34af35d78..65f404bbda55 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -38,7 +38,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -63,7 +63,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/select_window.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -82,7 +82,7 @@ mod tests { let proto_plan = read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))"); diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index d4e2d48885ae..d03ab5182028 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -979,8 +979,8 @@ async fn extension_logical_plan() -> Result<()> { }), }); - let proto = to_substrait_plan(&ext_plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&ext_plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan1str = format!("{ext_plan}"); let plan2str = format!("{plan2}"); @@ -1081,8 +1081,8 @@ async fn roundtrip_repartition_roundrobin() -> Result<()> { partitioning_scheme: Partitioning::RoundRobinBatch(8), }); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; assert_eq!(format!("{plan}"), format!("{plan2}")); @@ -1098,8 +1098,8 @@ async fn roundtrip_repartition_hash() -> Result<()> { partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8), }); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; assert_eq!(format!("{plan}"), format!("{plan2}")); @@ -1199,8 +1199,8 @@ async fn assert_expected_plan_unoptimized( let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_unoptimized_plan(); - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; println!("{plan}"); println!("{plan2}"); @@ -1225,8 +1225,8 @@ async fn assert_expected_plan( let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; println!("{plan}"); @@ -1250,7 +1250,7 @@ async fn assert_expected_plan_substrait( ) -> Result<()> { let ctx = create_context().await?; - let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?; let plan = ctx.state().optimize(&plan)?; @@ -1265,7 +1265,7 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { let expected = ctx.sql(sql).await?.into_optimized_plan()?; - let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?; let plan = ctx.state().optimize(&plan)?; @@ -1280,8 +1280,8 @@ async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; // Format plan string and replace all None's with 0 @@ -1301,12 +1301,12 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { let ctx = create_context().await?; let df_a = ctx.sql(sql_with_alias).await?; - let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?; - let plan_with_alias = from_substrait_plan(&ctx, &proto_a).await?; + let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx.state())?; + let plan_with_alias = from_substrait_plan(&ctx.state(), &proto_a).await?; let df = ctx.sql(sql_no_alias).await?; - let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?; - let plan = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx.state())?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; println!("{plan_with_alias}"); println!("{plan}"); @@ -1323,8 +1323,8 @@ async fn roundtrip_logical_plan_with_ctx( plan: LogicalPlan, ctx: SessionContext, ) -> Result> { - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; + let proto = to_substrait_plan(&plan, &ctx.state())?; + let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; println!("{plan}"); diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index 54d55d1b6f10..e28c63312788 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -45,7 +45,7 @@ mod tests { // Read substrait plan from file let proto = serializer::deserialize(path).await?; // Check plan equality - let plan = from_substrait_plan(&ctx, &proto).await?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; let plan_str_ref = format!("{plan_ref}"); let plan_str = format!("{plan}"); assert_eq!(plan_str_ref, plan_str); @@ -60,7 +60,7 @@ mod tests { let ctx = create_context().await?; let table = provider_as_source(ctx.table_provider("data").await?); let table_scan = LogicalPlanBuilder::scan("data", table, None)?.build()?; - let convert_result = to_substrait_plan(&table_scan, &ctx); + let convert_result = to_substrait_plan(&table_scan, &ctx.state()); assert!(convert_result.is_ok()); Ok(()) @@ -78,7 +78,9 @@ mod tests { \n TableScan: data projection=[a, b]", ); - let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + let plan = to_substrait_plan(&datafusion_plan, &ctx.state())? + .as_ref() + .clone(); let relation = plan.relations.first().unwrap().rel_type.as_ref(); let root_rel = match relation { @@ -121,7 +123,9 @@ mod tests { \n TableScan: data projection=[a, b, c]", ); - let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + let plan = to_substrait_plan(&datafusion_plan, &ctx.state())? + .as_ref() + .clone(); let relation = plan.relations.first().unwrap().rel_type.as_ref(); let root_rel = match relation { diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs index 5ae586afe56f..c77bf1489f4e 100644 --- a/datafusion/substrait/tests/cases/substrait_validations.rs +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -65,7 +65,7 @@ mod tests { vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -86,7 +86,7 @@ mod tests { ("c", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -109,7 +109,7 @@ mod tests { ("b", DataType::Int32, false), ]; let ctx = generate_context_with_table("DATA", df_schema)?; - let plan = from_substrait_plan(&ctx, &proto_plan).await?; + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; assert_eq!( format!("{}", plan), @@ -128,7 +128,7 @@ mod tests { vec![("a", DataType::Int32, false), ("c", DataType::Int32, true)]; let ctx = generate_context_with_table("DATA", df_schema)?; - let res = from_substrait_plan(&ctx, &proto_plan).await; + let res = from_substrait_plan(&ctx.state(), &proto_plan).await; assert!(res.is_err()); Ok(()) } @@ -140,7 +140,7 @@ mod tests { let ctx = generate_context_with_table("DATA", vec![("a", DataType::Date32, true)])?; - let res = from_substrait_plan(&ctx, &proto_plan).await; + let res = from_substrait_plan(&ctx.state(), &proto_plan).await; assert!(res.is_err()); Ok(()) } diff --git a/docs/source/user-guide/concepts-readings-events.md b/docs/source/user-guide/concepts-readings-events.md index 092f8433d47b..135fbc47ad90 100644 --- a/docs/source/user-guide/concepts-readings-events.md +++ b/docs/source/user-guide/concepts-readings-events.md @@ -131,10 +131,11 @@ This is a list of DataFusion related blog posts, articles, and other resources. # 🌎 Community Events +- **2025-01-25** (Upcoming) [Amsterdam Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/12988) - **2025-01-15** (Upcoming) [Boston Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/13165) - **2024-12-18** (Upcoming) [Chicago Apache DataFusion Meetup](https://lu.ma/eq5myc5i) +- **2024-10-14** [Seattle Apache DataFusion Meetup](https://lu.ma/tnwl866b) - **2024-09-27** [Belgrade Apache DataFusion Meetup](https://lu.ma/tmwuz4lg), [recap](https://github.com/apache/datafusion/discussions/11431#discussioncomment-10832070), [slides](https://github.com/apache/datafusion/discussions/11431#discussioncomment-10826169), [recordings](https://www.youtube.com/watch?v=4huEsFFv6bQ&list=PLrhIfEjaw9ilQEczOQlHyMznabtVRptyX) - **2024-06-26** [New York City Apache DataFusion Meetup](https://lu.ma/2iwba0xm). [slides](https://docs.google.com/presentation/d/1dOLPAFPEMLhLv4NN6O9QSDIyyeiIySqAjky5cVgdWAE/edit#slide=id.g26bebde4fcc_3_7) - **2024-06-25** [San Francisco Bay Area Apache DataFusion Meetup](https://lu.ma/6bphole2). [slides](https://docs.google.com/presentation/d/1Oz2yGllrWBkNGyiRMLr8qXTt4vmvtJWuI_weGThaZak/edit#slide=id.g26bebde4fcc_3_7) - **2024-03-27** [Austin Apache DataFusion Meetup](https://github.com/apache/datafusion/discussions/8522). [slides](https://docs.google.com/presentation/d/1S51TK8waxHEJaxi_-uiSMrgQZ09m_hfaasPk5X5ExEY), [recording](https://www.youtube.com/watch?v=q1N3pH3tFw8) -- **2024-03-26** [Seattle Apache DataFusion Meetup](