From 079f6219fed8eb8812b028257c031dfdfba96cb4 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 20 Dec 2024 07:40:48 -0600 Subject: [PATCH 1/6] replace CASE expressions in predicate pruning with boolean algebra (#13795) * replace CASE expressions in predicate pruning with boolean algebra * fix merge * update tests * add some more tests * add some more tests * remove duplicate test case * Update datafusion/physical-optimizer/src/pruning.rs * swap NOT for != * replace comments, update docstrings * fix example * update tests * update tests * Apply suggestions from code review Co-authored-by: Andrew Lamb * Update pruning.rs Co-authored-by: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> * Update pruning.rs Co-authored-by: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> --------- Co-authored-by: Andrew Lamb Co-authored-by: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> --- .../datasource/physical_plan/parquet/mod.rs | 2 +- datafusion/physical-optimizer/src/pruning.rs | 325 ++++++++++-------- .../test_files/parquet_filter_pushdown.slt | 8 +- .../test_files/repartition_scan.slt | 8 +- datafusion/sqllogictest/test_files/window.slt | 2 +- 5 files changed, 184 insertions(+), 161 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index cb79055ce301..7573e32f8652 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -2001,7 +2001,7 @@ mod tests { assert_contains!( &display, - "pruning_predicate=CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != bar OR bar != c1_max@1 END" + "pruning_predicate=c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != bar OR bar != c1_max@1)" ); assert_contains!(&display, r#"predicate=c1@0 != bar"#); diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs index 3cfb03b7205a..77fc76c35352 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -287,7 +287,12 @@ pub trait PruningStatistics { /// predicate can never possibly be true). The container can be pruned (skipped) /// entirely. /// -/// Note that in order to be correct, `PruningPredicate` must return false +/// While `PruningPredicate` will never return a `NULL` value, the +/// rewritten predicate (as returned by `build_predicate_expression` and used internally +/// by `PruningPredicate`) may evaluate to `NULL` when some of the min/max values +/// or null / row counts are not known. +/// +/// In order to be correct, `PruningPredicate` must return false /// **only** if it can determine that for all rows in the container, the /// predicate could never evaluate to `true` (always evaluates to either `NULL` /// or `false`). 
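To make the guarantee above concrete: under SQL's three-valued logic, an unknown (`NULL`) null-count guard can only relax the outcome toward "keep the container", never toward "prune". The following standalone sketch illustrates that argument. It is plain Rust with invented names (`kleene_and`, `may_contain`) and made-up statistics values, not DataFusion's actual `PruningPredicate` API:

```rust
// SQL-style three-valued AND: `None` models SQL NULL.
fn kleene_and(l: Option<bool>, r: Option<bool>) -> Option<bool> {
    match (l, r) {
        (Some(false), _) | (_, Some(false)) => Some(false),
        (Some(true), Some(true)) => Some(true),
        _ => None, // unknown
    }
}

// `x = v` rewritten per the scheme in this patch:
// x_null_count != x_row_count AND (x_min <= v AND v <= x_max)
fn may_contain(
    min: Option<i32>,
    max: Option<i32>,
    null_count: Option<u64>,
    row_count: Option<u64>,
    v: i32,
) -> bool {
    let guard = match (null_count, row_count) {
        (Some(n), Some(r)) => Some(n != r),
        _ => None, // unknown counts -> NULL guard
    };
    let bounds = kleene_and(min.map(|m| m <= v), max.map(|m| v <= m));
    // Only a definite `false` lets us prune; `true` and NULL both mean "keep".
    kleene_and(guard, bounds) != Some(false)
}

fn main() {
    // All rows null: pruned even though min/max are unknown.
    assert!(!may_contain(None, None, Some(100), Some(100), 5));
    // Null counts unknown, but min/max exclude 5: still pruned.
    assert!(!may_contain(Some(6), Some(9), None, None, 5));
    // No statistics at all: NULL result, container must be kept.
    assert!(may_contain(None, None, None, None, 5));
}
```

Only a provable `false` prunes, which is exactly the property the `CASE`-to-`AND` rewrite in this patch must preserve.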
@@ -327,12 +332,12 @@ pub trait PruningStatistics {
 ///
 /// Original Predicate | Rewritten Predicate
 /// ------------------ | --------------------
-/// `x = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END`
-/// `x < 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_max < 5 END`
-/// `x = 5 AND y = 10` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END AND CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END`
+/// `x = 5` | `x_null_count != x_row_count AND (x_min <= 5 AND 5 <= x_max)`
+/// `x < 5` | `x_null_count != x_row_count AND (x_max < 5)`
+/// `x = 5 AND y = 10` | `x_null_count != x_row_count AND (x_min <= 5 AND 5 <= x_max) AND y_null_count != y_row_count AND (y_min <= 10 AND 10 <= y_max)`
 /// `x IS NULL` | `x_null_count > 0`
 /// `x IS NOT NULL` | `x_null_count != row_count`
-/// `CAST(x as int) = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int) END`
+/// `CAST(x as int) = 5` | `x_null_count != x_row_count AND (CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int))`
 ///
 /// ## Predicate Evaluation
 /// The PruningPredicate works in two passes
@@ -352,15 +357,9 @@ pub trait PruningStatistics {
 /// Given the predicate, `x = 5 AND y = 10`, the rewritten predicate would look like:
 ///
 /// ```sql
-/// CASE
-///     WHEN x_null_count = x_row_count THEN false
-///     ELSE x_min <= 5 AND 5 <= x_max
-/// END
+/// x_null_count != x_row_count AND (x_min <= 5 AND 5 <= x_max)
 /// AND
-/// CASE
-///     WHEN y_null_count = y_row_count THEN false
-///     ELSE y_min <= 10 AND 10 <= y_max
-/// END
+/// y_null_count != y_row_count AND (y_min <= 10 AND 10 <= y_max)
 /// ```
 ///
 /// If we know that for a given container, `x` is between `1 and 100` and we know that
@@ -381,16 +380,22 @@ pub trait PruningStatistics {
 /// When these statistics values are substituted into the rewritten predicate and
 /// simplified, the result is `false`:
 ///
-/// * `CASE WHEN null = null THEN false ELSE 1 <= 5 AND 5 <= 100 END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END`
-/// * `null = null` is `null` which is not true, so the `CASE` expression will use the `ELSE` clause
-/// * `1 <= 5 AND 5 <= 100 AND 4 <= 10 AND 10 <= 7`
-/// * `true AND true AND true AND false`
+/// * `null != null AND (1 <= 5 AND 5 <= 100) AND null != null AND (4 <= 10 AND 10 <= 7)`
+/// * `null != null` is `null`, which is not `true`, so the `AND` moves on to the next clause
+/// * `null AND (1 <= 5 AND 5 <= 100) AND null AND (4 <= 10 AND 10 <= 7)`
+/// * evaluating the clauses further we get:
+/// * `null AND true AND null AND false`
+/// * `null AND false`
 /// * `false`
 ///
 /// Returning `false` means the container can be pruned, which matches the
 /// intuition that `x = 5 AND y = 10` can’t be true for any row if all values of `y`
 /// are `7` or less.
 ///
+/// Note that if we had ended up with `null AND true AND null AND true` the result
+/// would have been `null`.
+/// `null` is treated the same as `true`, because we can't prove that the predicate is `false`.
+///
 /// If, for some other container, we knew `y` was between the values `4` and
 /// `15`, then the rewritten predicate evaluates to `true` (verifying this is
 /// left as an exercise to the reader -- are you still here?), and the container
@@ -405,15 +410,9 @@
 /// look the same as in example 1:
 ///
 /// ```sql
-/// CASE
-///     WHEN x_null_count = x_row_count THEN false
-///     ELSE x_min <= 5 AND 5 <= x_max
-/// END
+/// x_null_count != x_row_count AND (x_min <= 5 AND 5 <= x_max)
 /// AND
-/// CASE
-///     WHEN y_null_count = y_row_count THEN false
-///     ELSE y_min <= 10 AND 10 <= y_max
-/// END
+/// y_null_count != y_row_count AND (y_min <= 10 AND 10 <= y_max)
 /// ```
 ///
 /// If we know that for another given container, `x_min` is NULL and `x_max` is
@@ -435,14 +434,13 @@
 /// When these statistics values are substituted into the rewritten predicate and
 /// simplified, the result is `false`:
 ///
-/// * `CASE WHEN 100 = 100 THEN false ELSE null <= 5 AND 5 <= null END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END`
-/// * Since `100 = 100` is `true`, the `CASE` expression will use the `THEN` clause, i.e. `false`
-/// * The other `CASE` expression will use the `ELSE` clause, i.e. `4 <= 10 AND 10 <= 7`
-/// * `false AND true`
+/// * `100 != 100 AND (null <= 5 AND 5 <= null) AND null != null AND (4 <= 10 AND 10 <= 7)`
+/// * `false AND null AND null AND false`
+/// * `false AND false`
 /// * `false`
 ///
 /// Returning `false` means the container can be pruned, which matches the
-/// intuition that `x = 5 AND y = 10` can’t be true for all values in `x`
+/// intuition that `x = 5 AND y = 10` can’t be true because all values in `x`
 /// are known to be NULL.
 ///
 /// # Related Work
@@ -1603,13 +1601,15 @@ fn build_statistics_expr(
             );
         }
     };
-    let statistics_expr = wrap_case_expr(statistics_expr, expr_builder)?;
+    let statistics_expr = wrap_null_count_check_expr(statistics_expr, expr_builder)?;
     Ok(statistics_expr)
 }
 
-/// Wrap the statistics expression in a case expression.
-/// This is necessary to handle the case where the column is known
-/// to be all nulls.
+/// Wrap the statistics expression in a check that skips the expression if the column is all nulls.
+/// This is important not only as an optimization but also because statistics may not be
+/// accurate for columns that are all nulls.
+/// For example, for an `int` column `x` with all nulls, the min/max/null_count statistics
+/// might be set to 0 and evaluating `x = 0` would incorrectly include the container.
 ///
 /// For example:
 ///
@@ -1618,33 +1618,29 @@ fn build_statistics_expr(
 /// will become
 ///
 /// ```sql
-/// CASE
-///     WHEN x_null_count = x_row_count THEN false
-///     ELSE x_min <= 10 AND 10 <= x_max
-/// END
+/// x_null_count != x_row_count AND (x_min <= 10 AND 10 <= x_max)
 /// ```
 ///
 /// If the column is known to be all nulls, then the expression
 /// `x_null_count = x_row_count` will be true, which will cause the
-/// case expression to return false. Therefore, prune out the container.
-fn wrap_case_expr(
+/// boolean expression to return false. Therefore, prune out the container.
+fn wrap_null_count_check_expr(
     statistics_expr: Arc<dyn PhysicalExpr>,
     expr_builder: &mut PruningExpressionBuilder,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    // x_null_count = x_row_count
-    let when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new(
+    // x_null_count != x_row_count
+    let not_when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new(
         expr_builder.null_count_column_expr()?,
-        Operator::Eq,
+        Operator::NotEq,
         expr_builder.row_count_column_expr()?,
     ));
-    let then = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false))));
-
-    // CASE WHEN x_null_count = x_row_count THEN false ELSE <statistics_expr> END
-    Ok(Arc::new(phys_expr::CaseExpr::try_new(
-        None,
-        vec![(when_null_count_eq_row_count, then)],
-        Some(statistics_expr),
-    )?))
+
+    // (x_null_count != x_row_count) AND (<statistics_expr>)
+    Ok(Arc::new(phys_expr::BinaryExpr::new(
+        not_when_null_count_eq_row_count,
+        Operator::And,
+        statistics_expr,
+    )))
 }
 
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
@@ -2052,6 +2048,110 @@ mod tests {
         }
     }
 
+    #[test]
+    fn prune_all_rows_null_counts() {
+        // if null_count = row_count then we should prune the container for i = 0
+        // regardless of the statistics
+        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+        let statistics = TestStatistics::new().with(
+            "i",
+            ContainerStats::new_i32(
+                vec![Some(0)], // min
+                vec![Some(0)], // max
+            )
+            .with_null_counts(vec![Some(1)])
+            .with_row_counts(vec![Some(1)]),
+        );
+        let expected_ret = &[false];
+        prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret);
+
+        // this should be true even if the container stats are missing
+        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+        let container_stats = ContainerStats {
+            min: Some(Arc::new(Int32Array::from(vec![None]))),
+            max: Some(Arc::new(Int32Array::from(vec![None]))),
+            null_counts: Some(Arc::new(UInt64Array::from(vec![Some(1)]))),
+            row_counts: Some(Arc::new(UInt64Array::from(vec![Some(1)]))),
+            ..ContainerStats::default()
+        };
+        let statistics = TestStatistics::new().with("i", container_stats);
+        let expected_ret = &[false];
+        prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret);
+
+        // If the null counts themselves are missing we should be able to fall back to the stats
+        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+        let container_stats = ContainerStats {
+            min: Some(Arc::new(Int32Array::from(vec![Some(0)]))),
+            max: Some(Arc::new(Int32Array::from(vec![Some(0)]))),
+            null_counts: Some(Arc::new(UInt64Array::from(vec![None]))),
+            row_counts: Some(Arc::new(UInt64Array::from(vec![Some(1)]))),
+            ..ContainerStats::default()
+        };
+        let statistics = TestStatistics::new().with("i", container_stats);
+        let expected_ret = &[true];
+        prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret);
+        let expected_ret = &[false];
+        prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret);
+
+        // Same for the row counts
+        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+        let container_stats = ContainerStats {
+            min: Some(Arc::new(Int32Array::from(vec![Some(0)]))),
+            max: Some(Arc::new(Int32Array::from(vec![Some(0)]))),
+            null_counts: Some(Arc::new(UInt64Array::from(vec![Some(1)]))),
+            row_counts: Some(Arc::new(UInt64Array::from(vec![None]))),
+            ..ContainerStats::default()
+        };
+        let statistics = TestStatistics::new().with("i", container_stats);
+        let expected_ret = &[true];
+        prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret);
+        let 
expected_ret = &[false]; + prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret); + } + + #[test] + fn prune_missing_statistics() { + // If the min or max stats are missing we should not prune + // (unless we know all rows are null, see `prune_all_rows_null_counts`) + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + let container_stats = ContainerStats { + min: Some(Arc::new(Int32Array::from(vec![None, Some(0)]))), + max: Some(Arc::new(Int32Array::from(vec![Some(0), None]))), + null_counts: Some(Arc::new(UInt64Array::from(vec![Some(0), Some(0)]))), + row_counts: Some(Arc::new(UInt64Array::from(vec![Some(1), Some(1)]))), + ..ContainerStats::default() + }; + let statistics = TestStatistics::new().with("i", container_stats); + let expected_ret = &[true, true]; + prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret); + let expected_ret = &[false, true]; + prune_with_expr(col("i").gt(lit(0)), &schema, &statistics, expected_ret); + let expected_ret = &[true, false]; + prune_with_expr(col("i").lt(lit(0)), &schema, &statistics, expected_ret); + } + + #[test] + fn prune_null_stats() { + // if null_count = row_count then we should prune the container for i = 0 + // regardless of the statistics + let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); + + let statistics = TestStatistics::new().with( + "i", + ContainerStats::new_i32( + vec![Some(0)], // min + vec![Some(0)], // max + ) + .with_null_counts(vec![Some(1)]) + .with_row_counts(vec![Some(1)]), + ); + + let expected_ret = &[false]; + + // i = 0 + prune_with_expr(col("i").eq(lit(0)), &schema, &statistics, expected_ret); + } + #[test] fn test_build_statistics_record_batch() { // Request a record batch with of s1_min, s2_max, s3_max, s3_min @@ -2233,7 +2333,8 @@ mod tests { #[test] fn row_group_predicate_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 END"; + let expected_expr = + "c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1"; // test column on the left let expr = col("c1").eq(lit(1)); @@ -2253,7 +2354,8 @@ mod tests { #[test] fn row_group_predicate_not_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != 1 OR 1 != c1_max@1 END"; + let expected_expr = + "c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1)"; // test column on the left let expr = col("c1").not_eq(lit(1)); @@ -2273,8 +2375,7 @@ mod tests { #[test] fn row_group_predicate_gt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = - "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 > 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_max@0 > 1"; // test column on the left let expr = col("c1").gt(lit(1)); @@ -2294,7 +2395,7 @@ mod tests { #[test] fn row_group_predicate_gt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 >= 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_max@0 >= 1"; // test column on the left let expr = col("c1").gt_eq(lit(1)); @@ -2313,8 +2414,7 @@ mod tests { #[test] fn 
row_group_predicate_lt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = - "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < 1"; // test column on the left let expr = col("c1").lt(lit(1)); @@ -2334,7 +2434,7 @@ mod tests { #[test] fn row_group_predicate_lt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 <= 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 <= 1"; // test column on the left let expr = col("c1").lt_eq(lit(1)); @@ -2359,8 +2459,7 @@ mod tests { ]); // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); - let expected_expr = - "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < 1"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2426,7 +2525,7 @@ mod tests { #[test] fn row_group_predicate_lt_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); - let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < true END"; + let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < true"; // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated @@ -2449,20 +2548,11 @@ mod tests { let expr = col("c1") .lt(lit(1)) .and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3)))); - let expected_expr = "\ - CASE \ - WHEN c1_null_count@1 = c1_row_count@2 THEN false \ - ELSE c1_min@0 < 1 \ - END \ - AND (\ - CASE \ - WHEN c2_null_count@5 = c2_row_count@6 THEN false \ - ELSE c2_min@3 <= 2 AND 2 <= c2_max@4 \ - END \ - OR CASE \ - WHEN c2_null_count@5 = c2_row_count@6 THEN false \ - ELSE c2_min@3 <= 3 AND 3 <= c2_max@4 \ - END\ + let expected_expr = "c1_null_count@1 != c1_row_count@2 \ + AND c1_min@0 < 1 AND (\ + c2_null_count@5 != c2_row_count@6 \ + AND c2_min@3 <= 2 AND 2 <= c2_max@4 OR \ + c2_null_count@5 != c2_row_count@6 AND c2_min@3 <= 3 AND 3 <= c2_max@4\ )"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut required_columns); @@ -2554,18 +2644,7 @@ mod tests { vec![lit(1), lit(2), lit(3)], false, )); - let expected_expr = "CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 3 AND 3 <= c1_max@1 \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 3 AND 3 <= c1_max@1"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2601,19 +2680,7 @@ mod tests { vec![lit(1), lit(2), lit(3)], true, )); - let expected_expr = "\ - CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 != 1 
OR 1 != c1_max@1 \ - END \ - AND CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 != 2 OR 2 != c1_max@1 \ - END \ - AND CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 != 3 OR 3 != c1_max@1 \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1) AND c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 2 OR 2 != c1_max@1) AND c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 3 OR 3 != c1_max@1)"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2659,24 +2726,7 @@ mod tests { // test c1 in(1, 2) and c2 BETWEEN 4 AND 5 let expr3 = expr1.and(expr2); - let expected_expr = "\ - (\ - CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ - END\ - ) AND CASE \ - WHEN c2_null_count@5 = c2_row_count@6 THEN false \ - ELSE c2_max@4 >= 4 \ - END \ - AND CASE \ - WHEN c2_null_count@5 = c2_row_count@6 THEN false \ - ELSE c2_min@7 <= 5 \ - END"; + let expected_expr = "(c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_null_count@5 != c2_row_count@6 AND c2_max@4 >= 4 AND c2_null_count@5 != c2_row_count@6 AND c2_min@7 <= 5"; let predicate_expr = test_build_predicate_expression(&expr3, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2703,10 +2753,7 @@ mod tests { #[test] fn row_group_predicate_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; // test cast(c1 as int64) = 1 // test column on the left @@ -2721,10 +2768,8 @@ mod tests { test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); - let expected_expr = "CASE \ - WHEN c1_null_count@1 = c1_row_count@2 THEN false \ - ELSE TRY_CAST(c1_max@0 AS Int64) > 1 \ - END"; + let expected_expr = + "c1_null_count@1 != c1_row_count@2 AND TRY_CAST(c1_max@0 AS Int64) > 1"; // test column on the left let expr = @@ -2756,18 +2801,7 @@ mod tests { ], false, )); - let expected_expr = "CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) \ - END \ - OR CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64) \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); 
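         // Note: the rewritten predicate repeats the `c1_null_count@2 != c1_row_count@3`
         // guard once per IN-list member instead of factoring it out of the OR; the
         // assertion below checks for that literal, unfactored form.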
assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2781,18 +2815,7 @@ mod tests { ], true, )); - let expected_expr = "CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64) \ - END \ - AND CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64) \ - END \ - AND CASE \ - WHEN c1_null_count@2 = c1_row_count@3 THEN false \ - ELSE CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64) \ - END"; + let expected_expr = "c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) AND c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) AND c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt index 24ffb963bbe2..806886b07170 100644 --- a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt +++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt @@ -85,7 +85,7 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] 02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], predicate=b@1 > 2, pruning_predicate=CASE WHEN b_null_count@1 = b_row_count@2 THEN false ELSE b_max@0 > 2 END, required_guarantees=[] +03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], predicate=b@1 > 2, pruning_predicate=b_null_count@1 != b_row_count@2 AND b_max@0 > 2, required_guarantees=[] # When filter pushdown *is* enabled, ParquetExec can filter exactly, @@ -113,7 +113,7 @@ physical_plan 03)----CoalesceBatchesExec: target_batch_size=8192 04)------FilterExec: b@1 > 2, projection=[a@0] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 -06)----------ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], predicate=b@1 > 2, pruning_predicate=CASE WHEN b_null_count@1 = b_row_count@2 THEN false ELSE b_max@0 > 2 END, required_guarantees=[] +06)----------ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], predicate=b@1 > 2, pruning_predicate=b_null_count@1 != b_row_count@2 AND b_max@0 > 2, required_guarantees=[] # also test querying on columns that are not in all the files query T @@ -131,7 +131,7 @@ logical_plan physical_plan 
01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] 02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], predicate=b@1 > 2 AND a@0 IS NOT NULL, pruning_predicate=CASE WHEN b_null_count@1 = b_row_count@2 THEN false ELSE b_max@0 > 2 END AND a_null_count@4 != a_row_count@3, required_guarantees=[] +03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], predicate=b@1 > 2 AND a@0 IS NOT NULL, pruning_predicate=b_null_count@1 != b_row_count@2 AND b_max@0 > 2 AND a_null_count@4 != a_row_count@3, required_guarantees=[] query I @@ -148,7 +148,7 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] 02)--SortExec: expr=[b@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[b], predicate=a@0 = bar, pruning_predicate=CASE WHEN a_null_count@2 = a_row_count@3 THEN false ELSE a_min@0 <= bar AND bar <= a_max@1 END, required_guarantees=[a in (bar)] +03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[b], predicate=a@0 = bar, pruning_predicate=a_null_count@2 != a_row_count@3 AND a_min@0 <= bar AND bar <= a_max@1, required_guarantees=[a in (bar)] ## cleanup statement ok diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 858e42106221..a1db84b87850 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != column1_row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != column1_row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ physical_plan 02)--SortExec: expr=[column1@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------FilterExec: column1@0 != 42 -05)--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..174], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:174..342, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..180], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:180..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +05)--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..174], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:174..342, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], 
[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..180], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:180..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != column1_row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)]
 
 
 ## Read the files as though they are ordered
@@ -138,7 +138,7 @@ physical_plan
 01)SortPreservingMergeExec: [column1@0 ASC NULLS LAST]
 02)--CoalesceBatchesExec: target_batch_size=8192
 03)----FilterExec: column1@0 != 42
-04)------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..171], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..175], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:175..351], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:171..342]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)]
+04)------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..171], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..175], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:175..351], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:171..342]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_null_count@2 != column1_row_count@3 AND (column1_min@0 != 42 OR 42 != column1_max@1), required_guarantees=[column1 not in (42)]
 
 # Cleanup
 statement ok
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index 188e2ae0915f..56f088dfd10f 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -5054,7 +5054,7 @@ select b, row_number() over (order by a) from (select TRUE as a, 1 as b);
 1 1
 
 # test window functions on boolean columns
-query T
+statement count 0
 create table t1 (id int, bool_col boolean) as values
     (1, true),
     (2, false),

From d7aeb1a2571708582422366896a6e9f749816afa Mon Sep 17 00:00:00 2001
From: Arttu
Date: Fri, 20 Dec 2024 15:03:41 +0100
Subject: [PATCH 2/6] enable DF's nested_expressions feature in
 datafusion-substrait tests to make them pass (#13857)

fixes #13854

Co-authored-by: Arttu Voutilainen
---
 datafusion/substrait/Cargo.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index 09c6e0351ed3..5e056d040e4c 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -46,6 +46,7 @@ substrait = { version = "0.50", features = ["serde"] }
 url = { workspace = true }
 
 [dev-dependencies]
+datafusion = { workspace = true, features = ["nested_expressions"] }
 datafusion-functions-aggregate = { workspace = true }
 serde_json = "1.0"
 tokio = { workspace = true }

From 31acf452d92c790234de385fc76b7347e8b9e7c6 
Mon Sep 17 00:00:00 2001 From: Dmitrii Blaginin Date: Fri, 20 Dec 2024 17:28:08 +0300 Subject: [PATCH 3/6] Add configurable normalization for configuration options and preserve case for S3 paths (#13576) * Do not normalize values * Fix tests & update docs * Prettier * Lowercase config params * Unify transform and parse * Fix tests * Rename `default_transform` and relax boundaries * Make `compression` case-insensitive * Comment to new line * Deprecate and ignore `enable_options_value_normalization` * Update datafusion/common/src/config.rs * fix typo --------- Co-authored-by: Oleks V --- datafusion-cli/Cargo.lock | 1 + datafusion-cli/src/object_storage.rs | 9 ++- datafusion/common/Cargo.toml | 1 + datafusion/common/src/config.rs | 80 +++++++++++++------ datafusion/core/src/datasource/stream.rs | 2 +- datafusion/core/tests/config_from_env.rs | 17 +++- datafusion/sql/src/planner.rs | 35 +------- datafusion/sql/src/statement.rs | 3 +- datafusion/sql/tests/sql_integration.rs | 69 +--------------- .../test_files/create_external_table.slt | 14 ++++ .../test_files/information_schema.slt | 4 +- .../sqllogictest/test_files/set_variable.slt | 8 +- docs/source/user-guide/configs.md | 2 +- 13 files changed, 104 insertions(+), 141 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index d33cbf396470..2ffc64114ef7 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1331,6 +1331,7 @@ dependencies = [ "hashbrown 0.14.5", "indexmap", "libc", + "log", "object_store", "parquet", "paste", diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index de66b60fe449..045c924e5037 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -472,12 +472,13 @@ mod tests { #[tokio::test] async fn s3_object_store_builder() -> Result<()> { - let access_key_id = "fake_access_key_id"; - let secret_access_key = "fake_secret_access_key"; + // "fake" is uppercase to ensure the values are not lowercased when parsed + let access_key_id = "FAKE_access_key_id"; + let secret_access_key = "FAKE_secret_access_key"; let region = "fake_us-east-2"; let endpoint = "endpoint33"; - let session_token = "fake_session_token"; - let location = "s3://bucket/path/file.parquet"; + let session_token = "FAKE_session_token"; + let location = "s3://bucket/path/FAKE/file.parquet"; let table_url = ListingTableUrl::parse(location)?; let scheme = table_url.scheme(); diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 82909404e455..a81ec724dd66 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -57,6 +57,7 @@ half = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } libc = "0.2.140" +log = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } paste = "1.0.15" diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 4948833b1f5f..6e64700bd2e0 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::collections::{BTreeMap, HashMap}; +use std::error::Error; use std::fmt::{self, Display}; use std::str::FromStr; @@ -29,7 +30,9 @@ use crate::{DataFusionError, Result}; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used -/// in the [`ConfigOptions`] configuration tree +/// in 
the [`ConfigOptions`] configuration tree. +/// +/// `transform` is used to normalize values before parsing. /// /// For example, /// @@ -38,7 +41,7 @@ use crate::{DataFusionError, Result}; /// /// Amazing config /// pub struct MyConfig { /// /// Field 1 doc -/// field1: String, default = "".to_string() +/// field1: String, transform = str::to_lowercase, default = "".to_string() /// /// /// Field 2 doc /// field2: usize, default = 232 @@ -67,9 +70,12 @@ use crate::{DataFusionError, Result}; /// fn set(&mut self, key: &str, value: &str) -> Result<()> { /// let (key, rem) = key.split_once('.').unwrap_or((key, "")); /// match key { -/// "field1" => self.field1.set(rem, value), -/// "field2" => self.field2.set(rem, value), -/// "field3" => self.field3.set(rem, value), +/// "field1" => { +/// let value = str::to_lowercase(value); +/// self.field1.set(rem, value.as_ref()) +/// }, +/// "field2" => self.field2.set(rem, value.as_ref()), +/// "field3" => self.field3.set(rem, value.as_ref()), /// _ => _internal_err!( /// "Config value \"{}\" not found on MyConfig", /// key @@ -102,7 +108,6 @@ use crate::{DataFusionError, Result}; /// ``` /// /// NB: Misplaced commas may result in nonsensical errors -/// #[macro_export] macro_rules! config_namespace { ( @@ -110,7 +115,7 @@ macro_rules! config_namespace { $vis:vis struct $struct_name:ident { $( $(#[doc = $d:tt])* - $field_vis:vis $field_name:ident : $field_type:ty, default = $default:expr + $field_vis:vis $field_name:ident : $field_type:ty, $(warn = $warn: expr,)? $(transform = $transform:expr,)? default = $default:expr )*$(,)* } ) => { @@ -127,9 +132,14 @@ macro_rules! config_namespace { impl ConfigField for $struct_name { fn set(&mut self, key: &str, value: &str) -> Result<()> { let (key, rem) = key.split_once('.').unwrap_or((key, "")); + match key { $( - stringify!($field_name) => self.$field_name.set(rem, value), + stringify!($field_name) => { + $(let value = $transform(value);)? + $(log::warn!($warn);)? + self.$field_name.set(rem, value.as_ref()) + }, )* _ => return _config_err!( "Config value \"{}\" not found on {}", key, stringify!($struct_name) @@ -211,12 +221,15 @@ config_namespace! { /// When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) pub enable_ident_normalization: bool, default = true - /// When set to true, SQL parser will normalize options value (convert value to lowercase) - pub enable_options_value_normalization: bool, default = true + /// When set to true, SQL parser will normalize options value (convert value to lowercase). + /// Note that this option is ignored and will be removed in the future. All case-insensitive values + /// are normalized automatically. + pub enable_options_value_normalization: bool, warn = "`enable_options_value_normalization` is deprecated and ignored", default = false /// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. pub dialect: String, default = "generic".to_string() + // no need to lowercase because `sqlparser::dialect_from_str`] is case-insensitive /// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but /// ignore the length. If false, error if a `VARCHAR` with a length is @@ -431,7 +444,7 @@ config_namespace! { /// /// Note that this default setting is not the same as /// the default parquet writer setting. 
-    pub compression: Option<String>, default = Some("zstd(3)".into())
+    pub compression: Option<String>, transform = str::to_lowercase, default = Some("zstd(3)".into())
 
     /// (writing) Sets if dictionary encoding is enabled. If NULL, uses
     /// default parquet writer setting
@@ -444,7 +457,7 @@ config_namespace! {
     /// Valid values are: "none", "chunk", and "page"
     /// These values are not case sensitive. If NULL, uses
     /// default parquet writer setting
-    pub statistics_enabled: Option<String>, default = Some("page".into())
+    pub statistics_enabled: Option<String>, transform = str::to_lowercase, default = Some("page".into())
 
     /// (writing) Sets max statistics size for any column. If NULL, uses
     /// default parquet writer setting
@@ -470,7 +483,7 @@ config_namespace! {
     /// delta_byte_array, rle_dictionary, and byte_stream_split.
     /// These values are not case sensitive. If NULL, uses
     /// default parquet writer setting
-    pub encoding: Option<String>, default = None
+    pub encoding: Option<String>, transform = str::to_lowercase, default = None
 
     /// (writing) Use any available bloom filters when reading parquet files
     pub bloom_filter_on_read: bool, default = true
@@ -971,21 +984,37 @@ impl<T: ConfigField + Default> ConfigField for Option<T> {
     }
 }
 
+fn default_transform<T>(input: &str) -> Result<T>
+where
+    T: FromStr,
+    <T as FromStr>::Err: Sync + Send + Error + 'static,
+{
+    input.parse().map_err(|e| {
+        DataFusionError::Context(
+            format!(
+                "Error parsing '{}' as {}",
+                input,
+                std::any::type_name::<T>()
+            ),
+            Box::new(DataFusionError::External(Box::new(e))),
+        )
+    })
+}
+
 #[macro_export]
 macro_rules! config_field {
     ($t:ty) => {
+        config_field!($t, value => default_transform(value)?);
+    };
+
+    ($t:ty, $arg:ident => $transform:expr) => {
         impl ConfigField for $t {
             fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
                 v.some(key, self, description)
             }
 
-            fn set(&mut self, _: &str, value: &str) -> Result<()> {
-                *self = value.parse().map_err(|e| {
-                    DataFusionError::Context(
-                        format!(concat!("Error parsing {} as ", stringify!($t),), value),
-                        Box::new(DataFusionError::External(Box::new(e))),
-                    )
-                })?;
+            fn set(&mut self, _: &str, $arg: &str) -> Result<()> {
+                *self = $transform;
                 Ok(())
             }
         }
     };
 }
 
 config_field!(String);
-config_field!(bool);
+config_field!(bool, value => default_transform(value.to_lowercase().as_str())?);
 config_field!(usize);
 config_field!(f64);
 config_field!(u64);
@@ -1508,7 +1537,7 @@ macro_rules! config_namespace_with_hashmap {
     $vis:vis struct $struct_name:ident {
         $(
         $(#[doc = $d:tt])*
-        $field_vis:vis $field_name:ident : $field_type:ty, default = $default:expr
+        $field_vis:vis $field_name:ident : $field_type:ty, $(transform = $transform:expr,)? default = $default:expr
         )*$(,)*
     }
     ) => {
@@ -1527,7 +1556,10 @@
             let (key, rem) = key.split_once('.').unwrap_or((key, ""));
             match key {
                 $(
-                    stringify!($field_name) => self.$field_name.set(rem, value),
+                    stringify!($field_name) => {
+                        $(let value = $transform(value);)?
+                        self.$field_name.set(rem, value.as_ref())
+                    },
                 )*
                 _ => _config_err!(
                     "Config value \"{}\" not found on {}", key, stringify!($struct_name)
@@ -1606,7 +1638,7 @@ config_namespace_with_hashmap! {
     /// lzo, brotli(level), lz4, zstd(level), and lz4_raw.
    /// These values are not case-sensitive. 
If NULL, uses
    /// default parquet options
-    pub compression: Option<String>, default = None
+    pub compression: Option<String>, transform = str::to_lowercase, default = None
 
     /// Sets if statistics are enabled for the column
     /// Valid values are: "none", "chunk", and "page"
diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs
index d8fad5b6cd37..2cea37fe17e2 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -62,7 +62,7 @@ impl TableProviderFactory for StreamTableFactory {
         let header = if let Ok(opt) = cmd
             .options
             .get("format.has_header")
-            .map(|has_header| bool::from_str(has_header))
+            .map(|has_header| bool::from_str(has_header.to_lowercase().as_str()))
             .transpose()
         {
             opt.unwrap_or(false)
diff --git a/datafusion/core/tests/config_from_env.rs b/datafusion/core/tests/config_from_env.rs
index a5a5a4524e60..976597c8a9ac 100644
--- a/datafusion/core/tests/config_from_env.rs
+++ b/datafusion/core/tests/config_from_env.rs
@@ -22,10 +22,19 @@ use std::env;
 fn from_env() {
     // Note: these must be a single test to avoid interference from concurrent execution
     let env_key = "DATAFUSION_OPTIMIZER_FILTER_NULL_JOIN_KEYS";
-    env::set_var(env_key, "true");
-    let config = ConfigOptions::from_env().unwrap();
+    // valid testing in different cases
+    for bool_option in ["true", "TRUE", "True", "tRUe"] {
+        env::set_var(env_key, bool_option);
+        let config = ConfigOptions::from_env().unwrap();
+        env::remove_var(env_key);
+        assert!(config.optimizer.filter_null_join_keys);
+    }
+
+    // invalid testing
+    env::set_var(env_key, "ttruee");
+    let err = ConfigOptions::from_env().unwrap_err().strip_backtrace();
+    assert_eq!(err, "Error parsing 'ttruee' as bool\ncaused by\nExternal error: provided string was not `true` or `false`");
     env::remove_var(env_key);
-    assert!(config.optimizer.filter_null_join_keys);
 
     let env_key = "DATAFUSION_EXECUTION_BATCH_SIZE";
 
@@ -37,7 +46,7 @@ fn from_env() {
     // for invalid testing
     env::set_var(env_key, "abc");
     let err = ConfigOptions::from_env().unwrap_err().strip_backtrace();
-    assert_eq!(err, "Error parsing abc as usize\ncaused by\nExternal error: invalid digit found in string");
+    assert_eq!(err, "Error parsing 'abc' as usize\ncaused by\nExternal error: invalid digit found in string");
     env::remove_var(env_key);
 
     let config = ConfigOptions::from_env().unwrap();
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 59fa4ca5f1f6..2d0ba8f8d994 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -24,10 +24,10 @@ use arrow_schema::*;
 use datafusion_common::{
     field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError,
 };
+use sqlparser::ast::TimezoneInfo;
 use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo};
 use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
 use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias};
-use sqlparser::ast::{TimezoneInfo, Value};
 
 use datafusion_common::TableReference;
 use datafusion_common::{
@@ -38,7 +38,7 @@ use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder};
 use datafusion_expr::utils::find_column_exprs;
 use datafusion_expr::{col, Expr};
 
-use crate::utils::{make_decimal_type, value_to_string};
+use crate::utils::make_decimal_type;
 pub use datafusion_expr::planner::ContextProvider;
 
 /// SQL parser options
@@ -56,7 +56,7 @@ impl Default for ParserOptions {
             parse_float_as_decimal: false,
             enable_ident_normalization: true,
             support_varchar_with_length: true,
- 
enable_options_value_normalization: true, + enable_options_value_normalization: false, } } } @@ -87,32 +87,6 @@ impl IdentNormalizer { } } -/// Value Normalizer -#[derive(Debug)] -pub struct ValueNormalizer { - normalize: bool, -} - -impl Default for ValueNormalizer { - fn default() -> Self { - Self { normalize: true } - } -} - -impl ValueNormalizer { - pub fn new(normalize: bool) -> Self { - Self { normalize } - } - - pub fn normalize(&self, value: Value) -> Option { - match (value_to_string(&value), self.normalize) { - (Some(s), true) => Some(s.to_ascii_lowercase()), - (Some(s), false) => Some(s), - (None, _) => None, - } - } -} - /// Struct to store the states used by the Planner. The Planner will leverage the states to resolve /// CTEs, Views, subqueries and PREPARE statements. The states include /// Common Table Expression (CTE) provided with WITH clause and @@ -254,7 +228,6 @@ pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, pub(crate) ident_normalizer: IdentNormalizer, - pub(crate) value_normalizer: ValueNormalizer, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -266,13 +239,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Create a new query planner pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { let ident_normalize = options.enable_ident_normalization; - let options_value_normalize = options.enable_options_value_normalization; SqlToRel { context_provider, options, ident_normalizer: IdentNormalizer::new(ident_normalize), - value_normalizer: ValueNormalizer::new(options_value_normalize), } } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 38695f98b5fe..f750afbc4a53 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1386,8 +1386,7 @@ impl SqlToRel<'_, S> { return plan_err!("Option {key} is specified multiple times"); } - let Some(value_string) = self.value_normalizer.normalize(value.clone()) - else { + let Some(value_string) = crate::utils::value_to_string(&value) else { return plan_err!("Unsupported Value {}", value); }; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 9363d16c9fc9..786f72741282 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -29,11 +29,10 @@ use datafusion_common::{ }; use datafusion_expr::{ col, - dml::CopyTo, logical_plan::{LogicalPlan, Prepare}, test::function_stub::sum_udaf, - ColumnarValue, CreateExternalTable, CreateIndex, DdlStatement, ScalarUDF, - ScalarUDFImpl, Signature, Statement, Volatility, + ColumnarValue, CreateIndex, DdlStatement, ScalarUDF, ScalarUDFImpl, Signature, + Statement, Volatility, }; use datafusion_functions::{string, unicode}; use datafusion_sql::{ @@ -161,70 +160,6 @@ fn parse_ident_normalization() { } } -#[test] -fn test_parse_options_value_normalization() { - let test_data = [ - ( - "CREATE EXTERNAL TABLE test OPTIONS ('location' 'LoCaTiOn') STORED AS PARQUET LOCATION 'fake_location'", - "CreateExternalTable: Bare { table: \"test\" }", - HashMap::from([("format.location", "LoCaTiOn")]), - false, - ), - ( - "CREATE EXTERNAL TABLE test OPTIONS ('location' 'LoCaTiOn') STORED AS PARQUET LOCATION 'fake_location'", - "CreateExternalTable: Bare { table: \"test\" }", - HashMap::from([("format.location", "location")]), - true, - ), - ( - "COPY test TO 'fake_location' STORED AS PARQUET OPTIONS ('location' 'LoCaTiOn')", - "CopyTo: format=csv 
output_url=fake_location options: (format.location LoCaTiOn)\n TableScan: test", - HashMap::from([("format.location", "LoCaTiOn")]), - false, - ), - ( - "COPY test TO 'fake_location' STORED AS PARQUET OPTIONS ('location' 'LoCaTiOn')", - "CopyTo: format=csv output_url=fake_location options: (format.location location)\n TableScan: test", - HashMap::from([("format.location", "location")]), - true, - ), - ]; - - for (sql, expected_plan, expected_options, enable_options_value_normalization) in - test_data - { - let plan = logical_plan_with_options( - sql, - ParserOptions { - parse_float_as_decimal: false, - enable_ident_normalization: false, - support_varchar_with_length: false, - enable_options_value_normalization, - }, - ); - if let Ok(plan) = plan { - assert_eq!(expected_plan, format!("{plan}")); - - match plan { - LogicalPlan::Ddl(DdlStatement::CreateExternalTable( - CreateExternalTable { options, .. }, - )) - | LogicalPlan::Copy(CopyTo { options, .. }) => { - expected_options.iter().for_each(|(k, v)| { - assert_eq!(Some(&v.to_string()), options.get(*k)); - }); - } - _ => panic!( - "Expected Ddl(CreateExternalTable) or Copy(CopyTo) but got {:?}", - plan - ), - } - } else { - assert_eq!(expected_plan, plan.unwrap_err().strip_backtrace()); - } - } -} - #[test] fn select_no_relation() { quick_test( diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index ed001cf9f84c..6a63ea1cd3e4 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -226,6 +226,20 @@ OPTIONS ( has_header false, compression gzip); +# Verify that some options are case insensitive +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS region ( + r_regionkey BIGINT, + r_name VARCHAR, + r_comment VARCHAR, + r_rev VARCHAR, +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' +OPTIONS ( + format.delimiter '|', + has_header FALSE, + compression GZIP); + + # Create an external parquet table and infer schema to order by # query should succeed diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 4d51a61c8a52..1f6b5f9852ec 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -258,7 +258,7 @@ datafusion.optimizer.skip_failed_rules false datafusion.optimizer.top_down_join_key_reordering true datafusion.sql_parser.dialect generic datafusion.sql_parser.enable_ident_normalization true -datafusion.sql_parser.enable_options_value_normalization true +datafusion.sql_parser.enable_options_value_normalization false datafusion.sql_parser.parse_float_as_decimal false datafusion.sql_parser.support_varchar_with_length true @@ -351,7 +351,7 @@ datafusion.optimizer.skip_failed_rules false When set to true, the logical plan datafusion.optimizer.top_down_join_key_reordering true When set to true, the physical plan optimizer will run a top down process to reorder the join keys datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. 
datafusion.sql_parser.enable_ident_normalization true When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) -datafusion.sql_parser.enable_options_value_normalization true When set to true, SQL parser will normalize options value (convert value to lowercase) +datafusion.sql_parser.enable_options_value_normalization false When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. datafusion.sql_parser.parse_float_as_decimal false When set to true, SQL parser will parse float as decimal type datafusion.sql_parser.support_varchar_with_length true If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. diff --git a/datafusion/sqllogictest/test_files/set_variable.slt b/datafusion/sqllogictest/test_files/set_variable.slt index 6f19c9f4d42f..bb4ac920d032 100644 --- a/datafusion/sqllogictest/test_files/set_variable.slt +++ b/datafusion/sqllogictest/test_files/set_variable.slt @@ -93,10 +93,10 @@ datafusion.execution.coalesce_batches false statement ok set datafusion.catalog.information_schema = true -statement error DataFusion error: Error parsing 1 as bool +statement error DataFusion error: Error parsing '1' as bool SET datafusion.execution.coalesce_batches to 1 -statement error DataFusion error: Error parsing abc as bool +statement error DataFusion error: Error parsing 'abc' as bool SET datafusion.execution.coalesce_batches to abc # set u64 variable @@ -132,10 +132,10 @@ datafusion.execution.batch_size 2 statement ok set datafusion.catalog.information_schema = true -statement error DataFusion error: Error parsing -1 as usize +statement error DataFusion error: Error parsing '-1' as usize SET datafusion.execution.batch_size to -1 -statement error DataFusion error: Error parsing abc as usize +statement error DataFusion error: Error parsing 'abc' as usize SET datafusion.execution.batch_size to abc statement error External error: invalid digit found in string diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 6a49fda668a9..77433c85cb66 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -122,6 +122,6 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | | datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | | datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | -| datafusion.sql_parser.enable_options_value_normalization | true | When set to true, SQL parser will normalize options value (convert value to lowercase) | +| datafusion.sql_parser.enable_options_value_normalization | false | When set to true, SQL parser will normalize options value (convert value to lowercase). Note that this option is ignored and will be removed in the future. All case-insensitive values are normalized automatically. 
|
 | datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. |
 | datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. |

From 87b77bba4fb47980245c19b8dd289a3ca83cf7e5 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Fri, 20 Dec 2024 09:32:48 -0500
Subject: [PATCH 4/6] Improve `Signature` and `comparison_coercion`
 documentation (#13840)

* Improve Signature documentation more

* Apply suggestions from code review

Co-authored-by: Piotr Findeisen

---------

Co-authored-by: Piotr Findeisen
---
 datafusion/expr-common/src/signature.rs       | 33 ++++++++++++-------
 .../expr-common/src/type_coercion/binary.rs   | 22 ++++++++++++-
 2 files changed, 42 insertions(+), 13 deletions(-)

diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs
index b5d25b4338c7..77ba1858e35b 100644
--- a/datafusion/expr-common/src/signature.rs
+++ b/datafusion/expr-common/src/signature.rs
@@ -103,9 +103,13 @@ pub enum TypeSignature {
     /// A function such as `concat` is `Variadic(vec![DataType::Utf8,
     /// DataType::LargeUtf8])`
     Variadic(Vec<DataType>),
-    /// The acceptable signature and coercions rules to coerce arguments to this
-    /// signature are special for this function. If this signature is specified,
-    /// DataFusion will call `ScalarUDFImpl::coerce_types` to prepare argument types.
+    /// The acceptable signature and coercion rules are special for this
+    /// function.
+    ///
+    /// If this signature is specified,
+    /// DataFusion will call [`ScalarUDFImpl::coerce_types`] to prepare argument types.
+    ///
+    /// [`ScalarUDFImpl::coerce_types`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/trait.ScalarUDFImpl.html#method.coerce_types
     UserDefined,
     /// One or more arguments with arbitrary types
     VariadicAny,
@@ -123,24 +127,29 @@ pub enum TypeSignature {
     /// One or more arguments belonging to the [`TypeSignatureClass`], in order.
     ///
     /// For example, `Coercible(vec![logical_float64()])` accepts
-    /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
+    /// arguments like `vec![Int32]` or `vec![Float32]`
     /// since i32 and f32 can be cast to f64
     ///
    /// For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`].
     Coercible(Vec<TypeSignatureClass>),
-    /// One or more arguments that can be "compared"
+    /// One or more arguments coercible to a single, comparable type.
+    ///
+    /// Each argument will be coerced to a single type using the
+    /// coercion rules described in [`comparison_coercion_numeric`].
+    ///
+    /// # Examples
+    ///
+    /// If the `nullif(1, 2)` function is called with `i32` and `i64` arguments
+    /// the types will both be coerced to `i64` before the function is invoked.
     ///
-    /// Each argument will be coerced to a single type based on comparison rules.
-    /// For example a function called with `i32` and `i64` has coerced type `Int64` so
-    /// each argument will be coerced to `Int64` before the function is invoked.
+    /// If the `nullif('1', 2)` function is called with `Utf8` and `i64` arguments
+    /// the types will both be coerced to `Utf8` before the function is invoked.
     ///
     /// Note:
-    /// - If compares with numeric and string, numeric is preferred for numeric string cases. For example, `nullif('2', 1)` has coerced types `Int64`.
-    /// - If the result is Null, it will be coerced to String (Utf8View).
-    /// - See [`comparison_coercion`] for more details.
     /// - For functions that take no arguments (e.g. `random()`) see [`TypeSignature::Nullary`].
+    /// - If all arguments have type [`DataType::Null`], they are coerced to `Utf8`
     ///
-    /// [`comparison_coercion`]: crate::type_coercion::binary::comparison_coercion
+    /// [`comparison_coercion_numeric`]: crate::type_coercion::binary::comparison_coercion_numeric
     Comparable(usize),
     /// One or more arguments of arbitrary types.
     ///
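For reference, here is a minimal sketch of how a UDF might declare the variants documented above. It assumes the `datafusion_expr` re-exports and `Signature::new`; the volatility choices are illustrative and not part of this patch:

```rust
use arrow::datatypes::DataType;
use datafusion_expr::{Signature, TypeSignature, Volatility};

fn example_signatures() -> Vec<Signature> {
    vec![
        // `concat`-style: any number of Utf8 / LargeUtf8 arguments
        Signature::new(
            TypeSignature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8]),
            Volatility::Immutable,
        ),
        // `nullif`-style: two arguments coerced to a single comparable type
        Signature::new(TypeSignature::Comparable(2), Volatility::Immutable),
        // fully custom coercion via `ScalarUDFImpl::coerce_types`
        Signature::new(TypeSignature::UserDefined, Volatility::Immutable),
    ]
}
```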
diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs
index 49c1ccff3814..c775d3131692 100644
--- a/datafusion/expr-common/src/type_coercion/binary.rs
+++ b/datafusion/expr-common/src/type_coercion/binary.rs
@@ -625,6 +625,19 @@ pub fn try_type_union_resolution_with_struct(
 /// data type. However, users can write queries where the two arguments are
 /// different data types. In such cases, the data types are automatically cast
 /// (coerced) to a single data type to pass to the kernels.
+///
+/// # Numeric comparisons
+///
+/// When comparing numeric values, the lower precision type is coerced to the
+/// higher precision type to avoid losing data. For example when comparing
+/// `Int32` to `Int64` the coerced type is `Int64` so the `Int32` argument will
+/// be cast.
+///
+/// # Numeric / String comparisons
+///
+/// When comparing numeric values and strings, both values will be coerced to
+/// strings. For example when comparing `'2' > 1`, the arguments will be
+/// coerced to `Utf8` for comparison.
 pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
     if lhs_type == rhs_type {
         // same type => equality is possible
@@ -642,7 +655,14 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
+/// Similar to [`comparison_coercion`] but prefers the numeric type when a
+/// numeric value is compared with a string.
+///
+/// For example when comparing `'2' > 1` if `1` is an `Int32`, the arguments
+/// will be coerced to `Int32`.
 pub fn comparison_coercion_numeric(
     lhs_type: &DataType,
     rhs_type: &DataType,
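The difference between the two functions reduces to a few assertions. A sketch, assuming the `datafusion_expr::type_coercion::binary` re-export; the expected values follow directly from the doc comments above:

```rust
use arrow::datatypes::DataType;
use datafusion_expr::type_coercion::binary::{
    comparison_coercion, comparison_coercion_numeric,
};

fn main() {
    // Numeric vs. numeric: widen to the higher precision type.
    assert_eq!(
        comparison_coercion(&DataType::Int32, &DataType::Int64),
        Some(DataType::Int64)
    );
    // Numeric vs. string: `comparison_coercion` compares as strings...
    assert_eq!(
        comparison_coercion(&DataType::Utf8, &DataType::Int64),
        Some(DataType::Utf8)
    );
    // ...while `comparison_coercion_numeric` prefers the numeric side.
    assert_eq!(
        comparison_coercion_numeric(&DataType::Utf8, &DataType::Int64),
        Some(DataType::Int64)
    );
}
```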
From 74480ac58ee5658a275b1e8b0ebd9764d0e48844 Mon Sep 17 00:00:00 2001
From: zhuliquan
Date: Fri, 20 Dec 2024 22:35:53 +0800
Subject: [PATCH 5/6] feat: support normalized expr in CSE (#13315)

* feat: support normalized expr in CSE

* feat: support normalize_eq in cse optimization

* feat: support cumulative binary expr result in normalize_eq

---------

Co-authored-by: Andrew Lamb
---
 datafusion/common/src/cse.rs                  | 150 +++++--
 datafusion/expr/src/expr.rs                   | 389 +++++++++++++++++-
 datafusion/expr/src/logical_plan/plan.rs      |  20 +
 .../optimizer/src/common_subexpr_eliminate.rs | 263 +++++++++++-
 4 files changed, 790 insertions(+), 32 deletions(-)

diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs
index ab02915858cd..f64571b8471e 100644
--- a/datafusion/common/src/cse.rs
+++ b/datafusion/common/src/cse.rs
@@ -50,12 +50,33 @@ impl<T: HashNode + ?Sized> HashNode for Arc<T> {
     }
 }
 
+/// The `Normalizeable` trait defines a method to determine whether a node can be normalized.
+///
+/// Normalization is the process of converting a node into a canonical form that can be used
+/// to compare nodes for equality. This is useful in optimizations like Common Subexpression Elimination (CSE),
+/// where semantically equivalent nodes (e.g., `a + b` and `b + a`) should be treated as equal.
+pub trait Normalizeable {
+    fn can_normalize(&self) -> bool;
+}
+
+/// The `NormalizeEq` trait extends `Eq` and `Normalizeable` to provide a method for comparing
+/// normalized nodes in optimizations like Common Subexpression Elimination (CSE).
+///
+/// The `normalize_eq` method ensures that two nodes that are semantically equivalent (after normalization)
+/// are considered equal in CSE optimization, even if their original forms differ.
+///
+/// This trait allows for equality comparisons between nodes with equivalent semantics, regardless of their
+/// internal representations.
+pub trait NormalizeEq: Eq + Normalizeable {
+    fn normalize_eq(&self, other: &Self) -> bool;
+}
+
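To make the contract concrete, a minimal sketch of an implementation for a toy commutative node type; the real implementations for `Expr`, `Subquery`, and `TestTreeNode` appear later in this patch:

```rust
// Toy node: an addition of two constants. Addition is commutative, so the
// node opts in to normalization and compares its operands in either order.
#[derive(Debug, PartialEq, Eq)]
struct Add {
    lhs: i32,
    rhs: i32,
}

impl Normalizeable for Add {
    fn can_normalize(&self) -> bool {
        true
    }
}

impl NormalizeEq for Add {
    fn normalize_eq(&self, other: &Self) -> bool {
        (self.lhs == other.lhs && self.rhs == other.rhs)
            || (self.lhs == other.rhs && self.rhs == other.lhs)
    }
}

// `1 + 2` and `2 + 1` are now the same subexpression for CSE purposes:
// assert!(Add { lhs: 1, rhs: 2 }.normalize_eq(&Add { lhs: 2, rhs: 1 }));
```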
 /// Identifier that represents a [`TreeNode`] tree.
 ///
 /// This identifier is designed to be efficient and "hash", "accumulate", "equal" and
 /// "have no collision (as low as possible)"
-#[derive(Debug, Eq, PartialEq)]
-struct Identifier<'n, N> {
+#[derive(Debug, Eq)]
+struct Identifier<'n, N: NormalizeEq> {
     // Hash of `node` built up incrementally during the first, visiting traversal.
     // Its value is not necessarily equal to default hash of the node. E.g. it is not
     // equal to `expr.hash()` if the node is `Expr`.
@@ -63,20 +84,29 @@ struct Identifier<'n, N> {
     node: &'n N,
 }
 
-impl<N> Clone for Identifier<'_, N> {
+impl<N: NormalizeEq> Clone for Identifier<'_, N> {
     fn clone(&self) -> Self {
         *self
     }
 }
 
-impl<N> Copy for Identifier<'_, N> {}
+impl<N: NormalizeEq> Copy for Identifier<'_, N> {}
 
-impl<N> Hash for Identifier<'_, N> {
+impl<N: NormalizeEq> Hash for Identifier<'_, N> {
     fn hash<H: Hasher>(&self, state: &mut H) {
         state.write_u64(self.hash);
     }
 }
 
-impl<'n, N: HashNode> Identifier<'n, N> {
+impl<N: NormalizeEq> PartialEq for Identifier<'_, N> {
+    fn eq(&self, other: &Self) -> bool {
+        self.hash == other.hash && self.node.normalize_eq(other.node)
+    }
+}
+
+impl<'n, N> Identifier<'n, N>
+where
+    N: HashNode + NormalizeEq,
+{
     fn new(node: &'n N, random_state: &RandomState) -> Self {
         let mut hasher = random_state.build_hasher();
         node.hash_node(&mut hasher);
@@ -213,7 +243,11 @@ pub enum FoundCommonNodes {
 ///
 /// A [`TreeNode`] without any children (column, literal etc.) will not have identifier
 /// because they should not be recognized as common subtree.
-struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
+struct CSEVisitor<'a, 'n, N, C>
+where
+    N: NormalizeEq,
+    C: CSEController<Node = N>,
+{
     /// statistics of [`TreeNode`]s
     node_stats: &'a mut NodeStats<'n, N>,
@@ -244,7 +278,10 @@ struct CSEVisitor<'a, 'n, N, C: CSEController<Node = N>> {
 }
 
 /// Record item that used when traversing a [`TreeNode`] tree.
-enum VisitRecord<'n, N> {
+enum VisitRecord<'n, N>
+where
+    N: NormalizeEq,
+{
     /// Marks the beginning of [`TreeNode`]. It contains:
     /// - The post-order index assigned during the first, visiting traversal.
     EnterMark(usize),
@@ -258,7 +295,11 @@ enum VisitRecord<'n, N> {
     NodeItem(Identifier<'n, N>, bool),
 }
 
-impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n, N, C> {
+impl<'n, N, C> CSEVisitor<'_, 'n, N, C>
+where
+    N: TreeNode + HashNode + NormalizeEq,
+    C: CSEController<Node = N>,
+{
     /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before
     /// it. Returns a tuple that contains:
     /// - The pre-order index of the [`TreeNode`] we marked.
@@ -271,17 +312,26 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
     /// information up from children to parents via `visit_stack` during the first,
     /// visiting traversal and no need to test the expression's validity beforehand with
     /// an extra traversal).
-    fn pop_enter_mark(&mut self) -> (usize, Option<Identifier<'n, N>>, bool) {
-        let mut node_id = None;
+    fn pop_enter_mark(
+        &mut self,
+        can_normalize: bool,
+    ) -> (usize, Option<Identifier<'n, N>>, bool) {
+        let mut node_ids: Vec<Identifier<'n, N>> = vec![];
         let mut is_valid = true;
 
         while let Some(item) = self.visit_stack.pop() {
             match item {
                 VisitRecord::EnterMark(down_index) => {
+                    if can_normalize {
+                        node_ids.sort_by_key(|i| i.hash);
+                    }
+                    let node_id = node_ids
+                        .into_iter()
+                        .fold(None, |accum, item| Some(item.combine(accum)));
                     return (down_index, node_id, is_valid);
                 }
                 VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => {
-                    node_id = Some(sub_node_id.combine(node_id));
+                    node_ids.push(sub_node_id);
                     is_valid &= sub_node_is_valid;
                 }
             }
@@ -290,8 +340,10 @@ impl<'n, N: TreeNode + HashNode, C: CSEController<Node = N>> CSEVisitor<'_, 'n,
     }
 }
 
-impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
-    for CSEVisitor<'_, 'n, N, C>
+impl<'n, N, C> TreeNodeVisitor<'n> for CSEVisitor<'_, 'n, N, C>
+where
+    N: TreeNode + HashNode + NormalizeEq,
+    C: CSEController<Node = N>,
 {
     type Node = N;
 
@@ -331,7 +383,8 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
     }
 
     fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
-        let (down_index, sub_node_id, sub_node_is_valid) = self.pop_enter_mark();
+        let (down_index, sub_node_id, sub_node_is_valid) =
+            self.pop_enter_mark(node.can_normalize());
 
         let node_id = Identifier::new(node, self.random_state).combine(sub_node_id);
         let is_valid = C::is_valid(node) && sub_node_is_valid;
@@ -369,7 +422,11 @@ impl<'n, N: TreeNode + HashNode + Eq, C: CSEController<Node = N>> TreeNodeVisitor<'n>
 /// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the
 /// corresponding temporary [`TreeNode`], that column contains the evaluate result of
 /// replaced [`TreeNode`] tree.
-struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
+struct CSERewriter<'a, 'n, N, C>
+where
+    N: NormalizeEq,
+    C: CSEController<Node = N>,
+{
     /// statistics of [`TreeNode`]s
     node_stats: &'a NodeStats<'n, N>,
@@ -386,8 +443,10 @@ struct CSERewriter<'a, 'n, N, C: CSEController<Node = N>> {
     controller: &'a mut C,
 }
 
-impl<N: TreeNode, C: CSEController<Node = N>> TreeNodeRewriter
-    for CSERewriter<'_, '_, N, C>
+impl<N, C> TreeNodeRewriter for CSERewriter<'_, '_, N, C>
+where
+    N: TreeNode + NormalizeEq,
+    C: CSEController<Node = N>,
 {
     type Node = N;
 
@@ -408,13 +467,30 @@ impl<N: TreeNode, C: CSEController<Node = N>> TreeNodeRewriter
                 self.down_index += 1;
             }
 
-            let (node, alias) =
-                self.common_nodes.entry(node_id).or_insert_with(|| {
-                    let node_alias = self.controller.generate_alias();
-                    (node, node_alias)
-                });
-
-            let rewritten = self.controller.rewrite(node, alias);
+            // We *must* replace all original nodes with same `node_id`, not just the first
+            // node which is inserted into the common_nodes. This is because nodes with the same
+            // `node_id` are semantically equivalent, but not exactly the same.
+            //
+            // For example, `a + 1` and `1 + a` are semantically equivalent but not identical.
+            // In this case, we should replace the common expression `1 + a` with a new variable
+            // (e.g., `__common_cse_1`). So, `a + 1` and `1 + a` would both be replaced by
+            // `__common_cse_1`.
+            //
+            // The final result would be:
+            // - `__common_cse_1 as a + 1`
+            // - `__common_cse_1 as 1 + a`
+            //
+            // This way, we can efficiently handle semantically equivalent expressions without
+            // incorrectly treating them as identical.
+            let rewritten = if let Some((_, alias)) = self.common_nodes.get(&node_id)
+            {
+                self.controller.rewrite(&node, alias)
+            } else {
+                let node_alias = self.controller.generate_alias();
+                let rewritten = self.controller.rewrite(&node, &node_alias);
+                self.common_nodes.insert(node_id, (node, node_alias));
+                rewritten
+            };
 
             return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump));
         }
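The effect of this aliasing scheme is visible in the optimizer tests added later in this patch: both operand orders of `1 + a` collapse into one computed column, for example:

```text
Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)
  Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c
    TableScan: test
```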
@@ -441,7 +517,11 @@ pub struct CSE<N, C: CSEController<Node = N>> {
     controller: C,
 }
 
-impl<N: TreeNode + HashNode + Clone, C: CSEController<Node = N>> CSE<N, C> {
+impl<N, C> CSE<N, C>
+where
+    N: TreeNode + HashNode + Clone + NormalizeEq,
+    C: CSEController<Node = N>,
+{
     pub fn new(controller: C) -> Self {
         Self {
             random_state: RandomState::new(),
@@ -557,6 +637,7 @@ impl<N: TreeNode + HashNode + Clone, C: CSEController<Node = N>> CSE<N, C> {
     ) -> Result<FoundCommonNodes<N>> {
         let mut found_common = false;
         let mut node_stats = NodeStats::new();
+
         let id_arrays_list = nodes_list
             .iter()
             .map(|nodes| {
@@ -596,7 +677,10 @@ impl<N: TreeNode + HashNode + Clone, C: CSEController<Node = N>> CSE<N, C> {
 #[cfg(test)]
 mod test {
     use crate::alias::AliasGenerator;
-    use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE};
+    use crate::cse::{
+        CSEController, HashNode, IdArray, Identifier, NodeStats, NormalizeEq,
+        Normalizeable, CSE,
+    };
     use crate::tree_node::tests::TestTreeNode;
     use crate::Result;
     use std::collections::HashSet;
@@ -662,6 +746,18 @@ mod test {
         }
     }
 
+    impl Normalizeable for TestTreeNode<String> {
+        fn can_normalize(&self) -> bool {
+            false
+        }
+    }
+
+    impl NormalizeEq for TestTreeNode<String> {
+        fn normalize_eq(&self, other: &Self) -> bool {
+            self == other
+        }
+    }
+
     #[test]
     fn id_array_visitor() -> Result<()> {
         let alias_generator = AliasGenerator::new();
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index c495b5396f53..af54dad79d2e 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -30,7 +30,7 @@ use crate::Volatility;
 use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
 
 use arrow::datatypes::{DataType, FieldRef};
-use datafusion_common::cse::HashNode;
+use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
 use datafusion_common::tree_node::{
     Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion,
 };
@@ -1665,6 +1665,393 @@ impl Expr {
     }
 }
 
+impl Normalizeable for Expr {
+    fn can_normalize(&self) -> bool {
+        #[allow(clippy::match_like_matches_macro)]
+        match self {
+            Expr::BinaryExpr(BinaryExpr {
+                op:
+                    _op @ (Operator::Plus
+                    | Operator::Multiply
+                    | Operator::BitwiseAnd
+                    | Operator::BitwiseOr
+                    | Operator::BitwiseXor
+                    | Operator::Eq
+                    | Operator::NotEq),
+                ..
+ }) => true, + _ => false, + } + } +} + +impl NormalizeEq for Expr { + fn normalize_eq(&self, other: &Self) -> bool { + match (self, other) { + ( + Expr::BinaryExpr(BinaryExpr { + left: self_left, + op: self_op, + right: self_right, + }), + Expr::BinaryExpr(BinaryExpr { + left: other_left, + op: other_op, + right: other_right, + }), + ) => { + if self_op != other_op { + return false; + } + + if matches!( + self_op, + Operator::Plus + | Operator::Multiply + | Operator::BitwiseAnd + | Operator::BitwiseOr + | Operator::BitwiseXor + | Operator::Eq + | Operator::NotEq + ) { + (self_left.normalize_eq(other_left) + && self_right.normalize_eq(other_right)) + || (self_left.normalize_eq(other_right) + && self_right.normalize_eq(other_left)) + } else { + self_left.normalize_eq(other_left) + && self_right.normalize_eq(other_right) + } + } + ( + Expr::Alias(Alias { + expr: self_expr, + relation: self_relation, + name: self_name, + }), + Expr::Alias(Alias { + expr: other_expr, + relation: other_relation, + name: other_name, + }), + ) => { + self_name == other_name + && self_relation == other_relation + && self_expr.normalize_eq(other_expr) + } + ( + Expr::Like(Like { + negated: self_negated, + expr: self_expr, + pattern: self_pattern, + escape_char: self_escape_char, + case_insensitive: self_case_insensitive, + }), + Expr::Like(Like { + negated: other_negated, + expr: other_expr, + pattern: other_pattern, + escape_char: other_escape_char, + case_insensitive: other_case_insensitive, + }), + ) + | ( + Expr::SimilarTo(Like { + negated: self_negated, + expr: self_expr, + pattern: self_pattern, + escape_char: self_escape_char, + case_insensitive: self_case_insensitive, + }), + Expr::SimilarTo(Like { + negated: other_negated, + expr: other_expr, + pattern: other_pattern, + escape_char: other_escape_char, + case_insensitive: other_case_insensitive, + }), + ) => { + self_negated == other_negated + && self_escape_char == other_escape_char + && self_case_insensitive == other_case_insensitive + && self_expr.normalize_eq(other_expr) + && self_pattern.normalize_eq(other_pattern) + } + (Expr::Not(self_expr), Expr::Not(other_expr)) + | (Expr::IsNull(self_expr), Expr::IsNull(other_expr)) + | (Expr::IsTrue(self_expr), Expr::IsTrue(other_expr)) + | (Expr::IsFalse(self_expr), Expr::IsFalse(other_expr)) + | (Expr::IsUnknown(self_expr), Expr::IsUnknown(other_expr)) + | (Expr::IsNotNull(self_expr), Expr::IsNotNull(other_expr)) + | (Expr::IsNotTrue(self_expr), Expr::IsNotTrue(other_expr)) + | (Expr::IsNotFalse(self_expr), Expr::IsNotFalse(other_expr)) + | (Expr::IsNotUnknown(self_expr), Expr::IsNotUnknown(other_expr)) + | (Expr::Negative(self_expr), Expr::Negative(other_expr)) + | ( + Expr::Unnest(Unnest { expr: self_expr }), + Expr::Unnest(Unnest { expr: other_expr }), + ) => self_expr.normalize_eq(other_expr), + ( + Expr::Between(Between { + expr: self_expr, + negated: self_negated, + low: self_low, + high: self_high, + }), + Expr::Between(Between { + expr: other_expr, + negated: other_negated, + low: other_low, + high: other_high, + }), + ) => { + self_negated == other_negated + && self_expr.normalize_eq(other_expr) + && self_low.normalize_eq(other_low) + && self_high.normalize_eq(other_high) + } + ( + Expr::Cast(Cast { + expr: self_expr, + data_type: self_data_type, + }), + Expr::Cast(Cast { + expr: other_expr, + data_type: other_data_type, + }), + ) + | ( + Expr::TryCast(TryCast { + expr: self_expr, + data_type: self_data_type, + }), + Expr::TryCast(TryCast { + expr: other_expr, + data_type: other_data_type, + }), + ) 
=> self_data_type == other_data_type && self_expr.normalize_eq(other_expr), + ( + Expr::ScalarFunction(ScalarFunction { + func: self_func, + args: self_args, + }), + Expr::ScalarFunction(ScalarFunction { + func: other_func, + args: other_args, + }), + ) => { + self_func.name() == other_func.name() + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } + ( + Expr::AggregateFunction(AggregateFunction { + func: self_func, + args: self_args, + distinct: self_distinct, + filter: self_filter, + order_by: self_order_by, + null_treatment: self_null_treatment, + }), + Expr::AggregateFunction(AggregateFunction { + func: other_func, + args: other_args, + distinct: other_distinct, + filter: other_filter, + order_by: other_order_by, + null_treatment: other_null_treatment, + }), + ) => { + self_func.name() == other_func.name() + && self_distinct == other_distinct + && self_null_treatment == other_null_treatment + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && match (self_filter, other_filter) { + (Some(self_filter), Some(other_filter)) => { + self_filter.normalize_eq(other_filter) + } + (None, None) => true, + _ => false, + } + && match (self_order_by, other_order_by) { + (Some(self_order_by), Some(other_order_by)) => self_order_by + .iter() + .zip(other_order_by.iter()) + .all(|(a, b)| { + a.asc == b.asc + && a.nulls_first == b.nulls_first + && a.expr.normalize_eq(&b.expr) + }), + (None, None) => true, + _ => false, + } + } + ( + Expr::WindowFunction(WindowFunction { + fun: self_fun, + args: self_args, + partition_by: self_partition_by, + order_by: self_order_by, + window_frame: self_window_frame, + null_treatment: self_null_treatment, + }), + Expr::WindowFunction(WindowFunction { + fun: other_fun, + args: other_args, + partition_by: other_partition_by, + order_by: other_order_by, + window_frame: other_window_frame, + null_treatment: other_null_treatment, + }), + ) => { + self_fun.name() == other_fun.name() + && self_window_frame == other_window_frame + && self_null_treatment == other_null_treatment + && self_args.len() == other_args.len() + && self_args + .iter() + .zip(other_args.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && self_partition_by + .iter() + .zip(other_partition_by.iter()) + .all(|(a, b)| a.normalize_eq(b)) + && self_order_by + .iter() + .zip(other_order_by.iter()) + .all(|(a, b)| { + a.asc == b.asc + && a.nulls_first == b.nulls_first + && a.expr.normalize_eq(&b.expr) + }) + } + ( + Expr::Exists(Exists { + subquery: self_subquery, + negated: self_negated, + }), + Expr::Exists(Exists { + subquery: other_subquery, + negated: other_negated, + }), + ) => { + self_negated == other_negated + && self_subquery.normalize_eq(other_subquery) + } + ( + Expr::InSubquery(InSubquery { + expr: self_expr, + subquery: self_subquery, + negated: self_negated, + }), + Expr::InSubquery(InSubquery { + expr: other_expr, + subquery: other_subquery, + negated: other_negated, + }), + ) => { + self_negated == other_negated + && self_expr.normalize_eq(other_expr) + && self_subquery.normalize_eq(other_subquery) + } + ( + Expr::ScalarSubquery(self_subquery), + Expr::ScalarSubquery(other_subquery), + ) => self_subquery.normalize_eq(other_subquery), + ( + Expr::GroupingSet(GroupingSet::Rollup(self_exprs)), + Expr::GroupingSet(GroupingSet::Rollup(other_exprs)), + ) + | ( + Expr::GroupingSet(GroupingSet::Cube(self_exprs)), + 
Expr::GroupingSet(GroupingSet::Cube(other_exprs)),
+            ) => {
+                self_exprs.len() == other_exprs.len()
+                    && self_exprs
+                        .iter()
+                        .zip(other_exprs.iter())
+                        .all(|(a, b)| a.normalize_eq(b))
+            }
+            (
+                Expr::GroupingSet(GroupingSet::GroupingSets(self_exprs)),
+                Expr::GroupingSet(GroupingSet::GroupingSets(other_exprs)),
+            ) => {
+                self_exprs.len() == other_exprs.len()
+                    && self_exprs.iter().zip(other_exprs.iter()).all(|(a, b)| {
+                        a.len() == b.len()
+                            && a.iter().zip(b.iter()).all(|(x, y)| x.normalize_eq(y))
+                    })
+            }
+            (
+                Expr::InList(InList {
+                    expr: self_expr,
+                    list: self_list,
+                    negated: self_negated,
+                }),
+                Expr::InList(InList {
+                    expr: other_expr,
+                    list: other_list,
+                    negated: other_negated,
+                }),
+            ) => {
+                // TODO: normalize_eq for lists, for example `a IN (c1 + c3, c3)` is equal to `a IN (c3, c1 + c3)`
+                self_negated == other_negated
+                    && self_expr.normalize_eq(other_expr)
+                    && self_list.len() == other_list.len()
+                    && self_list
+                        .iter()
+                        .zip(other_list.iter())
+                        .all(|(a, b)| a.normalize_eq(b))
+            }
+            (
+                Expr::Case(Case {
+                    expr: self_expr,
+                    when_then_expr: self_when_then_expr,
+                    else_expr: self_else_expr,
+                }),
+                Expr::Case(Case {
+                    expr: other_expr,
+                    when_then_expr: other_when_then_expr,
+                    else_expr: other_else_expr,
+                }),
+            ) => {
+                // TODO: normalize_eq for when_then_expr
+                // for example `CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END` is equal to `CASE a WHEN 3 THEN 4 WHEN 1 THEN 2 ELSE 5 END`
+                self_when_then_expr.len() == other_when_then_expr.len()
+                    && self_when_then_expr
+                        .iter()
+                        .zip(other_when_then_expr.iter())
+                        .all(|((self_when, self_then), (other_when, other_then))| {
+                            self_when.normalize_eq(other_when)
+                                && self_then.normalize_eq(other_then)
+                        })
+                    && match (self_expr, other_expr) {
+                        (Some(self_expr), Some(other_expr)) => {
+                            self_expr.normalize_eq(other_expr)
+                        }
+                        (None, None) => true,
+                        (_, _) => false,
+                    }
+                    && match (self_else_expr, other_else_expr) {
+                        (Some(self_else_expr), Some(other_else_expr)) => {
+                            self_else_expr.normalize_eq(other_else_expr)
+                        }
+                        (None, None) => true,
+                        (_, _) => false,
+                    }
+            }
+            (_, _) => self == other,
+        }
+    }
+}
+
 impl HashNode for Expr {
     /// As it is pretty easy to forget changing this method when `Expr` changes the
     /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 31bf4c573444..6c2b923cf6ad 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -45,6 +45,7 @@ use crate::{
 };
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion_common::cse::{NormalizeEq, Normalizeable};
 use datafusion_common::tree_node::{
     Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion,
 };
@@ -3354,6 +3355,25 @@ pub struct Subquery {
     pub outer_ref_columns: Vec<Expr>,
 }
 
+impl Normalizeable for Subquery {
+    fn can_normalize(&self) -> bool {
+        false
+    }
+}
+
+impl NormalizeEq for Subquery {
+    fn normalize_eq(&self, other: &Self) -> bool {
+        // TODO: maybe implement NormalizeEq for LogicalPlan?
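+        // Note: the subquery plans themselves are compared with plain `==`
+        // (structural equality); only the outer-referenced columns below are
+        // compared with `normalize_eq`.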
+ *self.subquery == *other.subquery + && self.outer_ref_columns.len() == other.outer_ref_columns.len() + && self + .outer_ref_columns + .iter() + .zip(other.outer_ref_columns.iter()) + .all(|(a, b)| a.normalize_eq(b)) + } +} + impl Subquery { pub fn try_from_expr(plan: &Expr) -> Result<&Subquery> { match plan { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 0ea2d24effbb..e7c9a198f3ad 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -795,8 +795,9 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, - ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility, + grouping_set, is_null, not, AccumulatorFactoryFunction, AggregateUDF, + ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + Volatility, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; @@ -1054,8 +1055,9 @@ mod test { .project(vec![lit(1) + col("a"), col("a") + lit(1)])? .build()?; - let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ - \n TableScan: test"; + let expected = "Projection: __common_expr_1 AS Int32(1) + test.a, __common_expr_1 AS test.a + Int32(1)\ + \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; assert_optimized_plan_eq(expected, plan, None); @@ -1412,6 +1414,259 @@ mod test { Ok(()) } + #[test] + fn test_normalize_add_expression() -> Result<()> { + // a + b <=> b + a + let table_scan = test_table_scan()?; + let expr = ((col("a") + col("b")) * (col("b") + col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_multi_expression() -> Result<()> { + // a * b <=> b * a + let table_scan = test_table_scan()?; + let expr = ((col("a") * col("b")) + (col("b") * col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a * test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_and_expression() -> Result<()> { + // a & b <=> b & a + let table_scan = test_table_scan()?; + let expr = ((col("a") & col("b")) + (col("b") & col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a & test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_or_expression() -> Result<()> { + // a | b <=> b | a + let table_scan = test_table_scan()?; + let expr = ((col("a") | col("b")) + (col("b") | col("a"))).eq(lit(30)); + let plan = 
LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a | test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_bitset_xor_expression() -> Result<()> { + // a # b <=> b # a + let table_scan = test_table_scan()?; + let expr = ((col("a") ^ col("b")) + (col("b") ^ col("a"))).eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 + __common_expr_1 = Int32(30)\ + \n Projection: test.a BIT_XOR test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_eq_expression() -> Result<()> { + // a = b <=> b = a + let table_scan = test_table_scan()?; + let expr = (col("a").eq(col("b"))).and(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 AND __common_expr_1\ + \n Projection: test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_ne_expression() -> Result<()> { + // a != b <=> b != a + let table_scan = test_table_scan()?; + let expr = (col("a").not_eq(col("b"))).and(col("b").not_eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 AND __common_expr_1\ + \n Projection: test.a != test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_normalize_complex_expression() -> Result<()> { + // case1: a + b * c <=> b * c + a + let table_scan = test_table_scan()?; + let expr = ((col("a") + col("b") * col("c")) - (col("b") * col("c") + col("a"))) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 - __common_expr_1 = Int32(30)\ + \n Projection: test.a + test.b * test.c AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // ((c1 + c2 / c3) * c3 <=> c3 * (c2 / c3 + c1)) + let table_scan = test_table_scan()?; + let expr = (((col("a") + col("b") / col("c")) * col("c")) + / (col("c") * (col("b") / col("c") + col("a"))) + + col("a")) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 / __common_expr_1 + test.a = Int32(30)\ + \n Projection: (test.a + test.b / test.c) * test.c AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // c2 / (c1 + c3) <=> c2 / (c3 + c1) + let table_scan = test_table_scan()?; + let expr = ((col("b") / (col("a") + col("c"))) + * (col("b") / (col("c") + col("a")))) + .eq(lit(30)); + let plan = LogicalPlanBuilder::from(table_scan).filter(expr)?.build()?; + let expected = "Projection: test.a, test.b, test.c\ + \n Filter: __common_expr_1 * __common_expr_1 = Int32(30)\ + 
\n Projection: test.b / (test.a + test.c) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[derive(Debug)] + pub struct TestUdf { + signature: Signature, + } + + impl TestUdf { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for TestUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "my_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _: &[ColumnarValue]) -> Result { + panic!("not implemented") + } + } + + #[test] + fn test_normalize_inner_binary_expression() -> Result<()> { + // Not(a == b) <=> Not(b == a) + let table_scan = test_table_scan()?; + let expr1 = not(col("a").eq(col("b"))); + let expr2 = not(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS NOT test.a = test.b, __common_expr_1 AS NOT test.b = test.a\ + \n Projection: NOT test.a = test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // is_null(a == b) <=> is_null(b == a) + let table_scan = test_table_scan()?; + let expr1 = is_null(col("a").eq(col("b"))); + let expr2 = is_null(col("b").eq(col("a"))); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.a = test.b IS NULL, __common_expr_1 AS test.b = test.a IS NULL\ + \n Projection: test.a = test.b IS NULL AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // a + b between 0 and 10 <=> b + a between 0 and 10 + let table_scan = test_table_scan()?; + let expr1 = (col("a") + col("b")).between(lit(0), lit(10)); + let expr2 = (col("b") + col("a")).between(lit(0), lit(10)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.a + test.b BETWEEN Int32(0) AND Int32(10), __common_expr_1 AS test.b + test.a BETWEEN Int32(0) AND Int32(10)\ + \n Projection: test.a + test.b BETWEEN Int32(0) AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // c between a + b and 10 <=> c between b + a and 10 + let table_scan = test_table_scan()?; + let expr1 = col("c").between(col("a") + col("b"), lit(10)); + let expr2 = col("c").between(col("b") + col("a"), lit(10)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? + .build()?; + let expected = "Projection: __common_expr_1 AS test.c BETWEEN test.a + test.b AND Int32(10), __common_expr_1 AS test.c BETWEEN test.b + test.a AND Int32(10)\ + \n Projection: test.c BETWEEN test.a + test.b AND Int32(10) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + + // function call with argument <=> function call with argument + let udf = ScalarUDF::from(TestUdf::new()); + let table_scan = test_table_scan()?; + let expr1 = udf.call(vec![col("a") + col("b")]); + let expr2 = udf.call(vec![col("b") + col("a")]); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![expr1, expr2])? 
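+            // `my_udf(test.a + test.b)` and `my_udf(test.b + test.a)` compare equal
+            // under `normalize_eq`, so the plan computes the UDF once (see `expected`)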
+ .build()?; + let expected = "Projection: __common_expr_1 AS my_udf(test.a + test.b), __common_expr_1 AS my_udf(test.b + test.a)\ + \n Projection: my_udf(test.a + test.b) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + assert_optimized_plan_eq(expected, plan, None); + Ok(()) + } + /// returns a "random" function that is marked volatile (aka each invocation /// returns a different value) /// From 75202b587108cd8de6c702a84463ac0ca0ca4d4b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 20 Dec 2024 09:37:35 -0500 Subject: [PATCH 6/6] Upgrade to sqlparser `0.53.0` (#13767) * chore: Udpate to sqlparser 0.53.0 * Update for new sqlparser API * more api updates * Avoid serializing query to SQL string unless it is necessary * Box wildcard options * chore: update datafusion-cli Cargo.lock --- Cargo.toml | 2 +- datafusion-cli/Cargo.lock | 8 +- datafusion/common/src/utils/mod.rs | 5 +- .../user_defined_scalar_functions.rs | 10 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_fn.rs | 8 +- .../src/analyzer/inline_table_scan.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 2 +- .../tests/cases/roundtrip_logical_plan.rs | 12 +- datafusion/sql/src/expr/function.rs | 24 +- datafusion/sql/src/expr/mod.rs | 8 +- datafusion/sql/src/parser.rs | 14 +- datafusion/sql/src/planner.rs | 14 +- datafusion/sql/src/select.rs | 1 + datafusion/sql/src/statement.rs | 207 ++++++++++++++---- datafusion/sql/src/unparser/ast.rs | 3 + datafusion/sql/src/unparser/dialect.rs | 6 +- datafusion/sql/src/unparser/expr.rs | 32 ++- datafusion/sql/src/unparser/plan.rs | 9 +- datafusion/sql/src/unparser/utils.rs | 12 +- .../test_files/information_schema.slt | 4 +- 21 files changed, 271 insertions(+), 120 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cc94b4292a50..b7c8c09a8537 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -147,7 +147,7 @@ recursive = "0.1.1" regex = "1.8" rstest = "0.23.0" serde_json = "1" -sqlparser = { version = "0.52.0", features = ["visitor"] } +sqlparser = { version = "0.53.0", features = ["visitor"] } tempfile = "3" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } url = "2.2" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 2ffc64114ef7..a435869dbece 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -3755,9 +3755,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.52.0" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a875d8cd437cc8a97e9aeaeea352ec9a19aea99c23e9effb17757291de80b08" +checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" dependencies = [ "log", "sqlparser_derive", @@ -3765,9 +3765,9 @@ dependencies = [ [[package]] name = "sqlparser_derive" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" +checksum = "da5fc6819faabb412da764b99d3b713bb55083c11e7e0c00144d386cd6a1939c" dependencies = [ "proc-macro2", "quote", diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 5e840f859400..29d33fec14ab 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -887,10 +887,10 @@ pub fn get_available_parallelism() -> usize { #[cfg(test)] mod tests { + use super::*; use crate::ScalarValue::Null; use arrow::array::Float64Array; - - use super::*; + use 
sqlparser::tokenizer::Span;
 
     #[test]
     fn test_bisect_linear_left_and_right() -> Result<()> {
@@ -1118,6 +1118,7 @@ mod tests {
         let expected_parsed = vec![Ident {
             value: identifier.to_string(),
             quote_style,
+            span: Span::empty(),
         }];
 
         assert_eq!(
diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index 4d2a536c4920..30b3c6e2bbeb 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -27,10 +27,6 @@ use arrow_array::{
     Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray,
 };
 use arrow_schema::{DataType, Field, Schema};
-use parking_lot::Mutex;
-use regex::Regex;
-use sqlparser::ast::Ident;
-
 use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
 use datafusion::prelude::*;
 use datafusion::{execution::registry::FunctionRegistry, test_util};
@@ -48,6 +44,10 @@ use datafusion_expr::{
     Volatility,
 };
 use datafusion_functions_nested::range::range_udf;
+use parking_lot::Mutex;
+use regex::Regex;
+use sqlparser::ast::Ident;
+use sqlparser::tokenizer::Span;
 
 /// test that casting happens on udfs.
 /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and
@@ -1187,6 +1187,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<(
                 name: Some(Ident {
                     value: "name".into(),
                     quote_style: None,
+                    span: Span::empty(),
                 }),
                 data_type: DataType::Utf8,
                 default_expr: None,
@@ -1196,6 +1197,7 @@ async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<(
         language: Some(Ident {
             value: "plrust".into(),
             quote_style: None,
+            span: Span::empty(),
         }),
         behavior: None,
         function_body: Some(lit(body)),
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index af54dad79d2e..c82572ebd5f1 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -313,7 +313,7 @@ pub enum Expr {
     /// plan into physical plan.
     Wildcard {
         qualifier: Option<TableReference>,
-        options: WildcardOptions,
+        options: Box<WildcardOptions>,
     },
     /// List of grouping set expressions.
Only valid in the context of an aggregate /// GROUP BY expression list diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a44dd24039dc..a2de5e7b259f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -123,7 +123,7 @@ pub fn placeholder(id: impl Into) -> Expr { pub fn wildcard() -> Expr { Expr::Wildcard { qualifier: None, - options: WildcardOptions::default(), + options: Box::new(WildcardOptions::default()), } } @@ -131,7 +131,7 @@ pub fn wildcard() -> Expr { pub fn wildcard_with_options(options: WildcardOptions) -> Expr { Expr::Wildcard { qualifier: None, - options, + options: Box::new(options), } } @@ -148,7 +148,7 @@ pub fn wildcard_with_options(options: WildcardOptions) -> Expr { pub fn qualified_wildcard(qualifier: impl Into) -> Expr { Expr::Wildcard { qualifier: Some(qualifier.into()), - options: WildcardOptions::default(), + options: Box::new(WildcardOptions::default()), } } @@ -159,7 +159,7 @@ pub fn qualified_wildcard_with_options( ) -> Expr { Expr::Wildcard { qualifier: Some(qualifier.into()), - options, + options: Box::new(options), } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 342d85a915b4..95781b395f3c 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -23,8 +23,7 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; -use datafusion_expr::expr::WildcardOptions; -use datafusion_expr::{logical_plan::LogicalPlan, Expr, LogicalPlanBuilder}; +use datafusion_expr::{logical_plan::LogicalPlan, wildcard, Expr, LogicalPlanBuilder}; /// Analyzed rule that inlines TableScan that provide a [`LogicalPlan`] /// (DataFrame / ViewTable) @@ -93,10 +92,7 @@ fn generate_projection_expr( ))); } } else { - exprs.push(Expr::Wildcard { - qualifier: None, - options: WildcardOptions::default(), - }); + exprs.push(wildcard()); } Ok(exprs) } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 301efc42a7c4..6ab3e0c9096c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -513,7 +513,7 @@ pub fn parse_expr( let qualifier = qualifier.to_owned().map(|x| x.try_into()).transpose()?; Ok(Expr::Wildcard { qualifier, - options: WildcardOptions::default(), + options: Box::new(WildcardOptions::default()), }) } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index f793e96f612b..c0885ece08bc 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -67,7 +67,7 @@ use datafusion_common::{ use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - Unnest, WildcardOptions, + Unnest, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ @@ -2061,10 +2061,7 @@ fn roundtrip_unnest() { #[test] fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard { - qualifier: None, - options: WildcardOptions::default(), - }; + let test_expr = wildcard(); let ctx = SessionContext::new(); 
roundtrip_expr_test(test_expr, ctx); @@ -2072,10 +2069,7 @@ fn roundtrip_wildcard() { #[test] fn roundtrip_qualified_wildcard() { - let test_expr = Expr::Wildcard { - qualifier: Some("foo".into()), - options: WildcardOptions::default(), - }; + let test_expr = qualified_wildcard("foo"); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 67fa23b86990..da1a4ba81f5a 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -22,11 +22,11 @@ use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; -use datafusion_expr::expr::WildcardOptions; use datafusion_expr::expr::{ScalarFunction, Unnest}; use datafusion_expr::planner::PlannerResult; use datafusion_expr::{ - expr, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, + expr, qualified_wildcard, wildcard, Expr, ExprFunctionExt, ExprSchemable, + WindowFrame, WindowFunctionDefinition, }; use sqlparser::ast::{ DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, @@ -169,6 +169,11 @@ impl FunctionArgs { "Calling {name}: SEPARATOR not supported in function arguments: {sep}" ) } + FunctionArgumentClause::JsonNullClause(jn) => { + return not_impl_err!( + "Calling {name}: JSON NULL clause not supported in function arguments: {jn}" + ) + } } } @@ -413,17 +418,11 @@ impl SqlToRel<'_, S> { name: _, arg: FunctionArgExpr::Wildcard, operator: _, - } => Ok(Expr::Wildcard { - qualifier: None, - options: WildcardOptions::default(), - }), + } => Ok(wildcard()), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { self.sql_expr_to_logical_expr(arg, schema, planner_context) } - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expr::Wildcard { - qualifier: None, - options: WildcardOptions::default(), - }), + FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(wildcard()), FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(object_name)) => { let qualifier = self.object_name_to_table_reference(object_name)?; // Sanity check on qualifier with schema @@ -431,10 +430,7 @@ impl SqlToRel<'_, S> { if qualified_indices.is_empty() { return plan_err!("Invalid qualifier {qualifier}"); } - Ok(Expr::Wildcard { - qualifier: Some(qualifier), - options: WildcardOptions::default(), - }) + Ok(qualified_wildcard(qualifier)) } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index e8ec8d7b7d1c..a651d8fa5d35 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -593,13 +593,13 @@ impl SqlToRel<'_, S> { } not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") } - SQLExpr::Wildcard => Ok(Expr::Wildcard { + SQLExpr::Wildcard(_token) => Ok(Expr::Wildcard { qualifier: None, - options: WildcardOptions::default(), + options: Box::new(WildcardOptions::default()), }), - SQLExpr::QualifiedWildcard(object_name) => Ok(Expr::Wildcard { + SQLExpr::QualifiedWildcard(object_name, _token) => Ok(Expr::Wildcard { qualifier: Some(self.object_name_to_table_reference(object_name)?), - options: WildcardOptions::default(), + options: Box::new(WildcardOptions::default()), }), SQLExpr::Tuple(values) => self.parse_tuple(schema, planner_context, values), _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), diff --git a/datafusion/sql/src/parser.rs 
b/datafusion/sql/src/parser.rs index efec6020641c..f185d65fa194 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -21,6 +21,7 @@ use std::collections::VecDeque; use std::fmt; use sqlparser::ast::ExprWithAlias; +use sqlparser::tokenizer::TokenWithSpan; use sqlparser::{ ast::{ ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, @@ -28,7 +29,7 @@ use sqlparser::{ }, dialect::{keywords::Keyword, Dialect, GenericDialect}, parser::{Parser, ParserError}, - tokenizer::{Token, TokenWithLocation, Tokenizer, Word}, + tokenizer::{Token, Tokenizer, Word}, }; // Use `Parser::expected` instead, if possible @@ -338,7 +339,7 @@ impl<'a> DFParser<'a> { fn expected( &self, expected: &str, - found: TokenWithLocation, + found: TokenWithSpan, ) -> Result { parser_err!(format!("Expected {expected}, found: {found}")) } @@ -876,6 +877,7 @@ mod tests { use super::*; use sqlparser::ast::Expr::Identifier; use sqlparser::ast::{BinaryOperator, DataType, Expr, Ident}; + use sqlparser::tokenizer::Span; fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), ParserError> { let statements = DFParser::parse_sql(sql)?; @@ -911,6 +913,7 @@ mod tests { name: Ident { value: name.into(), quote_style: None, + span: Span::empty(), }, data_type, collation: None, @@ -1219,6 +1222,7 @@ mod tests { expr: Identifier(Ident { value: "c1".to_owned(), quote_style: None, + span: Span::empty(), }), asc, nulls_first, @@ -1250,6 +1254,7 @@ mod tests { expr: Identifier(Ident { value: "c1".to_owned(), quote_style: None, + span: Span::empty(), }), asc: Some(true), nulls_first: None, @@ -1259,6 +1264,7 @@ mod tests { expr: Identifier(Ident { value: "c2".to_owned(), quote_style: None, + span: Span::empty(), }), asc: Some(false), nulls_first: Some(true), @@ -1290,11 +1296,13 @@ mod tests { left: Box::new(Identifier(Ident { value: "c1".to_owned(), quote_style: None, + span: Span::empty(), })), op: BinaryOperator::Minus, right: Box::new(Identifier(Ident { value: "c2".to_owned(), quote_style: None, + span: Span::empty(), })), }, asc: Some(true), @@ -1335,11 +1343,13 @@ mod tests { left: Box::new(Identifier(Ident { value: "c1".to_owned(), quote_style: None, + span: Span::empty(), })), op: BinaryOperator::Minus, right: Box::new(Identifier(Ident { value: "c2".to_owned(), quote_style: None, + span: Span::empty(), })), }, asc: Some(true), diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 2d0ba8f8d994..d917a707ca20 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -310,7 +310,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan: LogicalPlan, alias: TableAlias, ) -> Result { - let plan = self.apply_expr_alias(plan, alias.columns)?; + let idents = alias.columns.into_iter().map(|c| c.name).collect(); + let plan = self.apply_expr_alias(plan, idents)?; LogicalPlanBuilder::from(plan) .alias(TableReference::bare( @@ -513,7 +514,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Regclass | SQLDataType::Custom(_, _) | SQLDataType::Array(_) - | SQLDataType::Enum(_) + | SQLDataType::Enum(_, _) | SQLDataType::Set(_) | SQLDataType::MediumInt(_) | SQLDataType::UnsignedMediumInt(_) @@ -557,6 +558,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Nullable(_) | SQLDataType::LowCardinality(_) | SQLDataType::Trigger + // MySQL datatypes + | SQLDataType::TinyBlob + | SQLDataType::MediumBlob + | SQLDataType::LongBlob + | SQLDataType::TinyText + | SQLDataType::MediumText + | SQLDataType::LongText + | SQLDataType::Bit(_) + 
|SQLDataType::BitVarying(_) => not_impl_err!( "Unsupported SQL type {sql_type:?}" ), diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index c2a9bac24e66..9a84c00a8044 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -655,6 +655,7 @@ impl SqlToRel<'_, S> { opt_rename, opt_replace: _opt_replace, opt_ilike: _opt_ilike, + wildcard_token: _wildcard_token, } = options; if opt_rename.is_some() { diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index f750afbc4a53..e264b9083cc0 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -33,7 +33,7 @@ use arrow_schema::{DataType, Fields}; use datafusion_common::error::_plan_err; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ - exec_err, not_impl_err, plan_datafusion_err, plan_err, schema_err, + exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, schema_err, unqualified_field_not_found, Column, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, ToDFSchema, @@ -54,13 +54,16 @@ use datafusion_expr::{ TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, }; -use sqlparser::ast::{self, SqliteOnConflict}; +use sqlparser::ast::{ + self, BeginTransactionKind, NullsDistinctOption, ShowStatementIn, + ShowStatementOptions, SqliteOnConflict, +}; use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, FromTable, Ident, Insert, ObjectName, ObjectType, OneOrManyWithParens, Query, SchemaName, SetExpr, - ShowCreateObject, ShowStatementFilter, Statement, TableConstraint, TableFactor, - TableWithJoins, TransactionMode, UnaryOperator, Value, + ShowCreateObject, Statement, TableConstraint, TableFactor, TableWithJoins, + TransactionMode, UnaryOperator, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -107,6 +110,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec SqlToRel<'_, S> { statement: Statement, planner_context: &mut PlannerContext, ) -> Result { - let sql = Some(statement.to_string()); match statement { Statement::ExplainTable { describe_alias: DescribeAlias::Describe, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' @@ -514,6 +517,35 @@ impl SqlToRel<'_, S> { return not_impl_err!("To not supported")?; } + // put the statement back together temporarily to get the SQL + // string representation + let stmt = Statement::CreateView { + or_replace, + materialized, + name, + columns, + query, + options: CreateTableOptions::None, + cluster_by, + comment, + with_no_schema_binding, + if_not_exists, + temporary, + to, + }; + let sql = stmt.to_string(); + let Statement::CreateView { + name, + columns, + query, + or_replace, + temporary, + .. 
diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs
index f750afbc4a53..e264b9083cc0 100644
--- a/datafusion/sql/src/statement.rs
+++ b/datafusion/sql/src/statement.rs
@@ -33,7 +33,7 @@ use arrow_schema::{DataType, Fields};
 use datafusion_common::error::_plan_err;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::{
-    exec_err, not_impl_err, plan_datafusion_err, plan_err, schema_err,
+    exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, schema_err,
     unqualified_field_not_found, Column, Constraint, Constraints, DFSchema, DFSchemaRef,
     DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference,
     ToDFSchema,
@@ -54,13 +54,16 @@ use datafusion_expr::{
     TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart,
     Volatility, WriteOp,
 };
-use sqlparser::ast::{self, SqliteOnConflict};
+use sqlparser::ast::{
+    self, BeginTransactionKind, NullsDistinctOption, ShowStatementIn,
+    ShowStatementOptions, SqliteOnConflict,
+};
 use sqlparser::ast::{
     Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable,
     CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, FromTable, Ident,
     Insert, ObjectName, ObjectType, OneOrManyWithParens, Query, SchemaName, SetExpr,
-    ShowCreateObject, ShowStatementFilter, Statement, TableConstraint, TableFactor,
-    TableWithJoins, TransactionMode, UnaryOperator, Value,
+    ShowCreateObject, Statement, TableConstraint, TableFactor, TableWithJoins,
+    TransactionMode, UnaryOperator, Value,
 };
 use sqlparser::parser::ParserError::ParserError;
@@ -107,6 +110,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec<TableConstraint> {
@@ -139,7 +143,6 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         statement: Statement,
         planner_context: &mut PlannerContext,
     ) -> Result<LogicalPlan> {
-        let sql = Some(statement.to_string());
         match statement {
             Statement::ExplainTable {
                 describe_alias: DescribeAlias::Describe, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name'
@@ -514,6 +517,35 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                     return not_impl_err!("To not supported")?;
                 }
 
+                // put the statement back together temporarily to get the SQL
+                // string representation
+                let stmt = Statement::CreateView {
+                    or_replace,
+                    materialized,
+                    name,
+                    columns,
+                    query,
+                    options: CreateTableOptions::None,
+                    cluster_by,
+                    comment,
+                    with_no_schema_binding,
+                    if_not_exists,
+                    temporary,
+                    to,
+                };
+                let sql = stmt.to_string();
+                let Statement::CreateView {
+                    name,
+                    columns,
+                    query,
+                    or_replace,
+                    temporary,
+                    ..
+                } = stmt
+                else {
+                    return internal_err!("Unreachable code in create view");
+                };
+
                 let columns = columns
                     .into_iter()
                     .map(|view_column_def| {
@@ -534,7 +566,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                     name: self.object_name_to_table_reference(name)?,
                     input: Arc::new(plan),
                     or_replace,
-                    definition: sql,
+                    definition: Some(sql),
                     temporary,
                 })))
             }
@@ -685,19 +717,99 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
             Statement::ShowTables {
                 extended,
                 full,
-                db_name,
-                filter,
-                // SHOW TABLES IN/FROM are equivalent, this field specifies which the user
-                // specified, but it doesn't affect the plan so ignore the field
-                clause: _,
-            } => self.show_tables_to_plan(extended, full, db_name, filter),
+                terse,
+                history,
+                external,
+                show_options,
+            } => {
+                // We only support the basic "SHOW TABLES"
+                // https://github.com/apache/datafusion/issues/3188
+                if extended {
+                    return not_impl_err!("SHOW TABLES EXTENDED not supported")?;
+                }
+                if full {
+                    return not_impl_err!("SHOW FULL TABLES not supported")?;
+                }
+                if terse {
+                    return not_impl_err!("SHOW TERSE TABLES not supported")?;
+                }
+                if history {
+                    return not_impl_err!("SHOW TABLES HISTORY not supported")?;
+                }
+                if external {
+                    return not_impl_err!("SHOW EXTERNAL TABLES not supported")?;
+                }
+                let ShowStatementOptions {
+                    show_in,
+                    starts_with,
+                    limit,
+                    limit_from,
+                    filter_position,
+                } = show_options;
+                if show_in.is_some() {
+                    return not_impl_err!("SHOW TABLES IN not supported")?;
+                }
+                if starts_with.is_some() {
+                    return not_impl_err!("SHOW TABLES LIKE not supported")?;
+                }
+                if limit.is_some() {
+                    return not_impl_err!("SHOW TABLES LIMIT not supported")?;
+                }
+                if limit_from.is_some() {
+                    return not_impl_err!("SHOW TABLES LIMIT FROM not supported")?;
+                }
+                if filter_position.is_some() {
+                    return not_impl_err!("SHOW TABLES FILTER not supported")?;
+                }
+                self.show_tables_to_plan()
+            }
 
             Statement::ShowColumns {
                 extended,
                 full,
-                table_name,
-                filter,
-            } => self.show_columns_to_plan(extended, full, table_name, filter),
+                show_options,
+            } => {
+                let ShowStatementOptions {
+                    show_in,
+                    starts_with,
+                    limit,
+                    limit_from,
+                    filter_position,
+                } = show_options;
+                if starts_with.is_some() {
+                    return not_impl_err!("SHOW COLUMNS LIKE not supported")?;
+                }
+                if limit.is_some() {
+                    return not_impl_err!("SHOW COLUMNS LIMIT not supported")?;
+                }
+                if limit_from.is_some() {
+                    return not_impl_err!("SHOW COLUMNS LIMIT FROM not supported")?;
+                }
+                if filter_position.is_some() {
+                    return not_impl_err!(
+                        "SHOW COLUMNS with WHERE or LIKE is not supported"
+                    )?;
+                }
+                let Some(ShowStatementIn {
+                    // specifies if the syntax was `SHOW COLUMNS IN` or `SHOW
+                    // COLUMNS FROM` which is not different in DataFusion
+                    clause: _,
+                    parent_type,
+                    parent_name,
+                }) = show_in
+                else {
+                    return plan_err!("SHOW COLUMNS requires a table name");
+                };
+
+                if let Some(parent_type) = parent_type {
+                    return not_impl_err!("SHOW COLUMNS IN {parent_type} not supported");
+                }
+                let Some(table_name) = parent_name else {
+                    return plan_err!("SHOW COLUMNS requires a table name");
+                };
+
+                self.show_columns_to_plan(extended, full, table_name)
+            }
 
             Statement::Insert(Insert {
                 or,
@@ -766,10 +878,14 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 from,
                 selection,
                 returning,
+                or,
             } => {
                 if returning.is_some() {
                     plan_err!("Update-returning clause not yet supported")?;
                 }
+                if or.is_some() {
+                    plan_err!("ON conflict not supported")?;
+                }
                 self.update_to_plan(table, assignments, from, selection)
             }
@@ -810,12 +926,14 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 modes,
                 begin: false,
                 modifier,
+                transaction,
             } => {
                 if let Some(modifier) = modifier {
                     return not_impl_err!(
                         "Transaction modifier not supported: {modifier}"
                     );
                 }
+                self.validate_transaction_kind(transaction)?;
                 let isolation_level: ast::TransactionIsolationLevel = modes
                     .iter()
                     .filter_map(|m: &TransactionMode| match m {
@@ -879,7 +997,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 });
                 Ok(LogicalPlan::Statement(statement))
             }
-            Statement::CreateFunction {
+            Statement::CreateFunction(ast::CreateFunction {
                 or_replace,
                 temporary,
                 name,
@@ -889,7 +1007,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 behavior,
                 language,
                 ..
-            } => {
+            }) => {
                 let return_type = match return_type {
                     Some(t) => Some(self.convert_data_type(&t)?),
                     None => None,
@@ -1033,8 +1151,8 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 },
             )))
         }
-            _ => {
-                not_impl_err!("Unsupported SQL statement: {sql:?}")
+            stmt => {
+                not_impl_err!("Unsupported SQL statement: {stmt}")
             }
         }
     }
@@ -1065,24 +1183,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
     }
     /// Generate a logical plan from a "SHOW TABLES" query
-    fn show_tables_to_plan(
-        &self,
-        extended: bool,
-        full: bool,
-        db_name: Option<Ident>,
-        filter: Option<ShowStatementFilter>,
-    ) -> Result<LogicalPlan> {
+    fn show_tables_to_plan(&self) -> Result<LogicalPlan> {
         if self.has_table("information_schema", "tables") {
-            // We only support the basic "SHOW TABLES"
-            // https://github.com/apache/datafusion/issues/3188
-            if db_name.is_some() || filter.is_some() || full || extended {
-                plan_err!("Unsupported parameters to SHOW TABLES")
-            } else {
-                let query = "SELECT * FROM information_schema.tables;";
-                let mut rewrite = DFParser::parse_sql(query)?;
-                assert_eq!(rewrite.len(), 1);
-                self.statement_to_plan(rewrite.pop_front().unwrap()) // length of rewrite is 1
-            }
+            let query = "SELECT * FROM information_schema.tables;";
+            let mut rewrite = DFParser::parse_sql(query)?;
+            assert_eq!(rewrite.len(), 1);
+            self.statement_to_plan(rewrite.pop_front().unwrap()) // length of rewrite is 1
         } else {
             plan_err!("SHOW TABLES is not supported unless information_schema is enabled")
         }
     }
@@ -1841,22 +1947,18 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         extended: bool,
         full: bool,
         sql_table_name: ObjectName,
-        filter: Option<ShowStatementFilter>,
     ) -> Result<LogicalPlan> {
-        if filter.is_some() {
-            return plan_err!("SHOW COLUMNS with WHERE or LIKE is not supported");
-        }
+        // Figure out the where clause
+        let where_clause = object_name_to_qualifier(
+            &sql_table_name,
+            self.options.enable_ident_normalization,
+        );
 
         if !self.has_table("information_schema", "columns") {
             return plan_err!(
                 "SHOW COLUMNS is not supported unless information_schema is enabled"
             );
         }
-        // Figure out the where clause
-        let where_clause = object_name_to_qualifier(
-            &sql_table_name,
-            self.options.enable_ident_normalization,
-        );
 
         // Do a table lookup to verify the table exists
         let table_ref = self.object_name_to_table_reference(sql_table_name)?;
@@ -1916,4 +2018,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
             .get_table_source(tables_reference)
             .is_ok()
     }
+
+    fn validate_transaction_kind(
+        &self,
+        kind: Option<BeginTransactionKind>,
+    ) -> Result<()> {
+        match kind {
+            // BEGIN
+            None => Ok(()),
+            // BEGIN TRANSACTION
+            Some(BeginTransactionKind::Transaction) => Ok(()),
+            Some(BeginTransactionKind::Work) => {
+                not_impl_err!("Transaction kind not supported: {kind:?}")
+            }
+        }
+    }
 }
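Note the pattern in the rewritten SHOW TABLES and SHOW COLUMNS arms above: `ShowStatementOptions` is destructured field by field with no `..` rest pattern, and each unsupported field is rejected with an explicit error. A sketch of why, using a hypothetical stand-in struct rather than the real sqlparser type:

```rust
// Hypothetical stand-in for sqlparser's `ShowStatementOptions`, only to
// illustrate the destructuring pattern used above.
struct Options {
    starts_with: Option<String>,
    limit: Option<u64>,
}

fn plan_show(options: Options) -> Result<(), String> {
    // Exhaustive destructuring: if upstream adds a field to `Options`,
    // this line stops compiling, forcing the planner to either support
    // or explicitly reject the new syntax. A `..` pattern would instead
    // silently ignore it.
    let Options { starts_with, limit } = options;
    if starts_with.is_some() {
        return Err("LIKE not supported".to_string());
    }
    if limit.is_some() {
        return Err("LIMIT not supported".to_string());
    }
    Ok(())
}

fn main() {
    assert!(plan_show(Options { starts_with: None, limit: None }).is_ok());
    assert!(plan_show(Options { starts_with: Some("f".into()), limit: None }).is_err());
}
```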
"Transaction modifier not supported: {modifier}" ); } + self.validate_transaction_kind(transaction)?; let isolation_level: ast::TransactionIsolationLevel = modes .iter() .filter_map(|m: &TransactionMode| match m { @@ -879,7 +997,7 @@ impl SqlToRel<'_, S> { }); Ok(LogicalPlan::Statement(statement)) } - Statement::CreateFunction { + Statement::CreateFunction(ast::CreateFunction { or_replace, temporary, name, @@ -889,7 +1007,7 @@ impl SqlToRel<'_, S> { behavior, language, .. - } => { + }) => { let return_type = match return_type { Some(t) => Some(self.convert_data_type(&t)?), None => None, @@ -1033,8 +1151,8 @@ impl SqlToRel<'_, S> { }, ))) } - _ => { - not_impl_err!("Unsupported SQL statement: {sql:?}") + stmt => { + not_impl_err!("Unsupported SQL statement: {stmt}") } } } @@ -1065,24 +1183,12 @@ impl SqlToRel<'_, S> { } /// Generate a logical plan from a "SHOW TABLES" query - fn show_tables_to_plan( - &self, - extended: bool, - full: bool, - db_name: Option, - filter: Option, - ) -> Result { + fn show_tables_to_plan(&self) -> Result { if self.has_table("information_schema", "tables") { - // We only support the basic "SHOW TABLES" - // https://github.com/apache/datafusion/issues/3188 - if db_name.is_some() || filter.is_some() || full || extended { - plan_err!("Unsupported parameters to SHOW TABLES") - } else { - let query = "SELECT * FROM information_schema.tables;"; - let mut rewrite = DFParser::parse_sql(query)?; - assert_eq!(rewrite.len(), 1); - self.statement_to_plan(rewrite.pop_front().unwrap()) // length of rewrite is 1 - } + let query = "SELECT * FROM information_schema.tables;"; + let mut rewrite = DFParser::parse_sql(query)?; + assert_eq!(rewrite.len(), 1); + self.statement_to_plan(rewrite.pop_front().unwrap()) // length of rewrite is 1 } else { plan_err!("SHOW TABLES is not supported unless information_schema is enabled") } @@ -1841,22 +1947,18 @@ impl SqlToRel<'_, S> { extended: bool, full: bool, sql_table_name: ObjectName, - filter: Option, ) -> Result { - if filter.is_some() { - return plan_err!("SHOW COLUMNS with WHERE or LIKE is not supported"); - } + // Figure out the where clause + let where_clause = object_name_to_qualifier( + &sql_table_name, + self.options.enable_ident_normalization, + ); if !self.has_table("information_schema", "columns") { return plan_err!( "SHOW COLUMNS is not supported unless information_schema is enabled" ); } - // Figure out the where clause - let where_clause = object_name_to_qualifier( - &sql_table_name, - self.options.enable_ident_normalization, - ); // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(sql_table_name)?; @@ -1916,4 +2018,19 @@ impl SqlToRel<'_, S> { .get_table_source(tables_reference) .is_ok() } + + fn validate_transaction_kind( + &self, + kind: Option, + ) -> Result<()> { + match kind { + // BEGIN + None => Ok(()), + // BEGIN TRANSACTION + Some(BeginTransactionKind::Transaction) => Ok(()), + Some(BeginTransactionKind::Work) => { + not_impl_err!("Transaction kind not supported: {kind:?}") + } + } + } } diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index ad0b5f16b283..345d16adef29 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -24,6 +24,7 @@ use core::fmt; use sqlparser::ast; +use sqlparser::ast::helpers::attached_token::AttachedToken; #[derive(Clone)] pub(super) struct QueryBuilder { @@ -268,6 +269,7 @@ impl SelectBuilder { connect_by: None, window_before_qualify: false, prewhere: None, + select_token: 
diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index ae387d441fa2..a82687533e31 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -18,15 +18,15 @@
 use std::sync::Arc;
 
 use arrow_schema::TimeUnit;
+use datafusion_common::Result;
 use datafusion_expr::Expr;
 use regex::Regex;
+use sqlparser::tokenizer::Span;
 use sqlparser::{
     ast::{self, BinaryOperator, Function, Ident, ObjectName, TimezoneInfo},
     keywords::ALL_KEYWORDS,
 };
 
-use datafusion_common::Result;
-
 use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser};
 
 /// `Dialect` to use for Unparsing
@@ -288,6 +288,7 @@ impl PostgreSqlDialect {
             name: ObjectName(vec![Ident {
                 value: func_name.to_string(),
                 quote_style: None,
+                span: Span::empty(),
             }]),
             args: ast::FunctionArguments::List(ast::FunctionArgumentList {
                 duplicate_treatment: None,
@@ -299,6 +300,7 @@ impl PostgreSqlDialect {
             over: None,
             within_group: vec![],
             parameters: ast::FunctionArguments::None,
+            uses_odbc_syntax: false,
         }))
     }
 }
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index 48dfc425ee63..f09de133b571 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -43,6 +43,8 @@ use datafusion_expr::{
     expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction},
     Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast,
 };
+use sqlparser::ast::helpers::attached_token::AttachedToken;
+use sqlparser::tokenizer::Span;
 
 /// Convert a DataFusion [`Expr`] to [`ast::Expr`]
 ///
@@ -233,6 +235,7 @@ impl Unparser<'_> {
                     name: ObjectName(vec![Ident {
                         value: func_name.to_string(),
                         quote_style: None,
+                        span: Span::empty(),
                     }]),
                     args: ast::FunctionArguments::List(ast::FunctionArgumentList {
                         duplicate_treatment: None,
@@ -244,6 +247,7 @@ impl Unparser<'_> {
                     over,
                     within_group: vec![],
                     parameters: ast::FunctionArguments::None,
+                    uses_odbc_syntax: false,
                 }))
             }
             Expr::SimilarTo(Like {
@@ -278,6 +282,7 @@ impl Unparser<'_> {
                     name: ObjectName(vec![Ident {
                         value: func_name.to_string(),
                         quote_style: None,
+                        span: Span::empty(),
                     }]),
                     args: ast::FunctionArguments::List(ast::FunctionArgumentList {
                         duplicate_treatment: agg
@@ -291,6 +296,7 @@ impl Unparser<'_> {
                     over: None,
                     within_group: vec![],
                     parameters: ast::FunctionArguments::None,
+                    uses_odbc_syntax: false,
                 }))
             }
             Expr::ScalarSubquery(subq) => {
@@ -404,12 +410,16 @@ impl Unparser<'_> {
             }
             // TODO: unparsing wildcard addition options
             Expr::Wildcard { qualifier, .. } => {
+                let attached_token = AttachedToken::empty();
                 if let Some(qualifier) = qualifier {
                     let idents: Vec<Ident> =
                         qualifier.to_vec().into_iter().map(Ident::new).collect();
-                    Ok(ast::Expr::QualifiedWildcard(ObjectName(idents)))
+                    Ok(ast::Expr::QualifiedWildcard(
+                        ObjectName(idents),
+                        attached_token,
+                    ))
                 } else {
-                    Ok(ast::Expr::Wildcard)
+                    Ok(ast::Expr::Wildcard(attached_token))
                 }
             }
             Expr::GroupingSet(grouping_set) => match grouping_set {
@@ -480,6 +490,7 @@ impl Unparser<'_> {
             name: ObjectName(vec![Ident {
                 value: func_name.to_string(),
                 quote_style: None,
+                span: Span::empty(),
             }]),
             args: ast::FunctionArguments::List(ast::FunctionArgumentList {
                 duplicate_treatment: None,
@@ -491,6 +502,7 @@ impl Unparser<'_> {
             over: None,
             within_group: vec![],
             parameters: ast::FunctionArguments::None,
+            uses_odbc_syntax: false,
         }))
     }
@@ -709,6 +721,7 @@ impl Unparser<'_> {
         Ident {
             value: ident,
             quote_style,
+            span: Span::empty(),
         }
     }
@@ -716,6 +729,7 @@ impl Unparser<'_> {
         Ident {
             value: str,
             quote_style: None,
+            span: Span::empty(),
         }
     }
@@ -1481,6 +1495,7 @@ impl Unparser<'_> {
             name: ObjectName(vec![Ident {
                 value: "UNNEST".to_string(),
                 quote_style: None,
+                span: Span::empty(),
             }]),
             args: ast::FunctionArguments::List(ast::FunctionArgumentList {
                 duplicate_treatment: None,
@@ -1492,6 +1507,7 @@ impl Unparser<'_> {
             over: None,
             within_group: vec![],
             parameters: ast::FunctionArguments::None,
+            uses_odbc_syntax: false,
         }))
     }
@@ -1672,7 +1688,7 @@ mod tests {
         let dummy_logical_plan = table_scan(Some("t"), &dummy_schema, None)?
             .project(vec![Expr::Wildcard {
                 qualifier: None,
-                options: WildcardOptions::default(),
+                options: Box::new(WildcardOptions::default()),
             }])?
             .filter(col("a").eq(lit(1)))?
             .build()?;
@@ -1864,10 +1880,7 @@ mod tests {
             (sum(col("a")), r#"sum(a)"#),
             (
                 count_udaf()
-                    .call(vec![Expr::Wildcard {
-                        qualifier: None,
-                        options: WildcardOptions::default(),
-                    }])
+                    .call(vec![wildcard()])
                     .distinct()
                     .build()
                     .unwrap(),
@@ -1875,10 +1888,7 @@ mod tests {
             ),
             (
                 count_udaf()
-                    .call(vec![Expr::Wildcard {
-                        qualifier: None,
-                        options: WildcardOptions::default(),
-                    }])
+                    .call(vec![wildcard()])
                     .filter(lit(true))
                     .build()
                     .unwrap(),
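`ast::Expr::Wildcard` stopped being a unit variant in sqlparser 0.53: both it and `QualifiedWildcard` now carry an `AttachedToken`, which is why the unparser threads `attached_token` through both branches above. A minimal sketch of the two constructions; the `to_string` results assume sqlparser's usual `Display` rendering of wildcards:

```rust
use sqlparser::ast::helpers::attached_token::AttachedToken;
use sqlparser::ast::{Expr, Ident, ObjectName};

fn main() {
    // `*` and `t.*` as sqlparser 0.53 AST nodes; synthesized wildcards
    // carry an empty token rather than a real source location.
    let star = Expr::Wildcard(AttachedToken::empty());
    let qualified = Expr::QualifiedWildcard(
        ObjectName(vec![Ident::new("t")]),
        AttachedToken::empty(),
    );

    assert_eq!(star.to_string(), "*");
    assert_eq!(qualified.to_string(), "t.*");
}
```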
diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs
index e9f9f486ea9a..f2d46a9f4cce 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -44,7 +44,7 @@ use datafusion_expr::{
     expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan,
     LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, Unnest,
 };
-use sqlparser::ast::{self, Ident, SetExpr};
+use sqlparser::ast::{self, Ident, SetExpr, TableAliasColumnDef};
 use std::sync::Arc;
 
 /// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`]
@@ -1069,6 +1069,13 @@ impl Unparser<'_> {
     }
 
     fn new_table_alias(&self, alias: String, columns: Vec<Ident>) -> ast::TableAlias {
+        let columns = columns
+            .into_iter()
+            .map(|ident| TableAliasColumnDef {
+                name: ident,
+                data_type: None,
+            })
+            .collect();
         ast::TableAlias {
             name: self.new_ident_quoted_if_needs(alias),
             columns,
diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs
index 354a68f60964..3a7fa5ddcabb 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -17,6 +17,10 @@
 
 use std::{cmp::Ordering, sync::Arc, vec};
 
+use super::{
+    dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle,
+    rewrite::TableAliasRewriter, Unparser,
+};
 use datafusion_common::{
     internal_err,
     tree_node::{Transformed, TransformedResult, TreeNode},
@@ -29,11 +33,7 @@ use datafusion_expr::{
 };
 use indexmap::IndexSet;
 use sqlparser::ast;
-
-use super::{
-    dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle,
-    rewrite::TableAliasRewriter, Unparser,
-};
+use sqlparser::tokenizer::Span;
 
 /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
 /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
@@ -426,6 +426,7 @@ pub(crate) fn date_part_to_sql(
                 name: ast::ObjectName(vec![ast::Ident {
                     value: "strftime".to_string(),
                     quote_style: None,
+                    span: Span::empty(),
                 }]),
                 args: ast::FunctionArguments::List(ast::FunctionArgumentList {
                     duplicate_treatment: None,
@@ -444,6 +445,7 @@ pub(crate) fn date_part_to_sql(
                 over: None,
                 within_group: vec![],
                 parameters: ast::FunctionArguments::None,
+                uses_odbc_syntax: false,
             })));
         }
     }
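The dialect.rs, expr.rs, and utils.rs hunks all make the same two additions wherever an `ast::Function` is built by hand: the `Ident` inside the function name needs the new `span` field, and the function itself needs the new `uses_odbc_syntax` flag. A condensed sketch of that construction pattern; the `filter` and `null_treatment` fields fall outside the hunk context above and are assumptions about the sqlparser 0.53 struct:

```rust
use sqlparser::ast::{
    Expr, Function, FunctionArgumentList, FunctionArguments, Ident, ObjectName,
};
use sqlparser::tokenizer::Span;

fn main() {
    // Skeleton of the unparser pattern above: a zero-argument call `foo()`.
    let func = Expr::Function(Function {
        name: ObjectName(vec![Ident {
            value: "foo".to_string(),
            quote_style: None,
            span: Span::empty(), // new in 0.53: every Ident carries a span
        }]),
        args: FunctionArguments::List(FunctionArgumentList {
            duplicate_treatment: None,
            args: vec![],
            clauses: vec![],
        }),
        filter: None,         // assumed field, not shown in the hunks above
        null_treatment: None, // assumed field, not shown in the hunks above
        over: None,
        within_group: vec![],
        parameters: FunctionArguments::None,
        uses_odbc_syntax: false, // new in 0.53: marks ODBC `{fn foo()}` calls
    });

    assert_eq!(func.to_string(), "foo()");
}
```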
TransformedResult, TreeNode},
diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt
index 1f6b5f9852ec..476a933c72b7 100644
--- a/datafusion/sqllogictest/test_files/information_schema.slt
+++ b/datafusion/sqllogictest/test_files/information_schema.slt
@@ -483,10 +483,10 @@ set datafusion.catalog.information_schema = true;
 statement ok
 CREATE TABLE t AS SELECT 1::int as i;
 
-statement error Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported
+statement error DataFusion error: This feature is not implemented: SHOW COLUMNS with WHERE or LIKE is not supported
 SHOW COLUMNS FROM t LIKE 'f';
 
-statement error Error during planning: SHOW COLUMNS with WHERE or LIKE is not supported
+statement error DataFusion error: This feature is not implemented: SHOW COLUMNS with WHERE or LIKE is not supported
 SHOW COLUMNS FROM t WHERE column_name = 'bar';
 
 query TTTTTT