Skip to content

Commit

Permalink
[GLUTEN-6816][CH] support function zip_with with some minor refactors (
Browse files Browse the repository at this point in the history
…#7211)

What changes were proposed in this pull request?
Support the Spark function zip_with in the ClickHouse backend by parsing zip_with(arr1, arr2, func) into the ClickHouse expression arrayMap(func, arrayZipUnaligned(arr1, arr2)). As a minor refactor, the array-related FunctionParser subclasses are renamed for consistency (e.g. ArrayFilter -> FunctionParserArrayFilter, LambdaFunction -> FunctionParserLambda), and LOGICAL_ERROR exceptions for invalid user arguments are changed to BAD_ARGUMENTS / SIZES_OF_COLUMNS_DOESNT_MATCH.

(Fixes: #6816)

How was this patch tested?
GlutenDataFrameFunctionsSuite

"arrays zip_with function - for non-primitive types"
"arrays zip_with function - for non-primitive types"
"arrays zip_with function - invalid"
  • Loading branch information
taiyang-li authored Sep 13, 2024
1 parent ce09169 commit 325f6a4
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ object CHExpressionUtil {
SKEWNESS -> DefaultValidator(),
MAKE_YM_INTERVAL -> DefaultValidator(),
MAP_ZIP_WITH -> DefaultValidator(),
ZIP_WITH -> DefaultValidator(),
KURTOSIS -> DefaultValidator(),
REGR_R2 -> DefaultValidator(),
REGR_SLOPE -> DefaultValidator(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@

namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
}

namespace local_engine
{
class ArrayFilter : public FunctionParser
class FunctionParserArrayFilter : public FunctionParser
{
public:
static constexpr auto name = "filter";
explicit ArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayFilter() override = default;
explicit FunctionParserArrayFilter(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserArrayFilter() override = default;

String getName() const override { return name; }

Expand Down Expand Up @@ -69,14 +70,14 @@ class ArrayFilter : public FunctionParser
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
}
};
static FunctionParserRegister<ArrayFilter> register_array_filter;
static FunctionParserRegister<FunctionParserArrayFilter> register_array_filter;

class ArrayTransform : public FunctionParser
class FunctionParserArrayTransform : public FunctionParser
{
public:
static constexpr auto name = "transform";
explicit ArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayTransform() override = default;
explicit FunctionParserArrayTransform(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserArrayTransform() override = default;
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
Expand Down Expand Up @@ -115,14 +116,14 @@ class ArrayTransform : public FunctionParser
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0], index_array_node});
}
};
static FunctionParserRegister<ArrayTransform> register_array_map;
static FunctionParserRegister<FunctionParserArrayTransform> register_array_map;

class ArrayAggregate : public FunctionParser
class FunctionParserArrayAggregate : public FunctionParser
{
public:
static constexpr auto name = "aggregate";
explicit ArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArrayAggregate() override = default;
explicit FunctionParserArrayAggregate(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserArrayAggregate() override = default;
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
Expand All @@ -134,9 +135,11 @@ class ArrayAggregate : public FunctionParser
auto ch_func_name = getCHFunctionName(substrait_func);
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
assert(parsed_args.size() == 3);

const auto * function_type = typeid_cast<const DataTypeFunction *>(parsed_args[2]->result_type.get());
if (!function_type)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The third argument of aggregate function must be a lambda function");
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The third argument of aggregate function must be a lambda function");

if (!parsed_args[1]->result_type->equals(*(function_type->getReturnType())))
{
parsed_args[1] = ActionsDAGUtil::convertNodeType(
Expand All @@ -160,14 +163,14 @@ class ArrayAggregate : public FunctionParser
return toFunctionNode(actions_dag, "if", {is_null_node, null_node, func_node});
}
};
static FunctionParserRegister<ArrayAggregate> register_array_aggregate;
static FunctionParserRegister<FunctionParserArrayAggregate> register_array_aggregate;

class ArraySort : public FunctionParser
class FunctionParserArraySort : public FunctionParser
{
public:
static constexpr auto name = "array_sort";
explicit ArraySort(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~ArraySort() override = default;
explicit FunctionParserArraySort(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserArraySort() override = default;
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override
{
Expand All @@ -180,7 +183,8 @@ class ArraySort : public FunctionParser
auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);

if (parsed_args.size() != 2)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "array_sort function must have two arguments");
throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "array_sort function must have two arguments");

if (isDefaultCompare(substrait_func.arguments()[1].value().scalar_function()))
{
return toFunctionNode(actions_dag, ch_func_name, {parsed_args[0]});
Expand Down Expand Up @@ -304,6 +308,34 @@ class ArraySort : public FunctionParser
return is_if_both_null_else(lambda_body);
}
};
static FunctionParserRegister<ArraySort> register_array_sort;
static FunctionParserRegister<FunctionParserArraySort> register_array_sort;

/// Parses the Spark higher-order function `zip_with` into an equivalent
/// ClickHouse expression:
///     zip_with(arr1, arr2, func)  ->  arrayMap(func, arrayZipUnaligned(arr1, arr2))
/// arrayZipUnaligned (rather than arrayZip) is used because it accepts input
/// arrays of different lengths — presumably padding the shorter one with nulls,
/// matching Spark's zip_with semantics; confirm against the CH function docs.
class FunctionParserZipWith : public FunctionParser
{
public:
    static constexpr auto name = "zip_with";
    explicit FunctionParserZipWith(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
    ~FunctionParserZipWith() override = default;
    String getName() const override { return name; }

    /// Builds the DAG nodes for zip_with.
    /// Expects exactly three arguments (two arrays and a binary lambda);
    /// throws SIZES_OF_COLUMNS_DOESNT_MATCH otherwise.
    const DB::ActionsDAG::Node *
    parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
    {
        /// Parse spark zip_with(arr1, arr2, func) as CH arrayMap(func, arrayZipUnaligned(arr1, arr2))
        auto parsed_args = parseFunctionArguments(substrait_func, actions_dag);
        if (parsed_args.size() != 3)
            throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "zip_with function must have three arguments");

        /// The lambda must take two parameters: one element from each input array.
        auto lambda_args = collectLambdaArguments(*plan_parser, substrait_func.arguments()[2].value().scalar_function());
        if (lambda_args.size() != 2)
            throw DB::Exception(DB::ErrorCodes::SIZES_OF_COLUMNS_DOESNT_MATCH, "The lambda function in zip_with must have two arguments");

        const auto * array_zip_unaligned = toFunctionNode(actions_dag, "arrayZipUnaligned", {parsed_args[0], parsed_args[1]});
        const auto * array_map = toFunctionNode(actions_dag, "arrayMap", {parsed_args[2], array_zip_unaligned});
        return convertNodeTypeIfNeeded(substrait_func, array_map, actions_dag);
    }
};
static FunctionParserRegister<FunctionParserZipWith> register_zip_with;


}
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ DB::NamesAndTypesList collectLambdaArguments(const SerializedPlanParser & plan_p
}

/// Refer to `PlannerActionsVisitorImpl::visitLambda` for how to build a lambda function node.
class LambdaFunction : public FunctionParser
class FunctionParserLambda : public FunctionParser
{
public:
static constexpr auto name = "lambdafunction";
explicit LambdaFunction(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~LambdaFunction() override = default;
explicit FunctionParserLambda(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {}
~FunctionParserLambda() override = default;

String getName() const override { return name; }

Expand All @@ -79,7 +79,7 @@ class LambdaFunction : public FunctionParser
for (const auto * output_node : actions_dag.getOutputs())
{
parent_header.emplace_back(output_node->result_name, output_node->result_type);
}
}
ActionsDAG lambda_actions_dag{parent_header};

/// The first argument is the lambda function body, followings are the lambda arguments which is
Expand Down Expand Up @@ -157,7 +157,7 @@ class LambdaFunction : public FunctionParser
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "parseFunctionArguments is not implemented for LambdaFunction");
}

const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
const substrait::Expression_ScalarFunction & substrait_func,
const DB::ActionsDAG::Node * func_node,
Expand All @@ -167,7 +167,7 @@ class LambdaFunction : public FunctionParser
}
};

static FunctionParserRegister<LambdaFunction> register_lambda_function;
static FunctionParserRegister<FunctionParserLambda> register_lambda_function;


class NamedLambdaVariable : public FunctionParser
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[GlutenDataFrameComplexTypeSuite]
enableSuite[GlutenDataFrameFunctionsSuite]
.exclude("map with arrays")
.exclude("bin")
.exclude("sequence")
.exclude("element_at function")
.exclude("flatten function")
.exclude("aggregate function - array for primitive type not containing null")
.exclude("aggregate function - array for primitive type containing null")
.exclude("aggregate function - array for non-primitive type")
.exclude("transform keys function - primitive data types")
.exclude("transform values function - test empty")
.exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union")
.exclude("SPARK-24734: Fix containsNull of Concat for array type")
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
.exclude("flatten function")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[GlutenDataFrameComplexTypeSuite]
enableSuite[GlutenDataFrameFunctionsSuite]
.exclude("map with arrays")
.exclude("bin")
.exclude("sequence")
.exclude("element_at function")
.exclude("flatten function")
.exclude("aggregate function - array for primitive type not containing null")
.exclude("aggregate function - array for primitive type containing null")
.exclude("aggregate function - array for non-primitive type")
.exclude("transform keys function - primitive data types")
.exclude("transform values function - test empty")
.exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union")
.exclude("SPARK-24734: Fix containsNull of Concat for array type")
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
.exclude("flatten function")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[GlutenDataFrameComplexTypeSuite]
enableSuite[GlutenDataFrameFunctionsSuite]
.exclude("map with arrays")
.exclude("bin")
.exclude("sequence")
.exclude("element_at function")
.exclude("flatten function")
.exclude("aggregate function - array for primitive type not containing null")
.exclude("aggregate function - array for primitive type containing null")
.exclude("aggregate function - array for non-primitive type")
.exclude("transform keys function - primitive data types")
.exclude("transform values function - test empty")
.exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union")
.exclude("SPARK-24734: Fix containsNull of Concat for array type")
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,13 @@ class ClickHouseTestSettings extends BackendTestSettings {
enableSuite[GlutenDataFrameComplexTypeSuite]
enableSuite[GlutenDataFrameFunctionsSuite]
.exclude("map with arrays")
.exclude("bin")
.exclude("sequence")
.exclude("element_at function")
.exclude("flatten function")
.exclude("aggregate function - array for primitive type not containing null")
.exclude("aggregate function - array for primitive type containing null")
.exclude("aggregate function - array for non-primitive type")
.exclude("transform keys function - primitive data types")
.exclude("transform values function - test empty")
.exclude("SPARK-14393: values generated by non-deterministic functions shouldn't change after coalesce or union")
.exclude("SPARK-24734: Fix containsNull of Concat for array type")
.exclude("shuffle function - array for primitive type not containing null")
.exclude("shuffle function - array for primitive type containing null")
.exclude("shuffle function - array for non-primitive type")
enableSuite[GlutenDataFrameHintSuite]
enableSuite[GlutenDataFrameImplicitsSuite]
enableSuite[GlutenDataFrameJoinSuite].exclude(
Expand Down

0 comments on commit 325f6a4

Please sign in to comment.