diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 39d678fdadac..fcdb6825ce73 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -22,9 +22,8 @@ use datafusion_expr::planner::{ }; use sqlparser::ast::{ AccessExpr, BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, - DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, StructField, Subscript, - TrimWhereField, - Value, + DictionaryField, Expr as SQLExpr, ExprWithAlias as SQLExprWithAlias, MapEntry, + StructField, Subscript, TrimWhereField, Value, }; use datafusion_common::{ diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 628bcb2fbdcd..c6975891482f 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -430,7 +430,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(DataType::UInt64), SQLDataType::Float(_) => Ok(DataType::Float32), SQLDataType::Real | SQLDataType::Float4 => Ok(DataType::Float32), - SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(DataType::Float64), + SQLDataType::Double(_) | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(DataType::Float64), SQLDataType::Char(_) | SQLDataType::Text | SQLDataType::String(_) => Ok(DataType::Utf8), @@ -514,6 +514,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Regclass | SQLDataType::Custom(_, _) | SQLDataType::Array(_) + | SQLDataType::AnyType | SQLDataType::Enum(_, _) | SQLDataType::Set(_) | SQLDataType::MediumInt(_) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index dfd3a4fd76a2..43cdc8032c09 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -56,7 +56,7 @@ use datafusion_expr::{ }; use sqlparser::ast::{ self, BeginTransactionKind, NullsDistinctOption, ShowStatementIn, - ShowStatementOptions, SqliteOnConflict, + ShowStatementOptions, SqliteOnConflict, UpdateTableFromKind, }; use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, @@ -890,6 +890,10 @@ impl SqlToRel<'_, S> { if or.is_some() { plan_err!("ON conflict not supported")?; } + let from = from.map(|update| match update { + UpdateTableFromKind::BeforeSet(t) + | UpdateTableFromKind::AfterSet(t) => t, + }); self.update_to_plan(table, assignments, from, selection) } diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index e320a4510e46..6d77c01ea888 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -466,6 +466,7 @@ impl TableRelationBuilder { partitions: self.partitions.clone(), with_ordinality: false, json_path: None, + sample: None, }) } fn create_empty() -> Self { diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 3a44d7f0ec48..e673dfa50669 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -17,18 +17,18 @@ use std::sync::Arc; +use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser}; use arrow_schema::TimeUnit; use datafusion_common::Result; use datafusion_expr::Expr; use regex::Regex; +use sqlparser::ast::ExactNumberInfo; use sqlparser::tokenizer::Span; use sqlparser::{ ast::{self, BinaryOperator, Function, Ident, ObjectName, TimezoneInfo}, keywords::ALL_KEYWORDS, }; -use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser}; - /// `Dialect` to use for Unparsing /// /// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`) @@ -60,7 +60,7 @@ pub trait Dialect: Send + Sync { /// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? /// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE fn float64_ast_dtype(&self) -> ast::DataType { - ast::DataType::Double + ast::DataType::Double(ExactNumberInfo::None) } /// The SQL type to use for Arrow Utf8 unparsing @@ -272,13 +272,13 @@ impl PostgreSqlDialect { { if let ast::Expr::Cast { data_type, .. } = expr { // Don't create an additional cast wrapper if we can update the existing one - *data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None); + *data_type = ast::DataType::Numeric(ExactNumberInfo::None); } else { // Wrap the expression in a new cast *expr = ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(expr.clone()), - data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None), + data_type: ast::DataType::Numeric(ExactNumberInfo::None), format: None, }; } @@ -469,7 +469,7 @@ impl Default for CustomDialect { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::SQLStandard, - float64_ast_dtype: ast::DataType::Double, + float64_ast_dtype: ast::DataType::Double(ExactNumberInfo::None), utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, @@ -650,7 +650,7 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::PostgresVerbose, - float64_ast_dtype: ast::DataType::Double, + float64_ast_dtype: ast::DataType::Double(ExactNumberInfo::None), utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 2b8e53c4243d..f401668e53ab 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1633,6 +1633,7 @@ mod tests { use datafusion_functions_nested::expr_fn::{array_element, make_array}; use datafusion_functions_nested::map::map; use datafusion_functions_window::row_number::row_number_udwf; + use sqlparser::ast::ExactNumberInfo; use crate::unparser::dialect::{ CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, @@ -2123,7 +2124,7 @@ mod tests { #[test] fn custom_dialect_float64_ast_dtype() -> Result<()> { for (float64_ast_dtype, identifier) in [ - (ast::DataType::Double, "DOUBLE"), + (ast::DataType::Double(ExactNumberInfo::None), "DOUBLE"), (ast::DataType::DoublePrecision, "DOUBLE PRECISION"), ] { let dialect = CustomDialectBuilder::new()