From 5da7a433584fe6af07efa205204bd31ac7a05ade Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Mon, 13 Nov 2023 23:27:10 +0400 Subject: [PATCH] feat(cubesql): Support `[NOT] IN` SQL push down --- .../src/adapter/BaseQuery.js | 3 +- .../cubesql/src/compile/engine/df/wrapper.rs | 42 ++++- rust/cubesql/cubesql/src/compile/mod.rs | 38 +++++ .../cubesql/src/compile/rewrite/mod.rs | 8 + .../rewrite/rules/wrapper/in_list_expr.rs | 156 ++++++++++++++++++ .../src/compile/rewrite/rules/wrapper/mod.rs | 2 + rust/cubesql/cubesql/src/compile/test/mod.rs | 1 + rust/cubesql/cubesql/src/transport/service.rs | 18 ++ 8 files changed, 266 insertions(+), 2 deletions(-) create mode 100644 rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/in_list_expr.rs diff --git a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js index 0c5b8879f943e..af695c58db687 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js @@ -2479,7 +2479,8 @@ class BaseQuery { binary: '({{ left }} {{ op }} {{ right }})', sort: '{{ expr }} {% if asc %}ASC{% else %}DESC{% endif %}{% if nulls_first %} NULLS FIRST{% endif %}', cast: 'CAST({{ expr }} AS {{ data_type }})', - window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})' + window_function: '{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})', + in_list: '{{ expr }} {% if negated %}NOT {% endif %}IN ({{ in_exprs_concat }})', }, quotes: { identifiers: '"', diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index b594c78d93eec..39bc3489fc080 100644 --- 
a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -1429,7 +1429,47 @@ impl CubeScanWrapperNode { Ok((resulting_sql, sql_query)) } // Expr::AggregateUDF { .. } => {} - // Expr::InList { .. } => {} + Expr::InList { + expr, + list, + negated, + } => { + let mut sql_query = sql_query; + let (sql_expr, query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + *expr, + ungrouped_scan_node.clone(), + ) + .await?; + sql_query = query; + let mut sql_in_exprs = Vec::new(); + for expr in list { + let (sql, query) = Self::generate_sql_for_expr( + plan.clone(), + sql_query, + sql_generator.clone(), + expr, + ungrouped_scan_node.clone(), + ) + .await?; + sql_query = query; + sql_in_exprs.push(sql); + } + Ok(( + sql_generator + .get_sql_templates() + .in_list_expr(sql_expr, sql_in_exprs, negated) + .map_err(|e| { + DataFusionError::Internal(format!( + "Can't generate SQL for in list expr: {}", + e + )) + })?, + sql_query, + )) + } // Expr::Wildcard => {} // Expr::QualifiedWildcard { .. 
} => {}
                 x => {
diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs
index 354b4f6d5de43..c43ab358fe50d 100644
--- a/rust/cubesql/cubesql/src/compile/mod.rs
+++ b/rust/cubesql/cubesql/src/compile/mod.rs
@@ -19591,4 +19591,44 @@ ORDER BY \"COUNT(count)\" DESC"
 
         Ok(())
     }
+
+    /// Checks that a `NOT IN` list used inside a `CASE` expression ends up in the
+    /// SQL generated for the cube scan wrapper (only when SQL push down is enabled).
+    #[tokio::test]
+    async fn test_inlist_expr() {
+        if !Rewriter::sql_push_down_enabled() {
+            return;
+        }
+        init_logger();
+
+        let query_plan = convert_select_to_query_plan(
+            "
+            SELECT
+                CASE
+                    WHEN (customer_gender NOT IN ('1', '2', '3')) THEN customer_gender
+                    ELSE '0'
+                END AS customer_gender
+            FROM KibanaSampleDataEcommerce AS k
+            GROUP BY 1
+            ORDER BY 1 DESC
+            "
+            .to_string(),
+            DatabaseProtocol::PostgreSQL,
+        )
+        .await;
+
+        let physical_plan = query_plan.as_physical_plan().await.unwrap();
+        println!(
+            "Physical plan: {}",
+            displayable(physical_plan.as_ref()).indent()
+        );
+
+        let logical_plan = query_plan.as_logical_plan();
+        assert!(logical_plan
+            .find_cube_scan_wrapper()
+            .wrapped_sql
+            .unwrap()
+            .sql
+            .contains("NOT IN ("));
+    }
 }
diff --git a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs
index 4a86a8bbb3aca..ec120fd17e769 100644
--- a/rust/cubesql/cubesql/src/compile/rewrite/mod.rs
+++ b/rust/cubesql/cubesql/src/compile/rewrite/mod.rs
@@ -892,6 +892,16 @@ fn inlist_expr(expr: impl Display, list: impl Display, negated: impl Display) ->
     format!("(InListExpr {} {} {})", expr, list, negated)
 }
 
+/// Formats the `(InListExprList <left> <right>)` pattern from its two children.
+fn inlist_expr_list(left: impl Display, right: impl Display) -> String {
+    format!("(InListExprList {} {})", left, right)
+}
+
+/// Formats the `InListExprList` pattern with no children (the empty list tail).
+fn inlist_expr_list_empty_tail() -> String {
+    "InListExprList".to_string()
+}
+
 fn between_expr(
     expr: impl Display,
     negated: impl Display,
diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/in_list_expr.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/in_list_expr.rs
new file mode 100644
index 0000000000000..1f387d389e0a8
--- /dev/null
+++
b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/in_list_expr.rs
@@ -0,0 +1,167 @@
+use crate::{
+    compile::rewrite::{
+        analysis::LogicalPlanAnalysis, inlist_expr, inlist_expr_list, inlist_expr_list_empty_tail,
+        rewrite, rules::wrapper::WrapperRules, transforming_rewrite, wrapper_pullup_replacer,
+        wrapper_pushdown_replacer, LogicalPlanLanguage, WrapperPullupReplacerAliasToCube,
+    },
+    var, var_iter,
+};
+use egg::{EGraph, Rewrite, Subst};
+
+impl WrapperRules {
+    /// Adds the wrapper rewrite rules for `IN` list expressions to `rules`.
+    ///
+    /// The rules move the pushdown replacer from an `inlist_expr` / `inlist_expr_list`
+    /// node onto its children, turn the pushdown of the empty list tail into a pullup,
+    /// and move the pullup replacer from the children back onto the parent node.
+    /// The `wrapper-pull-up-in-list` rule only fires when `transform_in_list_expr` allows it.
+    pub fn in_list_expr_rules(
+        &self,
+        rules: &mut Vec<Rewrite<LogicalPlanLanguage, LogicalPlanAnalysis>>,
+    ) {
+        rules.extend(vec![
+            rewrite(
+                "wrapper-push-down-in-list",
+                wrapper_pushdown_replacer(
+                    inlist_expr("?expr", "?list", "?negated"),
+                    "?alias_to_cube",
+                    "?ungrouped",
+                    "?cube_members",
+                ),
+                inlist_expr(
+                    wrapper_pushdown_replacer(
+                        "?expr",
+                        "?alias_to_cube",
+                        "?ungrouped",
+                        "?cube_members",
+                    ),
+                    wrapper_pushdown_replacer(
+                        "?list",
+                        "?alias_to_cube",
+                        "?ungrouped",
+                        "?cube_members",
+                    ),
+                    "?negated",
+                ),
+            ),
+            transforming_rewrite(
+                "wrapper-pull-up-in-list",
+                inlist_expr(
+                    wrapper_pullup_replacer(
+                        "?expr",
+                        "?alias_to_cube",
+                        "?ungrouped",
+                        "?cube_members",
+                    ),
+                    wrapper_pullup_replacer(
+                        "?list",
+                        "?alias_to_cube",
+                        "?ungrouped",
+                        "?cube_members",
+                    ),
+                    "?negated",
+                ),
+                wrapper_pullup_replacer(
+                    inlist_expr("?expr", "?list", "?negated"),
+                    "?alias_to_cube",
+                    "?ungrouped",
+                    "?cube_members",
+                ),
+                self.transform_in_list_expr("?alias_to_cube"),
+            ),
+            rewrite(
+                "wrapper-push-down-in-list-exprs",
+                wrapper_pushdown_replacer(
+                    inlist_expr_list("?left", "?right"),
+                    "?alias_to_cube",
+                    "?ungrouped",
+                    "?cube_members",
+                ),
+                inlist_expr_list(
+                    wrapper_pushdown_replacer(
+                        "?left",
+                        "?alias_to_cube",
+                        "?ungrouped",
+                        "?cube_members",
+                    ),
+                    wrapper_pushdown_replacer(
+                        "?right",
+                        "?alias_to_cube",
+                        "?ungrouped",
+                        "?cube_members",
+                    ),
+                ),
+            ),
+            rewrite(
+                "wrapper-pull-up-in-list-exprs",
+                inlist_expr_list(
+                    wrapper_pullup_replacer(
+                        "?left",
+                        "?alias_to_cube",
+                        "?ungrouped",
+                        "?cube_members",
+                    ),
+                    wrapper_pullup_replacer(
+                        "?right",
+                        "?alias_to_cube",
+                        "?ungrouped",
+                        "?cube_members",
+                    ),
+                ),
+                wrapper_pullup_replacer(
+                    inlist_expr_list("?left", "?right"),
+                    "?alias_to_cube",
+                    "?ungrouped",
+                    "?cube_members",
+                ),
+            ),
+            rewrite(
+                "wrapper-push-down-in-list-exprs-empty-tail",
+                wrapper_pushdown_replacer(
+                    inlist_expr_list_empty_tail(),
+                    "?alias_to_cube",
+                    "?ungrouped",
+                    "?cube_members",
+                ),
+                wrapper_pullup_replacer(
+                    inlist_expr_list_empty_tail(),
+                    "?alias_to_cube",
+                    "?ungrouped",
+                    "?cube_members",
+                ),
+            ),
+        ]);
+    }
+
+    /// Builds the condition for the `wrapper-pull-up-in-list` rule.
+    ///
+    /// The returned closure is true when any alias-to-cube value bound to
+    /// `alias_to_cube_var` resolves to a SQL generator whose templates contain
+    /// the `expressions/in_list` key.
+    fn transform_in_list_expr(
+        &self,
+        alias_to_cube_var: &'static str,
+    ) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
+        let alias_to_cube_var = var!(alias_to_cube_var);
+        let meta = self.cube_context.meta.clone();
+        move |egraph, subst| {
+            for alias_to_cube in var_iter!(
+                egraph[subst[alias_to_cube_var]],
+                WrapperPullupReplacerAliasToCube
+            )
+            .cloned()
+            {
+                if let Some(sql_generator) = meta.sql_generator_by_alias_to_cube(&alias_to_cube) {
+                    if sql_generator
+                        .get_sql_templates()
+                        .templates
+                        .contains_key("expressions/in_list")
+                    {
+                        return true;
+                    }
+                }
+            }
+            false
+        }
+    }
+}
diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs
index e3f469a4778f3..e7cff566f0116 100644
--- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs
+++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/mod.rs
@@ -7,6 +7,7 @@ mod cast;
 mod column;
 mod cube_scan_wrapper;
 mod extract;
+mod in_list_expr;
 mod is_null_expr;
 mod limit;
 mod literal;
@@ -60,6 +61,7 @@ impl RewriteRules for WrapperRules {
         self.cast_rules(&mut rules);
         self.column_rules(&mut rules);
         self.literal_rules(&mut rules);
+        self.in_list_expr_rules(&mut rules);
 
         rules
     }
diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs
index 094a4199148e5..6ddf25789b0f6 100644
---
a/rust/cubesql/cubesql/src/compile/test/mod.rs
+++ b/rust/cubesql/cubesql/src/compile/test/mod.rs
@@ -234,6 +234,7 @@ pub fn get_test_tenant_ctx() -> Arc {
                     ("expressions/cast".to_string(), "CAST({{ expr }} AS {{ data_type }})".to_string()),
                     ("expressions/interval".to_string(), "INTERVAL '{{ interval }}'".to_string()),
                     ("expressions/window_function".to_string(), "{{ fun_call }} OVER ({% if partition_by %}PARTITION BY {{ partition_by }}{% if order_by %} {% endif %}{% endif %}{% if order_by %}ORDER BY {{ order_by }}{% endif %})".to_string()),
+                    ("expressions/in_list".to_string(), "{{ expr }} {% if negated %}NOT {% endif %}IN ({{ in_exprs_concat }})".to_string()),
                     ("quotes/identifiers".to_string(), "\"".to_string()),
                     ("quotes/escape".to_string(), "\"\"".to_string()),
                     ("params/param".to_string(), "${{ param_index + 1 }}".to_string())
diff --git a/rust/cubesql/cubesql/src/transport/service.rs b/rust/cubesql/cubesql/src/transport/service.rs
index 4a655f8784c40..65bb4d335de3a 100644
--- a/rust/cubesql/cubesql/src/transport/service.rs
+++ b/rust/cubesql/cubesql/src/transport/service.rs
@@ -536,6 +536,29 @@ impl SqlTemplates {
         )
     }
 
+    /// Renders the `expressions/in_list` template for an `IN` / `NOT IN` expression.
+    ///
+    /// The template context receives `expr`, the list items both as the `in_exprs`
+    /// vector and as `in_exprs_concat` (the items joined with `", "`), and the
+    /// `negated` flag.
+    pub fn in_list_expr(
+        &self,
+        expr: String,
+        in_exprs: Vec<String>,
+        negated: bool,
+    ) -> Result<String, CubeError> {
+        let in_exprs_concat = in_exprs.join(", ");
+        self.render_template(
+            "expressions/in_list",
+            context! {
+                expr => expr,
+                in_exprs_concat => in_exprs_concat,
+                in_exprs => in_exprs,
+                negated => negated
+            },
+        )
+    }
+
     pub fn param(&self, param_index: usize) -> Result<String, CubeError> {
         self.render_template("params/param", context! { param_index => param_index })
     }