Skip to content

Commit

Permalink
Merge pull request #1 from alamb/alamb/move_cast
Browse files Browse the repository at this point in the history
Move `cast_column` to `ColumnarValue::cast_to`
  • Loading branch information
Omega359 authored Mar 1, 2024
2 parents 10b0a8a + d5f9615 commit e29e62d
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 60 deletions.
7 changes: 7 additions & 0 deletions datafusion/common/src/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use arrow::compute::CastOptions;
use arrow::util::display::{DurationFormat, FormatOptions};

/// The default [`FormatOptions`] to use within DataFusion.
///
/// Built from [`FormatOptions::new`] with the duration format switched to
/// [`DurationFormat::Pretty`]; every other setting keeps the value that
/// `FormatOptions::new()` provides.
pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> =
FormatOptions::new().with_duration_format(DurationFormat::Pretty);

/// The default [`CastOptions`] to use within DataFusion.
///
/// Sets `safe` to `false` and formats values with
/// [`DEFAULT_FORMAT_OPTIONS`].
///
/// NOTE(review): arrow documents `safe: false` as "return an error when a
/// value cannot be cast" (instead of producing a null) — confirm against the
/// arrow version this crate is pinned to.
pub const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions {
safe: false,
format_options: DEFAULT_FORMAT_OPTIONS,
};
40 changes: 39 additions & 1 deletion datafusion/expr/src/columnar_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

use arrow::array::ArrayRef;
use arrow::array::NullArray;
use arrow::datatypes::DataType;
use arrow::compute::{kernels, CastOptions};
use arrow::datatypes::{DataType, TimeUnit};
use datafusion_common::format::DEFAULT_CAST_OPTIONS;
use datafusion_common::{internal_err, Result, ScalarValue};
use std::sync::Arc;

Expand Down Expand Up @@ -122,6 +124,42 @@ impl ColumnarValue {

Ok(args)
}

/// Casts this [`ColumnarValue`] to the specified `DataType`.
///
/// When `cast_options` is `None`, `DEFAULT_CAST_OPTIONS` is used.
///
/// * `ColumnarValue::Array` values are passed straight to the arrow cast
///   kernel and the result is returned as an array.
/// * `ColumnarValue::Scalar` values are converted with `to_array`, cast with
///   the same kernel, and element 0 of the result is read back as a scalar.
///
/// Special case: a non-null `Float64` scalar cast to
/// `Timestamp(Nanosecond, None)` is first multiplied by `1_000_000_000`,
/// truncated, and converted to `Int64` before the kernel runs (this treats
/// the float as a count of seconds — presumably since the Unix epoch; verify
/// against callers). The array path applies no such scaling.
///
/// # Errors
///
/// Propagates any error from building the intermediate array, from the cast
/// kernel, or from extracting the resulting scalar.
pub fn cast_to(
    &self,
    cast_type: &DataType,
    cast_options: Option<&CastOptions<'static>>,
) -> Result<ColumnarValue> {
    let options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS);

    match self {
        ColumnarValue::Array(array) => {
            let converted =
                kernels::cast::cast_with_options(array, cast_type, &options)?;
            Ok(ColumnarValue::Array(converted))
        }
        ColumnarValue::Scalar(scalar) => {
            let wants_nanos =
                cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None);

            // Float64 -> nanosecond timestamp: scale to whole nanoseconds
            // up front so the kernel only sees an Int64 input.
            let input = match scalar {
                ScalarValue::Float64(Some(float_ts)) if wants_nanos => {
                    let nanos = (float_ts * 1_000_000_000_f64).trunc() as i64;
                    ScalarValue::Int64(Some(nanos)).to_array()?
                }
                _ => scalar.to_array()?,
            };

            let converted =
                kernels::cast::cast_with_options(&input, cast_type, &options)?;
            let result = ScalarValue::try_from_array(&converted, 0)?;
            Ok(ColumnarValue::Scalar(result))
        }
    }
}
}

#[cfg(test)]
Expand Down
1 change: 0 additions & 1 deletion datafusion/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-physical-expr = { workspace = true }
hex = { version = "0.4", optional = true }
itertools = { workspace = true }
log = "0.4.20"
Expand Down
19 changes: 8 additions & 11 deletions datafusion/functions/src/datetime/to_timestamp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ use itertools::Either;
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarType, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_physical_expr::expressions::cast_column;

/// Error message used when a nanosecond conversion is requested for a date
/// outside the interval that nanosecond timestamps can represent.
const ERR_NANOSECONDS_NOT_SUPPORTED: &str = "The dates that can be represented as nanoseconds have to be between 1677-09-21T00:12:44.0 and 2262-04-11T23:47:16.854775804";
Expand Down Expand Up @@ -144,13 +143,11 @@ impl ScalarUDFImpl for ToTimestampFunc {
}

match args[0].data_type() {
DataType::Int32 | DataType::Int64 => cast_column(
&cast_column(&args[0], &Timestamp(Second, None), None)?,
&Timestamp(Nanosecond, None),
None,
),
DataType::Int32 | DataType::Int64 => args[0]
.cast_to(&Timestamp(Second, None), None)?
.cast_to(&Timestamp(Nanosecond, None), None),
DataType::Null | DataType::Float64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Nanosecond, None), None)
args[0].cast_to(&Timestamp(Nanosecond, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampNanosecondType>(args, "to_timestamp")
Expand Down Expand Up @@ -201,7 +198,7 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc {

match args[0].data_type() {
DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Second, None), None)
args[0].cast_to(&Timestamp(Second, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampSecondType>(args, "to_timestamp_seconds")
Expand Down Expand Up @@ -252,7 +249,7 @@ impl ScalarUDFImpl for ToTimestampMillisFunc {

match args[0].data_type() {
DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Millisecond, None), None)
args[0].cast_to(&Timestamp(Millisecond, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampMillisecondType>(args, "to_timestamp_millis")
Expand Down Expand Up @@ -303,7 +300,7 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc {

match args[0].data_type() {
DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Microsecond, None), None)
args[0].cast_to(&Timestamp(Microsecond, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampMicrosecondType>(args, "to_timestamp_micros")
Expand Down Expand Up @@ -354,7 +351,7 @@ impl ScalarUDFImpl for ToTimestampNanosFunc {

match args[0].data_type() {
DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => {
cast_column(&args[0], &Timestamp(Nanosecond, None), None)
args[0].cast_to(&Timestamp(Nanosecond, None), None)
}
DataType::Utf8 => {
to_timestamp_impl::<TimestampNanosecondType>(args, "to_timestamp_nanos")
Expand Down
10 changes: 4 additions & 6 deletions datafusion/physical-expr/src/datetime_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ use datafusion_common::cast::{
use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;

use crate::expressions::cast_column;

/// Create an implementation of `now()` that always returns the
/// specified timestamp.
///
Expand Down Expand Up @@ -328,9 +326,9 @@ pub fn make_date(args: &[ColumnarValue]) -> Result<ColumnarValue> {
let is_scalar = len.is_none();
let array_size = if is_scalar { 1 } else { len.unwrap() };

let years = cast_column(&args[0], &DataType::Int32, None)?;
let months = cast_column(&args[1], &DataType::Int32, None)?;
let days = cast_column(&args[2], &DataType::Int32, None)?;
let years = args[0].cast_to(&DataType::Int32, None)?;
let months = args[1].cast_to(&DataType::Int32, None)?;
let days = args[2].cast_to(&DataType::Int32, None)?;

// since the epoch for the date32 datatype is the unix epoch
// we need to subtract the unix epoch from the current date
Expand Down Expand Up @@ -1154,7 +1152,7 @@ pub fn from_unixtime_invoke(args: &[ColumnarValue]) -> Result<ColumnarValue> {

match args[0].data_type() {
DataType::Int64 => {
cast_column(&args[0], &DataType::Timestamp(TimeUnit::Second, None), None)
args[0].cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)
}
other => {
exec_err!(
Expand Down
43 changes: 3 additions & 40 deletions datafusion/physical-expr/src/expressions/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ use std::hash::{Hash, Hasher};
use std::sync::Arc;
use DataType::*;

use arrow::compute::{can_cast_types, kernels, CastOptions};
use arrow::compute::{can_cast_types, CastOptions};
use arrow::datatypes::{DataType, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::format::DEFAULT_FORMAT_OPTIONS;
use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::ColumnarValue;

Expand Down Expand Up @@ -120,7 +120,7 @@ impl PhysicalExpr for CastExpr {

fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let value = self.expr.evaluate(batch)?;
cast_column(&value, &self.cast_type, Some(&self.cast_options))
value.cast_to(&self.cast_type, Some(&self.cast_options))
}

fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
Expand Down Expand Up @@ -182,43 +182,6 @@ impl PartialEq<dyn Any> for CastExpr {
}
}

/// Internal helper that casts a `ColumnarValue` to `cast_type`, returning a
/// `ColumnarValue` of the same kind (array in, array out; scalar in, scalar
/// out).
///
/// When `cast_options` is `None`, `DEFAULT_CAST_OPTIONS` is used.
///
/// Scalars are converted with `to_array`, cast with the arrow kernel, and
/// element 0 of the result is read back. A non-null `Float64` scalar cast to
/// `Timestamp(Nanosecond, None)` is multiplied by `1_000_000_000`, truncated,
/// and converted to `Int64` before the kernel runs; arrays get no such
/// scaling.
///
/// Errors from array construction, the cast kernel, or scalar extraction are
/// propagated to the caller.
pub fn cast_column(
    value: &ColumnarValue,
    cast_type: &DataType,
    cast_options: Option<&CastOptions<'static>>,
) -> Result<ColumnarValue> {
    let options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS);

    // Arrays need no preparation: cast and return immediately.
    let scalar = match value {
        ColumnarValue::Array(array) => {
            let converted =
                kernels::cast::cast_with_options(array, cast_type, &options)?;
            return Ok(ColumnarValue::Array(converted));
        }
        ColumnarValue::Scalar(scalar) => scalar,
    };

    let nanos_target = DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, None);
    let single_row = match scalar {
        // Float64 -> nanosecond timestamp: pre-scale to whole nanoseconds.
        ScalarValue::Float64(Some(float_ts)) if cast_type == &nanos_target => {
            let nanos = (float_ts * 1_000_000_000_f64).trunc() as i64;
            ScalarValue::Int64(Some(nanos)).to_array()?
        }
        _ => scalar.to_array()?,
    };

    let converted = kernels::cast::cast_with_options(&single_row, cast_type, &options)?;
    let result = ScalarValue::try_from_array(&converted, 0)?;
    Ok(ColumnarValue::Scalar(result))
}

/// Return a PhysicalExpression representing `expr` cast to
/// `cast_type`, if any casting is needed.
///
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub use crate::PhysicalSortExpr;

pub use binary::{binary, BinaryExpr};
pub use case::{case, CaseExpr};
pub use cast::{cast, cast_column, cast_with_options, CastExpr};
pub use cast::{cast, cast_with_options, CastExpr};
pub use column::{col, Column, UnKnownColumn};
pub use get_indexed_field::{GetFieldAccessExpr, GetIndexedFieldExpr};
pub use in_list::{in_list, InListExpr};
Expand Down

0 comments on commit e29e62d

Please sign in to comment.