Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move cast_column to ColumnarValue::cast_to #1

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions datafusion/common/src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use arrow::compute::CastOptions;
use arrow::util::display::{DurationFormat, FormatOptions};

/// The default [`FormatOptions`] to use within DataFusion
pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> =
FormatOptions::new().with_duration_format(DurationFormat::Pretty);

/// The default [`CastOptions`] to use within DataFusion
pub const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
safe: false,
format_options: DEFAULT_FORMAT_OPTIONS,
};
40 changes: 39 additions & 1 deletion datafusion/expr/src/columnar_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

use arrow::array::ArrayRef;
use arrow::array::NullArray;
use arrow::datatypes::DataType;
use arrow::compute::{kernels, CastOptions};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::format::DEFAULT_CAST_OPTIONS;
use datafusion_common::{internal_err, Result, ScalarValue};
use std::sync::Arc;

Expand Down Expand Up @@ -122,6 +124,42 @@ impl ColumnarValue {

Ok(args)
}

/// Cast's this [ColumnarValue] to the specified `DataType`
pub fn cast_to(
&self,
cast_type: &DataType,
cast_options: Option<&CastOptions<'static>>,
) -> Result<ColumnarValue> {
let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS);
match self {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(
kernels::cast::cast_with_options(array, cast_type, &cast_options)?,
)),
ColumnarValue::Scalar(scalar) => {
let scalar_array =
if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) {
if let ScalarValue::Float64(Some(float_ts)) = scalar {
ScalarValue::Int64(Some(
(float_ts * 1_000_000_000_f64).trunc() as i64,
))
.to_array()?
} else {
scalar.to_array()?
}
} else {
scalar.to_array()?
};
let cast_array = kernels::cast::cast_with_options(
&scalar_array,
cast_type,
&cast_options,
)?;
let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
Ok(ColumnarValue::Scalar(cast_scalar))
}
}
}
}

#[cfg(test)]
Expand Down
1 change: 0 additions & 1 deletion datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-physical-expr = { workspace = true }
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rationale for this PR is removing this line

hex = { version = "0.4", optional = true }
itertools = { workspace = true }
log = "0.4.20"
Expand Down
19 changes: 8 additions & 11 deletions datafusion/functions/src/datetime/to_timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use itertools::Either;
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarType, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_physical_expr::expressions::cast_column;

/// Error message if nanosecond conversion request beyond supported interval
const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804";
Expand Down Expand Up @@ -144,13 +143,11 @@ impl ScalarUDFImpl for ToTimestampFunc {
}

match args[0].data_type() {
DataType::Int32 | DataType::Int64 => cast_column(
&cast_column(&args[0], &Timestamp(Second, None), None)?,
&Timestamp(Nanosecond, None),
None,
),
DataType::Int32 | DataType::Int64 => args[0]
.cast_to(&Timestamp(Second, None), None)?
.cast_to(&Timestamp(Nanosecond, None), None),
DataType::Null | DataType::Float64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Nanosecond, None), None)
args[0].cast_to(&Timestamp(Nanosecond, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampNanosecondType>(args, "to_timestamp")
Expand Down Expand Up @@ -201,7 +198,7 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc {

match args[0].data_type() {
DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Second, None), None)
args[0].cast_to(&Timestamp(Second, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampSecondType>(args, "to_timestamp_seconds")
Expand Down Expand Up @@ -252,7 +249,7 @@ impl ScalarUDFImpl for ToTimestampMillisFunc {

match args[0].data_type() {
DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Millisecond, None), None)
args[0].cast_to(&Timestamp(Millisecond, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampMillisecondType>(args, "to_timestamp_millis")
Expand Down Expand Up @@ -303,7 +300,7 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc {

match args[0].data_type() {
DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Microsecond, None), None)
args[0].cast_to(&Timestamp(Microsecond, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampMicrosecondType>(args, "to_timestamp_micros")
Expand Down Expand Up @@ -354,7 +351,7 @@ impl ScalarUDFImpl for ToTimestampNanosFunc {

match args[0].data_type() {
DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Nanosecond, None), None)
args[0].cast_to(&Timestamp(Nanosecond, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampNanosecondType>(args, "to_timestamp_nanos")
Expand Down
10 changes: 4 additions & 6 deletions datafusion/physical-expr/src/datetime_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ use datafusion_common::cast::{
use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;

use crate::expressions::cast_column;

/// Create an implementation of `now()` that always returns the
/// specified timestamp.
///
Expand Down Expand Up @@ -328,9 +326,9 @@ pub fn make_date(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let is_scalar = len.is_none();
let array_size = if is_scalar { 1 } else { len.unwrap() };

let years = cast_column(&args[0], &DataType::Int32, None)?;
let months = cast_column(&args[1], &DataType::Int32, None)?;
let days = cast_column(&args[2], &DataType::Int32, None)?;
let years = args[0].cast_to(&DataType::Int32, None)?;
let months = args[1].cast_to(&DataType::Int32, None)?;
let days = args[2].cast_to(&DataType::Int32, None)?;

// since the epoch for the date32 datatype is the unix epoch
// we need to subtract the unix epoch from the current date
Expand Down Expand Up @@ -1154,7 +1152,7 @@ pub fn from_unixtime_invoke(args: &[ColumnarValue]) -> Result<ColumnarValue> {

match args[0].data_type() {
DataType::Int64 => {
cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None)
args[0].cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)
}
other => {
exec_err!(
Expand Down
43 changes: 3 additions & 40 deletions datafusion/physical-expr/src/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ use std::hash::{Hash, Hasher};
use std::sync::Arc;
use DataType::*;

use arrow::compute::{can_cast_types, kernels, CastOptions};
use arrow::compute::{can_cast_types, CastOptions};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::ColumnarValue;

Expand Down Expand Up @@ -120,7 +120,7 @@ impl PhysicalExpr for CastExpr {

fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let value = self.expr.evaluate(batch)?;
cast_column(&value, &self.cast_type, Some(&self.cast_options))
value.cast_to(&self.cast_type, Some(&self.cast_options))
}

fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
Expand Down Expand Up @@ -182,43 +182,6 @@ impl PartialEq<dyn Any> for CastExpr {
}
}

/// Internal cast function for casting ColumnarValue -> ColumnarValue for cast_type
pub fn cast_column(
value: &ColumnarValue,
cast_type: &DataType,
cast_options: Option<&CastOptions<'static>>,
) -> Result<ColumnarValue> {
let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS);
match value {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(
kernels::cast::cast_with_options(array, cast_type, &cast_options)?,
)),
ColumnarValue::Scalar(scalar) => {
let scalar_array = if cast_type
== &DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None)
{
if let ScalarValue::Float64(Some(float_ts)) = scalar {
ScalarValue::Int64(
Some((float_ts * 1_000_000_000_f64).trunc() as i64),
)
.to_array()?
} else {
scalar.to_array()?
}
} else {
scalar.to_array()?
};
let cast_array = kernels::cast::cast_with_options(
&scalar_array,
cast_type,
&cast_options,
)?;
let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
Ok(ColumnarValue::Scalar(cast_scalar))
}
}
}

/// Return a PhysicalExpression representing `expr` casted to
/// `cast_type`, if any casting is needed.
///
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub use crate::PhysicalSortExpr;

pub use binary::{binary, BinaryExpr};
pub use case::{case, CaseExpr};
pub use cast::{cast, cast_column, cast_with_options, CastExpr};
pub use cast::{cast, cast_with_options, CastExpr};
pub use column::{col, Column, UnKnownColumn};
pub use get_indexed_field::{GetFieldAccessExpr, GetIndexedFieldExpr};
pub use in_list::{in_list, InListExpr};
Expand Down