diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index fe65af0a1486..48a5e2f9a07c 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -995,16 +995,19 @@ fn project_with_column_index( .enumerate() .map(|(i, e)| match e { Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { - e.unalias().alias(schema.field(i).name()) + Ok(e.unalias().alias(schema.field(i).name())) } Expr::Column(Column { relation: _, ref name, - }) if name != schema.field(i).name() => e.alias(schema.field(i).name()), - Expr::Alias { .. } | Expr::Column { .. } => e, - _ => e.alias(schema.field(i).name()), + }) if name != schema.field(i).name() => Ok(e.alias(schema.field(i).name())), + Expr::Alias { .. } | Expr::Column { .. } => Ok(e), + Expr::Wildcard { .. } => { + plan_err!("Wildcard should be expanded before type coercion") + } + _ => Ok(e.alias(schema.field(i).name())), }) - .collect::>(); + .collect::>>()?; Projection::try_new_with_schema(alias_expr, input, schema) .map(LogicalPlan::Projection) @@ -1018,6 +1021,10 @@ mod test { use arrow::datatypes::DataType::Utf8; use arrow::datatypes::{DataType, Field, TimeUnit}; + use crate::analyzer::type_coercion::{ + coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + }; + use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; @@ -1032,11 +1039,6 @@ mod test { }; use datafusion_functions_aggregate::average::AvgAccumulator; - use crate::analyzer::type_coercion::{ - coerce_case_expression, TypeCoercion, TypeCoercionRewriter, - }; - use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; - fn empty() -> Arc { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { 
produce_one_row: false, diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 29fac5cc3dec..b9073f5ac881 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -22,11 +22,13 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{plan_err, Result}; +use datafusion_common::{assert_contains, plan_err, Result}; +use datafusion_expr::sqlparser::dialect::PostgreSqlDialect; use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; use datafusion_functions_aggregate::average::avg_udaf; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_optimizer::analyzer::type_coercion::TypeCoercionRewriter; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; @@ -387,6 +389,32 @@ fn select_correlated_predicate_subquery_with_uppercase_ident() { assert_eq!(expected, format!("{plan}")); } +// This test expects `TypeCoercionRewriter::coerce_union` to return a planning error +// because the wildcard was not expanded before type coercion +#[test] +fn test_union_coercion_with_wildcard() -> Result<()> { + let dialect = PostgreSqlDialect {}; + let context_provider = MyContextProvider::default(); + let sql = "select * from (SELECT col_int32, col_uint32 FROM test) union all select * from(SELECT col_uint32, col_int32 FROM test)"; + let statements = Parser::parse_sql(&dialect, sql)?; + let sql_to_rel = SqlToRel::new(&context_provider); + let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; + + if let LogicalPlan::Union(union) = logical_plan { + let err = TypeCoercionRewriter::coerce_union(union) + .err() + .unwrap() + .to_string(); + assert_contains!( + err, 
"Error during planning: Wildcard should be expanded before type coercion" + ); + } else { + panic!("Expected Union plan"); + } + Ok(()) +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...