diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index cfc28f2c499f..02934a004d6f 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -120,6 +120,12 @@ pub trait Dialect: Send + Sync { true } + /// Whether the dialect requires a table alias for any subquery in the FROM clause + /// This affects behavior when deriving logical plans for Sort, Limit, etc. + fn requires_derived_table_alias(&self) -> bool { + false + } + /// Allows the dialect to override scalar function unparsing if the dialect has specific rules. /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is /// a custom implementation for the function. @@ -300,6 +306,10 @@ impl Dialect for MySqlDialect { ast::DataType::Datetime(None) } + fn requires_derived_table_alias(&self) -> bool { + true + } + fn scalar_function_to_sql_overrides( &self, unparser: &Unparser, @@ -362,6 +372,7 @@ pub struct CustomDialect { timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: sqlparser::ast::DataType, supports_column_alias_in_table_alias: bool, + requires_derived_table_alias: bool, } impl Default for CustomDialect { @@ -384,6 +395,7 @@ impl Default for CustomDialect { ), date32_cast_dtype: sqlparser::ast::DataType::Date, supports_column_alias_in_table_alias: true, + requires_derived_table_alias: false, } } } @@ -472,6 +484,10 @@ impl Dialect for CustomDialect { Ok(None) } + + fn requires_derived_table_alias(&self) -> bool { + self.requires_derived_table_alias + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -503,6 +519,7 @@ pub struct CustomDialectBuilder { timestamp_tz_cast_dtype: ast::DataType, date32_cast_dtype: ast::DataType, supports_column_alias_in_table_alias: bool, + requires_derived_table_alias: bool, } impl Default for CustomDialectBuilder { @@ -531,6 +548,7 @@ impl CustomDialectBuilder { ), date32_cast_dtype: sqlparser::ast::DataType::Date, supports_column_alias_in_table_alias: true, + requires_derived_table_alias: false, } } @@ -551,6 +569,7 @@ impl CustomDialectBuilder { date32_cast_dtype: self.date32_cast_dtype, supports_column_alias_in_table_alias: self .supports_column_alias_in_table_alias, + requires_derived_table_alias: self.requires_derived_table_alias, } } @@ -653,4 +672,12 @@ impl CustomDialectBuilder { self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias; self } + + pub fn with_requires_derived_table_alias( + mut self, + requires_derived_table_alias: bool, + ) -> Self { + self.requires_derived_table_alias = requires_derived_table_alias; + self + } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index c22400f1faa1..8e70654d8d6f 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -222,9 +222,14 @@ impl Unparser<'_> { Ok(()) } - fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) -> Result<()> { + fn derive( + &self, + plan: &LogicalPlan, + relation: &mut RelationBuilder, + alias: Option, + ) -> Result<()> { let mut derived_builder = DerivedRelationBuilder::default(); - derived_builder.lateral(false).alias(None).subquery({ + derived_builder.lateral(false).alias(alias).subquery({ let inner_statement = self.plan_to_sql(plan)?; if let ast::Statement::Query(inner_query) = inner_statement { inner_query @@ -239,6 +244,23 @@ impl Unparser<'_> { Ok(()) } + fn derive_with_dialect_alias( + &self, + alias: &str, + plan: &LogicalPlan, + relation: &mut RelationBuilder, + ) -> Result<()> { + if self.dialect.requires_derived_table_alias() { + self.derive( + plan, + relation, + Some(self.new_table_alias(alias.to_string(), vec![])), + ) + } else { + self.derive(plan, relation, None) + } + } + fn select_to_sql_recursively( &self, plan: &LogicalPlan, @@ -284,7 +306,11 @@ impl Unparser<'_> { // Projection can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_projection", + plan, + relation, + ); } self.reconstruct_select_statement(plan, p, select)?; self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) @@ -311,8 +337,13 @@ impl Unparser<'_> { LogicalPlan::Limit(limit) => { // Limit can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_limit", + plan, + relation, + ); } + if let Some(fetch) = limit.fetch { let Some(query) = query.as_mut() else { return internal_err!( @@ -350,7 +381,11 @@ impl Unparser<'_> { LogicalPlan::Sort(sort) => { // Sort can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_sort", + plan, + relation, + ); } let Some(query_ref) = query else { return internal_err!( @@ -396,7 +431,11 @@ impl Unparser<'_> { LogicalPlan::Distinct(distinct) => { // Distinct can be top-level plan for derived table if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_distinct", + plan, + relation, + ); } let (select_distinct, input) = match distinct { Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()), @@ -559,7 +598,11 @@ impl Unparser<'_> { // Covers cases where the UNION is a subquery and the projection is at the top level if select.already_projected() { - return self.derive(plan, relation); + return self.derive_with_dialect_alias( + "derived_union", + plan, + relation, + ); } let input_exprs: Vec = union diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 2a3c5b5f6b2b..0de74e050553 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -261,6 +261,45 @@ fn roundtrip_statement_with_dialect() -> Result<()> { unparser_dialect: Box, } let tests: Vec = vec![ + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + expected: + // top projection sort gets derived into a subquery + // for MySQL, this subquery needs an alias + "SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + expected: + // top projection sort still gets derived into a subquery in default dialect + // except for the default dialect, the subquery is left non-aliased + "SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", + expected: + "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select 1 as j1_id);", + expected: + "SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select * from (select * from j1 limit 10);", + expected: + "SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, TestStatementWithDialect { sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", expected: