Skip to content

Commit

Permalink
high order array functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Jun 28, 2024
1 parent 32808dd commit 719cd31
Show file tree
Hide file tree
Showing 7 changed files with 436 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,24 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
CHGenerateExecTransformer(generator, requiredChildOutput, outer, generatorOutput, child)
}

/** Transform array filter to Substrait. */
override def genArrayFilterTransformer(
substraitExprName: String,
argument: ExpressionTransformer,
function: ExpressionTransformer,
expr: ArrayFilter): ExpressionTransformer = {
GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
}

/** Transform array transform to Substrait. */
override def genArrayTransformTransformer(
substraitExprName: String,
argument: ExpressionTransformer,
function: ExpressionTransformer,
expr: ArrayTransform): ExpressionTransformer = {
GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
}

override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate

override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,4 +713,21 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
}

}

test("array functions with lambda") {
withTable("tb_array") {
sql("create table tb_array(ids array<int>) using parquet")
sql("""
|insert into tb_array values (array(1,5,2,null, 3)), (array(1,1,3,2)), (null), (array())
|""".stripMargin)
val transform_sql = "select transform(ids, x -> x + 1) from tb_array"
runQueryAndCompare(transform_sql)(checkGlutenOperatorMatch[ProjectExecTransformer])

val filter_sql = "select filter(ids, x -> x % 2 == 1) from tb_array";
runQueryAndCompare(filter_sql)(checkGlutenOperatorMatch[ProjectExecTransformer])

val aggregate_sql = "select ids, aggregate(ids, 3, (acc, x) -> acc + x) from tb_array";
runQueryAndCompare(aggregate_sql)(checkGlutenOperatorMatch[ProjectExecTransformer])
}
}
}
10 changes: 10 additions & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,16 @@ NamesAndTypesList SerializedPlanParser::blockToNameAndTypeList(const Block & hea
return types;
}

std::optional<String> SerializedPlanParser::getFunctionSignatureName(UInt32 function_ref) const
{
auto it = function_mapping.find(std::to_string(function_ref));
if (it == function_mapping.end())
return {};
auto function_signature = it->second;
auto pos = function_signature.find(':');
return function_signature.substr(0, pos);
}

std::string
SerializedPlanParser::getFunctionName(const std::string & function_signature, const substrait::Expression_ScalarFunction & function)
{
Expand Down
4 changes: 3 additions & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,12 @@ class SerializedPlanParser
RelMetricPtr getMetric() { return metrics.empty() ? nullptr : metrics.at(0); }

static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function);
std::optional<std::string> getFunctionSignatureName(UInt32 function_ref) const;

IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set<String> & columns);
IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header);

static std::pair<DataTypePtr, Field> parseLiteral(const substrait::Expression_Literal & literal);

static ContextMutablePtr global_context;
static Context::ConfigurationPtr config;
Expand Down Expand Up @@ -384,7 +387,6 @@ class SerializedPlanParser
// remove nullable after isNotNull
void removeNullableForRequiredColumns(const std::set<String> & require_columns, const ActionsDAGPtr & actions_dag) const;
std::string getUniqueName(const std::string & name) { return name + "_" + std::to_string(name_no++); }
static std::pair<DataTypePtr, Field> parseLiteral(const substrait::Expression_Literal & literal);
void wrapNullable(
const std::vector<String> & columns, ActionsDAGPtr actions_dag, std::map<std::string, std::string> & nullable_measure_names);
static std::pair<DB::DataTypePtr, DB::Field> convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <Parser/FunctionParser.h>
#include <Common/Exception.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include <Common/CHUtil.h>
#include <DataTypes/DataTypeFunction.h>
#include <DataTypes/DataTypeNullable.h>
#include <Core/Types.h>
#include <Parser/TypeParser.h>
#include <Parser/scalar_function_parser/lambdaFunction.h>

namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
}

namespace local_engine
{
class ArrayFilter : public FunctionParser
{
public:
static constexpr auto name = "filter";
explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayFilter() override = default;

String getName() const override { return name; }

String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
return "arrayFilter";
}

const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
DB::ActionsDAGPtr & actions_dag) const
{
auto ch_func_name = getCHFunctionName(substrait_func);
auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
assert(parsed_args.size() == 2);
if (collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function()).size() == 1)
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});

/// filter with index argument.
const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})});
range_end_node = ActionsDAGUtil::convertNodeType(
actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name);
const auto * index_array_node = toFunctionNode(
actions_dag,
"range",
{addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), 0), range_end_node});
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
}
};
static FunctionParserRegister<ArrayFilter> register_array_filter;

class ArrayTransform : public FunctionParser
{
public:
static constexpr auto name = "transform";
explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayTransform() override = default;
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
return "arrayMap";
}

const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
DB::ActionsDAGPtr & actions_dag) const
{
auto ch_func_name = getCHFunctionName(substrait_func);
auto lambda_args = collectLambdaArguments(*plan_parser, substrait_func.arguments()[1].value().scalar_function());
auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
assert(parsed_args.size() == 2);
if (lambda_args.size() == 1)
{
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]});
}

/// transform with index argument.
const auto * range_end_node = toFunctionNode(actions_dag, "length", {toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]})});
range_end_node = ActionsDAGUtil::convertNodeType(
actions_dag, range_end_node, "Nullable(Int32)", range_end_node->result_name);
const auto * index_array_node = toFunctionNode(
actions_dag,
"range",
{addColumnToActionsDAG(actions_dag, std::make_shared<DataTypeInt32>(), 0), range_end_node});
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
}
};
static FunctionParserRegister<ArrayTransform> register_array_map;

class ArrayAggregate : public FunctionParser
{
public:
static constexpr auto name = "aggregate";
explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayAggregate() override = default;
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
return "arrayFold";
}
const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func,
DB::ActionsDAGPtr & actions_dag) const
{
auto ch_func_name = getCHFunctionName(substrait_func);
auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag);
assert(parsed_args.size() == 3);
const auto * function_type = typeid_cast<const DataTypeFunction *>(parsed_args[2]->result_type.get());
if (!function_type)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function");
if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType())))
{
parsed_args[1] = ActionsDAGUtil::convertNodeType(
actions_dag,
parsed_args[1],
function_type->getReturnType()->getName(),
parsed_args[1]->result_name);
}

/// arrayFold cannot accept nullable(array)
const auto * array_col_node = parsed_args[0];
if (parsed_args[0]->result_type->isNullable())
{
array_col_node = toFunctionNode(actions_dag, "assumeNotNull", {parsed_args[0]});
}
const auto * func_node = toFunctionNode(actions_dag, ch_func_name, {parsed_args[2], array_col_node, parsed_args[1]});
/// For null array, result is null.
/// TODO: make a new version of arrayFold that can handle nullable array.
const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {parsed_args[0]});
const auto * null_node = addColumnToActionsDAG(actions_dag, DB::makeNullable(func_node->result_type), DB::Null());
return toFunctionNode(actions_dag, "if", {is_null_node, null_node, func_node});
}
};
static FunctionParserRegister<ArrayAggregate> register_array_aggregate;

}
Loading

0 comments on commit 719cd31

Please sign in to comment.