Skip to content

Commit

Permalink
feat(cubesql): Support [NOT] IN SQL push down
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Nov 13, 2023
1 parent 3e1a075 commit 5da7a43
Show file tree
Hide file tree
Showing 8 changed files with 266 additions and 2 deletions.
3 changes: 2 additions & 1 deletion packages/cubejs-schema-compiler/src/adapter/BaseQuery.js
Original file line number Diff line number Diff line change
Expand Up @@ -2479,7 +2479,8 @@ class BaseQuery {
binary: '({{ left }} {{ op }} {{ right }})',
sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}',
cast: 'CAST({{ expr }} AS {{ data_type }})',
window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})'
window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})',
in_list: '{{ expr }} {% if negated %}NOT {% endif %}IN ({{ in_exprs_concat }})',
},
quotes: {
identifiers: '"',
Expand Down
42 changes: 41 additions & 1 deletion rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,47 @@ impl CubeScanWrapperNode {
Ok((resulting_sql, sql_query))
}
// Expr::AggregateUDF { .. } => {}
// Expr::InList { .. } => {}
Expr::InList {
expr,
list,
negated,
} => {
let mut sql_query = sql_query;
let (sql_expr, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
*expr,
ungrouped_scan_node.clone(),
)
.await?;

Check warning on line 1445 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1445

Added line #L1445 was not covered by tests
sql_query = query;
let mut sql_in_exprs = Vec::new();
for expr in list {
let (sql, query) = Self::generate_sql_for_expr(
plan.clone(),
sql_query,
sql_generator.clone(),
expr,
ungrouped_scan_node.clone(),
)
.await?;

Check warning on line 1456 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1456

Added line #L1456 was not covered by tests
sql_query = query;
sql_in_exprs.push(sql);
}
Ok((
sql_generator
.get_sql_templates()
.in_list_expr(sql_expr, sql_in_exprs, negated)
.map_err(|e| {
DataFusionError::Internal(format!(

Check warning on line 1465 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L1464-L1465

Added lines #L1464 - L1465 were not covered by tests
"Can't generate SQL for in list expr: {}",
e
))
})?,
sql_query,
))
}
// Expr::Wildcard => {}
// Expr::QualifiedWildcard { .. } => {}
x => {
Expand Down
38 changes: 38 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19591,4 +19591,42 @@ ORDER BY \"COUNT(count)\" DESC"

Ok(())
}

#[tokio::test]
async fn test_inlist_expr() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_logger();

let query_plan = convert_select_to_query_plan(
"
SELECT
CASE
WHEN (customer_gender NOT IN ('1', '2', '3')) THEN customer_gender
ELSE '0'
END AS customer_gender
FROM KibanaSampleDataEcommerce AS k
GROUP BY 1
ORDER BY 1 DESC
"
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
"Physical plan: {}",
displayable(physical_plan.as_ref()).indent()
);

let logical_plan = query_plan.as_logical_plan();
assert!(logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql
.contains("NOT IN ("));
}
}
8 changes: 8 additions & 0 deletions rust/cubesql/cubesql/src/compile/rewrite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,14 @@ fn inlist_expr(expr: impl Display, list: impl Display, negated: impl Display) ->
format!("(InListExpr {} {} {})", expr, list, negated)
}

fn inlist_expr_list(left: impl Display, right: impl Display) -> String {
format!("(InListExprList {} {})", left, right)
}

fn inlist_expr_list_empty_tail() -> String {
format!("InListExprList")
}

fn between_expr(
expr: impl Display,
negated: impl Display,
Expand Down
156 changes: 156 additions & 0 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/in_list_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use crate::{
compile::rewrite::{
analysis::LogicalPlanAnalysis, inlist_expr, inlist_expr_list, inlist_expr_list_empty_tail,
rewrite, rules::wrapper::WrapperRules, transforming_rewrite, wrapper_pullup_replacer,
wrapper_pushdown_replacer, LogicalPlanLanguage, WrapperPullupReplacerAliasToCube,
},
var, var_iter,
};
use egg::{EGraph, Rewrite, Subst};

impl WrapperRules {
pub fn in_list_expr_rules(
&self,
rules: &mut Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>,
) {
rules.extend(vec![
rewrite(
"wrapper-push-down-in-list",
wrapper_pushdown_replacer(
inlist_expr("?expr", "?list", "?negated"),
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
inlist_expr(
wrapper_pushdown_replacer(
"?expr",
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
wrapper_pushdown_replacer(
"?list",
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
"?negated",
),
),
transforming_rewrite(
"wrapper-pull-up-in-list",
inlist_expr(
wrapper_pullup_replacer(
"?expr",
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
wrapper_pullup_replacer(
"?list",
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
"?negated",
),
wrapper_pullup_replacer(
inlist_expr("?expr", "?list", "?negated"),
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
self.transform_in_list_expr("?alias_to_cube"),
),
rewrite(
"wrapper-push-down-in-list-exprs",
wrapper_pushdown_replacer(
inlist_expr_list("?left", "?right"),
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
inlist_expr_list(
wrapper_pushdown_replacer(
"?left",
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
wrapper_pushdown_replacer(
"?right",
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
),
),
rewrite(
"wrapper-pull-up-in-list-exprs",
inlist_expr_list(
wrapper_pullup_replacer(
"?left",
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
wrapper_pullup_replacer(
"?right",
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
),
wrapper_pullup_replacer(
inlist_expr_list("?left", "?right"),
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
),
rewrite(
"wrapper-push-down-in-list-exprs-empty-tail",
wrapper_pushdown_replacer(
inlist_expr_list_empty_tail(),
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
wrapper_pullup_replacer(
inlist_expr_list_empty_tail(),
"?alias_to_cube",
"?ungrouped",
"?cube_members",
),
),
]);
}

fn transform_in_list_expr(
&self,
alias_to_cube_var: &'static str,
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
let alias_to_cube_var = var!(alias_to_cube_var);
let meta = self.cube_context.meta.clone();
move |egraph, subst| {
for alias_to_cube in var_iter!(
egraph[subst[alias_to_cube_var]],
WrapperPullupReplacerAliasToCube
)
.cloned()
{
if let Some(sql_generator) = meta.sql_generator_by_alias_to_cube(&alias_to_cube) {
if sql_generator
.get_sql_templates()
.templates
.contains_key("expressions/in_list")
{
return true;
}
}
}
false

Check warning on line 153 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/in_list_expr.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/in_list_expr.rs#L153

Added line #L153 was not covered by tests
}
}
}
2 changes: 2 additions & 0 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod cast;
mod column;
mod cube_scan_wrapper;
mod extract;
mod in_list_expr;
mod is_null_expr;
mod limit;
mod literal;
Expand Down Expand Up @@ -60,6 +61,7 @@ impl RewriteRules for WrapperRules {
self.cast_rules(&mut rules);
self.column_rules(&mut rules);
self.literal_rules(&mut rules);
self.in_list_expr_rules(&mut rules);

rules
}
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/compile/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ pub fn get_test_tenant_ctx() -> Arc<MetaContext> {
("expressions/cast".to_string(), "CAST({{ expr }} AS {{ data_type }})".to_string()),
("expressions/interval".to_string(), "INTERVAL '{{ interval }}'".to_string()),
("expressions/window_function".to_string(), "{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})".to_string()),
("expressions/in_list".to_string(), "{{ expr }} {% if negated %}NOT {% endif %}IN ({{ in_exprs_concat }})".to_string()),
("quotes/identifiers".to_string(), "\"".to_string()),
("quotes/escape".to_string(), "\"\"".to_string()),
("params/param".to_string(), "${{ param_index + 1 }}".to_string())
Expand Down
18 changes: 18 additions & 0 deletions rust/cubesql/cubesql/src/transport/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,24 @@ impl SqlTemplates {
)
}

pub fn in_list_expr(
&self,
expr: String,
in_exprs: Vec<String>,
negated: bool,
) -> Result<String, CubeError> {
let in_exprs_concat = in_exprs.join(", ");
self.render_template(
"expressions/in_list",
context! {
expr => expr,
in_exprs_concat => in_exprs_concat,
in_exprs => in_exprs,
negated => negated
},
)
}

pub fn param(&self, param_index: usize) -> Result<String, CubeError> {
self.render_template("params/param", context! { param_index => param_index })
}
Expand Down

0 comments on commit 5da7a43

Please sign in to comment.