Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-6816][CH] support function zip_with with some minor refactors #7211

Merged
merged 2 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ object CHExpressionUtil {
SKEWNESS -> DefaultValidator(),
MAKE_YM_INTERVAL -> DefaultValidator(),
MAP_ZIP_WITH -> DefaultValidator(),
ZIP_WITH -> DefaultValidator(),
KURTOSIS -> DefaultValidator(),
REGR_R2 -> DefaultValidator(),
REGR_SLOPE -> DefaultValidator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@

namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
}

namespace local_engine
{
class ArrayFilter : public FunctionParser
class FunctionParserArrayFilter : public FunctionParser
{
public:
static constexpr auto name = "filter";
explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayFilter() override = default;
explicit FunctionParserArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserArrayFilter() override = default;

String getName() const override { return name; }

Expand Down Expand Up @@ -69,14 +70,14 @@ class ArrayFilter : public FunctionParser
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
}
};
static FunctionParserRegister<ArrayFilter> register_array_filter;
static FunctionParserRegister<FunctionParserArrayFilter> register_array_filter;

class ArrayTransform : public FunctionParser
class FunctionParserArrayTransform : public FunctionParser
{
public:
static constexpr auto name = "transform";
explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayTransform() override = default;
explicit FunctionParserArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserArrayTransform() override = default;
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
Expand Down Expand Up @@ -115,14 +116,14 @@ class ArrayTransform : public FunctionParser
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
}
};
static FunctionParserRegister<ArrayTransform> register_array_map;
static FunctionParserRegister<FunctionParserArrayTransform> register_array_map;

class ArrayAggregate : public FunctionParser
class FunctionParserArrayAggregate : public FunctionParser
{
public:
static constexpr auto name = "aggregate";
explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayAggregate() override = default;
explicit FunctionParserArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserArrayAggregate() override = default;
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
Expand All @@ -134,9 +135,11 @@ class ArrayAggregate : public FunctionParser
auto ch_func_name = getCHFunctionName(substrait_func);
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
assert(parsed_args.size() == 3);

const auto * function_type = typeid_cast<const DataTypeFunction *>(parsed_args[2]->result_type.get());
if (!function_type)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function");
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The third argument of aggregate function must be a lambda function");

if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType())))
{
parsed_args[1] = ActionsDAGUtil::convertNodeType(
Expand All @@ -160,14 +163,14 @@ class ArrayAggregate : public FunctionParser
return toFunctionNode(actions_dag, "if", {is_null_node, null_node, func_node});
}
};
static FunctionParserRegister<ArrayAggregate> register_array_aggregate;
static FunctionParserRegister<FunctionParserArrayAggregate> register_array_aggregate;

class ArraySort : public FunctionParser
class FunctionParserArraySort : public FunctionParser
{
public:
static constexpr auto name = "array_sort";
explicit ArraySort(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArraySort() override = default;
explicit FunctionParserArraySort(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserArraySort() override = default;
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
Expand All @@ -180,7 +183,8 @@ class ArraySort : public FunctionParser
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);

if (parsed_args.size() != 2)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "array_sort function must have two arguments");
throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "array_sort function must have two arguments");

if (isDefaultCompare(substrait_func.arguments()[1].value().scalar_function()))
{
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[0]});
Expand Down Expand Up @@ -304,6 +308,34 @@ class ArraySort : public FunctionParser
return is_if_both_null_else(lambda_body);
}
};
static FunctionParserRegister<ArraySort> register_array_sort;
static FunctionParserRegister<FunctionParserArraySort> register_array_sort;

/// Parses Spark's `zip_with(arr1, arr2, func)` into the equivalent ClickHouse
/// expression `arrayMap(func, arrayZipUnaligned(arr1, arr2))`.
/// arrayZipUnaligned pairs up elements of the two arrays (padding the shorter
/// one), and arrayMap then applies the user-supplied lambda to each pair.
class FunctionParserZipWith : public FunctionParser
{
public:
    static constexpr auto name = "zip_with";
    explicit FunctionParserZipWith(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
    ~FunctionParserZipWith() override = default;
    String getName() const override { return name; }

    /// @param substrait_func the serialized Spark scalar function to translate
    /// @param actions_dag    the DAG into which the translated nodes are added
    /// @return the DAG node producing the zipped-and-mapped result
    /// @throws DB::Exception if the argument count or lambda arity is wrong
    const DB::ActionsDAG::Node *
    parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
    {
        const auto args = parseFunctionArguments(substrait_func, actions_dag);
        if (args.size() != 3)
            throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "zip_with function must have three arguments");

        /// The third argument must be a two-parameter lambda (one element from each array).
        const auto lambda_args = collectLambdaArguments(*plan_parser, substrait_func.arguments()[2].value().scalar_function());
        if (lambda_args.size() != 2)
            throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "The lambda function in zip_with must have two arguments");

        const auto * zipped_node = toFunctionNode(actions_dag, "arrayZipUnaligned", {args[0], args[1]});
        const auto * mapped_node = toFunctionNode(actions_dag, "arrayMap", {args[2], zipped_node});
        return convertNodeTypeIfNeeded(substrait_func, mapped_node, actions_dag);
    }
};
static FunctionParserRegister<FunctionParserZipWith> register_zip_with;


}
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_p
}

/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node.
class LambdaFunction : public FunctionParser
class FunctionParserLambda : public FunctionParser
{
public:
static constexpr auto name = "lambdafunction";
explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~LambdaFunction() override = default;
explicit FunctionParserLambda(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserLambda() override = default;

String getName() const override { return name; }

Expand All @@ -79,7 +79,7 @@ class LambdaFunction : public FunctionParser
for (const auto * output_node : actions_dag.getOutputs())
{
parent_header.emplace_back(output_node->result_name, output_node->result_type);
}
}
ActionsDAG lambda_actions_dag{parent_header};

/// The first argument is the lambda function body, followings are the lambda arguments which is
Expand Down Expand Up @@ -157,7 +157,7 @@ class LambdaFunction : public FunctionParser
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction");
}

const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
const substrait::Expression_ScalarFunction & substrait_func,
const DB::ActionsDAG::Node * func_node,
Expand All @@ -167,7 +167,7 @@ class LambdaFunction : public FunctionParser
}
};

static FunctionParserRegister<LambdaFunction> register_lambda_function;
static FunctionParserRegister<FunctionParserLambda> register_lambda_function;


class NamedLambdaVariable : public FunctionParser
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[GlutenDataFrameComplexTypeSuite]
enableSuite[GlutenDataFrameFunctionsSuite]
.exclude("map with arrays")
.exclude("bin")
.exclude("sequence")
.exclude("element_at function")
.exclude("flatten function")
.exclude("aggregate function - array for primitive type not containing null")
.exclude("aggregate function - array for primitive type containing null")
.exclude("aggregate function - array for non-primitive type")
.exclude("transform keys function - primitive data types")
.exclude("transform values function - test empty")
.exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union")
.exclude("SPARK-24734: Fix containsNull of Concat for array type")
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
.exclude("flatten function")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[GlutenDataFrameComplexTypeSuite]
enableSuite[GlutenDataFrameFunctionsSuite]
.exclude("map with arrays")
.exclude("bin")
.exclude("sequence")
.exclude("element_at function")
.exclude("flatten function")
.exclude("aggregate function - array for primitive type not containing null")
.exclude("aggregate function - array for primitive type containing null")
.exclude("aggregate function - array for non-primitive type")
.exclude("transform keys function - primitive data types")
.exclude("transform values function - test empty")
.exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union")
.exclude("SPARK-24734: Fix containsNull of Concat for array type")
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
.exclude("flatten function")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[GlutenDataFrameComplexTypeSuite]
enableSuite[GlutenDataFrameFunctionsSuite]
.exclude("map with arrays")
.exclude("bin")
.exclude("sequence")
.exclude("element_at function")
.exclude("flatten function")
.exclude("aggregate function - array for primitive type not containing null")
.exclude("aggregate function - array for primitive type containing null")
.exclude("aggregate function - array for non-primitive type")
.exclude("transform keys function - primitive data types")
.exclude("transform values function - test empty")
.exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union")
.exclude("SPARK-24734: Fix containsNull of Concat for array type")
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[GlutenDataFrameComplexTypeSuite]
enableSuite[GlutenDataFrameFunctionsSuite]
.exclude("map with arrays")
.exclude("bin")
.exclude("sequence")
.exclude("element_at function")
.exclude("flatten function")
.exclude("aggregate function - array for primitive type not containing null")
.exclude("aggregate function - array for primitive type containing null")
.exclude("aggregate function - array for non-primitive type")
.exclude("transform keys function - primitive data types")
.exclude("transform values function - test empty")
.exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union")
.exclude("SPARK-24734: Fix containsNull of Concat for array type")
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
Expand Down
Loading