From 9b492c6a5e168171f14d4e985fcb43b535c2e872 Mon Sep 17 00:00:00 2001 From: Sergei Grebnov Date: Sun, 6 Oct 2024 04:12:50 -0700 Subject: [PATCH] Improve `round` scalar function unparsing for Postgres (#12744) * Postgres: enforce required `NUMERIC` type for `round` scalar function (#34) Includes initial support for dialects to override scalar functions unparsing * Document scalar_function_to_sql_overrides fn --- datafusion/sql/src/unparser/dialect.rs | 119 ++++++++++++++- datafusion/sql/src/unparser/expr.rs | 198 +++++++++---------------- datafusion/sql/src/unparser/utils.rs | 82 +++++++++- 3 files changed, 273 insertions(+), 126 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index d8a4fb254264..609e6f2240e1 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -18,12 +18,17 @@ use std::sync::Arc; use arrow_schema::TimeUnit; +use datafusion_expr::Expr; use regex::Regex; use sqlparser::{ - ast::{self, Ident, ObjectName, TimezoneInfo}, + ast::{self, Function, Ident, ObjectName, TimezoneInfo}, keywords::ALL_KEYWORDS, }; +use datafusion_common::Result; + +use super::{utils::date_part_to_sql, Unparser}; + /// `Dialect` to use for Unparsing /// /// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`) @@ -108,6 +113,18 @@ pub trait Dialect: Send + Sync { fn supports_column_alias_in_table_alias(&self) -> bool { true } + + /// Allows the dialect to override scalar function unparsing if the dialect has specific rules. + /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is + /// a custom implementation for the function. + fn scalar_function_to_sql_overrides( + &self, + _unparser: &Unparser, + _func_name: &str, + _args: &[Expr], + ) -> Result> { + Ok(None) + } } /// `IntervalStyle` to use for unparsing @@ -171,6 +188,67 @@ impl Dialect for PostgreSqlDialect { fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { sqlparser::ast::DataType::DoublePrecision } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "round" { + return Ok(Some( + self.round_to_sql_enforce_numeric(unparser, func_name, args)?, + )); + } + + Ok(None) + } +} + +impl PostgreSqlDialect { + fn round_to_sql_enforce_numeric( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result { + let mut args = unparser.function_args_to_sql(args)?; + + // Enforce the first argument to be Numeric + if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) = + args.first_mut() + { + if let ast::Expr::Cast { data_type, .. } = expr { + // Don't create an additional cast wrapper if we can update the existing one + *data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None); + } else { + // Wrap the expression in a new cast + *expr = ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(expr.clone()), + data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None), + format: None, + }; + } + } + + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } } pub struct MySqlDialect {} @@ -211,6 +289,19 @@ impl Dialect for MySqlDialect { ) -> ast::DataType { ast::DataType::Datetime(None) } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } pub struct SqliteDialect {} @@ -231,6 +322,19 @@ impl Dialect for SqliteDialect { fn supports_column_alias_in_table_alias(&self) -> bool { false } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } pub struct CustomDialect { @@ -339,6 +443,19 @@ impl Dialect for CustomDialect { fn supports_column_alias_in_table_alias(&self) -> bool { self.supports_column_alias_in_table_alias } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index b924268a7657..537ac2274424 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,16 +15,15 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::ScalarUDF; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, - ObjectName, TimezoneInfo, UnaryOperator, + self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, + TimezoneInfo, UnaryOperator, }; use std::sync::Arc; use std::vec; -use super::dialect::{DateFieldExtractStyle, IntervalStyle}; +use super::dialect::IntervalStyle; use super::Unparser; use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType}; use arrow::util::display::array_value_to_string; @@ -116,47 +115,14 @@ impl Unparser<'_> { Expr::ScalarFunction(ScalarFunction { func, args }) => { let func_name = func.name(); - if let Some(expr) = - self.scalar_function_to_sql_overrides(func_name, func, args) + if let Some(expr) = self + .dialect + .scalar_function_to_sql_overrides(self, func_name, args)? { return Ok(expr); } - let args = args - .iter() - .map(|e| { - if matches!( - e, - Expr::Wildcard { - qualifier: None, - .. - } - ) { - Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) - } else { - self.expr_to_sql_inner(e).map(|e| { - FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) - }) - } - }) - .collect::>>()?; - - Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { - value: func_name.to_string(), - quote_style: None, - }]), - args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: None, - args, - clauses: vec![], - }), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - parameters: ast::FunctionArguments::None, - })) + self.scalar_function_to_sql(func_name, args) } Expr::Between(Between { expr, @@ -508,6 +474,30 @@ impl Unparser<'_> { } } + pub fn scalar_function_to_sql( + &self, + func_name: &str, + args: &[Expr], + ) -> Result { + let args = self.function_args_to_sql(args)?; + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } + pub fn sort_to_sql(&self, sort: &Sort) -> Result { let Sort { expr, @@ -530,87 +520,6 @@ impl Unparser<'_> { }) } - fn scalar_function_to_sql_overrides( - &self, - func_name: &str, - _func: &Arc, - args: &[Expr], - ) -> Option { - if func_name.to_lowercase() == "date_part" { - match (self.dialect.date_field_extract_style(), args.len()) { - (DateFieldExtractStyle::Extract, 2) => { - let date_expr = self.expr_to_sql(&args[1]).ok()?; - - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] { - let field = match field.to_lowercase().as_str() { - "year" => ast::DateTimeField::Year, - "month" => ast::DateTimeField::Month, - "day" => ast::DateTimeField::Day, - "hour" => ast::DateTimeField::Hour, - "minute" => ast::DateTimeField::Minute, - "second" => ast::DateTimeField::Second, - _ => return None, - }; - - return Some(ast::Expr::Extract { - field, - expr: Box::new(date_expr), - syntax: ast::ExtractSyntax::From, - }); - } - } - (DateFieldExtractStyle::Strftime, 2) => { - let column = self.expr_to_sql(&args[1]).ok()?; - - if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] { - let field = match field.to_lowercase().as_str() { - "year" => "%Y", - "month" => "%m", - "day" => "%d", - "hour" => "%H", - "minute" => "%M", - "second" => "%S", - _ => return None, - }; - - return Some(ast::Expr::Function(ast::Function { - name: ast::ObjectName(vec![ast::Ident { - value: "strftime".to_string(), - quote_style: None, - }]), - args: ast::FunctionArguments::List( - ast::FunctionArgumentList { - duplicate_treatment: None, - args: vec![ - ast::FunctionArg::Unnamed( - ast::FunctionArgExpr::Expr(ast::Expr::Value( - ast::Value::SingleQuotedString( - field.to_string(), - ), - )), - ), - ast::FunctionArg::Unnamed( - ast::FunctionArgExpr::Expr(column), - ), - ], - clauses: vec![], - }, - ), - filter: None, - null_treatment: None, - over: None, - within_group: vec![], - parameters: ast::FunctionArguments::None, - })); - } - } - _ => {} // no overrides for DateFieldExtractStyle::DatePart, because it's already a date_part - } - } - - None - } - fn ast_type_for_date64_in_cast(&self) -> ast::DataType { if self.dialect.use_timestamp_for_date64() { ast::DataType::Timestamp(None, ast::TimezoneInfo::None) @@ -665,7 +574,10 @@ impl Unparser<'_> { } } - fn function_args_to_sql(&self, args: &[Expr]) -> Result> { + pub(crate) fn function_args_to_sql( + &self, + args: &[Expr], + ) -> Result> { args.iter() .map(|e| { if matches!( @@ -1554,7 +1466,10 @@ mod tests { use datafusion_functions_aggregate::expr_fn::sum; use datafusion_functions_window::row_number::row_number_udwf; - use crate::unparser::dialect::{CustomDialect, CustomDialectBuilder}; + use crate::unparser::dialect::{ + CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect, + PostgreSqlDialect, + }; use super::*; @@ -2428,4 +2343,39 @@ mod tests { assert_eq!(actual, expected); } } + + #[test] + fn test_round_scalar_fn_to_expr() -> Result<()> { + let default_dialect: Arc = Arc::new( + CustomDialectBuilder::new() + .with_identifier_quote_style('"') + .build(), + ); + let postgres_dialect: Arc = Arc::new(PostgreSqlDialect {}); + + for (dialect, identifier) in + [(default_dialect, "DOUBLE"), (postgres_dialect, "NUMERIC")] + { + let unparser = Unparser::new(dialect.as_ref()); + let expr = Expr::ScalarFunction(ScalarFunction { + func: Arc::new(ScalarUDF::from( + datafusion_functions::math::round::RoundFunc::new(), + )), + args: vec![ + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Float64, + }), + Expr::Literal(ScalarValue::Int64(Some(2))), + ], + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#); + + assert_eq!(actual, expected); + } + Ok(()) + } } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 0059aba25738..8b2530a7499b 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -18,11 +18,14 @@ use datafusion_common::{ internal_err, tree_node::{Transformed, TreeNode}, - Column, DataFusionError, Result, + Column, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window, }; +use sqlparser::ast; + +use super::{dialect::DateFieldExtractStyle, Unparser}; /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -187,3 +190,80 @@ fn find_window_expr<'a>( .flat_map(|w| w.window_expr.iter()) .find(|expr| expr.schema_name().to_string() == column_name) } + +/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style. +pub(crate) fn date_part_to_sql( + unparser: &Unparser, + style: DateFieldExtractStyle, + date_part_args: &[Expr], +) -> Result> { + match (style, date_part_args.len()) { + (DateFieldExtractStyle::Extract, 2) => { + let date_expr = unparser.expr_to_sql(&date_part_args[1])?; + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + let field = match field.to_lowercase().as_str() { + "year" => ast::DateTimeField::Year, + "month" => ast::DateTimeField::Month, + "day" => ast::DateTimeField::Day, + "hour" => ast::DateTimeField::Hour, + "minute" => ast::DateTimeField::Minute, + "second" => ast::DateTimeField::Second, + _ => return Ok(None), + }; + + return Ok(Some(ast::Expr::Extract { + field, + expr: Box::new(date_expr), + syntax: ast::ExtractSyntax::From, + })); + } + } + (DateFieldExtractStyle::Strftime, 2) => { + let column = unparser.expr_to_sql(&date_part_args[1])?; + + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + let field = match field.to_lowercase().as_str() { + "year" => "%Y", + "month" => "%m", + "day" => "%d", + "hour" => "%H", + "minute" => "%M", + "second" => "%S", + _ => return Ok(None), + }; + + return Ok(Some(ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: "strftime".to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + ast::Expr::Value(ast::Value::SingleQuotedString( + field.to_string(), + )), + )), + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)), + ], + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + }))); + } + } + (DateFieldExtractStyle::DatePart, _) => { + return Ok(Some( + unparser.scalar_function_to_sql("date_part", date_part_args)?, + )); + } + _ => {} + }; + + Ok(None) +}