diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index eaa1ae1cca36..ee19a8c9ddb1 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -24,9 +24,7 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::zip::zip; -use arrow::compute::{ - and, and_not, is_null, not, nullif, or, or_kleene, prep_null_mask_filter, -}; +use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; @@ -405,7 +403,9 @@ impl CaseExpr { } fn expr_or_expr(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - // evaluate condition on batch + let return_type = self.data_type(&batch.schema())?; + + // evaluate when condition on batch let when_value = self.when_then_expr[0].0.evaluate(batch)?; let when_value = when_value.into_array(batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|e| { @@ -415,22 +415,28 @@ impl CaseExpr { ) })?; + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + let then_value = self.when_then_expr[0] .1 - .evaluate_selection(batch, when_value)?; - let then_value = then_value.into_array(batch.num_rows())?; - - let remainder = or_kleene(&not(when_value)?, &is_null(when_value)?)?; + .evaluate_selection(batch, &when_value)? 
+ .into_array(batch.num_rows())?; - let else_value = self - .else_expr - .as_ref() - .unwrap() + // evaluate else expression on the values not covered by when_value + let remainder = not(&when_value)?; + let e = self.else_expr.as_ref().unwrap(); + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); + let else_ = expr .evaluate_selection(batch, &remainder)? .into_array(batch.num_rows())?; - let current_value = zip(&remainder, &else_value, &then_value)?; - Ok(ColumnarValue::Array(current_value)) + Ok(ColumnarValue::Array(zip(&remainder, &else_, &then_value)?)) } }