Skip to content

Commit

Permalink
Merge remote-tracking branch 'apache/main' into comet-parquet-exec
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Nov 11, 2024
2 parents 16033d9 + 712658e commit ad46821
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 166 deletions.
112 changes: 110 additions & 2 deletions native/spark-expr/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use arrow::{
};
use arrow_array::builder::StringBuilder;
use arrow_array::{DictionaryArray, StringArray, StructArray};
use arrow_schema::{DataType, Schema};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{
cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
};
Expand Down Expand Up @@ -714,6 +714,14 @@ fn cast_array(
(DataType::Struct(_), DataType::Utf8) => {
Ok(casts_struct_to_string(array.as_struct(), &timezone)?)
}
(DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
array.as_struct(),
from_type,
to_type,
eval_mode,
timezone,
allow_incompat,
)?),
_ if is_datafusion_spark_compatible(from_type, to_type, allow_incompat) => {
// use DataFusion cast only when we know that it is compatible with Spark
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
Expand Down Expand Up @@ -811,6 +819,35 @@ fn is_datafusion_spark_compatible(
}
}

/// Cast between struct types based on logic in
/// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`.
fn cast_struct_to_struct(
array: &StructArray,
from_type: &DataType,
to_type: &DataType,
eval_mode: EvalMode,
timezone: String,
allow_incompat: bool,
) -> DataFusionResult<ArrayRef> {
match (from_type, to_type) {
(DataType::Struct(_), DataType::Struct(to_fields)) => {
let mut cast_fields: Vec<(Arc<Field>, ArrayRef)> = Vec::with_capacity(to_fields.len());
for i in 0..to_fields.len() {
let cast_field = cast_array(
Arc::clone(array.column(i)),
to_fields[i].data_type(),
eval_mode,
timezone.clone(),
allow_incompat,
)?;
cast_fields.push((Arc::clone(&to_fields[i]), cast_field));
}
Ok(Arc::new(StructArray::from(cast_fields)))
}
_ => unreachable!(),
}
}

fn casts_struct_to_string(array: &StructArray, timezone: &str) -> DataFusionResult<ArrayRef> {
// cast each field to a string
let string_arrays: Vec<ArrayRef> = array
Expand Down Expand Up @@ -1929,7 +1966,7 @@ fn trim_end(s: &str) -> &str {
mod tests {
use arrow::datatypes::TimestampMicrosecondType;
use arrow_array::StringArray;
use arrow_schema::{Field, TimeUnit};
use arrow_schema::{Field, Fields, TimeUnit};
use std::str::FromStr;

use super::*;
Expand Down Expand Up @@ -2336,4 +2373,75 @@ mod tests {
assert_eq!(r#"{4, d}"#, string_array.value(3));
assert_eq!(r#"{5, e}"#, string_array.value(4));
}

#[test]
fn test_cast_struct_to_struct() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![
Some(1),
Some(2),
None,
Some(4),
Some(5),
]));
let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
let c: ArrayRef = Arc::new(StructArray::from(vec![
(Arc::new(Field::new("a", DataType::Int32, true)), a),
(Arc::new(Field::new("b", DataType::Utf8, true)), b),
]));
// change type of "a" from Int32 to Utf8
let fields = Fields::from(vec![
Field::new("a", DataType::Utf8, true),
Field::new("b", DataType::Utf8, true),
]);
let cast_array = spark_cast(
ColumnarValue::Array(c),
&DataType::Struct(fields),
EvalMode::Legacy,
"UTC",
false,
)
.unwrap();
if let ColumnarValue::Array(cast_array) = cast_array {
assert_eq!(5, cast_array.len());
let a = cast_array.as_struct().column(0).as_string::<i32>();
assert_eq!("1", a.value(0));
} else {
unreachable!()
}
}

#[test]
fn test_cast_struct_to_struct_drop_column() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![
Some(1),
Some(2),
None,
Some(4),
Some(5),
]));
let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
let c: ArrayRef = Arc::new(StructArray::from(vec![
(Arc::new(Field::new("a", DataType::Int32, true)), a),
(Arc::new(Field::new("b", DataType::Utf8, true)), b),
]));
// change type of "a" from Int32 to Utf8 and drop "b"
let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]);
let cast_array = spark_cast(
ColumnarValue::Array(c),
&DataType::Struct(fields),
EvalMode::Legacy,
"UTC",
false,
)
.unwrap();
if let ColumnarValue::Array(cast_array) = cast_array {
assert_eq!(5, cast_array.len());
let struct_array = cast_array.as_struct();
assert_eq!(1, struct_array.columns().len());
let a = struct_array.column(0).as_string::<i32>();
assert_eq!("1", a.value(0));
} else {
unreachable!()
}
}
}
10 changes: 10 additions & 0 deletions spark/src/main/scala/org/apache/comet/expressions/CometCast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ object CometCast {
canCastFromFloat(toType)
case (DataTypes.DoubleType, _) =>
canCastFromDouble(toType)
case (from_struct: StructType, to_struct: StructType) =>
from_struct.fields.zip(to_struct.fields).foreach { case (a, b) =>
isSupported(a.dataType, b.dataType, timeZoneId, evalMode) match {
case Compatible(_) =>
// all good
case other =>
return other
}
}
Compatible()
case _ => Unsupported
}
}
Expand Down
Loading

0 comments on commit ad46821

Please sign in to comment.