Skip to content

Commit

Permalink
Allow aggregation without projection in Unparser
Browse files Browse the repository at this point in the history
  • Loading branch information
blaginin committed Nov 9, 2024
1 parent 0f584c8 commit 7fedaa5
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 56 deletions.
23 changes: 21 additions & 2 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ impl Unparser<'_> {
let body = self.select_to_sql_expr(plan, &mut query_builder)?;

let query = query_builder.unwrap().body(Box::new(body)).build()?;

Ok(ast::Statement::Query(Box::new(query)))
}

Expand Down Expand Up @@ -420,7 +419,27 @@ impl Unparser<'_> {
)
}
LogicalPlan::Aggregate(agg) => {
// Aggregate nodes are handled simultaneously with Projection nodes
// Aggregation can be already handled in the projection case
if !select.already_projected() {
// The query returns aggregate and group expressions. If that weren't the case,
// the aggregate would have been placed inside a projection, making the check above^ false
let exprs: Vec<_> = agg
.aggr_expr
.iter()
.chain(agg.group_expr.iter())
.map(|expr| self.select_item_to_sql(expr))
.collect::<Result<Vec<_>>>()?;
select.projection(exprs);

select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
vec![],
));
}

self.select_to_sql_recursively(
agg.input.as_ref(),
query,
Expand Down
136 changes: 82 additions & 54 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use std::vec;

use arrow_schema::*;
use datafusion_common::{DFSchema, Result, TableReference};
use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf};
use datafusion_expr::test::function_stub::{
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
};
use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder};
use datafusion_functions::unicode;
use datafusion_functions_aggregate::grouping::grouping_udaf;
Expand Down Expand Up @@ -87,45 +89,45 @@ fn roundtrip_expr() {
#[test]
fn roundtrip_statement() -> Result<()> {
let tests: Vec<&str> = vec![
"select 1;",
"select 1 limit 0;",
"select ta.j1_id from j1 ta join (select 1 as j1_id) tb on ta.j1_id = tb.j1_id;",
"select ta.j1_id from j1 ta join (select 1 as j1_id) tb using (j1_id);",
"select ta.j1_id from j1 ta join (select 1 as j1_id) tb on ta.j1_id = tb.j1_id where ta.j1_id > 1;",
"select ta.j1_id from (select 1 as j1_id) ta;",
"select ta.j1_id from j1 ta;",
"select ta.j1_id from j1 ta order by ta.j1_id;",
"select * from j1 ta order by ta.j1_id, ta.j1_string desc;",
"select * from j1 limit 10;",
"select ta.j1_id from j1 ta where ta.j1_id > 1;",
"select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);",
"select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
"select * from (select id, first_name from person)",
"select * from (select id, first_name from (select * from person))",
"select id, count(*) as cnt from (select id from person) group by id",
"select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id",
"select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))",
r#"select "First Name" from person_quoted_cols"#,
"select DISTINCT id FROM person",
"select DISTINCT on (id) id, first_name from person",
"select DISTINCT on (id) id, first_name from person order by id",
r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#,
"select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id",
"select id, count(*), first_name from person group by first_name, id",
"select id, sum(age), first_name from person group by first_name, id",
"select id, count(*), first_name
"select 1;",
"select 1 limit 0;",
"select ta.j1_id from j1 ta join (select 1 as j1_id) tb on ta.j1_id = tb.j1_id;",
"select ta.j1_id from j1 ta join (select 1 as j1_id) tb using (j1_id);",
"select ta.j1_id from j1 ta join (select 1 as j1_id) tb on ta.j1_id = tb.j1_id where ta.j1_id > 1;",
"select ta.j1_id from (select 1 as j1_id) ta;",
"select ta.j1_id from j1 ta;",
"select ta.j1_id from j1 ta order by ta.j1_id;",
"select * from j1 ta order by ta.j1_id, ta.j1_string desc;",
"select * from j1 limit 10;",
"select ta.j1_id from j1 ta where ta.j1_id > 1;",
"select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);",
"select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);",
"select * from (select id, first_name from person)",
"select * from (select id, first_name from (select * from person))",
"select id, count(*) as cnt from (select id from person) group by id",
"select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id",
"select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))",
r#"select "First Name" from person_quoted_cols"#,
"select DISTINCT id FROM person",
"select DISTINCT on (id) id, first_name from person",
"select DISTINCT on (id) id, first_name from person order by id",
r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#,
"select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id",
"select id, count(*), first_name from person group by first_name, id",
"select id, sum(age), first_name from person group by first_name, id",
"select id, count(*), first_name
from person
where id!=3 and first_name=='test'
group by first_name, id
having count(*)>5 and count(*)<10
order by count(*)",
r#"select id, count("First Name") as count_first_name, "Last Name"
r#"select id, count("First Name") as count_first_name, "Last Name"
from person_quoted_cols
where id!=3 and "First Name"=='test'
group by "Last Name", id
having count_first_name>5 and count_first_name<10
order by count_first_name, "Last Name""#,
r#"select p.id, count("First Name") as count_first_name,
r#"select p.id, count("First Name") as count_first_name,
"Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person)
from (select id, "First Name", "Last Name" from person_quoted_cols) qp
inner join (select * from person) p
Expand All @@ -135,46 +137,46 @@ fn roundtrip_statement() -> Result<()> {
group by "Last Name", p.id
having count_first_name>5 and count_first_name<10
order by count_first_name, "Last Name""#,
r#"SELECT j1_string as string FROM j1
r#"SELECT j1_string as string FROM j1
UNION ALL
SELECT j2_string as string FROM j2"#,
r#"SELECT j1_string as string FROM j1
r#"SELECT j1_string as string FROM j1
UNION ALL
SELECT j2_string as string FROM j2
ORDER BY string DESC
LIMIT 10"#,
r#"SELECT col1, id FROM (
r#"SELECT col1, id FROM (
SELECT j1_string AS col1, j1_id AS id FROM j1
UNION ALL
SELECT j2_string AS col1, j2_id AS id FROM j2
UNION ALL
SELECT j3_string AS col1, j3_id AS id FROM j3
) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#,
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
first_name from person",
r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#,
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person",
"WITH t1 AS (SELECT j1_id AS id, j1_string name FROM j1), t2 AS (SELECT j2_id AS id, j2_string name FROM j2) SELECT * FROM t1 JOIN t2 USING (id, name)",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col",
r#"SELECT id, first_name,
"SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person",
"WITH t1 AS (SELECT j1_id AS id, j1_string name FROM j1), t2 AS (SELECT j2_id AS id, j2_string name FROM j2) SELECT * FROM t1 JOIN t2 USING (id, name)",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4",
"WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col",
r#"SELECT id, first_name,
SUM(id) AS total_sum,
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
FROM person JOIN orders ON person.id = orders.customer_id GROUP BY id, first_name"#,
r#"SELECT id, first_name,
r#"SELECT id, first_name,
SUM(id) AS total_sum,
SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum,
MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total
FROM (SELECT id, first_name from person) person JOIN (SELECT customer_id FROM orders) orders ON person.id = orders.customer_id GROUP BY id, first_name"#,
r#"SELECT id, first_name, last_name, customer_id, SUM(id) AS total_sum
r#"SELECT id, first_name, last_name, customer_id, SUM(id) AS total_sum
FROM person
JOIN orders ON person.id = orders.customer_id
GROUP BY ROLLUP(id, first_name, last_name, customer_id)"#,
r#"SELECT id, first_name, last_name,
r#"SELECT id, first_name, last_name,
SUM(id) AS total_sum,
COUNT(*) AS total_count,
SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total
Expand Down Expand Up @@ -265,46 +267,46 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;",
expected:
// top projection sort gets derived into a subquery
// for MySQL, this subquery needs an alias
"SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
// top projection sort gets derived into a subquery
// for MySQL, this subquery needs an alias
"SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;",
expected:
// top projection sort still gets derived into a subquery in default dialect
// except for the default dialect, the subquery is left non-aliased
"SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10",
// top projection sort still gets derived into a subquery in default dialect
// except for the default dialect, the subquery is left non-aliased
"SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10",
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;",
expected:
"SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
"SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select j1_id from (select 1 as j1_id);",
expected:
"SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`",
"SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select * from (select * from j1 limit 10);",
expected:
"SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`",
"SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
TestStatementWithDialect {
sql: "select ta.j1_id from j1 ta order by j1_id limit 10;",
expected:
"SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10",
"SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10",
parser_dialect: Box::new(MySqlDialect {}),
unparser_dialect: Box::new(UnparserMySqlDialect {}),
},
Expand Down Expand Up @@ -563,6 +565,32 @@ Projection: unnest_placeholder(unnest_table.struct_col).field1, unnest_placehold
Ok(())
}

#[test]
fn test_aggregation_without_projection() -> Result<()> {
let schema = Schema::new(vec![
Field::new("name", DataType::Utf8, false),
Field::new("age", DataType::UInt8, false),
]);

let plan = LogicalPlanBuilder::from(
table_scan(Some("users"), &schema, Some(vec![0, 1]))?.build()?,
)
.aggregate(vec![col("name")], vec![sum(col("age"))])?
.build()?;

let unparser = Unparser::default();
let roundtrip_statement = unparser.plan_to_sql(&plan)?;

let actual = &roundtrip_statement.to_string();

assert_eq!(
actual,
r#"SELECT sum(users.age), users."name" FROM (SELECT users."name", users.age FROM users) GROUP BY users."name""#
);

Ok(())
}

#[test]
fn test_table_references_in_plan_to_sql() {
fn test(table_name: &str, expected_sql: &str, dialect: &impl UnparserDialect) {
Expand Down

0 comments on commit 7fedaa5

Please sign in to comment.