Skip to content

Commit

Permalink
feat(core): avoid to generate unnamed subquery when planning calculat…
Browse files Browse the repository at this point in the history
…ed fields (#969)
  • Loading branch information
goldmedal authored Dec 5, 2024
1 parent 45a1ab6 commit 6c8557a
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 46 deletions.
3 changes: 3 additions & 0 deletions ibis-server/tests/routers/v3/connector/postgres/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def postgres(request) -> PostgresContainer:
pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql(
"orders", engine, index=False
)
pd.read_parquet(file_path("resource/tpch/data/customer.parquet")).to_sql(
"customer", engine, index=False
)
request.addfinalizer(pg.stop)
return pg

Expand Down
98 changes: 94 additions & 4 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
},
{
"name": "o_totalprice",
"type": "text",
"isHidden": True,
},
{
"name": "o_totalprice_double",
"type": "double",
"expression": "cast(o_totalprice as double)",
},
{"name": "o_orderdate", "type": "date"},
{
Expand Down Expand Up @@ -57,6 +63,33 @@
],
"primaryKey": "o_orderkey",
},
{
"name": "customer",
"tableReference": {
"schema": "public",
"table": "customer",
},
"columns": [
{"name": "c_custkey", "type": "integer"},
{"name": "c_name", "type": "varchar"},
{"name": "orders", "type": "orders", "relationship": "orders_customer"},
{
"name": "sum_totalprice",
"type": "double",
"isCalculated": True,
"expression": "sum(orders.o_totalprice_double)",
},
],
"primaryKey": "c_custkey",
},
],
"relationships": [
{
"name": "orders_customer",
"models": ["orders", "customer"],
"joinType": "many_to_one",
"condition": "orders.o_custkey = customer.c_custkey",
}
],
}

Expand All @@ -79,25 +112,25 @@ def test_query(manifest_str, connection_info):
)
assert response.status_code == 200
result = response.json()
assert len(result["columns"]) == len(manifest["models"][0]["columns"])
assert len(result["columns"]) == 10
assert len(result["data"]) == 1
assert result["data"][0] == [
"2024-01-01 23:59:59.000000",
"2024-01-01 23:59:59.000000 UTC",
"2024-01-16 04:00:00.000000 UTC", # utc-5
"2024-07-16 03:00:00.000000 UTC", # utc-4
172799.49,
"1_370",
370,
"1996-01-02",
1,
"O",
"172799.49",
]
assert result["dtypes"] == {
"o_orderkey": "int32",
"o_custkey": "int32",
"o_orderstatus": "object",
"o_totalprice": "object",
"o_totalprice_double": "float64",
"o_orderdate": "object",
"order_cust_key": "object",
"timestamp": "object",
Expand All @@ -117,7 +150,7 @@ def test_query_with_connection_url(manifest_str, connection_url):
)
assert response.status_code == 200
result = response.json()
assert len(result["columns"]) == len(manifest["models"][0]["columns"])
assert len(result["columns"]) == 10
assert len(result["data"]) == 1
assert result["data"][0][0] == "2024-01-01 23:59:59.000000"
assert result["dtypes"] is not None
Expand Down Expand Up @@ -227,3 +260,60 @@ def test_query_with_dry_run_and_invalid_sql(manifest_str, connection_info):
)
assert response.status_code == 422
assert response.text is not None

def test_query_to_many_calculation(manifest_str, connection_info):
    """Query the calculated column `sum_totalprice` of the `customer` model.

    `sum_totalprice` is declared in the manifest as
    `sum(orders.o_totalprice_double)`, so every query below has to be planned
    through the `orders_customer` relationship. Each case checks that the
    request succeeds, returns exactly one row, and reports the expected
    columns and dtypes.
    """
    # (sql, expected dtypes) pairs; the expected column count is the number
    # of entries in the dtypes mapping.
    cases = [
        (
            "SELECT sum_totalprice FROM wren.public.customer limit 1",
            {"sum_totalprice": "float64"},
        ),
        (
            "SELECT sum_totalprice FROM wren.public.customer where c_name = 'Customer#000000001' limit 1",
            {"sum_totalprice": "float64"},
        ),
        (
            "SELECT c_name, sum_totalprice FROM wren.public.customer limit 1",
            {"c_name": "object", "sum_totalprice": "float64"},
        ),
        (
            "SELECT c_custkey, sum_totalprice FROM wren.public.customer limit 1",
            {"c_custkey": "int32", "sum_totalprice": "float64"},
        ),
    ]

    for sql, expected_dtypes in cases:
        payload = {
            "connectionInfo": connection_info,
            "manifestStr": manifest_str,
            "sql": sql,
        }
        response = client.post(url=f"{base_url}/query", json=payload)
        assert response.status_code == 200

        result = response.json()
        assert len(result["columns"]) == len(expected_dtypes)
        assert len(result["data"]) == 1
        assert result["dtypes"] == expected_dtypes
2 changes: 1 addition & 1 deletion ibis-server/tools/mdl_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
for model in mdl["models"]:
for column in model["columns"]:
# ignore hidden columns and relationship columns
if column.get("isHidden"):
if column.get("isHidden") or column.get("relationship") is not None:
continue
sql = f"select \"{column['name']}\" from \"{model['name']}\""
try:
Expand Down
56 changes: 43 additions & 13 deletions wren-core/core/src/logical_plan/analyze/model_generation.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
use std::fmt::Debug;
use std::sync::Arc;

use datafusion::common::alias::AliasGenerator;
use datafusion::common::config::ConfigOptions;
use datafusion::common::tree_node::{Transformed, TransformedResult};
use datafusion::common::{plan_err, Result};
use datafusion::logical_expr::{col, ident, Extension, UserDefinedLogicalNodeCore};
use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
use datafusion::optimizer::analyzer::AnalyzerRule;
use datafusion::physical_plan::internal_err;
use datafusion::sql::TableReference;

use crate::logical_plan::analyze::plan::{
CalculationPlanNode, ModelPlanNode, ModelSourceNode, PartialModelPlanNode,
};
use crate::logical_plan::utils::create_remote_table_source;
use crate::logical_plan::utils::{
create_remote_table_source, eliminate_ambiguous_columns, rebase_column,
};
use crate::mdl::manifest::Model;
use crate::mdl::utils::quoted;
use crate::mdl::{AnalyzedWrenMDL, SessionStateRef};
Expand All @@ -35,24 +39,37 @@ impl ModelGenerationRule {
&self,
plan: LogicalPlan,
) -> Result<Transformed<LogicalPlan>> {
let alias_generator = AliasGenerator::default();
match plan {
LogicalPlan::Extension(extension) => {
if let Some(model_plan) =
extension.node.as_any().downcast_ref::<ModelPlanNode>()
{
let source_plan = model_plan.relation_chain.clone().plan(
let (source_plan, alias) = model_plan.relation_chain.clone().plan(
ModelGenerationRule::new(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
),
&alias_generator,
)?;

let projections = if let Some(alias) = alias {
model_plan
.required_exprs
.iter()
.map(|expr| rebase_column(expr, &alias).unwrap())
.collect()
} else {
model_plan.required_exprs.clone()
};
let projections = eliminate_ambiguous_columns(projections);
let result = match source_plan {
Some(plan) => {
if model_plan.required_exprs.is_empty() {
plan
} else {
LogicalPlanBuilder::from(plan)
.project(model_plan.required_exprs.clone())?
.project(projections)?
.build()?
}
}
Expand Down Expand Up @@ -119,23 +136,35 @@ impl ModelGenerationRule {
.as_any()
.downcast_ref::<CalculationPlanNode>(
) {
let source_plan = calculation_plan.relation_chain.clone().plan(
ModelGenerationRule::new(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
),
)?;
let (source_plan, plan_alias) =
calculation_plan.relation_chain.clone().plan(
ModelGenerationRule::new(
Arc::clone(&self.analyzed_wren_mdl),
Arc::clone(&self.session_state),
),
&alias_generator,
)?;

let plan_alias = if let Some(alias) = plan_alias {
alias
} else {
return internal_err!("calculation plan should have an alias");
};

if let Expr::Alias(alias) = calculation_plan.measures[0].clone() {
let measure: Expr = *alias.expr.clone();
let rebased_measure = rebase_column(&measure, &plan_alias)?;
let name = alias.name.clone();
let ident = ident(measure.to_string()).alias(name);
let project = vec![calculation_plan.dimensions[0].clone(), ident];
let ident =
ident(rebased_measure.to_string()).alias(name.clone());
let rebased_dimension =
rebase_column(&calculation_plan.dimensions[0], &plan_alias)?;
let project = vec![rebased_dimension.clone(), ident];
let result = match source_plan {
Some(plan) => LogicalPlanBuilder::from(plan)
.aggregate(
calculation_plan.dimensions.clone(),
vec![measure],
vec![rebased_dimension],
vec![rebased_measure],
)?
.project(project)?
.build()?,
Expand Down Expand Up @@ -169,6 +198,7 @@ impl ModelGenerationRule {
.iter()
.map(|f| col(datafusion::common::Column::from((None, f))))
.collect();
let projection = eliminate_ambiguous_columns(projection);
let alias = LogicalPlanBuilder::from(source_plan)
.project(projection)?
.alias(quoted(&partial_model.model_node.plan_name))?
Expand Down
25 changes: 25 additions & 0 deletions wren-core/core/src/logical_plan/analyze/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,31 @@ impl ModelPlanNodeBuilder {
let Some(source) = self.directed_graph.node_weight(start) else {
return internal_err!("Dataset not found");
};

// Insert each model's primary key into its required fields so the model can be joined with the calculation
let keys = self
.model_required_fields
.keys()
.cloned()
.collect::<Vec<_>>();
for model in keys {
let Some(pk_column) = self
.analyzed_wren_mdl
.wren_mdl()
.get_model(model.table())
.and_then(|m| m.primary_key().and_then(|pk| m.get_column(pk)))
else {
debug!("Primary key not found for model {}", model);
continue;
};
self.model_required_fields
.entry(model.clone())
.or_default()
.insert(OrdExpr::new(Expr::Column(Column::from_qualified_name(
format!("{}.{}", quoted(model.table()), quoted(pk_column.name()),),
))));
}

let mut source_required_fields: Vec<Expr> = self
.model_required_fields
.get(&model_ref)
Expand Down
Loading

0 comments on commit 6c8557a

Please sign in to comment.