From bda2f7a2e79e69cff44bea3f456704a0f118140c Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 19 Dec 2024 16:46:26 +0800 Subject: [PATCH] eliminate cse during convert substrait expression to actions dag --- .../local-engine/Parser/ExpressionParser.cpp | 81 ++++++++++++++++++- cpp-ch/local-engine/Parser/ExpressionParser.h | 5 ++ 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.cpp b/cpp-ch/local-engine/Parser/ExpressionParser.cpp index e7de46483eda0..e436c1a9dd0a9 100644 --- a/cpp-ch/local-engine/Parser/ExpressionParser.cpp +++ b/cpp-ch/local-engine/Parser/ExpressionParser.cpp @@ -254,12 +254,23 @@ std::pair LiteralParser::parse(const substrait::Expr return std::make_pair(std::move(type), std::move(field)); } +const static std::string REUSE_COMMON_SUBEXPRESSION_CONF = "reuse_cse_in_expression_parser"; +bool ExpressionParser::reuseCSE() const +{ + return context->queryContext()->getConfigRef().getBool(REUSE_COMMON_SUBEXPRESSION_CONF, true); +} + const DB::ActionsDAG::Node * ExpressionParser::addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr type, const DB::Field & field) const { String name = toString(field).substr(0, 10); name = getUniqueName(name); - return &actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, field), type, name)); + LOG_ERROR(getLogger("ExpressionParser"), "xxx add new const col: {}", name); + const auto * res_node = &actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, field), type, name)); + if (reuseCSE()) + if (const auto * exists_node = findOneStructureEqualNode(res_node, actions_dag)) + res_node = exists_node; + return res_node; } @@ -613,7 +624,18 @@ const DB::ActionsDAG::Node * ExpressionParser::toFunctionNode( std::string args_name = join(args, ','); result_name = ch_function_name + "(" + args_name + ")"; } - return &actions_dag.addFunction(function_builder, args, result_name); + const auto * res_node = &actions_dag.addFunction(function_builder, args, result_name); + if (reuseCSE()) + { + if (const auto * exists_node = findOneStructureEqualNode(res_node, actions_dag)) + { + if (result_name_.empty() || result_name == exists_node->result_name) + res_node = exists_node; + else + res_node = &actions_dag.addAlias(*exists_node, result_name); + } + } + return res_node; } std::atomic ExpressionParser::unique_name_counter = 0; @@ -828,4 +850,59 @@ ExpressionParser::parseJsonTuple(const substrait::Expression_ScalarFunction & fu } return res_nodes; } + +bool ExpressionParser::areEqualNodes(const DB::ActionsDAG::Node * a, const ActionsDAG::Node * b) +{ + if (a->type != b->type || !a->result_type->equals(*(b->result_type)) || a->children.size() != b->children.size()) + return false; + + switch (a->type) + { + case DB::ActionsDAG::ActionType::INPUT: { + if (a->result_name != b->result_name) + return false; + break; + } + case DB::ActionsDAG::ActionType::ALIAS: { + if (a->result_name != b->result_name) + return false; + break; + } + case DB::ActionsDAG::ActionType::COLUMN: { + if (a->column->compareAt(0, 0, *(b->column), 1) != 0) + return false; + break; + } + case DB::ActionsDAG::ActionType::ARRAY_JOIN: { + break; + } + case DB::ActionsDAG::ActionType::FUNCTION: { + if (a->function_base->getName() != b->function_base->getName()) + return false; + break; + } + default: + break; + } + + for (size_t i = 0; i < a->children.size(); ++i) + if (!areEqualNodes(a->children[i], b->children[i])) + return false; + + return true; +} + +// since each new node is added at the end of ActionsDAG::nodes, we expect to find the previous node and the new node will be dropped later. +const DB::ActionsDAG::Node * +ExpressionParser::findOneStructureEqualNode(const DB::ActionsDAG::Node * node_, const DB::ActionsDAG & actions_dag) const +{ + for (const auto & node : actions_dag.getNodes()) + { + if (node_ == &node) + continue; + if (areEqualNodes(node_, &node)) + return &node; + } + return nullptr; +} } diff --git a/cpp-ch/local-engine/Parser/ExpressionParser.h b/cpp-ch/local-engine/Parser/ExpressionParser.h index 06a80d756e3f5..1d0188b47d5d9 100644 --- a/cpp-ch/local-engine/Parser/ExpressionParser.h +++ b/cpp-ch/local-engine/Parser/ExpressionParser.h @@ -77,11 +77,16 @@ class ExpressionParser static std::atomic unique_name_counter; std::shared_ptr context; + bool reuseCSE() const; + DB::ActionsDAG::NodeRawConstPtrs parseArrayJoin(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool position) const; DB::ActionsDAG::NodeRawConstPtrs parseArrayJoinArguments( const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool position, bool & is_map) const; DB::ActionsDAG::NodeRawConstPtrs parseJsonTuple(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag) const; + + static bool areEqualNodes(const DB::ActionsDAG::Node * a, const DB::ActionsDAG::Node * b); + const DB::ActionsDAG::Node * findOneStructureEqualNode(const DB::ActionsDAG::Node * node_, const DB::ActionsDAG & actions_dag) const; }; }