Skip to content

Commit

Permalink
feat: Optimize CASE expression for "column or null" use case (#11534)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Jul 19, 2024
1 parent 5f0dfbb commit 28fa74b
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 14 deletions.
Binary file added datafusion/core/example.parquet
Binary file not shown.
41 changes: 37 additions & 4 deletions datafusion/physical-expr/benches/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,29 @@ fn criterion_benchmark(c: &mut Criterion) {
// create input data
let mut c1 = Int32Builder::new();
let mut c2 = StringBuilder::new();
let mut c3 = StringBuilder::new();
for i in 0..1000 {
c1.append_value(i);
if i % 7 == 0 {
c2.append_null();
} else {
c2.append_value(&format!("string {i}"));
}
if i % 9 == 0 {
c3.append_null();
} else {
c3.append_value(&format!("other string {i}"));
}
}
let c1 = Arc::new(c1.finish());
let c2 = Arc::new(c2.finish());
let c3 = Arc::new(c3.finish());
let schema = Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Utf8, true),
Field::new("c3", DataType::Utf8, true),
]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap();

// use same predicate for all benchmarks
let predicate = Arc::new(BinaryExpr::new(
Expand All @@ -63,7 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) {
make_lit_i32(500),
));

// CASE WHEN expr THEN 1 ELSE 0 END
// CASE WHEN c1 <= 500 THEN 1 ELSE 0 END
c.bench_function("case_when: scalar or scalar", |b| {
let expr = Arc::new(
CaseExpr::try_new(
Expand All @@ -76,13 +84,38 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});

// CASE WHEN expr THEN col ELSE null END
// CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END
c.bench_function("case_when: column or null", |b| {
let expr = Arc::new(
CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});

// CASE WHEN c1 <= 500 THEN c2 ELSE c3 END
c.bench_function("case_when: expr or expr", |b| {
let expr = Arc::new(
CaseExpr::try_new(
None,
vec![(predicate.clone(), make_col("c2", 1))],
Some(Arc::new(Literal::new(ScalarValue::Utf8(None)))),
Some(make_col("c3", 2)),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});

// CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END
c.bench_function("case_when: CASE expr", |b| {
let expr = Arc::new(
CaseExpr::try_new(
Some(make_col("c1", 0)),
vec![
(make_lit_i32(1), make_col("c2", 1)),
(make_lit_i32(2), make_col("c3", 2)),
],
None,
)
.unwrap(),
);
Expand Down
161 changes: 152 additions & 9 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,33 @@ use datafusion_common::cast::as_boolean_array;
use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;

use datafusion_physical_expr_common::expressions::column::Column;
use datafusion_physical_expr_common::expressions::Literal;
use itertools::Itertools;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);

/// Strategy chosen when a `CaseExpr` is constructed and later used to decide
/// how the expression is evaluated.
///
/// The enum carries no data, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum EvalMethod {
    /// Form without a base expression; every WHEN is a boolean condition:
    ///
    /// ```text
    /// CASE WHEN condition THEN result
    ///      [WHEN ...]
    ///      [ELSE result]
    /// END
    /// ```
    NoExpression,
    /// Form with a base expression that is compared against each WHEN value:
    ///
    /// ```text
    /// CASE expression
    ///     WHEN value THEN result
    ///     [WHEN ...]
    ///     [ELSE result]
    /// END
    /// ```
    WithExpression,
    /// Fast path for a single WHEN/THEN pair without an ELSE result, where the
    /// THEN expression is infallible and cheap enough to compute for the entire
    /// record batch rather than just for the rows where the predicate is true:
    ///
    /// ```text
    /// CASE WHEN condition THEN column [ELSE NULL] END
    /// ```
    InfallibleExprOrNull,
}

/// The CASE expression is similar to a series of nested if/else and there are two forms that
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
Expand All @@ -61,6 +84,8 @@ pub struct CaseExpr {
when_then_expr: Vec<WhenThen>,
/// Optional "else" expression
else_expr: Option<Arc<dyn PhysicalExpr>>,
/// Evaluation method to use
eval_method: EvalMethod,
}

impl std::fmt::Display for CaseExpr {
Expand All @@ -79,20 +104,51 @@ impl std::fmt::Display for CaseExpr {
}
}

/// Reports whether `expr` qualifies for the "column or null" fast path, i.e.
/// whether it is infallible and cheap enough to evaluate for every row of a
/// record batch instead of only the rows selected by the predicate.
///
/// Only plain `Column` references are accepted for now; other expression kinds
/// could potentially be admitted in the future.
fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
    expr.as_any().downcast_ref::<Column>().is_some()
}

impl CaseExpr {
/// Create a new CASE WHEN expression
///
/// * `expr` - optional base expression (`CASE expr WHEN ...` form)
/// * `when_then_expr` - the WHEN/THEN pairs; must not be empty
/// * `else_expr` - optional ELSE result; a null literal is treated as absent
///
/// Returns an execution error when no WHEN clause is supplied. The evaluation
/// method is selected here, once, and stored on the expression.
pub fn try_new(
    expr: Option<Arc<dyn PhysicalExpr>>,
    when_then_expr: Vec<WhenThen>,
    else_expr: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Self> {
    // Drop an ELSE branch that is just a null literal so it looks the same as
    // a missing ELSE (SQL planning already does this, but other callers that
    // build the expression directly may not).
    let else_expr = else_expr.filter(|candidate| {
        !matches!(
            candidate.as_any().downcast_ref::<Literal>(),
            Some(lit) if lit.value().is_null()
        )
    });

    if when_then_expr.is_empty() {
        return exec_err!("There must be at least one WHEN clause");
    }

    let eval_method = match &expr {
        Some(_) => EvalMethod::WithExpression,
        // single WHEN/THEN, cheap THEN expression, no ELSE: take the fast path
        None if when_then_expr.len() == 1
            && is_cheap_and_infallible(&when_then_expr[0].1)
            && else_expr.is_none() =>
        {
            EvalMethod::InfallibleExprOrNull
        }
        None => EvalMethod::NoExpression,
    };

    Ok(Self {
        expr,
        when_then_expr,
        else_expr,
        eval_method,
    })
}
Expand Down Expand Up @@ -256,6 +312,38 @@ impl CaseExpr {

Ok(ColumnarValue::Array(current_value))
}

/// This function evaluates the specialized case of:
///
/// CASE WHEN condition THEN column
/// [ELSE NULL]
/// END
///
/// Note that this function is only safe to use for "then" expressions
/// that are infallible because the expression will be evaluated for all
/// rows in the input batch.
fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let when_expr = &self.when_then_expr[0].0;
let then_expr = &self.when_then_expr[0].1;
if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? {
let bit_mask = bit_mask
.as_any()
.downcast_ref::<BooleanArray>()
.expect("predicate should evaluate to a boolean array");
// invert the bitmask
let bit_mask = not(bit_mask)?;
match then_expr.evaluate(batch)? {
ColumnarValue::Array(array) => {
Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
}
ColumnarValue::Scalar(_) => {
internal_err!("expression did not evaluate to an array")
}
}
} else {
internal_err!("predicate did not evaluate to an array")
}
}
}

impl PhysicalExpr for CaseExpr {
Expand Down Expand Up @@ -303,14 +391,21 @@ impl PhysicalExpr for CaseExpr {
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
if self.expr.is_some() {
// this use case evaluates "expr" and then compares the values with the "when"
// values
self.case_when_with_expr(batch)
} else {
// The "when" conditions all evaluate to boolean in this use case and can be
// arbitrary expressions
self.case_when_no_expr(batch)
match self.eval_method {
EvalMethod::WithExpression => {
// this use case evaluates "expr" and then compares the values with the "when"
// values
self.case_when_with_expr(batch)
}
EvalMethod::NoExpression => {
// The "when" conditions all evaluate to boolean in this use case and can be
// arbitrary expressions
self.case_when_no_expr(batch)
}
EvalMethod::InfallibleExprOrNull => {
// Specialization for CASE WHEN expr THEN column [ELSE NULL] END
self.case_column_or_null(batch)
}
}
}

Expand Down Expand Up @@ -409,7 +504,7 @@ pub fn case(
#[cfg(test)]
mod tests {
use super::*;
use crate::expressions::{binary, cast, col, lit};
use crate::expressions::{binary, cast, col, lit, BinaryExpr};

use arrow::buffer::Buffer;
use arrow::datatypes::DataType::Float64;
Expand All @@ -419,6 +514,7 @@ mod tests {
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_expr::type_coercion::binary::comparison_coercion;
use datafusion_expr::Operator;
use datafusion_physical_expr_common::expressions::Literal;

#[test]
fn case_with_expr() -> Result<()> {
Expand Down Expand Up @@ -998,6 +1094,53 @@ mod tests {
Ok(())
}

#[test]
fn test_column_or_null_specialization() -> Result<()> {
    // Input: 1000 rows where c1 counts 0..1000 and c2 holds "string {i}",
    // except that every 7th row of c2 is null.
    let mut int_builder = Int32Builder::new();
    let mut str_builder = StringBuilder::new();
    for i in 0..1000 {
        int_builder.append_value(i);
        match i % 7 {
            0 => str_builder.append_null(),
            _ => str_builder.append_value(&format!("string {i}")),
        }
    }
    let c1 = Arc::new(int_builder.finish());
    let c2 = Arc::new(str_builder.finish());
    let fields = vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Utf8, true),
    ];
    let batch =
        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![c1, c2]).unwrap();

    // CASE WHEN c1 <= 250 THEN c2 END must be planned onto the fast path
    let predicate = Arc::new(BinaryExpr::new(
        make_col("c1", 0),
        Operator::LtEq,
        make_lit_i32(250),
    ));
    let when_then = vec![(predicate, make_col("c2", 1))];
    let expr = CaseExpr::try_new(None, when_then, None)?;
    assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));

    let ColumnarValue::Array(array) = expr.evaluate(&batch)? else {
        unreachable!()
    };
    assert_eq!(1000, array.len());
    // 251 rows satisfy the predicate and 36 of those are null in c2, so
    // 215 values survive and the remaining 785 rows are null
    assert_eq!(785, array.null_count());
    Ok(())
}

/// Test helper: builds a physical `Column` expression for the column called
/// `name` at position `index`.
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
    let column = Column::new(name, index);
    Arc::new(column)
}

/// Test helper: builds a physical literal expression holding the non-null
/// `Int32` value `n`.
fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
    let value = ScalarValue::Int32(Some(n));
    Arc::new(Literal::new(value))
}

fn generate_case_when_with_type_coercion(
expr: Option<Arc<dyn PhysicalExpr>>,
when_thens: Vec<WhenThen>,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ In order to run the sqllogictests running against a previously running Postgres
PG_COMPAT=true PG_URI="postgresql://[email protected]/postgres" cargo test --features=postgres --test sqllogictests
```

The environemnt variables:
The environment variables:

1. `PG_COMPAT` instructs sqllogictest to run against Postgres (not DataFusion)
2. `PG_URI` contains a `libpq` style connection string, whose format is described in
Expand Down
52 changes: 52 additions & 0 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# create test data
statement ok
create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6);

# CASE with an operand expression (CASE a WHEN value THEN ...)
query T
SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo
----
one
three
?

# CASE without an operand expression (each WHEN is a boolean condition)
query I
SELECT CASE WHEN a > 2 THEN a ELSE b END FROM foo
----
2
3
5

# column or explicit null
query I
SELECT CASE WHEN a > 2 THEN b ELSE null END FROM foo
----
NULL
4
6

# column or implicit null
query I
SELECT CASE WHEN a > 2 THEN b END FROM foo
----
NULL
4
6

0 comments on commit 28fa74b

Please sign in to comment.