Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Dec 27, 2024
1 parent 4068fbd commit 4f29471
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -863,13 +863,25 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS

// Runs transform_keys / transform_values over maps built with map_from_arrays on range(10)
// and compares the result against vanilla Spark via compareResultsAgainstVanillaSpark.
// NOTE(review): this text comes from a rendered diff, so pre-change and post-change lines
// appear side by side (e.g. two variants of the lambda lines inside the SQL) — confirm
// against the real file before relying on the exact query text.
test("Test transform_keys/transform_values") {
val sql = """
|select id, sort_array(map_entries(m1)), sort_array(map_entries(m2)) from(
|select id, first(m1) as m1, first(m2) as m2 from(
|select
| id,
| transform_keys(map_from_arrays(array(id+1, id+2, id+3),
| array(1, id+2, 3)), (k, v) -> k + 1),
| array(1, id+2, 3)), (k, v) -> k + 1) as m1,
| transform_values(map_from_arrays(array(id+1, id+2, id+3),
| array(1, id+2, 3)), (k, v) -> v + 1)
| array(1, id+2, 3)), (k, v) -> v + 1) as m2
|from range(10)
|) group by id
|) order by id
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

// Plan check: collects every ProjectExecTransformer found by collectWithSubqueries on the
// executed plan and requires at least one to be present.
def checkProjects(df: DataFrame): Unit = {
val projects = collectWithSubqueries(df.queryExecution.executedPlan) {
case e: ProjectExecTransformer => e
}
assert(projects.size >= 1)
}
compareResultsAgainstVanillaSpark(sql, true, checkProjects, false)
}
}
61 changes: 48 additions & 13 deletions cpp-ch/local-engine/Parser/ExpressionParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,25 +258,30 @@ std::pair<DB::DataTypePtr, DB::Field> LiteralParser::parse(const substrait::Expr
}

/// Config key read by ExpressionParser::reuseCSE() to decide whether structurally equal
/// nodes (common sub-expressions) are reused when building an ActionsDAG.
const static std::string REUSE_COMMON_SUBEXPRESSION_CONF = "reuse_cse_in_expression_parser";

/// Tells whether the parser should reuse an already existing, structurally equal node
/// instead of keeping a freshly added one. The flag is looked up in the query context's
/// config under REUSE_COMMON_SUBEXPRESSION_CONF, with `true` passed as the fallback value.
bool ExpressionParser::reuseCSE() const
{
    const auto query_context = context->queryContext();
    return query_context->getConfigRef().getBool(REUSE_COMMON_SUBEXPRESSION_CONF, true);
}

/// Adds a constant column of `type` holding `field` to `actions_dag` and returns its node.
/// The column name is the first 10 characters of toString(field), passed through getUniqueName().
/// When reuseCSE() is true and a structurally equal node is found in the DAG,
/// that existing node is returned instead of the newly added one.
/// NOTE(review): this text comes from a rendered diff; the two return-type lines and the
/// findOneStructureEqualNode / findFirstStructureEqualNode lines look like the pre-change and
/// post-change variants of the same statement — confirm against the real file.
const DB::ActionsDAG::Node *
ExpressionParser::NodeRawConstPtr
ExpressionParser::addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr type, const DB::Field & field) const
{
String name = toString(field).substr(0, 10);
name = getUniqueName(name);
const auto * res_node = &actions_dag.addColumn(DB::ColumnWithTypeAndName(type->createColumnConst(1, field), type, name));
if (reuseCSE())
if (const auto * exists_node = findOneStructureEqualNode(res_node, actions_dag))
{
// The newly added node, res_node, remains in the ActionsDAG, but it does not affect execution.
// It is removed once `ActionsDAG::removeUnusedActions` is called.
if (const auto * exists_node = findFirstStructureEqualNode(res_node, actions_dag))
res_node = exists_node;
}
return res_node;
}


const ActionsDAG::Node * ExpressionParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) const
ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) const
{
switch (rel.rex_type_case())
{
Expand Down Expand Up @@ -614,7 +619,7 @@ ExpressionParser::parseFunctionArguments(DB::ActionsDAG & actions_dag, const sub
return parsed_args;
}

const DB::ActionsDAG::Node *
ExpressionParser::NodeRawConstPtr
ExpressionParser::parseFunction(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool add_to_output) const
{
auto function_signature = getFunctionNameInSignature(func);
Expand All @@ -625,7 +630,7 @@ ExpressionParser::parseFunction(const substrait::Expression_ScalarFunction & fun
return function_node;
}

const DB::ActionsDAG::Node * ExpressionParser::toFunctionNode(
ExpressionParser::NodeRawConstPtr ExpressionParser::toFunctionNode(
DB::ActionsDAG & actions_dag,
const String & ch_function_name,
const DB::ActionsDAG::NodeRawConstPtrs & args,
Expand All @@ -641,7 +646,7 @@ const DB::ActionsDAG::Node * ExpressionParser::toFunctionNode(
const auto * res_node = &actions_dag.addFunction(function_builder, args, result_name);
if (reuseCSE())
{
if (const auto * exists_node = findOneStructureEqualNode(res_node, actions_dag))
if (const auto * exists_node = findFirstStructureEqualNode(res_node, actions_dag))
{
if (result_name_.empty() || result_name == exists_node->result_name)
res_node = exists_node;
Expand Down Expand Up @@ -865,9 +870,13 @@ ExpressionParser::parseJsonTuple(const substrait::Expression_ScalarFunction & fu
return res_nodes;
}

bool ExpressionParser::areEqualNodes(const DB::ActionsDAG::Node * a, const ActionsDAG::Node * b)
bool ExpressionParser::areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b)
{
if (a->type != b->type || !a->result_type->equals(*(b->result_type)) || a->children.size() != b->children.size())
if (a == b)
return true;

if (a->type != b->type || !a->result_type->equals(*(b->result_type)) || a->children.size() != b->children.size()
|| !a->isDeterministic() || !b->isDeterministic())
return false;

switch (a->type)
Expand Down Expand Up @@ -900,27 +909,53 @@ bool ExpressionParser::areEqualNodes(const DB::ActionsDAG::Node * a, const Actio
break;
}
default: {
LOG_DEBUG(getLogger("ExpressionParser"), "Unknow node type: {}|{}|{}", a->type, a->result_type->getName(), a->result_name);
LOG_WARNING(
getLogger("ExpressionParser"),
"Unknow node type. type:{}, data type:{}, result_name:{}",
a->type,
a->result_type->getName(),
a->result_name);
return false;
}
}

for (size_t i = 0; i < a->children.size(); ++i)
if (!areEqualNodes(a->children[i], b->children[i]))
return false;
LOG_TEST(
getLogger("ExpressionParser"),
"Nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data type:{},name:{}",
a->type,
a->result_type->getName(),
a->result_name,
b->type,
b->result_type->getName(),
b->result_name);
return true;
}

// Walks actions_dag.getNodes() in order and returns the first node, other than the target
// itself, for which areEqualNodes() reports equality; returns nullptr when there is none.
// Since each new node is added at the end of ActionsDAG::nodes, the match is expected to be
// the previously existing node, and the new node will be dropped later.
// NOTE(review): this text comes from a rendered diff; the findOneStructureEqualNode / `node_`
// lines look like the pre-change variants of the findFirstStructureEqualNode / `target` lines —
// confirm against the real file.
const DB::ActionsDAG::Node *
ExpressionParser::findOneStructureEqualNode(const DB::ActionsDAG::Node * node_, const DB::ActionsDAG & actions_dag) const
ExpressionParser::NodeRawConstPtr
ExpressionParser::findFirstStructureEqualNode(NodeRawConstPtr target, const DB::ActionsDAG & actions_dag) const
{
for (const auto & node : actions_dag.getNodes())
{
// Never match a node against itself.
if (node_ == &node)
if (target == &node)
continue;
if (areEqualNodes(node_, &node))

if (areEqualNodes(target, &node))
{
LOG_TEST(
getLogger("ExpressionParser"),
"Two nodes are equal:\ntype:{},data type:{},name:{}\ntype:{},data type:{},name:{}",
target->type,
target->result_type->getName(),
target->result_name,
node.type,
node.result_type->getName(),
node.result_name);
return &node;
}
}
return nullptr;
}
Expand Down
13 changes: 7 additions & 6 deletions cpp-ch/local-engine/Parser/ExpressionParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,27 @@ class LiteralParser
class ExpressionParser
{
public:
using NodeRawConstPtr = const DB::ActionsDAG::Node *;
ExpressionParser(const std::shared_ptr<const ParserContext> & context_) : context(context_) { }
~ExpressionParser() = default;

/// Append a counter-suffix to name
String getUniqueName(const String & name) const;

const DB::ActionsDAG::Node * addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr type, const DB::Field & field) const;
NodeRawConstPtr addConstColumn(DB::ActionsDAG & actions_dag, const DB::DataTypePtr type, const DB::Field & field) const;

/// Parse expr and add an expression node in actions_dag
const DB::ActionsDAG::Node * parseExpression(DB::ActionsDAG & actions_dag, const substrait::Expression & expr) const;
NodeRawConstPtr parseExpression(DB::ActionsDAG & actions_dag, const substrait::Expression & expr) const;
/// Build an actions dag that contains expressions. header is used as input columns for the actions dag.
DB::ActionsDAG expressionsToActionsDAG(const std::vector<substrait::Expression> & expressions, const DB::Block & header) const;

// Parse func's arguments into actions dag, and return the node ptrs.
DB::ActionsDAG::NodeRawConstPtrs
parseFunctionArguments(DB::ActionsDAG & actions_dag, const substrait::Expression_ScalarFunction & func) const;
const DB::ActionsDAG::Node *
NodeRawConstPtr
parseFunction(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag, bool add_to_output = false) const;
// Add a new function node into the actions dag
const DB::ActionsDAG::Node * toFunctionNode(
NodeRawConstPtr toFunctionNode(
DB::ActionsDAG & actions_dag,
const String & ch_function_name,
const DB::ActionsDAG::NodeRawConstPtrs & args,
Expand All @@ -86,7 +87,7 @@ class ExpressionParser

DB::ActionsDAG::NodeRawConstPtrs parseJsonTuple(const substrait::Expression_ScalarFunction & func, DB::ActionsDAG & actions_dag) const;

static bool areEqualNodes(const DB::ActionsDAG::Node * a, const DB::ActionsDAG::Node * b);
const DB::ActionsDAG::Node * findOneStructureEqualNode(const DB::ActionsDAG::Node * node_, const DB::ActionsDAG & actions_dag) const;
static bool areEqualNodes(NodeRawConstPtr a, NodeRawConstPtr b);
NodeRawConstPtr findFirstStructureEqualNode(NodeRawConstPtr target, const DB::ActionsDAG & actions_dag) const;
};
}

0 comments on commit 4f29471

Please sign in to comment.