Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Optimize CASE expression for "column or null" use case #11534

Merged
merged 5 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added datafusion/core/example.parquet
Binary file not shown.
41 changes: 37 additions & 4 deletions datafusion/physical-expr/benches/case_when.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,29 @@ fn criterion_benchmark(c: &mut Criterion) {
// create input data
let mut c1 = Int32Builder::new();
let mut c2 = StringBuilder::new();
let mut c3 = StringBuilder::new();
for i in 0..1000 {
c1.append_value(i);
if i % 7 == 0 {
c2.append_null();
} else {
c2.append_value(&format!("string {i}"));
}
if i % 9 == 0 {
c3.append_null();
} else {
c3.append_value(&format!("other string {i}"));
}
}
let c1 = Arc::new(c1.finish());
let c2 = Arc::new(c2.finish());
let c3 = Arc::new(c3.finish());
let schema = Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Utf8, true),
Field::new("c3", DataType::Utf8, true),
]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap();

// use same predicate for all benchmarks
let predicate = Arc::new(BinaryExpr::new(
Expand All @@ -63,7 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) {
make_lit_i32(500),
));

// CASE WHEN expr THEN 1 ELSE 0 END
// CASE WHEN c1 <= 500 THEN 1 ELSE 0 END
c.bench_function("case_when: scalar or scalar", |b| {
let expr = Arc::new(
CaseExpr::try_new(
Expand All @@ -76,13 +84,38 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});

// CASE WHEN expr THEN col ELSE null END
// CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END
c.bench_function("case_when: column or null", |b| {
let expr = Arc::new(
CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});

// CASE WHEN c1 <= 500 THEN c2 ELSE c3 END
c.bench_function("case_when: expr or expr", |b| {
let expr = Arc::new(
CaseExpr::try_new(
None,
vec![(predicate.clone(), make_col("c2", 1))],
Some(Arc::new(Literal::new(ScalarValue::Utf8(None)))),
Some(make_col("c3", 2)),
)
.unwrap(),
);
b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap()))
});

// CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END
c.bench_function("case_when: CASE expr", |b| {
let expr = Arc::new(
CaseExpr::try_new(
Some(make_col("c1", 0)),
vec![
(make_lit_i32(1), make_col("c2", 1)),
(make_lit_i32(2), make_col("c3", 2)),
],
None,
)
.unwrap(),
);
Expand Down
161 changes: 152 additions & 9 deletions datafusion/physical-expr/src/expressions/case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,33 @@ use datafusion_common::cast::as_boolean_array;
use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;

use datafusion_physical_expr_common::expressions::column::Column;
use datafusion_physical_expr_common::expressions::Literal;
use itertools::Itertools;

type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);

#[derive(Debug, Hash)]
enum EvalMethod {
/// CASE WHEN condition THEN result
/// [WHEN ...]
/// [ELSE result]
/// END
NoExpression,
/// CASE expression
/// WHEN value THEN result
/// [WHEN ...]
/// [ELSE result]
/// END
WithExpression,
/// This is a specialization for a specific use case where we can take a fast path
/// for expressions that are infallible and can be cheaply computed for the entire
/// record batch rather than just for the rows where the predicate is true.
///
/// CASE WHEN condition THEN column [ELSE NULL] END
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I'm wondering why this special format of case when is not optimized to if/else during query optimization.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think DataFusion has an if/else expression. We have one in Comet that could be upstreamed.

InfallibleExprOrNull,
}

/// The CASE expression is similar to a series of nested if/else and there are two forms that
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
Expand All @@ -61,6 +84,8 @@ pub struct CaseExpr {
when_then_expr: Vec<WhenThen>,
/// Optional "else" expression
else_expr: Option<Arc<dyn PhysicalExpr>>,
/// Evaluation method to use
eval_method: EvalMethod,
}

impl std::fmt::Display for CaseExpr {
Expand All @@ -79,20 +104,51 @@ impl std::fmt::Display for CaseExpr {
}
}

/// Returns `true` when `expr` qualifies for the "expression or null" fast path.
///
/// The fast path evaluates the "then" expression for every row of the record
/// batch instead of only the rows where the predicate is true, so it is only
/// valid for expressions that cannot fail and are cheap to compute for the
/// whole batch. At present only `Column` expressions are accepted; other
/// expression types could potentially be admitted in the future.
fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
    let as_column = expr.as_any().downcast_ref::<Column>();
    as_column.is_some()
}

impl CaseExpr {
/// Create a new CASE WHEN expression
pub fn try_new(
expr: Option<Arc<dyn PhysicalExpr>>,
when_then_expr: Vec<WhenThen>,
else_expr: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Self> {
// normalize null literals to None in the else_expr (this already happens
// during SQL planning, but not necessarily for other use cases)
let else_expr = match &else_expr {
Some(e) => match e.as_any().downcast_ref::<Literal>() {
Some(lit) if lit.value().is_null() => None,
_ => else_expr,
},
_ => else_expr,
};

if when_then_expr.is_empty() {
exec_err!("There must be at least one WHEN clause")
} else {
let eval_method = if expr.is_some() {
EvalMethod::WithExpression
} else if when_then_expr.len() == 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this optimization and I think it would be valid for any CASE expression that had a NULL ELSE, not just column

So in other words, I think you could remove the when_then_expr[0].1.as_any().is::<Column>() check and this would still work fine

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally, it is only safe to use this approach for expressions that are infallible. One of the main reasons that regular CaseExpr is expensive is that we need to only evaluate the "true" expression on rows where the predicate has evaluated to true.

I do think that this could be extended beyond just Column expressions though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add a comment in the code to explain this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added more comments to explain the rationale and what is safe or not

&& is_cheap_and_infallible(&(when_then_expr[0].1))
&& else_expr.is_none()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If else value is a null literal, is else_expr still none? Or is it a null literal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From SQL planning, it will be None, but it makes sense to add code here to normalize this for other use cases (other query engines that delegate to DataFusion). I have added this.

{
EvalMethod::InfallibleExprOrNull
} else {
EvalMethod::NoExpression
};

Ok(Self {
expr,
when_then_expr,
else_expr,
eval_method,
})
}
}
Expand Down Expand Up @@ -256,6 +312,38 @@ impl CaseExpr {

Ok(ColumnarValue::Array(current_value))
}

/// This function evaluates the specialized case of:
///
/// CASE WHEN condition THEN column
/// [ELSE NULL]
/// END
///
/// Note that this function is only safe to use for "then" expressions
/// that are infallible because the expression will be evaluated for all
/// rows in the input batch.
fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let when_expr = &self.when_then_expr[0].0;
let then_expr = &self.when_then_expr[0].1;
if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it is evaluated to Scalar, maybe we can still work on it? I.e., we can convert the scalar value to a boolean array.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am planning on adding a specialization for scalar values as well to avoid converting scalars to arrays

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, never mind, your question is different. Yes, we could implement special handling for CASE WHEN [true|false] THEN, but I'm not sure that is a real-world use case that is worth optimizing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I forgot that in this optimization, the when expression could only be a Column.

let bit_mask = bit_mask
.as_any()
.downcast_ref::<BooleanArray>()
.expect("predicate should evaluate to a boolean array");
// invert the bitmask
let bit_mask = not(bit_mask)?;
Copy link
Contributor

@Dandandan Dandandan Jul 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a further optimization I think the when predicate can be transformed to be not(condition) so it's possible to use bit_mask directly instead of not(bit_mask).
This makes it possible to optimize/simplify a condition like x=1 to x!=1 instead of not(x=1), saving the invert.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I will look at this suggestion as a follow on. Great idea.

match then_expr.evaluate(batch)? {
ColumnarValue::Array(array) => {
Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
}
ColumnarValue::Scalar(_) => {
internal_err!("expression did not evaluate to an array")
}
}
} else {
internal_err!("predicate did not evaluate to an array")
}
}
}

impl PhysicalExpr for CaseExpr {
Expand Down Expand Up @@ -303,14 +391,21 @@ impl PhysicalExpr for CaseExpr {
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
if self.expr.is_some() {
// this use case evaluates "expr" and then compares the values with the "when"
// values
self.case_when_with_expr(batch)
} else {
// The "when" conditions all evaluate to boolean in this use case and can be
// arbitrary expressions
self.case_when_no_expr(batch)
match self.eval_method {
EvalMethod::WithExpression => {
// this use case evaluates "expr" and then compares the values with the "when"
// values
self.case_when_with_expr(batch)
}
EvalMethod::NoExpression => {
// The "when" conditions all evaluate to boolean in this use case and can be
// arbitrary expressions
self.case_when_no_expr(batch)
}
EvalMethod::InfallibleExprOrNull => {
// Specialization for CASE WHEN expr THEN column [ELSE NULL] END
self.case_column_or_null(batch)
}
}
}

Expand Down Expand Up @@ -409,7 +504,7 @@ pub fn case(
#[cfg(test)]
mod tests {
use super::*;
use crate::expressions::{binary, cast, col, lit};
use crate::expressions::{binary, cast, col, lit, BinaryExpr};

use arrow::buffer::Buffer;
use arrow::datatypes::DataType::Float64;
Expand All @@ -419,6 +514,7 @@ mod tests {
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_expr::type_coercion::binary::comparison_coercion;
use datafusion_expr::Operator;
use datafusion_physical_expr_common::expressions::Literal;

#[test]
fn case_with_expr() -> Result<()> {
Expand Down Expand Up @@ -998,6 +1094,53 @@ mod tests {
Ok(())
}

#[test]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to make sure we had a slt level test to cover this as well,

Maybe in
https://github.com/apache/datafusion/blob/382bf4f3c7a730828684b9e4ce01369b89717e19/datafusion/sqllogictest/test_files/expr.slt

Or we could start adding a file just for CASE if we are about to spend a bunch of time optimizing it 🤔

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add slt tests

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am planning on more CASE optimizations so will create a separate file

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the slt test. This is my first time using slt and it is very cool

fn test_column_or_null_specialization() -> Result<()> {
// Checks that `CASE WHEN c1 <= 250 THEN c2 END` (single WHEN, column THEN,
// no ELSE) selects the `InfallibleExprOrNull` evaluation method and that
// evaluating it yields the expected number of nulls.

// create input data
// c1 holds 0..1000 with no nulls; c2 is null for every multiple of 7 and
// "string {i}" otherwise.
let mut c1 = Int32Builder::new();
let mut c2 = StringBuilder::new();
for i in 0..1000 {
c1.append_value(i);
if i % 7 == 0 {
c2.append_null();
} else {
c2.append_value(&format!("string {i}"));
}
}
let c1 = Arc::new(c1.finish());
let c2 = Arc::new(c2.finish());
let schema = Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Utf8, true),
]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();

// CaseWhenExprOrNull should produce same results as CaseExpr
// predicate: c1 <= 250
let predicate = Arc::new(BinaryExpr::new(
make_col("c1", 0),
Operator::LtEq,
make_lit_i32(250),
));
let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
// the specialized evaluation method must have been chosen by try_new
assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull));
match expr.evaluate(&batch)? {
ColumnarValue::Array(array) => {
assert_eq!(1000, array.len());
// 251 rows satisfy c1 <= 250 (values 0..=250); 36 of them are
// multiples of 7 (0, 7, ..., 245) and therefore have a null c2,
// leaving 215 non-null results and 1000 - 215 = 785 nulls.
assert_eq!(785, array.null_count());
}
_ => unreachable!(),
}
Ok(())
}

/// Test helper: builds a `Column` physical expression with the given name
/// and column index, wrapped in an `Arc`.
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
    let column = Column::new(name, index);
    Arc::new(column)
}

/// Test helper: builds a non-null `Int32` literal physical expression
/// holding `n`, wrapped in an `Arc`.
fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
    let value = ScalarValue::Int32(Some(n));
    Arc::new(Literal::new(value))
}

fn generate_case_when_with_type_coercion(
expr: Option<Arc<dyn PhysicalExpr>>,
when_thens: Vec<WhenThen>,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ In order to run the sqllogictests running against a previously running Postgres
PG_COMPAT=true PG_URI="postgresql://[email protected]/postgres" cargo test --features=postgres --test sqllogictests
```

The environemnt variables:
The environment variables:

1. `PG_COMPAT` instructs sqllogictest to run against Postgres (not DataFusion)
2. `PG_URI` contains a `libpq` style connection string, whose format is described in
Expand Down
52 changes: 52 additions & 0 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# create test data
statement ok
create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6);

# CASE WHEN with condition
query T
SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo
----
one
three
?

# CASE WHEN with no condition
query I
SELECT CASE WHEN a > 2 THEN a ELSE b END FROM foo
----
2
3
5

# column or explicit null
query I
SELECT CASE WHEN a > 2 THEN b ELSE null END FROM foo
----
NULL
4
6

# column or implicit null
query I
SELECT CASE WHEN a > 2 THEN b END FROM foo
----
NULL
4
6