diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index 5dcba3b47686..e1287c8b6d86 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -881,13 +881,25 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS test("Test transform_keys/transform_values") { val sql = """ + |select id, sort_array(map_entries(m1)), sort_array(map_entries(m2)) from( + |select id, first(m1) as m1, first(m2) as m2 from( |select + | id, | transform_keys(map_from_arrays(array(id+1, id+2, id+3), - | array(1, id+2, 3)), (k, v) -> k + 1), + | array(1, id+2, 3)), (k, v) -> k + 1) as m1, | transform_values(map_from_arrays(array(id+1, id+2, id+3), - | array(1, id+2, 3)), (k, v) -> v + 1) + | array(1, id+2, 3)), (k, v) -> v + 1) as m2 |from range(10) + |) group by id + |) order by id |""".stripMargin - compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + def checkProjects(df: DataFrame): Unit = { + val projects = collectWithSubqueries(df.queryExecution.executedPlan) { + case e: ProjectExecTransformer => e + } + assert(projects.size >= 1) + } + compareResultsAgainstVanillaSpark(sql, true, checkProjects, false) } } diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp b/cpp-ch/local-engine/Parser/ExpressionParser.cpp index 23dede13f27c..7d41c78a3db7 100644 --- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp +++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp @@ -258,25 +258,30 @@ std::pair LiteralParser::parse(const substrait::Expr } const static std::string REUSE_COMMON_SUBEXPRESSION_CONF = "reuse_cse_in_expression_parser"; + bool ExpressionParser::reuseCSE() const { return 
context->queryContext()->getConfigRef().getBool(REUSE_COMMON_SUBEXPRESSION_CONF, true); } -const DB::ActionsDAG::Node * +ExpressionParser::NodeRawConstPtr ExpressionParser::addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr type, const DB::Field & field) const { String name = toString(field).substr(0, 10); name = getUniqueName(name); const auto * res_node = &actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, field), type, name)); if (reuseCSE()) - if (const auto * exists_node = findOneStructureEqualNode(res_node, actions_dag)) + { + // The new node, res_node, remains in the ActionsDAG, but it does not affect the execution. + // It will be removed once `ActionsDAG::removeUnusedActions` is called. + if (const auto * exists_node = findFirstStructureEqualNode(res_node, actions_dag)) res_node = exists_node; + } return res_node; } -const ActionsDAG::Node * ExpressionParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) const +ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) const { switch (rel.rex_type_case()) { @@ -614,7 +619,7 @@ ExpressionParser::parseFunctionArguments(DB::ActionsDAG & actions_dag, const sub return parsed_args; } -const DB::ActionsDAG::Node * +ExpressionParser::NodeRawConstPtr ExpressionParser::parseFunction(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool add_to_output) const { auto function_signature = getFunctionNameInSignature(func); @@ -625,7 +630,7 @@ ExpressionParser::parseFunction(const substrait::Expression_ScalarFunction & fun return function_node; } -const DB::ActionsDAG::Node * ExpressionParser::toFunctionNode( +ExpressionParser::NodeRawConstPtr ExpressionParser::toFunctionNode( DB::ActionsDAG & actions_dag, const String & ch_function_name, const DB::ActionsDAG::NodeRawConstPtrs & args, @@ -641,7 +646,8 @@ const DB::ActionsDAG::Node * 
ExpressionParser::toFunctionNode( const auto * res_node = &actions_dag.addFunction(function_builder, args, result_name); if (reuseCSE()) { - if (const auto * exists_node = findOneStructureEqualNode(res_node, actions_dag)) + const auto * exists_node = findFirstStructureEqualNode(res_node, actions_dag); + if (exists_node) { if (result_name_.empty() || result_name == exists_node->result_name) res_node = exists_node; @@ -865,9 +871,46 @@ ExpressionParser::parseJsonTuple(const substrait::Expression_ScalarFunction & fu return res_nodes; } -bool ExpressionParser::areEqualNodes(const DB::ActionsDAG::Node * a, const ActionsDAG::Node * b) + +static bool isAllowedDataType(const DB::IDataType & data_type) { - if (a->type != b->type || !a->result_type->equals(*(b->result_type)) || a->children.size() != b->children.size()) + DB::WhichDataType which(data_type); + if (which.isNullable()) + { + const auto * null_type = typeid_cast(&data_type); + return isAllowedDataType(*(null_type->getNestedType())); + } + else if (which.isNumber() || which.isStringOrFixedString() || which.isDateOrDate32OrDateTimeOrDateTime64()) + return true; + else if (which.isArray()) + { + auto nested_type = typeid_cast(&data_type)->getNestedType(); + return isAllowedDataType(*nested_type); + } + else if (which.isTuple()) + { + const auto * tuple_type = typeid_cast(&data_type); + for (const auto & nested_type : tuple_type->getElements()) + if (!isAllowedDataType(*nested_type)) + return false; + return true; + } + else if (which.isMap()) + { + const auto * map_type = typeid_cast(&data_type); + return isAllowedDataType(*(map_type->getKeyType())) && isAllowedDataType(*(map_type->getValueType())); + } + + return false; +} + +bool ExpressionParser::areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b) +{ + if (a == b) + return true; + + if (a->type != b->type || !a->result_type->equals(*(b->result_type)) || a->children.size() != b->children.size() + || !a->isDeterministic() || !b->isDeterministic() || 
!isAllowedDataType(*(a->result_type))) return false; switch (a->type) @@ -900,7 +943,12 @@ bool ExpressionParser::areEqualNodes(const DB::ActionsDAG::Node * a, const Actio break; } default: { - LOG_DEBUG(getLogger("ExpressionParser"), "Unknow node type: {}|{}|{}", a->type, a->result_type->getName(), a->result_name); + LOG_WARNING( + getLogger("ExpressionParser"), + "Unknow node type. type:{}, data type:{}, result_name:{}", + a->type, + a->result_type->getName(), + a->result_name); return false; } } @@ -908,19 +956,40 @@ bool ExpressionParser::areEqualNodes(const DB::ActionsDAG::Node * a, const Actio for (size_t i = 0; i < a->children.size(); ++i) if (!areEqualNodes(a->children[i], b->children[i])) return false; + LOG_TEST( + getLogger("ExpressionParser"), + "Nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data type:{},name:{}", + a->type, + a->result_type->getName(), + a->result_name, + b->type, + b->result_type->getName(), + b->result_name); return true; } // since each new node is added at the end of ActionsDAG::nodes, we expect to find the previous node and the new node will be dropped later. 
-const DB::ActionsDAG::Node * -ExpressionParser::findOneStructureEqualNode(const DB::ActionsDAG::Node * node_, const DB::ActionsDAG & actions_dag) const +ExpressionParser::NodeRawConstPtr +ExpressionParser::findFirstStructureEqualNode(NodeRawConstPtr target, const DB::ActionsDAG & actions_dag) const { for (const auto & node : actions_dag.getNodes()) { - if (node_ == &node) + if (target == &node) continue; - if (areEqualNodes(node_, &node)) + + if (areEqualNodes(target, &node)) + { + LOG_TEST( + getLogger("ExpressionParser"), + "Two nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data type:{},name:{}", + target->type, + target->result_type->getName(), + target->result_name, + node.type, + node.result_type->getName(), + node.result_name); return &node; + } } return nullptr; } diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.h b/cpp-ch/local-engine/Parser/ExpressionParser.h index 1d0188b47d5d..1e4a48282ac9 100644 --- a/cpp-ch/local-engine/Parser/ExpressionParser.h +++ b/cpp-ch/local-engine/Parser/ExpressionParser.h @@ -40,26 +40,27 @@ class LiteralParser class ExpressionParser { public: + using NodeRawConstPtr = const DB::ActionsDAG::Node *; ExpressionParser(const std::shared_ptr & context_) : context(context_) { } ~ExpressionParser() = default; /// Append a counter-suffix to name String getUniqueName(const String & name) const; - const DB::ActionsDAG::Node * addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr type, const DB::Field & field) const; + NodeRawConstPtr addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr type, const DB::Field & field) const; /// Parse expr and add an expression node in actions_dag - const DB::ActionsDAG::Node * parseExpression(DB::ActionsDAG & actions_dag, const substrait::Expression & expr) const; + NodeRawConstPtr parseExpression(DB::ActionsDAG & actions_dag, const substrait::Expression & expr) const; /// Build an actions dag that contains expressions. 
header is used as input columns for the actions dag. DB::ActionsDAG expressionsToActionsDAG(const std::vector & expressions, const DB::Block & header) const; // Parse func's arguments into actions dag, and return the node ptrs. DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(DB::ActionsDAG & actions_dag, const substrait::Expression_ScalarFunction & func) const; - const DB::ActionsDAG::Node * + NodeRawConstPtr parseFunction(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool add_to_output = false) const; // Add a new function node into the actions dag - const DB::ActionsDAG::Node * toFunctionNode( + NodeRawConstPtr toFunctionNode( DB::ActionsDAG & actions_dag, const String & ch_function_name, const DB::ActionsDAG::NodeRawConstPtrs & args, @@ -86,7 +87,7 @@ class ExpressionParser DB::ActionsDAG::NodeRawConstPtrs parseJsonTuple(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag) const; - static bool areEqualNodes(const DB::ActionsDAG::Node * a, const DB::ActionsDAG::Node * b); - const DB::ActionsDAG::Node * findOneStructureEqualNode(const DB::ActionsDAG::Node * node_, const DB::ActionsDAG & actions_dag) const; + static bool areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b); + NodeRawConstPtr findFirstStructureEqualNode(NodeRawConstPtr target, const DB::ActionsDAG & actions_dag) const; }; }