Skip to content

Commit

Permalink
high order array functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Jun 27, 2024
1 parent 32808dd commit d8ea7ee
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,20 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
CHGenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child)
}

/**
 * Builds the Substrait expression transformer for Spark's `ArrayFilter`.
 *
 * Only the single-argument lambda form (`x -> predicate(x)`) is handled. When the lambda declares
 * two arguments (element and index) a [[GlutenNotSupportException]] is thrown.
 *
 * @param substraitExprName Substrait function name to emit
 * @param argument transformer of the array being filtered
 * @param function transformer of the lambda predicate
 * @param expr the original Spark expression
 * @return a generic transformer whose children are `argument` followed by `function`
 */
override def genArrayFilterTransformer(
    substraitExprName: String,
    argument: ExpressionTransformer,
    function: ExpressionTransformer,
    expr: ArrayFilter): ExpressionTransformer = {
  // A two-argument lambda is the (element, index) variant.
  val usesIndexArgument = expr.function match {
    case LambdaFunction(_, lambdaArgs, _) => lambdaArgs.size == 2
    case _ => false
  }
  if (usesIndexArgument) {
    throw new GlutenNotSupportException(
      "filter on array with lambda using index argument is not supported yet")
  }
  GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
}

override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate

override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate
Expand Down
10 changes: 10 additions & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,16 @@ NamesAndTypesList SerializedPlanParser::blockToNameAndTypeList(const Block & hea
return types;
}

/// Resolves a Substrait function reference to the name part of its registered signature.
/// The signature is looked up in `function_mapping` under the decimal string form of
/// `function_ref`; the returned name is everything before the first ':' (or the whole
/// signature when it contains no ':'). Returns an empty optional for an unknown reference.
std::optional<String> SerializedPlanParser::getFunctionSignatureName(UInt32 function_ref)
{
    const auto found = function_mapping.find(std::to_string(function_ref));
    if (found == function_mapping.end())
        return std::nullopt;

    // Work on a copy of the mapped signature, exactly as before.
    auto signature = found->second;
    const auto separator_pos = signature.find(':');
    return signature.substr(0, separator_pos);
}

std::string
SerializedPlanParser::getFunctionName(const std::string & function_signature, const substrait::Expression_ScalarFunction & function)
{
Expand Down
1 change: 1 addition & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class SerializedPlanParser
/// Returns the first element of `metrics`, or nullptr when `metrics` is empty.
RelMetricPtr getMetric() { return metrics.empty() ? nullptr : metrics.at(0); }

static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function);
/// Looks up `function_ref` in `function_mapping` and returns the part of the mapped signature
/// that precedes the first ':'. Returns an empty optional when the reference is not registered.
std::optional<std::string> getFunctionSignatureName(UInt32 function_ref);

IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set<String> & columns);
IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <Parser/FunctionParser.h>
#include <Common/Exception.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include <Common/CHUtil.h>
#include <DataTypes/DataTypeFunction.h>

namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
}

namespace local_engine
{
/// Parser for the Substrait scalar function `filter` (array, lambda).
/// It is emitted as the ClickHouse function `arrayFilter`, which takes the lambda first,
/// so the two parsed arguments are swapped when the function node is built.
class ArrayFilter : public FunctionParser
{
public:
    static constexpr auto name = "filter";
    explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
    ~ArrayFilter() override = default;

    String getName() const override { return name; }

    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
    {
        return "arrayFilter";
    }

    /// Builds the `arrayFilter(lambda, array)` node in `actions_dag`.
    /// Throws DB::Exception (LOGICAL_ERROR) when the plan does not supply exactly two arguments.
    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
        DB::ActionsDAGPtr & actions_dag) const override
    {
        auto ch_func_name = getCHFunctionName(substrait_func);
        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
        // An assert would vanish in release builds and leave the indexing below unchecked.
        if (parsed_args.size() != 2)
            throw DB::Exception(
                DB::ErrorCodes::LOGICAL_ERROR,
                "Function {} requires exactly 2 arguments, but got {}",
                getName(),
                parsed_args.size());

        // Substrait order is (array, lambda); ClickHouse expects (lambda, array).
        return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
    }
};
static FunctionParserRegister<ArrayFilter> register_array_filter;

/// Parser for the Substrait scalar function `transform` (array, lambda).
/// It is emitted as the ClickHouse function `arrayMap`, which takes the lambda first,
/// so the two parsed arguments are swapped when the function node is built.
class ArrayTransform : public FunctionParser
{
public:
    static constexpr auto name = "transform";
    explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
    ~ArrayTransform() override = default;

    String getName() const override { return name; }

    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
    {
        return "arrayMap";
    }

    /// Builds the `arrayMap(lambda, array)` node in `actions_dag`.
    /// Throws DB::Exception (LOGICAL_ERROR) when the plan does not supply exactly two arguments.
    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
        DB::ActionsDAGPtr & actions_dag) const override
    {
        auto ch_func_name = getCHFunctionName(substrait_func);
        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
        // An assert would vanish in release builds and leave the indexing below unchecked.
        if (parsed_args.size() != 2)
            throw DB::Exception(
                DB::ErrorCodes::LOGICAL_ERROR,
                "Function {} requires exactly 2 arguments, but got {}",
                getName(),
                parsed_args.size());

        // Substrait order is (array, lambda); ClickHouse expects (lambda, array).
        return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
    }
};
static FunctionParserRegister<ArrayTransform> register_array_map;

/// Parser for the Substrait scalar function `aggregate` with three arguments.
/// It is emitted as the ClickHouse function `arrayFold`; the parsed arguments are reordered
/// from (arg0, arg1, lambda) to (lambda, arg0, arg1) when the function node is built.
class ArrayAggregate : public FunctionParser
{
public:
    static constexpr auto name = "aggregate";
    explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
    ~ArrayAggregate() override = default;

    String getName() const override { return name; }

    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
    {
        return "arrayFold";
    }

    /// Builds the `arrayFold` node in `actions_dag`.
    /// Throws DB::Exception (LOGICAL_ERROR) when the plan does not supply exactly three
    /// arguments or when the third argument is not a lambda.
    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
        DB::ActionsDAGPtr & actions_dag) const override
    {
        auto ch_func_name = getCHFunctionName(substrait_func);
        auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
        // An assert would vanish in release builds and leave the indexing below unchecked.
        if (parsed_args.size() != 3)
            throw DB::Exception(
                DB::ErrorCodes::LOGICAL_ERROR,
                "Function {} requires exactly 3 arguments, but got {}",
                getName(),
                parsed_args.size());

        const auto * function_type = typeid_cast<const DB::DataTypeFunction *>(parsed_args[2]->result_type.get());
        if (!function_type)
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function");

        // The second argument must have the same type as the lambda's return type;
        // cast it when the two differ.
        const auto & lambda_return_type = function_type->getReturnType();
        if (!parsed_args[1]->result_type->equals(*lambda_return_type))
        {
            parsed_args[1] = ActionsDAGUtil::convertNodeType(
                actions_dag,
                parsed_args[1],
                lambda_return_type->getName(),
                parsed_args[1]->result_name);
        }

        return toFunctionNode(actions_dag, ch_func_name, {parsed_args[2], parsed_args[0], parsed_args[1]});
    }
};
static FunctionParserRegister<ArrayAggregate> register_array_aggregate;

}
184 changes: 184 additions & 0 deletions cpp-ch/local-engine/Parser/scalar_function_parser/lambdafunction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
#include <Common/Exception.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include <Interpreters/ExpressionActionsSettings.h>
#include <Interpreters/ExpressionActions.h>
#include <Functions/FunctionsMiscellaneous.h>
#include <Common/CHUtil.h>

namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
}

namespace local_engine
{
/// Parser for the Substrait scalar function `lambdafunction`.
/// Argument 0 is the lambda body; the remaining arguments that are `namedlambdavariable`
/// calls declare the lambda's parameters. The body is parsed into its own ActionsDAG and
/// wrapped into a function-capture node that is added to the outer DAG.
/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node.
class LambdaFunction : public FunctionParser
{
public:
    static constexpr auto name = "lambdafunction";
    explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
    ~LambdaFunction() override = default;

    String getName() const override { return name; }
protected:
    /// Not applicable: a lambda does not map to a named ClickHouse function.
    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for LambdaFunction");
    }

    /// Not applicable: arguments are handled directly in `parse`.
    DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
        const substrait::Expression_ScalarFunction & substrait_func,
        const String & ch_func_name,
        DB::ActionsDAGPtr & actions_dag) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction");
    }

    /// Not applicable for lambdas.
    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
        const substrait::Expression_ScalarFunction & substrait_func,
        const DB::ActionsDAG::Node * func_node,
        DB::ActionsDAGPtr & actions_dag) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for LambdaFunction");
    }

    /// Builds the function-capture node for this lambda and adds it to `actions_dag`.
    /// Columns required by the body that are not lambda parameters are treated as captured
    /// columns and become children of the resulting node.
    /// Throws DB::Exception (LOGICAL_ERROR) when the lambda has no body argument or when a
    /// required column cannot be found among the inputs of the lambda's DAG.
    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override
    {
        if (substrait_func.arguments_size() == 0)
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "LambdaFunction requires the lambda body as its first argument");

        auto lambda_actions_dag = std::make_shared<DB::ActionsDAG>();

        const auto & substrait_lambda_body = substrait_func.arguments()[0].value();
        const auto * lambda_body_node = parseExpression(lambda_actions_dag, substrait_lambda_body);
        lambda_actions_dag->getOutputs().push_back(lambda_body_node);
        lambda_actions_dag->removeUnusedActions(DB::Names(1, lambda_body_node->result_name));

        auto expression_actions_settings = DB::ExpressionActionsSettings::fromContext(getContext(), DB::CompileExpressions::yes);
        auto lambda_actions = std::make_shared<DB::ExpressionActions>(lambda_actions_dag, expression_actions_settings);

        DB::Names captured_column_names;
        DB::Names required_column_names = lambda_actions->getRequiredColumns();
        DB::ActionsDAG::NodeRawConstPtrs lambda_children;
        auto lambda_function_args = collectLambdaArguments(substrait_func);
        const auto & lambda_actions_inputs = lambda_actions_dag->getInputs();

        const auto is_lambda_argument = [&lambda_function_args](const String & column_name)
        {
            return std::find_if(
                       lambda_function_args.begin(),
                       lambda_function_args.end(),
                       [&column_name](const DB::NameAndTypePair & name_type) { return name_type.name == column_name; })
                != lambda_function_args.end();
        };

        for (const auto & required_column_name : required_column_names)
        {
            if (is_lambda_argument(required_column_name))
                continue;

            // Anything the body needs that is not a lambda parameter is a captured column.
            auto it = std::find_if(
                lambda_actions_inputs.begin(),
                lambda_actions_inputs.end(),
                [&required_column_name](const auto & node) { return node->result_name == required_column_name; });
            if (it == lambda_actions_inputs.end())
            {
                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Required column not found: {}", required_column_name);
            }
            lambda_children.push_back(*it);
            captured_column_names.push_back(required_column_name);
        }

        LOG_DEBUG(getLogger("LambdaFunction"), "lambda actions dag:\n{}", lambda_actions_dag->dumpDAG());
        // The return type is taken from the parsed body node rather than from
        // `substrait_func.output_type()`.
        auto function_capture = std::make_shared<DB::FunctionCaptureOverloadResolver>(
            lambda_actions,
            captured_column_names,
            lambda_function_args,
            lambda_body_node->result_type,
            lambda_body_node->result_name);

        const auto * result = &actions_dag->addFunction(function_capture, lambda_children, lambda_body_node->result_name);
        LOG_DEBUG(getLogger("LambdaFunction"), "actions dag:\n{}", actions_dag->dumpDAG());
        return result;
    }

    /// Collects the declared parameters (name and type) of the lambda, in declaration order.
    /// Argument 0 is the lambda body and is skipped: when the body is itself a bare
    /// `namedlambdavariable` (e.g. `x -> x`) scanning it would declare the parameter twice.
    DB::NamesAndTypesList collectLambdaArguments(const substrait::Expression_ScalarFunction & substrait_func) const
    {
        DB::NamesAndTypesList lambda_arguments;

        const auto & args = substrait_func.arguments();
        for (int i = 1; i < args.size(); ++i)
        {
            const auto & arg_value = args[i].value();
            if (arg_value.has_scalar_function()
                && plan_parser->getFunctionSignatureName(arg_value.scalar_function().function_reference()) == "namedlambdavariable")
            {
                const auto & variable = arg_value.scalar_function();
                auto [_, col_name_field] = parseLiteral(variable.arguments()[0].value().literal());
                String col_name = col_name_field.get<String>();
                auto type = TypeParser::parseType(variable.output_type());
                lambda_arguments.emplace_back(col_name, type);
            }
        }
        return lambda_arguments;
    }
};

static FunctionParserRegister<LambdaFunction> register_lambda_function;


/// Parser for the Substrait scalar function `namedlambdavariable`.
/// Argument 0 is a literal holding the variable name; the function's output type is the
/// variable's type. The variable is represented as an input column of the given ActionsDAG.
class NamedLambdaVariable : public FunctionParser
{
public:
    static constexpr auto name = "namedlambdavariable";
    explicit NamedLambdaVariable(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
    ~NamedLambdaVariable() override = default;

    String getName() const override { return name; }
protected:
    /// Not applicable: a lambda variable does not map to a named ClickHouse function.
    String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "getCHFunctionName is not implemented for NamedLambdaVariable");
    }

    /// Not applicable: the single argument is handled directly in `parse`.
    DB::ActionsDAG::NodeRawConstPtrs parseFunctionArguments(
        const substrait::Expression_ScalarFunction & substrait_func,
        const String & ch_func_name,
        DB::ActionsDAGPtr & actions_dag) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for NamedLambdaVariable");
    }

    /// Not applicable for lambda variables.
    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
        const substrait::Expression_ScalarFunction & substrait_func,
        const DB::ActionsDAG::Node * func_node,
        DB::ActionsDAGPtr & actions_dag) const override
    {
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "convertNodeTypeIfNeeded is not implemented for NamedLambdaVariable");
    }

    /// Returns the input node of `actions_dag` that represents this variable.
    /// If an input with the same name already exists it is reused, so a variable referenced
    /// several times in one lambda body (e.g. `x -> x + x`) maps to a single input instead of
    /// several inputs sharing a name; otherwise a new input is added.
    /// Throws DB::Exception (LOGICAL_ERROR) when the variable name argument is missing.
    const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAGPtr & actions_dag) const override
    {
        if (substrait_func.arguments_size() == 0)
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "NamedLambdaVariable requires the variable name as its first argument");

        auto [_, col_name_field] = parseLiteral(substrait_func.arguments()[0].value().literal());
        String col_name = col_name_field.get<String>();

        const auto & inputs = actions_dag->getInputs();
        auto it = std::find_if(
            inputs.begin(),
            inputs.end(),
            [&col_name](const auto * node) { return node->result_name == col_name; });
        if (it != inputs.end())
            return *it;

        auto type = TypeParser::parseType(substrait_func.output_type());
        return &(actions_dag->addInput(col_name, type));
    }
};

static FunctionParserRegister<NamedLambdaVariable> register_named_lambda_variable;

}

0 comments on commit d8ea7ee

Please sign in to comment.