Skip to content

Commit

Permalink
feat: IsNaN expression in Comet (#612)
Browse files Browse the repository at this point in the history
* Support IsNaN expression in Comet

* Document that IsNaN is supported

* Update spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Co-authored-by: Liang-Chi Hsieh <[email protected]>

* Fix whitespace

---------

Co-authored-by: Liang-Chi Hsieh <[email protected]>
  • Loading branch information
eejbyfeldt and viirya authored Jul 3, 2024
1 parent 68efa57 commit 8d07204
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 1 deletion.
52 changes: 51 additions & 1 deletion core/src/execution/datafusion/expressions/scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::{
},
datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
};
use arrow_array::{Array, ArrowNativeTypeOp, Decimal128Array};
use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array};
use arrow_schema::DataType;
use datafusion::{
execution::FunctionRegistry,
Expand Down Expand Up @@ -129,6 +129,10 @@ pub fn create_comet_physical_fun(
let func = Arc::new(spark_chr);
make_comet_scalar_udf!("chr", func, without data_type)
}
"isnan" => {
let func = Arc::new(spark_isnan);
make_comet_scalar_udf!("isnan", func, without data_type)
}
sha if sha2_functions.contains(&sha) => {
// Spark requires hex string as the result of sha2 functions, we have to wrap the
// result of digest functions as hex string
Expand Down Expand Up @@ -634,3 +638,49 @@ fn spark_decimal_div(
let result = result.with_data_type(DataType::Decimal128(p3, s3));
Ok(ColumnarValue::Array(Arc::new(result)))
}

fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
match is_nan.nulls() {
Some(nulls) => {
let is_not_null = nulls.inner();
ColumnarValue::Array(Arc::new(BooleanArray::new(
is_nan.values() & is_not_null,
None,
)))
}
None => ColumnarValue::Array(Arc::new(is_nan)),
}
}
let value = &args[0];
match value {
ColumnarValue::Array(array) => match array.data_type() {
DataType::Float64 => {
let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
Ok(set_nulls_to_false(is_nan))
}
DataType::Float32 => {
let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
Ok(set_nulls_to_false(is_nan))
}
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function isnan",
other,
))),
},
ColumnarValue::Scalar(a) => match a {
ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
a.map(|x| x.is_nan()).unwrap_or(false),
)))),
ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
a.map(|x| x.is_nan()).unwrap_or(false),
)))),
_ => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function isnan",
value.data_type(),
))),
},
}
}
1 change: 1 addition & 0 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ The following Spark expressions are currently available. Any known compatibility
| Cos | |
| Exp | |
| Floor | |
| IsNaN | |
| Log | log(0) will produce `-Infinity` unlike Spark which returns `null` |
| Log2 | log2(0) will produce `-Infinity` unlike Spark which returns `null` |
| Log10 | log10(0) will produce `-Infinity` unlike Spark which returns `null` |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,13 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
None
}

case IsNaN(child) =>
val childExpr = exprToProtoInternal(child, inputs)
val optExpr =
scalarExprToProtoWithReturnType("isnan", BooleanType, childExpr)

optExprWithInfo(optExpr, expr, child)

case SortOrder(child, direction, nullOrdering, _) =>
val childExpr = exprToProtoInternal(child, inputs)

Expand Down
16 changes: 16 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1722,4 +1722,20 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("isnan") {
Seq("true", "false").foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary) {
withParquetTable(
Seq(Some(1.0), Some(Double.NaN), None).map(i => Tuple1(i)),
"tbl",
withDictionary = dictionary.toBoolean) {
checkSparkAnswerAndOperator("SELECT isnan(_1), isnan(cast(_1 as float)) FROM tbl")
// Use inside a nullable statement to make sure isnan has correct behavior for null input
checkSparkAnswerAndOperator(
"SELECT CASE WHEN (_1 > 0) THEN NULL ELSE isnan(_1) END FROM tbl")
}
}
}
}
}

0 comments on commit 8d07204

Please sign in to comment.