From e25ab2e7adc197ada9285e18f4e35544fec4e3fe Mon Sep 17 00:00:00 2001 From: "shuai.xu" Date: Fri, 21 Jun 2024 18:10:16 +0800 Subject: [PATCH] [GLUTEN-4451] [CH] fix header maybe changed by FilterTransform (#6166) What changes were proposed in this pull request? Rollback header if changed in FilterTransform (Fixes: #4451) How was this patch tested? This patch was tested by integration tests. --- ...enClickHouseTPCHSaltNullParquetSuite.scala | 50 +++++++++++++++++++ .../local-engine/Parser/FilterRelParser.cpp | 7 ++- .../Parser/SerializedPlanParser.cpp | 13 +++++ .../Parser/SerializedPlanParser.h | 1 + 4 files changed, 70 insertions(+), 1 deletion(-) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 1d3bbec848bc..5040153320fc 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -2638,5 +2638,55 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr spark.sql("drop table test_tbl_5910_0") spark.sql("drop table test_tbl_5910_1") } + + test("GLUTEN-4451: Fix schema may be changed by filter") { + val create_sql = + """ + |create table if not exists test_tbl_4451( + | month_day string, + | month_dif int, + | is_month_new string, + | country string, + | os string, + | mr bigint + |) using parquet + |PARTITIONED BY ( + | day string, + | app_name string) + |""".stripMargin + val insert_sql1 = + "INSERT into test_tbl_4451 partition (day='2024-06-01', app_name='abc') " + + "values('2024-06-01', 0, '1', 'CN', 'iOS', 100)" + val insert_sql2 = + "INSERT into test_tbl_4451 partition (day='2024-06-01', app_name='abc') " + + "values('2024-06-01', 0, '1', 'CN', 'iOS', 50)" + val 
insert_sql3 = + "INSERT into test_tbl_4451 partition (day='2024-06-01', app_name='abc') " + + "values('2024-06-01', 1, '1', 'CN', 'iOS', 80)" + spark.sql(create_sql) + spark.sql(insert_sql1) + spark.sql(insert_sql2) + spark.sql(insert_sql3) + val select_sql = + """ + |SELECT * FROM ( + | SELECT + | month_day, + | country, + | if(os = 'ALite','Android',os) AS os, + | is_month_new, + | nvl(sum(if(month_dif = 0, mr, 0)),0) AS `month0_n`, + | nvl(sum(if(month_dif = 1, mr, 0)) / sum(if(month_dif = 0, mr, 0)),0) AS `month1_rate`, + | '2024-06-18' as day, + | app_name + | FROM test_tbl_4451 + | GROUP BY month_day,country,if(os = 'ALite','Android',os),is_month_new,app_name + |) tt + |WHERE month0_n > 0 AND month1_rate <= 1 AND os IN ('all','Android','iOS') + | AND app_name IS NOT NULL + |""".stripMargin + compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) + spark.sql("drop table test_tbl_4451") + } } // scalastyle:on line.size.limit diff --git a/cpp-ch/local-engine/Parser/FilterRelParser.cpp b/cpp-ch/local-engine/Parser/FilterRelParser.cpp index 4c71cc3126af..e0098f747c2a 100644 --- a/cpp-ch/local-engine/Parser/FilterRelParser.cpp +++ b/cpp-ch/local-engine/Parser/FilterRelParser.cpp @@ -59,7 +59,12 @@ DB::QueryPlanPtr FilterRelParser::parse(DB::QueryPlanPtr query_plan, const subst filter_step->setStepDescription("WHERE"); steps.emplace_back(filter_step.get()); query_plan->addStep(std::move(filter_step)); - + + // header maybe changed, need to rollback it + if (!blocksHaveEqualStructure(input_header, query_plan->getCurrentDataStream().header)) { + steps.emplace_back(getPlanParser()->addRollbackFilterHeaderStep(query_plan, input_header)); + } + // remove nullable auto * remove_null_step = getPlanParser()->addRemoveNullableStep(*query_plan, non_nullable_columns); if (remove_null_step) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 5f2c9cc33150..40e01e3052a3 100644 --- 
a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -335,6 +335,19 @@ IQueryPlanStep * SerializedPlanParser::addRemoveNullableStep(QueryPlan & plan, c return step_ptr; } +IQueryPlanStep * SerializedPlanParser::addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header) +{ + auto convert_actions_dag = ActionsDAG::makeConvertingActions( + query_plan->getCurrentDataStream().header.getColumnsWithTypeAndName(), + input_header.getColumnsWithTypeAndName(), + ActionsDAG::MatchColumnsMode::Name); + auto expression_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), convert_actions_dag); + expression_step->setStepDescription("Generator for rollback filter"); + auto * step_ptr = expression_step.get(); + query_plan->addStep(std::move(expression_step)); + return step_ptr; +} + DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type) { return wrapNullableType(nullable == substrait::Type_Nullability_NULLABILITY_NULLABLE, nested_type); diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index ccd5c0fdc4c8..45ff5a20b5ae 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -299,6 +299,7 @@ class SerializedPlanParser static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set<String> & columns); + IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header); static ContextMutablePtr global_context; static Context::ConfigurationPtr config;