Skip to content

Commit

Permalink
fix(cubesql): Avoid panics during filter rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
mcheshkov committed Jan 31, 2025
1 parent 57bcbc4 commit 910b324
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 20 deletions.
62 changes: 42 additions & 20 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2807,7 +2807,7 @@ impl FilterRules {
ScalarValue::TimestampNanosecond(_, _)
| ScalarValue::Date32(_)
| ScalarValue::Date64(_) => {
if let Some(timestamp) =
if let Ok(Some(timestamp)) =
Self::scalar_to_native_datetime(&literal)
{
let value = format_iso_timestamp(timestamp);
Expand Down Expand Up @@ -2842,7 +2842,10 @@ impl FilterRules {
continue;
}
}
x => panic!("Unsupported filter scalar: {:?}", x),
x => {
log::trace!("Unsupported filter scalar: {x:?}");
continue;
}
};

subst.insert(
Expand Down Expand Up @@ -3442,6 +3445,7 @@ impl FilterRules {
}

// Transform ?expr IN (?literal) to ?expr = ?literal
// TODO it's incorrect: inner expr can be null, or can be non-literal (and domain in not clear)
fn transform_filter_in_to_equal(
&self,
negated_var: &'static str,
Expand Down Expand Up @@ -3501,7 +3505,10 @@ impl FilterRules {
let values = list
.into_iter()
.map(|literal| FilterRules::scalar_to_value(literal))
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>();
let Ok(values) = values else {
return false;
};

if let Some((member_name, cube)) = Self::filter_member_name(
egraph,
Expand Down Expand Up @@ -3552,8 +3559,8 @@ impl FilterRules {
}
}

fn scalar_to_value(literal: &ScalarValue) -> String {
match literal {
fn scalar_to_value(literal: &ScalarValue) -> Result<String, &'static str> {
Ok(match literal {

Check warning on line 3563 in rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs#L3563

Added line #L3563 was not covered by tests
ScalarValue::Utf8(Some(value)) => value.to_string(),
ScalarValue::Int64(Some(value)) => value.to_string(),
ScalarValue::Boolean(Some(value)) => value.to_string(),
Expand All @@ -3564,18 +3571,24 @@ impl FilterRules {
ScalarValue::TimestampNanosecond(_, _)
| ScalarValue::Date32(_)
| ScalarValue::Date64(_) => {
if let Some(timestamp) = Self::scalar_to_native_datetime(literal) {
return format_iso_timestamp(timestamp);
if let Some(timestamp) = Self::scalar_to_native_datetime(literal)? {
format_iso_timestamp(timestamp)
} else {
log::trace!("Unsupported filter scalar: {literal:?}");
return Err("Unsupported filter scalar");

Check warning on line 3578 in rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs#L3577-L3578

Added lines #L3577 - L3578 were not covered by tests
}

panic!("Unsupported filter scalar: {:?}", literal);
}
x => panic!("Unsupported filter scalar: {:?}", x),
}
x => {
log::trace!("Unsupported filter scalar: {x:?}");
return Err("Unsupported filter scalar");
}
})
}

fn scalar_to_native_datetime(literal: &ScalarValue) -> Option<NaiveDateTime> {
match literal {
fn scalar_to_native_datetime(
literal: &ScalarValue,
) -> Result<Option<NaiveDateTime>, &'static str> {
Ok(match literal {
ScalarValue::TimestampNanosecond(_, _)
| ScalarValue::Date32(_)
| ScalarValue::Date64(_) => {
Expand All @@ -3589,13 +3602,17 @@ impl FilterRules {
} else if let Some(array) = array.as_any().downcast_ref::<Date64Array>() {
array.value_as_datetime(0)
} else {
panic!("Unexpected array type: {:?}", array.data_type())
log::trace!("Unexpected array type: {:?}", array.data_type());
return Err("Unexpected array type");

Check warning on line 3606 in rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs#L3605-L3606

Added lines #L3605 - L3606 were not covered by tests
};

timestamp
}
_ => panic!("Unsupported filter scalar: {:?}", literal),
}
x => {
log::trace!("Unsupported filter scalar: {x:?}");
return Err("Unsupported filter scalar");

Check warning on line 3613 in rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs#L3611-L3613

Added lines #L3611 - L3613 were not covered by tests
}
})
}

fn transform_is_null(
Expand Down Expand Up @@ -3865,10 +3882,15 @@ impl FilterRules {
Some(MemberType::Time) => (),
_ => continue,
}
let values = vec![
FilterRules::scalar_to_value(&low),
FilterRules::scalar_to_value(&high),
];

let Ok(low) = FilterRules::scalar_to_value(&low) else {
return false;

Check warning on line 3887 in rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs#L3887

Added line #L3887 was not covered by tests
};
let Ok(high) = FilterRules::scalar_to_value(&high) else {
return false;

Check warning on line 3890 in rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs#L3890

Added line #L3890 was not covered by tests
};

let values = vec![low, high];

subst.insert(
filter_member_var,
Expand Down
73 changes: 73 additions & 0 deletions rust/cubesql/cubesql/src/compile/test/test_filters.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use cubeclient::models::{V1LoadRequestQuery, V1LoadRequestQueryFilterItem};
use datafusion::physical_plan::displayable;
use pretty_assertions::assert_eq;

use crate::compile::{
Expand Down Expand Up @@ -60,3 +61,75 @@ GROUP BY
}
);
}

#[tokio::test]
async fn test_filter_dim_in_null() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_testing_logger();

let query_plan = convert_select_to_query_plan(
// language=PostgreSQL
r#"
SELECT
dim_str0
FROM
MultiTypeCube
WHERE dim_str1 IN (NULL)
"#
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
"Physical plan: {}",
displayable(physical_plan.as_ref()).indent()
);

// For now this tests only that query is rewritable
// TODO support this as "notSet" filter

assert!(query_plan
.as_logical_plan()
.find_cube_scan_wrapped_sql()
.wrapped_sql
.sql
.contains(r#"\"expr\":\"${MultiTypeCube.dim_str1} IN (NULL)\""#));
}

#[tokio::test]
async fn test_filter_superset_is_null() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_testing_logger();

let query_plan = convert_select_to_query_plan(
// language=PostgreSQL
r#"
SELECT dim_str0 FROM MultiTypeCube WHERE (dim_str1 IS NULL OR dim_str1 IN (NULL) AND (1<>1))
"#
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
"Physical plan: {}",
displayable(physical_plan.as_ref()).indent()
);

// For now this tests only that query is rewritable
// TODO support this as "notSet" filter

assert!(query_plan
.as_logical_plan()
.find_cube_scan_wrapped_sql()
.wrapped_sql
.sql
.contains(r#"\"expr\":\"(${MultiTypeCube.dim_str1} IS NULL OR (${MultiTypeCube.dim_str1} IN (NULL) AND FALSE))\""#));
}

0 comments on commit 910b324

Please sign in to comment.