From 6c8557a561ef34cd7e98102da535ff23a0911de4 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 5 Dec 2024 15:48:34 +0800 Subject: [PATCH] feat(core): avoid to generate unnamed subquery when planning calculated fields (#969) --- .../routers/v3/connector/postgres/conftest.py | 3 + .../v3/connector/postgres/test_query.py | 98 ++++++++++++++++++- ibis-server/tools/mdl_validation.py | 2 +- .../logical_plan/analyze/model_generation.rs | 56 ++++++++--- .../core/src/logical_plan/analyze/plan.rs | 25 +++++ .../logical_plan/analyze/relation_chain.rs | 91 +++++++++++++---- wren-core/core/src/logical_plan/utils.rs | 48 +++++++-- wren-core/core/src/mdl/mod.rs | 51 +++++++++- 8 files changed, 328 insertions(+), 46 deletions(-) diff --git a/ibis-server/tests/routers/v3/connector/postgres/conftest.py b/ibis-server/tests/routers/v3/connector/postgres/conftest.py index 0b5b3d116..fa277fc4f 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/conftest.py +++ b/ibis-server/tests/routers/v3/connector/postgres/conftest.py @@ -26,6 +26,9 @@ def postgres(request) -> PostgresContainer: pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql( "orders", engine, index=False ) + pd.read_parquet(file_path("resource/tpch/data/customer.parquet")).to_sql( + "customer", engine, index=False + ) request.addfinalizer(pg.stop) return pg diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index 685472fec..d156b236d 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -26,7 +26,13 @@ }, { "name": "o_totalprice", + "type": "text", + "isHidden": True, + }, + { + "name": "o_totalprice_double", "type": "double", + "expression": "cast(o_totalprice as double)", }, {"name": "o_orderdate", "type": "date"}, { @@ -57,6 +63,33 @@ ], "primaryKey": "o_orderkey", }, + { + "name": "customer", + "tableReference": { + "schema": "public", + "table": "customer", + }, + "columns": [ + {"name": "c_custkey", "type": "integer"}, + {"name": "c_name", "type": "varchar"}, + {"name": "orders", "type": "orders", "relationship": "orders_customer"}, + { + "name": "sum_totalprice", + "type": "double", + "isCalculated": True, + "expression": "sum(orders.o_totalprice_double)", + }, + ], + "primaryKey": "c_custkey", + }, + ], + "relationships": [ + { + "name": "orders_customer", + "models": ["orders", "customer"], + "joinType": "many_to_one", + "condition": "orders.o_custkey = customer.c_custkey", + } ], } @@ -79,25 +112,25 @@ def test_query(manifest_str, connection_info): ) assert response.status_code == 200 result = response.json() - assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["columns"]) == 10 assert len(result["data"]) == 1 assert result["data"][0] == [ "2024-01-01 23:59:59.000000", "2024-01-01 23:59:59.000000 UTC", "2024-01-16 04:00:00.000000 UTC", # utc-5 "2024-07-16 03:00:00.000000 UTC", # utc-4 + 172799.49, "1_370", 370, "1996-01-02", 1, "O", - "172799.49", ] assert result["dtypes"] == { "o_orderkey": "int32", "o_custkey": "int32", "o_orderstatus": "object", - "o_totalprice": "object", + "o_totalprice_double": "float64", "o_orderdate": "object", "order_cust_key": "object", "timestamp": "object", @@ -117,7 +150,7 @@ def test_query_with_connection_url(manifest_str, connection_url): ) assert response.status_code == 200 result = response.json() - assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["columns"]) == 10 assert len(result["data"]) == 1 assert result["data"][0][0] == "2024-01-01 23:59:59.000000" assert result["dtypes"] is not None @@ -227,3 +260,60 @@ def test_query_with_dry_run_and_invalid_sql(manifest_str, connection_info): ) assert response.status_code == 422 assert response.text is not None + + def test_query_to_many_calculation(manifest_str, connection_info): + response = client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT sum_totalprice FROM wren.public.customer limit 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == 1 + assert len(result["data"]) == 1 + assert result["dtypes"] == {"sum_totalprice": "float64"} + + response = client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT sum_totalprice FROM wren.public.customer where c_name = 'Customer#000000001' limit 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == 1 + assert len(result["data"]) == 1 + assert result["dtypes"] == {"sum_totalprice": "float64"} + + response = client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT c_name, sum_totalprice FROM wren.public.customer limit 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == 2 + assert len(result["data"]) == 1 + assert result["dtypes"] == {"c_name": "object", "sum_totalprice": "float64"} + + response = client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT c_custkey, sum_totalprice FROM wren.public.customer limit 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == 2 + assert len(result["data"]) == 1 + assert result["dtypes"] == {"c_custkey": "int32", "sum_totalprice": "float64"} diff --git a/ibis-server/tools/mdl_validation.py b/ibis-server/tools/mdl_validation.py index aa708515e..8b9f593f1 100644 --- a/ibis-server/tools/mdl_validation.py +++ b/ibis-server/tools/mdl_validation.py @@ -35,7 +35,7 @@ for model in mdl["models"]: for column in model["columns"]: # ignore hidden columns - if column.get("isHidden"): + if column.get("isHidden") or column.get("relationship") is not None: continue sql = f"select \"{column['name']}\" from \"{model['name']}\"" try: diff --git a/wren-core/core/src/logical_plan/analyze/model_generation.rs b/wren-core/core/src/logical_plan/analyze/model_generation.rs index 603e0137a..ffd933766 100644 --- a/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -1,18 +1,22 @@ use std::fmt::Debug; use std::sync::Arc; +use datafusion::common::alias::AliasGenerator; use datafusion::common::config::ConfigOptions; use datafusion::common::tree_node::{Transformed, TransformedResult}; use datafusion::common::{plan_err, Result}; use datafusion::logical_expr::{col, ident, Extension, UserDefinedLogicalNodeCore}; use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion::optimizer::analyzer::AnalyzerRule; +use datafusion::physical_plan::internal_err; use datafusion::sql::TableReference; use crate::logical_plan::analyze::plan::{ CalculationPlanNode, ModelPlanNode, ModelSourceNode, PartialModelPlanNode, }; -use crate::logical_plan::utils::create_remote_table_source; +use crate::logical_plan::utils::{ + create_remote_table_source, eliminate_ambiguous_columns, rebase_column, +}; use crate::mdl::manifest::Model; use crate::mdl::utils::quoted; use crate::mdl::{AnalyzedWrenMDL, SessionStateRef}; @@ -35,24 +39,37 @@ impl ModelGenerationRule { &self, plan: LogicalPlan, ) -> Result> { + let alias_generator = AliasGenerator::default(); match plan { LogicalPlan::Extension(extension) => { if let Some(model_plan) = extension.node.as_any().downcast_ref::() { - let source_plan = model_plan.relation_chain.clone().plan( + let (source_plan, alias) = model_plan.relation_chain.clone().plan( ModelGenerationRule::new( Arc::clone(&self.analyzed_wren_mdl), Arc::clone(&self.session_state), ), + &alias_generator, )?; + + let projections = if let Some(alias) = alias { + model_plan + .required_exprs + .iter() + .map(|expr| rebase_column(expr, &alias).unwrap()) + .collect() + } else { + model_plan.required_exprs.clone() + }; + let projections = eliminate_ambiguous_columns(projections); let result = match source_plan { Some(plan) => { if model_plan.required_exprs.is_empty() { plan } else { LogicalPlanBuilder::from(plan) - .project(model_plan.required_exprs.clone())? + .project(projections)? .build()? } } @@ -119,23 +136,35 @@ impl ModelGenerationRule { .as_any() .downcast_ref::( ) { - let source_plan = calculation_plan.relation_chain.clone().plan( - ModelGenerationRule::new( - Arc::clone(&self.analyzed_wren_mdl), - Arc::clone(&self.session_state), - ), - )?; + let (source_plan, plan_alias) = + calculation_plan.relation_chain.clone().plan( + ModelGenerationRule::new( + Arc::clone(&self.analyzed_wren_mdl), + Arc::clone(&self.session_state), + ), + &alias_generator, + )?; + + let plan_alias = if let Some(alias) = plan_alias { + alias + } else { + return internal_err!("calculation plan should have an alias"); + }; if let Expr::Alias(alias) = calculation_plan.measures[0].clone() { let measure: Expr = *alias.expr.clone(); + let rebased_measure = rebase_column(&measure, &plan_alias)?; let name = alias.name.clone(); - let ident = ident(measure.to_string()).alias(name); - let project = vec![calculation_plan.dimensions[0].clone(), ident]; + let ident = + ident(rebased_measure.to_string()).alias(name.clone()); + let rebased_dimension = + rebase_column(&calculation_plan.dimensions[0], &plan_alias)?; + let project = vec![rebased_dimension.clone(), ident]; let result = match source_plan { Some(plan) => LogicalPlanBuilder::from(plan) .aggregate( - calculation_plan.dimensions.clone(), - vec![measure], + vec![rebased_dimension], + vec![rebased_measure], )? .project(project)? .build()?, @@ -169,6 +198,7 @@ impl ModelGenerationRule { .iter() .map(|f| col(datafusion::common::Column::from((None, f)))) .collect(); + let projection = eliminate_ambiguous_columns(projection); let alias = LogicalPlanBuilder::from(source_plan) .project(projection)? .alias(quoted(&partial_model.model_node.plan_name))? diff --git a/wren-core/core/src/logical_plan/analyze/plan.rs b/wren-core/core/src/logical_plan/analyze/plan.rs index a114fc768..22c1d446b 100644 --- a/wren-core/core/src/logical_plan/analyze/plan.rs +++ b/wren-core/core/src/logical_plan/analyze/plan.rs @@ -254,6 +254,31 @@ impl ModelPlanNodeBuilder { let Some(source) = self.directed_graph.node_weight(start) else { return internal_err!("Dataset not found"); }; + + // insert the primary key to the required fields for join with the calculation + let keys = self + .model_required_fields + .keys() + .cloned() + .collect::>(); + for model in keys { + let Some(pk_column) = self + .analyzed_wren_mdl + .wren_mdl() + .get_model(model.table()) + .and_then(|m| m.primary_key().and_then(|pk| m.get_column(pk))) + else { + debug!("Primary key not found for model {}", model); + continue; + }; + self.model_required_fields + .entry(model.clone()) + .or_default() + .insert(OrdExpr::new(Expr::Column(Column::from_qualified_name( + format!("{}.{}", quoted(model.table()), quoted(pk_column.name()),), + )))); + } + let mut source_required_fields: Vec = self .model_required_fields .get(&model_ref) diff --git a/wren-core/core/src/logical_plan/analyze/relation_chain.rs b/wren-core/core/src/logical_plan/analyze/relation_chain.rs index 8abeb594f..99d603080 100644 --- a/wren-core/core/src/logical_plan/analyze/relation_chain.rs +++ b/wren-core/core/src/logical_plan/analyze/relation_chain.rs @@ -3,23 +3,31 @@ use crate::logical_plan::analyze::plan::{ CalculationPlanNode, ModelPlanNode, ModelSourceNode, OrdExpr, PartialModelPlanNode, }; use crate::logical_plan::analyze::relation_chain::RelationChain::Start; -use crate::logical_plan::utils::create_schema; -use crate::mdl; +use crate::logical_plan::utils::{ + create_schema, eliminate_ambiguous_columns, rebase_column, +}; use crate::mdl::lineage::DatasetLink; use crate::mdl::manifest::JoinType; use crate::mdl::utils::{qualify_name_from_column_name, quoted}; use crate::mdl::Dataset; use crate::mdl::{AnalyzedWrenMDL, SessionStateRef}; +use crate::{mdl, DataFusionError}; +use datafusion::common::alias::AliasGenerator; use datafusion::common::TableReference; -use datafusion::common::{internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef}; +use datafusion::common::{ + internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef, Result, +}; use datafusion::logical_expr::{ - col, Expr, Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNodeCore, + col, Expr, Extension, LogicalPlan, LogicalPlanBuilder, SubqueryAlias, + UserDefinedLogicalNodeCore, }; use petgraph::graph::NodeIndex; use petgraph::Graph; use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; +const ALIAS: &str = "__relation_"; + /// RelationChain is a chain of models that are connected by the relationship. /// The chain is used to generate the join plan for the model. /// The physical layout will be looked like: @@ -36,7 +44,7 @@ impl RelationChain { required_fields: Vec, analyzed_wren_mdl: Arc, session_state_ref: SessionStateRef, - ) -> datafusion::common::Result { + ) -> Result { match dataset { Dataset::Model(source_model) => { Ok(Start(LogicalPlan::Extension(Extension { @@ -63,7 +71,7 @@ impl RelationChain { model_required_fields: &HashMap>, analyzed_wren_mdl: Arc, session_state_ref: SessionStateRef, - ) -> datafusion::common::Result { + ) -> Result { let mut relation_chain = source; for next in iter { @@ -129,18 +137,52 @@ impl RelationChain { pub(crate) fn plan( &mut self, rule: ModelGenerationRule, - ) -> datafusion::common::Result> { + alias_generator: &AliasGenerator, + ) -> Result<(Option, Option)> { match self { RelationChain::Chain(plan, _, condition, ref mut next) => { let left = rule.generate_model_internal(plan.clone())?.data; + let left_alias = if let LogicalPlan::SubqueryAlias(SubqueryAlias { + alias, + .. + }) = &left + { + alias.table() + } else { + return internal_err!( + "model plan should be wrapped in a subquery alias" + ); + }; + + let (Some(right), right_alias) = next.plan(rule, alias_generator)? else { + return plan_err!("Nil relation chain"); + }; + let join_keys: Vec = mdl::utils::collect_identifiers(condition)? .iter() .map(|c| col(qualify_name_from_column_name(c))) .collect(); + + // The right key should be rebased if the right table has a generated alias + let join_keys = join_keys + .into_iter() + .map(|expr| match expr { + Expr::Column(c) => { + if c.relation + .clone() + .map(|r| r.table() != left_alias) + .unwrap_or(false) + { + if let Some(right_alias) = &right_alias { + return rebase_column(&Expr::Column(c), right_alias); + } + } + Ok::<_, DataFusionError>(Expr::Column(c)) + } + _ => Ok::<_, DataFusionError>(expr), + }) + .collect::>>()?; let join_condition = join_keys[0].clone().eq(join_keys[1].clone()); - let Some(right) = next.plan(rule)? else { - return plan_err!("Nil relation chain"); - }; let mut required_exprs = BTreeSet::new(); // collect the output calculated fields match plan { @@ -242,19 +284,26 @@ impl RelationChain { .iter() .map(|expr| expr.expr.clone()) .collect(); - - Ok(Some( - LogicalPlanBuilder::from(left) - .join_on( - right, - datafusion::logical_expr::JoinType::Right, - vec![join_condition], - )? - .project(required_field)? - .build()?, + let required_field = eliminate_ambiguous_columns(required_field); + let alias = alias_generator.next(ALIAS); + Ok(( + Some( + LogicalPlanBuilder::from(left) + .join_on( + right, + datafusion::logical_expr::JoinType::Right, + vec![join_condition], + )? + .project(required_field)? + .alias(&alias)? + .build()?, + ), + Some(alias), )) } - Start(plan) => Ok(Some(rule.generate_model_internal(plan.clone())?.data)), + Start(plan) => { + Ok((Some(rule.generate_model_internal(plan.clone())?.data), None)) + } } } } diff --git a/wren-core/core/src/logical_plan/utils.rs b/wren-core/core/src/logical_plan/utils.rs index db5a78188..27b82dae3 100644 --- a/wren-core/core/src/logical_plan/utils.rs +++ b/wren-core/core/src/logical_plan/utils.rs @@ -1,16 +1,16 @@ use crate::mdl::lineage::DatasetLink; +use crate::mdl::manifest::Column; use crate::mdl::utils::quoted; -use crate::mdl::{ - manifest::{Column, Model}, - WrenMDL, -}; +use crate::mdl::{manifest::Model, WrenMDL}; use crate::mdl::{Dataset, SessionStateRef}; use datafusion::arrow::datatypes::{ DataType, Field, IntervalUnit, Schema, SchemaBuilder, SchemaRef, TimeUnit, }; use datafusion::catalog_common::TableReference; use datafusion::common::plan_err; -use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion::datasource::DefaultTableSource; use datafusion::error::Result; use datafusion::logical_expr::sqlparser::ast::ArrayElemTypeDef; @@ -21,7 +21,7 @@ use datafusion::sql::sqlparser::parser::Parser; use log::debug; use petgraph::dot::{Config, Dot}; use petgraph::Graph; -use std::collections::HashSet; +use std::collections::{BTreeMap, HashSet}; use std::{collections::HashMap, sync::Arc}; fn create_list_type(array_type: &str) -> Result { @@ -303,6 +303,42 @@ pub fn expr_to_columns( .map(|_| ()) } +/// Rebase the column reference to the new base reference +/// +/// e.g. `a.b` with base_reference `c` will be transformed to `c.b` +pub fn rebase_column(expr: &Expr, base_reference: &str) -> Result { + expr.clone() + .transform_down(|expr| { + if let Expr::Column(datafusion::common::Column { name, .. }) = expr { + let rewritten = Expr::Column(datafusion::common::Column::new( + Some(base_reference), + name, + )); + Ok(Transformed::yes(rewritten)) + } else { + Ok(Transformed::no(expr)) + } + }) + .data() +} + +/// Eliminate the ambiguous columns in the expressions. If there are columns with the same name, +/// only the first one will be kept. +pub fn eliminate_ambiguous_columns(expr: Vec) -> Vec { + let mut columns = BTreeMap::new(); + for e in expr { + match e { + Expr::Column(c) => { + columns.insert(c.name.clone(), Expr::Column(c)); + } + _ => { + columns.insert(e.clone().schema_name().to_string(), e); + } + } + } + columns.into_values().collect() +} + #[cfg(test)] mod test { use crate::logical_plan::utils::{ diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 7d5cf71bd..aa85f4070 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -514,6 +514,56 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_plan_calculation_without_unnamed_subquery() -> Result<()> { + let test_data: PathBuf = + [env!("CARGO_MANIFEST_DIR"), "tests", "data", "mdl.json"] + .iter() + .collect(); + let mdl_json = fs::read_to_string(test_data.as_path())?; + let mdl = match serde_json::from_str::(&mdl_json) { + Ok(mdl) => mdl, + Err(e) => return not_impl_err!("Failed to parse mdl json: {}", e), + }; + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let sql = "select totalcost from profile"; + let result = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + let expected = "SELECT profile.totalcost FROM (SELECT totalcost.totalcost FROM \ + (SELECT __relation__2.p_custkey AS p_custkey, sum(CAST(__relation__2.o_totalprice AS BIGINT)) AS totalcost \ + FROM (SELECT __relation__1.c_custkey, orders.o_custkey, orders.o_totalprice, __relation__1.p_custkey \ + FROM (SELECT orders.o_custkey AS o_custkey, orders.o_totalprice AS o_totalprice FROM orders) AS orders \ + RIGHT JOIN (SELECT customer.c_custkey, profile.p_custkey FROM (SELECT customer.c_custkey AS c_custkey FROM customer) AS customer \ + RIGHT JOIN (SELECT profile.p_custkey AS p_custkey FROM profile) AS profile ON customer.c_custkey = profile.p_custkey) AS __relation__1 \ + ON orders.o_custkey = __relation__1.c_custkey) AS __relation__2 GROUP BY __relation__2.p_custkey) AS totalcost) AS profile"; + assert_eq!(result, expected); + + let sql = "select totalcost from profile where p_sex = 'M'"; + let result = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + assert_eq!(result, "SELECT profile.totalcost FROM (SELECT __relation__1.p_sex, __relation__1.totalcost \ + FROM (SELECT totalcost.p_custkey, profile.p_sex, totalcost.totalcost FROM \ + (SELECT __relation__2.p_custkey AS p_custkey, sum(CAST(__relation__2.o_totalprice AS BIGINT)) AS totalcost FROM \ + (SELECT __relation__1.c_custkey, orders.o_custkey, orders.o_totalprice, __relation__1.p_custkey FROM \ + (SELECT orders.o_custkey AS o_custkey, orders.o_totalprice AS o_totalprice FROM orders) AS orders RIGHT JOIN \ + (SELECT customer.c_custkey, profile.p_custkey FROM (SELECT customer.c_custkey AS c_custkey FROM customer) AS customer \ + RIGHT JOIN (SELECT profile.p_custkey AS p_custkey FROM profile) AS profile ON customer.c_custkey = profile.p_custkey) AS __relation__1 \ + ON orders.o_custkey = __relation__1.c_custkey) AS __relation__2 GROUP BY __relation__2.p_custkey) AS totalcost RIGHT JOIN \ + (SELECT profile.p_custkey AS p_custkey, profile.p_sex AS p_sex FROM profile) AS profile \ + ON totalcost.p_custkey = profile.p_custkey) AS __relation__1) AS profile WHERE profile.p_sex = 'M'"); + Ok(()) + } + #[tokio::test] async fn test_uppercase_catalog_schema() -> Result<()> { let ctx = SessionContext::new(); @@ -558,7 +608,6 @@ mod test { .into_deserialize::() .filter_map(Result::ok) .collect::>(); - dbg!(&functions); let manifest = ManifestBuilder::new() .catalog("CTest") .schema("STest")