Skip to content

Commit

Permalink
support sort_array (#6323)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc authored Jul 5, 2024
1 parent ff0b473 commit 995145e
Show file tree
Hide file tree
Showing 10 changed files with 454 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,15 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
}

/** Transform array sort to Substrait. */
override def genArraySortTransformer(
substraitExprName: String,
argument: ExpressionTransformer,
function: ExpressionTransformer,
expr: ArraySort): ExpressionTransformer = {
GenericExpressionTransformer(substraitExprName, Seq(argument, function), expr)
}

override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = generate

override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = generate
Expand Down
223 changes: 180 additions & 43 deletions cpp-ch/local-engine/Functions/SparkFunctionArraySort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,75 +14,212 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/SparkFunctionArraySort.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionFactory.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnFunction.h>
#include <Columns/ColumnNullable.h>
#include <Common/Exception.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeFunction.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
#include <base/sort.h>

namespace DB
namespace DB::ErrorCodes
{
extern const int TOO_FEW_ARGUMENTS_FOR_FUNCTION;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int TYPE_MISMATCH;
extern const int ILLEGAL_COLUMN;
}

namespace ErrorCodes
/// The usage of `arraySort` in CH is different from Spark's `sort_array` function.
/// We need to implement a custom function to sort arrays.
namespace local_engine
{
extern const int LOGICAL_ERROR;
}

namespace
struct LambdaLess
{
const DB::IColumn & column;
DB::DataTypePtr type;
const DB::ColumnFunction & lambda;
explicit LambdaLess(const DB::IColumn & column_, DB::DataTypePtr type_, const DB::ColumnFunction & lambda_)
: column(column_), type(type_), lambda(lambda_) {}

/// May not efficient
bool operator()(size_t lhs, size_t rhs) const
{
/// The column name seems not matter.
auto left_value_col = DB::ColumnWithTypeAndName(oneRowColumn(lhs), type, "left");
auto right_value_col = DB::ColumnWithTypeAndName(oneRowColumn(rhs), type, "right");
auto cloned_lambda = lambda.cloneResized(1);
auto * lambda_ = typeid_cast<DB::ColumnFunction *>(cloned_lambda.get());
lambda_->appendArguments({std::move(left_value_col), std::move(right_value_col)});
auto compare_res_col = lambda_->reduce();
DB::Field field;
compare_res_col.column->get(0, field);
return field.get<Int32>() < 0;
}
private:
ALWAYS_INLINE DB::ColumnPtr oneRowColumn(size_t i) const
{
auto res = column.cloneEmpty();
res->insertFrom(column, i);
return std::move(res);
}
};

template <bool positive>
struct Less
{
const IColumn & column;
const DB::IColumn & column;

explicit Less(const IColumn & column_) : column(column_) { }
explicit Less(const DB::IColumn & column_) : column(column_) { }

bool operator()(size_t lhs, size_t rhs) const
{
if constexpr (positive)
/*
Note: We use nan_direction_hint=-1 for ascending sort to make NULL the least value.
However, NaN is also considered the least value,
which results in different sorting results compared to Spark since Spark treats NaN as the greatest value.
For now, we are temporarily ignoring this issue because cases with NaN are rare,
and aligning with Spark would require tricky modifications to the CH underlying code.
*/
return column.compareAt(lhs, rhs, column, -1) < 0;
else
return column.compareAt(lhs, rhs, column, -1) > 0;
return column.compareAt(lhs, rhs, column, 1) < 0;
}
};

}

template <bool positive>
ColumnPtr SparkArraySortImpl<positive>::execute(
const ColumnArray & array,
ColumnPtr mapped,
const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]])
class FunctionSparkArraySort : public DB::IFunction
{
const ColumnArray::Offsets & offsets = array.getOffsets();
public:
static constexpr auto name = "arraySortSpark";
static DB::FunctionPtr create(DB::ContextPtr /*context*/) { return std::make_shared<FunctionSparkArraySort>(); }

size_t size = offsets.size();
size_t nested_size = array.getData().size();
IColumn::Permutation permutation(nested_size);
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isSuitableForShortCircuitArgumentsExecution(const DB::DataTypesWithConstInfo &) const override { return true; }

for (size_t i = 0; i < nested_size; ++i)
permutation[i] = i;
bool useDefaultImplementationForNulls() const override { return false; }
bool useDefaultImplementationForLowCardinalityColumns() const override { return false; }

ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < size; ++i)
void getLambdaArgumentTypes(DB::DataTypes & arguments) const override
{
auto next_offset = offsets[i];
::sort(&permutation[current_offset], &permutation[next_offset], Less<positive>(*mapped));
current_offset = next_offset;
if (arguments.size() < 2)
throw DB::Exception(DB::ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION, "Function {} requires as arguments a lambda function and an array", getName());

if (arguments.size() > 1)
{
const auto * lambda_function_type = DB::checkAndGetDataType<DB::DataTypeFunction>(arguments[0].get());
if (!lambda_function_type || lambda_function_type->getArgumentTypes().size() != 2)
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument of function {} must be a lambda function with 2 arguments, found {} instead.",
getName(), arguments[0]->getName());
auto array_nesteed_type = DB::checkAndGetDataType<DB::DataTypeArray>(arguments.back().get())->getNestedType();
DB::DataTypes lambda_args = {array_nesteed_type, array_nesteed_type};
arguments[0] = std::make_shared<DB::DataTypeFunction>(lambda_args);
}
}

return ColumnArray::create(array.getData().permute(permutation, 0), array.getOffsetsPtr());
}
DB::DataTypePtr getReturnTypeImpl(const DB::ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() > 1)
{
const auto * lambda_function_type = checkAndGetDataType<DB::DataTypeFunction>(arguments[0].type.get());
if (!lambda_function_type)
throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be a function", getName());
}

return arguments.back().type;
}

DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr &, size_t input_rows_count) const override
{
auto array_col = arguments.back().column;
auto array_type = arguments.back().type;
DB::ColumnPtr null_map = nullptr;
if (const auto * null_col = typeid_cast<const DB::ColumnNullable *>(array_col.get()))
{
null_map = null_col->getNullMapColumnPtr();
array_col = null_col->getNestedColumnPtr();
array_type = typeid_cast<const DB::DataTypeNullable *>(array_type.get())->getNestedType();
}

const auto * array_col_concrete = DB::checkAndGetColumn<DB::ColumnArray>(array_col.get());
if (!array_col_concrete)
{
const auto * aray_col_concrete_const = DB::checkAndGetColumnConst<DB::ColumnArray>(array_col.get());
if (!aray_col_concrete_const)
{
throw DB::Exception(DB::ErrorCodes::ILLEGAL_COLUMN, "Expected array column, found {}", array_col->getName());
}
array_col = DB::recursiveRemoveLowCardinality(aray_col_concrete_const->convertToFullColumn());
array_col_concrete = DB::checkAndGetColumn<DB::ColumnArray>(array_col.get());
}
auto array_nested_type = DB::checkAndGetDataType<DB::DataTypeArray>(array_type.get())->getNestedType();

DB::ColumnPtr sorted_array_col = nullptr;
if (arguments.size() > 1)
sorted_array_col = executeWithLambda(*array_col_concrete, array_nested_type, *checkAndGetColumn<DB::ColumnFunction>(arguments[0].column.get()));
else
sorted_array_col = executeWithoutLambda(*array_col_concrete);

if (null_map)
{
sorted_array_col = DB::ColumnNullable::create(sorted_array_col, null_map);
}
return sorted_array_col;
}
private:
static DB::ColumnPtr executeWithLambda(const DB::ColumnArray & array_col, DB::DataTypePtr array_nested_type, const DB::ColumnFunction & lambda)
{
const auto & offsets = array_col.getOffsets();
auto rows = array_col.size();

size_t nested_size = array_col.getData().size();
DB::IColumn::Permutation permutation(nested_size);
for (size_t i = 0; i < nested_size; ++i)
permutation[i] = i;

DB::ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < rows; ++i)
{
auto next_offset = offsets[i];
::sort(&permutation[current_offset],
&permutation[next_offset],
LambdaLess(array_col.getData(),
array_nested_type,
lambda));
current_offset = next_offset;
}
auto res = DB::ColumnArray::create(array_col.getData().permute(permutation, 0), array_col.getOffsetsPtr());
return res;
}

static DB::ColumnPtr executeWithoutLambda(const DB::ColumnArray & array_col)
{
const auto & offsets = array_col.getOffsets();
auto rows = array_col.size();

size_t nested_size = array_col.getData().size();
DB::IColumn::Permutation permutation(nested_size);
for (size_t i = 0; i < nested_size; ++i)
permutation[i] = i;

DB::ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < rows; ++i)
{
auto next_offset = offsets[i];
::sort(&permutation[current_offset],
&permutation[next_offset],
Less(array_col.getData()));
current_offset = next_offset;
}
auto res = DB::ColumnArray::create(array_col.getData().permute(permutation, 0), array_col.getOffsetsPtr());
return res;
}

String getName() const override
{
return name;
}

};

REGISTER_FUNCTION(ArraySortSpark)
{
factory.registerFunction<SparkFunctionArraySort>();
factory.registerFunction<SparkFunctionArrayReverseSort>();
factory.registerFunction<FunctionSparkArraySort>();
}

}
88 changes: 88 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionSortArray.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/SparkFunctionSortArray.h>
#include <Functions/FunctionFactory.h>

namespace DB
{

namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}

namespace
{

template <bool positive>
struct Less
{
const IColumn & column;

explicit Less(const IColumn & column_) : column(column_) { }

bool operator()(size_t lhs, size_t rhs) const
{
if constexpr (positive)
/*
Note: We use nan_direction_hint=-1 for ascending sort to make NULL the least value.
However, NaN is also considered the least value,
which results in different sorting results compared to Spark since Spark treats NaN as the greatest value.
For now, we are temporarily ignoring this issue because cases with NaN are rare,
and aligning with Spark would require tricky modifications to the CH underlying code.
*/
return column.compareAt(lhs, rhs, column, -1) < 0;
else
return column.compareAt(lhs, rhs, column, -1) > 0;
}
};

}

template <bool positive>
ColumnPtr SparkSortArrayImpl<positive>::execute(
const ColumnArray & array,
ColumnPtr mapped,
const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]])
{
const ColumnArray::Offsets & offsets = array.getOffsets();

size_t size = offsets.size();
size_t nested_size = array.getData().size();
IColumn::Permutation permutation(nested_size);

for (size_t i = 0; i < nested_size; ++i)
permutation[i] = i;

ColumnArray::Offset current_offset = 0;
for (size_t i = 0; i < size; ++i)
{
auto next_offset = offsets[i];
::sort(&permutation[current_offset], &permutation[next_offset], Less<positive>(*mapped));
current_offset = next_offset;
}

return ColumnArray::create(array.getData().permute(permutation, 0), array.getOffsetsPtr());
}

REGISTER_FUNCTION(SortArraySpark)
{
factory.registerFunction<SparkFunctionSortArray>();
factory.registerFunction<SparkFunctionReverseSortArray>();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace ErrorCodes
/** Sort arrays, by values of its elements, or by values of corresponding elements of calculated expression (known as "schwartzsort").
*/
template <bool positive>
struct SparkArraySortImpl
struct SparkSortArrayImpl
{
static bool needBoolean() { return false; }
static bool needExpression() { return false; }
Expand Down Expand Up @@ -67,16 +67,16 @@ struct SparkArraySortImpl
const ColumnWithTypeAndName * fixed_arguments [[maybe_unused]] = nullptr);
};

struct NameArraySort
struct NameSortArray
{
static constexpr auto name = "arraySortSpark";
static constexpr auto name = "sortArraySpark";
};
struct NameArrayReverseSort
struct NameReverseSortArray
{
static constexpr auto name = "arrayReverseSortSpark";
static constexpr auto name = "reverseSortArraySpark";
};

using SparkFunctionArraySort = FunctionArrayMapped<SparkArraySortImpl<true>, NameArraySort>;
using SparkFunctionArrayReverseSort = FunctionArrayMapped<SparkArraySortImpl<false>, NameArrayReverseSort>;
using SparkFunctionSortArray = FunctionArrayMapped<SparkSortArrayImpl<true>, NameSortArray>;
using SparkFunctionReverseSortArray = FunctionArrayMapped<SparkSortArrayImpl<false>, NameReverseSortArray>;

}
Loading

0 comments on commit 995145e

Please sign in to comment.