Skip to content

Commit

Permalink
eliminate cse during convert substrait expression to actions dag
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Dec 19, 2024
1 parent aff865b commit bda2f7a
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 2 deletions.
81 changes: 79 additions & 2 deletions cpp-ch/local-engine/Parser/ExpressionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,23 @@ std::pair<DB::DataTypePtr, DB::Field> LiteralParser::parse(const substrait::Expr
return std::make_pair(std::move(type), std::move(field));
}

const static std::string REUSE_COMMON_SUBEXPRESSION_CONF = "reuse_cse_in_expression_parser";
bool ExpressionParser::reuseCSE() const
{
return context->queryContext()->getConfigRef().getBool(REUSE_COMMON_SUBEXPRESSION_CONF, true);
}

const DB::ActionsDAG::Node *
ExpressionParser::addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr type, const DB::Field & field) const
{
String name = toString(field).substr(0, 10);
name = getUniqueName(name);
return &actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, field), type, name));
LOG_ERROR(getLogger("ExpressionParser"), "xxx add new const col: {}", name);
const auto * res_node = &actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, field), type, name));
if (reuseCSE())
if (const auto * exists_node = findOneStructureEqualNode(res_node, actions_dag))
res_node = exists_node;
return res_node;
}


Expand Down Expand Up @@ -613,7 +624,18 @@ const DB::ActionsDAG::Node * ExpressionParser::toFunctionNode(
std::string args_name = join(args, ',');
result_name = ch_function_name + "(" + args_name + ")";
}
return &actions_dag.addFunction(function_builder, args, result_name);
const auto * res_node = &actions_dag.addFunction(function_builder, args, result_name);
if (reuseCSE())
{
if (const auto * exists_node = findOneStructureEqualNode(res_node, actions_dag))
{
if (result_name_.empty() || result_name == exists_node->result_name)
res_node = exists_node;
else
res_node = &actions_dag.addAlias(*exists_node, result_name);
}
}
return res_node;
}

std::atomic<UInt64> ExpressionParser::unique_name_counter = 0;
Expand Down Expand Up @@ -828,4 +850,59 @@ ExpressionParser::parseJsonTuple(const substrait::Expression_ScalarFunction & fu
}
return res_nodes;
}

bool ExpressionParser::areEqualNodes(const DB::ActionsDAG::Node * a, const ActionsDAG::Node * b)
{
if (a->type != b->type || !a->result_type->equals(*(b->result_type)) || a->children.size() != b->children.size())
return false;

switch (a->type)
{
case DB::ActionsDAG::ActionType::INPUT: {
if (a->result_name != b->result_name)
return false;
break;
}
case DB::ActionsDAG::ActionType::ALIAS: {
if (a->result_name != b->result_name)
return false;
break;
}
case DB::ActionsDAG::ActionType::COLUMN: {
if (a->column->compareAt(0, 0, *(b->column), 1) != 0)
return false;
break;
}
case DB::ActionsDAG::ActionType::ARRAY_JOIN: {
break;
}
case DB::ActionsDAG::ActionType::FUNCTION: {
if (a->function_base->getName() != b->function_base->getName())
return false;
break;
}
default:
break;
}

for (size_t i = 0; i < a->children.size(); ++i)
if (!areEqualNodes(a->children[i], b->children[i]))
return false;

return true;
}

// since each new node is added at the end of ActionsDAG::nodes, we expect to find the previous node and the new node will be dropped later.
const DB::ActionsDAG::Node *
ExpressionParser::findOneStructureEqualNode(const DB::ActionsDAG::Node * node_, const DB::ActionsDAG & actions_dag) const
{
for (const auto & node : actions_dag.getNodes())
{
if (node_ == &node)
continue;
if (areEqualNodes(node_, &node))
return &node;
}
return nullptr;
}
}
5 changes: 5 additions & 0 deletions cpp-ch/local-engine/Parser/ExpressionParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,16 @@ class ExpressionParser
static std::atomic<UInt64> unique_name_counter;
std::shared_ptr<const ParserContext> context;

bool reuseCSE() const;

DB::ActionsDAG::NodeRawConstPtrs
parseArrayJoin(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool position) const;
DB::ActionsDAG::NodeRawConstPtrs parseArrayJoinArguments(
const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool position, bool & is_map) const;

DB::ActionsDAG::NodeRawConstPtrs parseJsonTuple(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag) const;

static bool areEqualNodes(const DB::ActionsDAG::Node * a, const DB::ActionsDAG::Node * b);
const DB::ActionsDAG::Node * findOneStructureEqualNode(const DB::ActionsDAG::Node * node_, const DB::ActionsDAG & actions_dag) const;
};
}

0 comments on commit bda2f7a

Please sign in to comment.