diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs index d5421c36cd73..484a7f2388f5 100644 --- a/datafusion/common/src/format.rs +++ b/datafusion/common/src/format.rs @@ -15,8 +15,15 @@ // specific language governing permissions and limitations // under the License. +use arrow::compute::CastOptions; use arrow::util::display::{DurationFormat, FormatOptions}; /// The default [`FormatOptions`] to use within DataFusion pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> = FormatOptions::new().with_duration_format(DurationFormat::Pretty); + +/// The default [`CastOptions`] to use within DataFusion +pub const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: false, + format_options: DEFAULT_FORMAT_OPTIONS, +}; diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index c845c81cb708..831edc078d6a 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -19,7 +19,9 @@ use arrow::array::ArrayRef; use arrow::array::NullArray; -use arrow::datatypes::DataType; +use arrow::compute::{kernels, CastOptions}; +use arrow::datatypes::{DataType, TimeUnit}; +use datafusion_common::format::DEFAULT_CAST_OPTIONS; use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; @@ -122,6 +124,42 @@ impl ColumnarValue { Ok(args) } + + /// Cast's this [ColumnarValue] to the specified `DataType` + pub fn cast_to( + &self, + cast_type: &DataType, + cast_options: Option<&CastOptions<'static>>, + ) -> Result { + let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS); + match self { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array( + kernels::cast::cast_with_options(array, cast_type, &cast_options)?, + )), + ColumnarValue::Scalar(scalar) => { + let scalar_array = + if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) { + if let ScalarValue::Float64(Some(float_ts)) = scalar { + ScalarValue::Int64(Some( + (float_ts * 1_000_000_000_f64).trunc() as i64, + )) + .to_array()? + } else { + scalar.to_array()? + } + } else { + scalar.to_array()? + }; + let cast_array = kernels::cast::cast_with_options( + &scalar_array, + cast_type, + &cast_options, + )?; + let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; + Ok(ColumnarValue::Scalar(cast_scalar)) + } + } + } } #[cfg(test)] diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index c22aa3db5994..dbbc94b662e3 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -55,7 +55,6 @@ chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { workspace = true } hex = { version = "0.4", optional = true } itertools = { workspace = true } log = "0.4.20" diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 4bff2c90eada..65153333a13c 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -35,7 +35,6 @@ use itertools::Either; use datafusion_common::cast::as_generic_string_array; use datafusion_common::{exec_err, DataFusionError, Result, ScalarType, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use datafusion_physical_expr::expressions::cast_column; /// Error message if nanosecond conversion request beyond supported interval const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804"; @@ -144,13 +143,11 @@ impl ScalarUDFImpl for ToTimestampFunc { } match args[0].data_type() { - DataType::Int32 | DataType::Int64 => cast_column( - &cast_column(&args[0], &Timestamp(Second, None), None)?, - &Timestamp(Nanosecond, None), - None, - ), + DataType::Int32 | DataType::Int64 => args[0] + .cast_to(&Timestamp(Second, None), None)? + .cast_to(&Timestamp(Nanosecond, None), None), DataType::Null | DataType::Float64 | Timestamp(_, None) => { - cast_column(&args[0], &Timestamp(Nanosecond, None), None) + args[0].cast_to(&Timestamp(Nanosecond, None), None) } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp") @@ -201,7 +198,7 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { match args[0].data_type() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { - cast_column(&args[0], &Timestamp(Second, None), None) + args[0].cast_to(&Timestamp(Second, None), None) } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp_seconds") @@ -252,7 +249,7 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { match args[0].data_type() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { - cast_column(&args[0], &Timestamp(Millisecond, None), None) + args[0].cast_to(&Timestamp(Millisecond, None), None) } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp_millis") @@ -303,7 +300,7 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { match args[0].data_type() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { - cast_column(&args[0], &Timestamp(Microsecond, None), None) + args[0].cast_to(&Timestamp(Microsecond, None), None) } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp_micros") @@ -354,7 +351,7 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { match args[0].data_type() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { - cast_column(&args[0], &Timestamp(Nanosecond, None), None) + args[0].cast_to(&Timestamp(Nanosecond, None), None) } DataType::Utf8 => { to_timestamp_impl::(args, "to_timestamp_nanos") diff --git a/datafusion/physical-expr/src/datetime_expressions.rs b/datafusion/physical-expr/src/datetime_expressions.rs index 306c49d950f8..3b322ae2692f 100644 --- a/datafusion/physical-expr/src/datetime_expressions.rs +++ b/datafusion/physical-expr/src/datetime_expressions.rs @@ -53,8 +53,6 @@ use datafusion_common::cast::{ use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -use crate::expressions::cast_column; - /// Create an implementation of `now()` that always returns the /// specified timestamp. /// @@ -328,9 +326,9 @@ pub fn make_date(args: &[ColumnarValue]) -> Result { let is_scalar = len.is_none(); let array_size = if is_scalar { 1 } else { len.unwrap() }; - let years = cast_column(&args[0], &DataType::Int32, None)?; - let months = cast_column(&args[1], &DataType::Int32, None)?; - let days = cast_column(&args[2], &DataType::Int32, None)?; + let years = args[0].cast_to(&DataType::Int32, None)?; + let months = args[1].cast_to(&DataType::Int32, None)?; + let days = args[2].cast_to(&DataType::Int32, None)?; // since the epoch for the date32 datatype is the unix epoch // we need to subtract the unix epoch from the current date @@ -1154,7 +1152,7 @@ pub fn from_unixtime_invoke(args: &[ColumnarValue]) -> Result { match args[0].data_type() { DataType::Int64 => { - cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None) + args[0].cast_to(&DataType::Timestamp(TimeUnit::Second, None), None) } other => { exec_err!( diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 9125f73048cb..a3bff578cad4 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -24,11 +24,11 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use DataType::*; -use arrow::compute::{can_cast_types, kernels, CastOptions}; +use arrow::compute::{can_cast_types, CastOptions}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_common::{not_impl_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; @@ -120,7 +120,7 @@ impl PhysicalExpr for CastExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; - cast_column(&value, &self.cast_type, Some(&self.cast_options)) + value.cast_to(&self.cast_type, Some(&self.cast_options)) } fn children(&self) -> Vec> { @@ -182,43 +182,6 @@ impl PartialEq for CastExpr { } } -/// Internal cast function for casting ColumnarValue -> ColumnarValue for cast_type -pub fn cast_column( - value: &ColumnarValue, - cast_type: &DataType, - cast_options: Option<&CastOptions<'static>>, -) -> Result { - let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS); - match value { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array( - kernels::cast::cast_with_options(array, cast_type, &cast_options)?, - )), - ColumnarValue::Scalar(scalar) => { - let scalar_array = if cast_type - == &DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None) - { - if let ScalarValue::Float64(Some(float_ts)) = scalar { - ScalarValue::Int64( - Some((float_ts * 1_000_000_000_f64).trunc() as i64), - ) - .to_array()? - } else { - scalar.to_array()? - } - } else { - scalar.to_array()? - }; - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - cast_type, - &cast_options, - )?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) - } - } -} - /// Return a PhysicalExpression representing `expr` casted to /// `cast_type`, if any casting is needed. /// diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index ec20345569c2..f9896bafca15 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -80,7 +80,7 @@ pub use crate::PhysicalSortExpr; pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; -pub use cast::{cast, cast_column, cast_with_options, CastExpr}; +pub use cast::{cast, cast_with_options, CastExpr}; pub use column::{col, Column, UnKnownColumn}; pub use get_indexed_field::{GetFieldAccessExpr, GetIndexedFieldExpr}; pub use in_list::{in_list, InListExpr};