From 3c05461c83e94033a49c5124719b67f4ea56cef5 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 2 Jul 2024 18:17:56 +0800 Subject: [PATCH] support sort_array --- .../clickhouse/CHSparkPlanExecApi.scala | 9 + .../Functions/SparkFunctionArraySort.cpp | 223 ++++++++++++++---- .../Functions/SparkFunctionSortArray.cpp | 88 +++++++ ...onArraySort.h => SparkFunctionSortArray.h} | 10 +- .../arrayHighOrderFunctions.cpp | 144 +++++++++++ .../scalar_function_parser/lambdaFunction.cpp | 1 + .../scalar_function_parser/sortArray.cpp | 4 +- .../gluten/backendsapi/SparkPlanExecApi.scala | 9 + .../expression/ExpressionConverter.scala | 13 + .../expression/ExpressionMappings.scala | 1 + .../gluten/expression/ExpressionNames.scala | 1 + 11 files changed, 453 insertions(+), 50 deletions(-) create mode 100644 cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp rename cpp-ch/local-engine/Functions/{SparkFunctionArraySort.h => SparkFunctionSortArray.h} (87%) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index c0dee707ef4f3..01c54861952c7 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -860,6 +860,15 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr) } + /** Transform array sort to Substrait. */ + override def genArraySortTransformer( + substraitExprName: String, + argument: ExpressionTransformer, + function: ExpressionTransformer, + expr: ArraySort): ExpressionTransformer = { + GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr) + } + override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp index 126b84eaaf95d..1371ec60e1796 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp @@ -14,75 +14,212 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -namespace DB +namespace DB::ErrorCodes { + extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int TYPE_MISMATCH; + extern const int ILLEGAL_COLUMN; +} -namespace ErrorCodes +/// The usage of `arraySort` in CH is different from Spark's `sort_array` function. +/// We need to implement a custom function to sort arrays. +namespace local_engine { - extern const int LOGICAL_ERROR; -} -namespace +struct LambdaLess { + const DB::IColumn & column; + DB::DataTypePtr type; + const DB::ColumnFunction & lambda; + explicit LambdaLess(const DB::IColumn & column_, DB::DataTypePtr type_, const DB::ColumnFunction & lambda_) + : column(column_), type(type_), lambda(lambda_) {} + + /// May not efficient + bool operator()(size_t lhs, size_t rhs) const + { + /// The column name seems not matter. + auto left_value_col = DB::ColumnWithTypeAndName(oneRowColumn(lhs), type, "left"); + auto right_value_col = DB::ColumnWithTypeAndName(oneRowColumn(rhs), type, "right"); + auto cloned_lambda = lambda.cloneResized(1); + auto * lambda_ = typeid_cast(cloned_lambda.get()); + lambda_->appendArguments({std::move(left_value_col), std::move(right_value_col)}); + auto compare_res_col = lambda_->reduce(); + DB::Field field; + compare_res_col.column->get(0, field); + return field.get() < 0; + } +private: + ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const + { + auto res = column.cloneEmpty(); + res->insertFrom(column, i); + return std::move(res); + } +}; -template struct Less { - const IColumn & column; + const DB::IColumn & column; - explicit Less(const IColumn & column_) : column(column_) { } + explicit Less(const DB::IColumn & column_) : column(column_) { } bool operator()(size_t lhs, size_t rhs) const { - if constexpr (positive) - /* - Note: We use nan_direction_hint=-1 for ascending sort to make NULL the least value. - However, NaN is also considered the least value, - which results in different sorting results compared to Spark since Spark treats NaN as the greatest value. - For now, we are temporarily ignoring this issue because cases with NaN are rare, - and aligning with Spark would require tricky modifications to the CH underlying code. - */ - return column.compareAt(lhs, rhs, column, -1) < 0; - else - return column.compareAt(lhs, rhs, column, -1) > 0; + return column.compareAt(lhs, rhs, column, 1) < 0; } }; -} - -template -ColumnPtr SparkArraySortImpl::execute( - const ColumnArray & array, - ColumnPtr mapped, - const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]]) +class FunctionSparkArraySort : public DB::IFunction { - const ColumnArray::Offsets & offsets = array.getOffsets(); +public: + static constexpr auto name = "arraySortSpark"; + static DB::FunctionPtr create(DB::ContextPtr /*context*/) { return std::make_shared(); } - size_t size = offsets.size(); - size_t nested_size = array.getData().size(); - IColumn::Permutation permutation(nested_size); + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo &) const override { return true; } - for (size_t i = 0; i < nested_size; ++i) - permutation[i] = i; + bool useDefaultImplementationForNulls() const override { return false; } + bool useDefaultImplementationForLowCardinalityColumns() const override { return false; } - ColumnArray::Offset current_offset = 0; - for (size_t i = 0; i < size; ++i) + void getLambdaArgumentTypes(DB::DataTypes & arguments) const override { - auto next_offset = offsets[i]; - ::sort(&permutation[current_offset], &permutation[next_offset], Less(*mapped)); - current_offset = next_offset; + if (arguments.size() < 2) + throw DB::Exception(DB::ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} requires as arguments a lambda function and an array", getName()); + + if (arguments.size() > 1) + { + const auto * lambda_function_type = DB::checkAndGetDataType(arguments[0].get()); + if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != 2) + throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with 2 arguments, found {} instead.", + getName(), arguments[0]->getName()); + auto array_nesteed_type = DB::checkAndGetDataType(arguments.back().get())->getNestedType(); + DB::DataTypes lambda_args = {array_nesteed_type, array_nesteed_type}; + arguments[0] = std::make_shared(lambda_args); + } } - return ColumnArray::create(array.getData().permute(permutation, 0), array.getOffsetsPtr()); -} + DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & arguments) const override + { + if (arguments.size() > 1) + { + const auto * lambda_function_type = checkAndGetDataType(arguments[0].type.get()); + if (!lambda_function_type) + throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName()); + } + + return arguments.back().type; + } + + DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t input_rows_count) const override + { + auto array_col = arguments.back().column; + auto array_type = arguments.back().type; + DB::ColumnPtr null_map = nullptr; + if (const auto * null_col = typeid_cast(array_col.get())) + { + null_map = null_col->getNullMapColumnPtr(); + array_col = null_col->getNestedColumnPtr(); + array_type = typeid_cast(array_type.get())->getNestedType(); + } + + const auto * array_col_concrete = DB::checkAndGetColumn(array_col.get()); + if (!array_col_concrete) + { + const auto * aray_col_concrete_const = DB::checkAndGetColumnConst(array_col.get()); + if (!aray_col_concrete_const) + { + throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Expected array column, found {}", array_col->getName()); + } + array_col = DB::recursiveRemoveLowCardinality(aray_col_concrete_const->convertToFullColumn()); + array_col_concrete = DB::checkAndGetColumn(array_col.get()); + } + auto array_nested_type = DB::checkAndGetDataType(array_type.get())->getNestedType(); + + DB::ColumnPtr sorted_array_col = nullptr; + if (arguments.size() > 1) + sorted_array_col = executeWithLambda(*array_col_concrete, array_nested_type, *checkAndGetColumn(arguments[0].column.get())); + else + sorted_array_col = executeWithoutLambda(*array_col_concrete); + + if (null_map) + { + sorted_array_col = DB::ColumnNullable::create(sorted_array_col, null_map); + } + return sorted_array_col; + } +private: + static DB::ColumnPtr executeWithLambda(const DB::ColumnArray & array_col, DB::DataTypePtr array_nested_type, const DB::ColumnFunction & lambda) + { + const auto & offsets = array_col.getOffsets(); + auto rows = array_col.size(); + + size_t nested_size = array_col.getData().size(); + DB::IColumn::Permutation permutation(nested_size); + for (size_t i = 0; i < nested_size; ++i) + permutation[i] = i; + + DB::ColumnArray::Offset current_offset = 0; + for (size_t i = 0; i < rows; ++i) + { + auto next_offset = offsets[i]; + ::sort(&permutation[current_offset], + &permutation[next_offset], + LambdaLess(array_col.getData(), + array_nested_type, + lambda)); + current_offset = next_offset; + } + auto res = DB::ColumnArray::create(array_col.getData().permute(permutation, 0), array_col.getOffsetsPtr()); + return res; + } + + static DB::ColumnPtr executeWithoutLambda(const DB::ColumnArray & array_col) + { + const auto & offsets = array_col.getOffsets(); + auto rows = array_col.size(); + + size_t nested_size = array_col.getData().size(); + DB::IColumn::Permutation permutation(nested_size); + for (size_t i = 0; i < nested_size; ++i) + permutation[i] = i; + + DB::ColumnArray::Offset current_offset = 0; + for (size_t i = 0; i < rows; ++i) + { + auto next_offset = offsets[i]; + ::sort(&permutation[current_offset], + &permutation[next_offset], + Less(array_col.getData())); + current_offset = next_offset; + } + auto res = DB::ColumnArray::create(array_col.getData().permute(permutation, 0), array_col.getOffsetsPtr()); + return res; + } + + String getName() const override + { + return name; + } + +}; REGISTER_FUNCTION(ArraySortSpark) { - factory.registerFunction(); - factory.registerFunction(); + factory.registerFunction(); } - } diff --git a/cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp new file mode 100644 index 0000000000000..42b88fbce730d --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +namespace +{ + +template +struct Less +{ + const IColumn & column; + + explicit Less(const IColumn & column_) : column(column_) { } + + bool operator()(size_t lhs, size_t rhs) const + { + if constexpr (positive) + /* + Note: We use nan_direction_hint=-1 for ascending sort to make NULL the least value. + However, NaN is also considered the least value, + which results in different sorting results compared to Spark since Spark treats NaN as the greatest value. + For now, we are temporarily ignoring this issue because cases with NaN are rare, + and aligning with Spark would require tricky modifications to the CH underlying code. + */ + return column.compareAt(lhs, rhs, column, -1) < 0; + else + return column.compareAt(lhs, rhs, column, -1) > 0; + } +}; + +} + +template +ColumnPtr SparkSortArrayImpl::execute( + const ColumnArray & array, + ColumnPtr mapped, + const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]]) +{ + const ColumnArray::Offsets & offsets = array.getOffsets(); + + size_t size = offsets.size(); + size_t nested_size = array.getData().size(); + IColumn::Permutation permutation(nested_size); + + for (size_t i = 0; i < nested_size; ++i) + permutation[i] = i; + + ColumnArray::Offset current_offset = 0; + for (size_t i = 0; i < size; ++i) + { + auto next_offset = offsets[i]; + ::sort(&permutation[current_offset], &permutation[next_offset], Less(*mapped)); + current_offset = next_offset; + } + + return ColumnArray::create(array.getData().permute(permutation, 0), array.getOffsetsPtr()); +} + +REGISTER_FUNCTION(SortArraySpark) +{ + factory.registerFunction(); + factory.registerFunction(); +} + +} diff --git a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.h b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.h similarity index 87% rename from cpp-ch/local-engine/Functions/SparkFunctionArraySort.h rename to cpp-ch/local-engine/Functions/SparkFunctionSortArray.h index 9ce48f9c0baf5..22ad6a636b2e7 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionArraySort.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionSortArray.h @@ -32,7 +32,7 @@ namespace ErrorCodes /** Sort arrays, by values of its elements, or by values of corresponding elements of calculated expression (known as "schwartzsort"). */ template -struct SparkArraySortImpl +struct SparkSortArrayImpl { static bool needBoolean() { return false; } static bool needExpression() { return false; } @@ -69,14 +69,14 @@ struct SparkArraySortImpl struct NameArraySort { - static constexpr auto name = "arraySortSpark"; + static constexpr auto name = "sortArraySpark"; }; struct NameArrayReverseSort { - static constexpr auto name = "arrayReverseSortSpark"; + static constexpr auto name = "reverseSortArraySpark"; }; -using SparkFunctionArraySort = FunctionArrayMapped, NameArraySort>; -using SparkFunctionArrayReverseSort = FunctionArrayMapped, NameArrayReverseSort>; +using SparkFunctionSortArray = FunctionArrayMapped, NameArraySort>; +using SparkFunctionReverseSortArray = FunctionArrayMapped, NameArrayReverseSort>; } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp index 584bc0ef1e04f..3811880aea63f 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arrayHighOrderFunctions.cpp @@ -151,4 +151,148 @@ class ArrayAggregate : public FunctionParser }; static FunctionParserRegister register_array_aggregate; +class ArraySort : public FunctionParser +{ +public: + static constexpr auto name = "array_sort"; + explicit ArraySort(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~ArraySort() override = default; + String getName() const override { return name; } + String getCHFunctionName(const substrait::Expression_ScalarFunction & scalar_function) const override + { + return "arraySortSpark"; + } + const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, + DB::ActionsDAGPtr & actions_dag) const + { + auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + + if (parsed_args.size() != 2) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "array_sort function must have two arguments"); + if (isDefaultCompare(substrait_func.arguments()[1].value().scalar_function())) + { + return toFunctionNode(actions_dag, ch_func_name, {parsed_args[0]}); + } + + return toFunctionNode(actions_dag, ch_func_name, {parsed_args[1], parsed_args[0]}); + } +private: + + /// The default lambda compare function for array_sort, `array_sort(x)`. + bool isDefaultCompare(const substrait::Expression_ScalarFunction & scalar_function) const + { + String left_variable_name, right_variable_name; + auto names_types = collectLambdaArguments(*plan_parser, scalar_function); + { + auto it = names_types.begin(); + left_variable_name = it->name; + it++; + right_variable_name = it->name; + } + + auto is_function = [&](const substrait::Expression & expr, const String & function_name) { + return expr.has_scalar_function() + && *(plan_parser->getFunctionSignatureName(expr.scalar_function().function_reference())) == function_name; + }; + + auto is_variable = [&](const substrait::Expression & expr, const String & var) { + if (!is_function(expr, "namedlambdavariable")) + { + return false; + } + const auto var_expr = expr.scalar_function().arguments()[0].value(); + if (!var_expr.has_literal()) + return false; + auto [_, name] = plan_parser->parseLiteral(var_expr.literal()); + return var == name.get(); + }; + + auto is_int_value = [&](const substrait::Expression & expr, Int32 val) { + if (!expr.has_literal()) + return false; + auto [_, x] = plan_parser->parseLiteral(expr.literal()); + return val == x.get(); + }; + + auto is_variable_null = [&](const substrait::Expression & expr, const String & var) { + return is_function(expr, "is_null") && is_variable(expr.scalar_function().arguments(0).value(), var); + }; + + auto is_both_null = [&](const substrait::Expression & expr) { + return is_function(expr, "and") + && is_variable_null(expr.scalar_function().arguments(0).value(), left_variable_name) + && is_variable_null(expr.scalar_function().arguments(1).value(), right_variable_name); + }; + + auto is_left_greater_right = [&](const substrait::Expression & expr) { + if (!expr.has_if_then()) + return false; + + const auto & if_ = expr.if_then().ifs(0); + if (!is_function(if_.if_(), "gt")) + return false; + + const auto & less_args = if_.if_().scalar_function().arguments(); + return is_variable(less_args[0].value(), left_variable_name) + && is_variable(less_args[1].value(), right_variable_name) + && is_int_value(if_.then(), 1) + && is_int_value(expr.if_then().else_(), 0); + }; + + auto is_left_less_right = [&](const substrait::Expression & expr) { + if (!expr.has_if_then()) + return false; + + const auto & if_ = expr.if_then().ifs(0); + if (!is_function(if_.if_(), "lt")) + return false; + + const auto & less_args = if_.if_().scalar_function().arguments(); + return is_variable(less_args[0].value(), left_variable_name) + && is_variable(less_args[1].value(), right_variable_name) + && is_int_value(if_.then(), -1) + && is_left_greater_right(expr.if_then().else_()); + }; + + auto is_right_null_else = [&](const substrait::Expression & expr) { + if (!expr.has_if_then()) + return false; + + /// if right arg is null, return 1 + const auto & if_then = expr.if_then(); + return is_variable_null(if_then.ifs(0).if_(), right_variable_name) + && is_int_value(if_then.ifs(0).then(), -1) + && is_left_less_right(if_then.else_()); + + }; + + auto is_left_null_else = [&](const substrait::Expression & expr) { + if (!expr.has_if_then()) + return false; + + /// if left arg is null, return 1 + const auto & if_then = expr.if_then(); + return is_variable_null(if_then.ifs(0).if_(), left_variable_name) + && is_int_value(if_then.ifs(0).then(), 1) + && is_right_null_else(if_then.else_()); + }; + + auto is_if_both_null_else = [&](const substrait::Expression & expr) { + if (!expr.has_if_then()) + { + return false; + } + const auto & if_ = expr.if_then().ifs(0); + return is_both_null(if_.if_()) + && is_int_value(if_.then(), 0) + && is_left_null_else(expr.if_then().else_()); + }; + + const auto & lambda_body = scalar_function.arguments()[0].value(); + return is_if_both_null_else(lambda_body); + } +}; +static FunctionParserRegister register_array_sort; + } diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp index 57c076ed2670d..91452888c4b0d 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/lambdaFunction.cpp @@ -149,6 +149,7 @@ class LambdaFunction : public FunctionParser required_column_name, actions_dag->dumpDAG()); } + LOG_ERROR(getLogger("LambdaFunction"), "xxx capture column: {}", required_column_name); /// The nodes must be the ones in `actions_dag`, otherwise `ActionsDAG::evaluatePartialResult` will fail. Because nodes may have the /// same name but their addresses are different. lambda_children.push_back(parent_node_it->second); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp index 85416bd71864b..4fd2fd4f68004 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/sortArray.cpp @@ -52,8 +52,8 @@ class FunctionParserSortArray : public FunctionParser const auto * array_arg = parsed_args[0]; const auto * order_arg = parsed_args[1]; - const auto * sort_node = toFunctionNode(actions_dag, "arraySortSpark", {array_arg}); - const auto * reverse_sort_node = toFunctionNode(actions_dag, "arrayReverseSortSpark", {array_arg}); + const auto * sort_node = toFunctionNode(actions_dag, "sortArraySpark", {array_arg}); + const auto * reverse_sort_node = toFunctionNode(actions_dag, "reverseSortArraySpark", {array_arg}); const auto * result_node = toFunctionNode(actions_dag, "if", {order_arg, sort_node, reverse_sort_node}); return result_node; diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index ff7449e2d3404..a69d41d00c12d 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -258,6 +258,15 @@ trait SparkPlanExecApi { throw new GlutenNotSupportException("all_match is not supported") } + /** Transform array array_sort to Substrait. */ + def genArraySortTransformer( + substraitExprName: String, + argument: ExpressionTransformer, + function: ExpressionTransformer, + expr: ArraySort): ExpressionTransformer = { + throw new GlutenNotSupportException("array_sort(on array) is not supported") + } + /** Transform array exists to Substrait */ def genArrayExistsTransformer( substraitExprName: String, diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index b5bcb6876e4d9..805ff94900fe9 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -556,6 +556,19 @@ object ExpressionConverter extends SQLConfHelper with Logging { expressionsMap), arrayTransform ) + case arraySort: ArraySort => + BackendsApiManager.getSparkPlanExecApiInstance.genArraySortTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal( + arraySort.argument, + attributeSeq, + expressionsMap), + replaceWithExpressionTransformerInternal( + arraySort.function, + attributeSeq, + expressionsMap), + arraySort + ) case tryEval @ TryEval(a: Add) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( substraitExprName, diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index 806ec844de601..74fbea67b6116 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -248,6 +248,7 @@ object ExpressionMappings { Sig[ArrayFilter](FILTER), Sig[ArrayForAll](FORALL), Sig[ArrayExists](EXISTS), + Sig[ArraySort](ARRAY_SORT), Sig[Shuffle](SHUFFLE), Sig[ZipWith](ZIP_WITH), Sig[Flatten](FLATTEN), diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index 7060e297ea10e..52b4ae9f41470 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -261,6 +261,7 @@ object ExpressionNames { final val ARRAY_EXCEPT = "array_except" final val ARRAY_REPEAT = "array_repeat" final val ARRAY_REMOVE = "array_remove" + final val ARRAY_SORT = "array_sort" final val ARRAYS_ZIP = "arrays_zip" final val FILTER = "filter" final val FORALL = "forall"