Skip to content

Commit

Permalink
fix: Dialect requires derived table alias (apache#12994)
Browse files Browse the repository at this point in the history
* fix: Dialect requires table alias (#46)

* fix: Add Dialect option for requiring table aliases

* feat: Add CustomDialectBuilder for requires_table_alias

* docs: Spelling

* refactor: rename requires_derived_table_alias

* refactor: rename requires_derived_table_alias

* review: Rewrite match to if, add another test case

* test: Update RHS expected

* test: Update tests with more cases
  • Loading branch information
peasee authored Oct 21, 2024
1 parent e9584bc commit b42d9b8
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 7 deletions.
27 changes: 27 additions & 0 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ pub trait Dialect: Send + Sync {
true
}

/// Whether the dialect requires a table alias for any subquery in the FROM clause
/// This affects behavior when deriving logical plans for Sort, Limit, etc.
fn requires_derived_table_alias(&self) -> bool {
false
}

/// Allows the dialect to override scalar function unparsing if the dialect has specific rules.
/// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is
/// a custom implementation for the function.
Expand Down Expand Up @@ -300,6 +306,10 @@ impl Dialect for MySqlDialect {
ast::DataType::Datetime(None)
}

fn requires_derived_table_alias(&self) -> bool {
true
}

fn scalar_function_to_sql_overrides(
&self,
unparser: &Unparser,
Expand Down Expand Up @@ -362,6 +372,7 @@ pub struct CustomDialect {
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: sqlparser::ast::DataType,
supports_column_alias_in_table_alias: bool,
requires_derived_table_alias: bool,
}

impl Default for CustomDialect {
Expand All @@ -384,6 +395,7 @@ impl Default for CustomDialect {
),
date32_cast_dtype: sqlparser::ast::DataType::Date,
supports_column_alias_in_table_alias: true,
requires_derived_table_alias: false,
}
}
}
Expand Down Expand Up @@ -472,6 +484,10 @@ impl Dialect for CustomDialect {

Ok(None)
}

fn requires_derived_table_alias(&self) -> bool {
self.requires_derived_table_alias
}
}

/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
Expand Down Expand Up @@ -503,6 +519,7 @@ pub struct CustomDialectBuilder {
timestamp_tz_cast_dtype: ast::DataType,
date32_cast_dtype: ast::DataType,
supports_column_alias_in_table_alias: bool,
requires_derived_table_alias: bool,
}

impl Default for CustomDialectBuilder {
Expand Down Expand Up @@ -531,6 +548,7 @@ impl CustomDialectBuilder {
),
date32_cast_dtype: sqlparser::ast::DataType::Date,
supports_column_alias_in_table_alias: true,
requires_derived_table_alias: false,
}
}

Expand All @@ -551,6 +569,7 @@ impl CustomDialectBuilder {
date32_cast_dtype: self.date32_cast_dtype,
supports_column_alias_in_table_alias: self
.supports_column_alias_in_table_alias,
requires_derived_table_alias: self.requires_derived_table_alias,
}
}

Expand Down Expand Up @@ -653,4 +672,12 @@ impl CustomDialectBuilder {
self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias;
self
}

pub fn with_requires_derived_table_alias(
mut self,
requires_derived_table_alias: bool,
) -> Self {
self.requires_derived_table_alias = requires_derived_table_alias;
self
}
}
57 changes: 50 additions & 7 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,14 @@ impl Unparser<'_> {
Ok(())
}

fn derive(&self, plan: &LogicalPlan, relation: &mut RelationBuilder) -> Result<()> {
fn derive(
&self,
plan: &LogicalPlan,
relation: &mut RelationBuilder,
alias: Option<ast::TableAlias>,
) -> Result<()> {
let mut derived_builder = DerivedRelationBuilder::default();
derived_builder.lateral(false).alias(None).subquery({
derived_builder.lateral(false).alias(alias).subquery({
let inner_statement = self.plan_to_sql(plan)?;
if let ast::Statement::Query(inner_query) = inner_statement {
inner_query
Expand All @@ -239,6 +244,23 @@ impl Unparser<'_> {
Ok(())
}

fn derive_with_dialect_alias(
&self,
alias: &str,
plan: &LogicalPlan,
relation: &mut RelationBuilder,
) -> Result<()> {
if self.dialect.requires_derived_table_alias() {
self.derive(
plan,
relation,
Some(self.new_table_alias(alias.to_string(), vec![])),
)
} else {
self.derive(plan, relation, None)
}
}

fn select_to_sql_recursively(
&self,
plan: &LogicalPlan,
Expand Down Expand Up @@ -284,7 +306,11 @@ impl Unparser<'_> {

// Projection can be top-level plan for derived table
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_projection",
plan,
relation,
);
}
self.reconstruct_select_statement(plan, p, select)?;
self.select_to_sql_recursively(p.input.as_ref(), query, select, relation)
Expand All @@ -311,8 +337,13 @@ impl Unparser<'_> {
LogicalPlan::Limit(limit) => {
// Limit can be top-level plan for derived table
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_limit",
plan,
relation,
);
}

if let Some(fetch) = limit.fetch {
let Some(query) = query.as_mut() else {
return internal_err!(
Expand Down Expand Up @@ -350,7 +381,11 @@ impl Unparser<'_> {
LogicalPlan::Sort(sort) => {
// Sort can be top-level plan for derived table
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_sort",
plan,
relation,
);
}
let Some(query_ref) = query else {
return internal_err!(
Expand Down Expand Up @@ -396,7 +431,11 @@ impl Unparser<'_> {
LogicalPlan::Distinct(distinct) => {
// Distinct can be top-level plan for derived table
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_distinct",
plan,
relation,
);
}
let (select_distinct, input) = match distinct {
Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()),
Expand Down Expand Up @@ -559,7 +598,11 @@ impl Unparser<'_> {

// Covers cases where the UNION is a subquery and the projection is at the top level
if select.already_projected() {
return self.derive(plan, relation);
return self.derive_with_dialect_alias(
"derived_union",
plan,
relation,
);
}

let input_exprs: Vec<SetExpr> = union
Expand Down
39 changes: 39 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,45 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
unparser_dialect: Box<dyn UnparserDialect>,
}
let tests: Vec<TestStatementWithDialect> = vec![
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;",
expected:
// top projection sort gets derived into a subquery
// for MySQL, this subquery needs an alias
"SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;",
expected:
// top projection sort still gets derived into a subquery in default dialect
// except for the default dialect, the subquery is left non-aliased
"SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10",
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;",
expected:
"SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id from (select 1 as j1_id);",
expected:
"SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select * from (select * from j1 limit 10);",
expected:
"SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
expected:
Expand Down

0 comments on commit b42d9b8

Please sign in to comment.