forked from apache/incubator-gluten
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GLUTEN-1626][CH] make round func act in the same way as vanilla spark (
apache#2303) What changes were proposed in this pull request? Add a new CH function, roundHalfUp, that rounds floating-point values half away from zero (towards +inf/-inf), and map the Gluten round function to it. There are no suitable SSE instructions for a fast vectorized version, so the std round function is used instead. round and roundHalfUp were compared on a local CH server using `select roundHalfUp(x) as round_half_up from system.numbers_mt limit 10000;` and `select round(x) as round from system.numbers_mt limit 10000;`; the performance difference is very small. (Fixes: apache#1626) How was this patch tested? GlutenClickHouseTPCDSParquetSuite Q78; GlutenClickHouseTPCDSParquetSuite GLUTEN-1626: test 'roundHalfup' clickhouse server response
- Loading branch information
1 parent
d942ee1
commit b3b80d5
Showing
5 changed files
with
371 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
19 changes: 19 additions & 0 deletions
19
cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#include "Functions/SparkFunctionRoundHalfUp.h" | ||
#include <Functions/FunctionFactory.h> | ||
|
||
|
||
namespace local_engine | ||
{ | ||
/// Registers FunctionRoundHalfUp in the function factory, with its documentation,
/// using case-insensitive name matching.
REGISTER_FUNCTION(RoundSpark)
{
    /// Built as a named object first so the registration call itself stays short.
    FunctionDocumentation documentation{
        .description = R"(
Similar to function round,except that in case when given number has equal distance to surrounding numbers, the function rounds away from zero(towards +inf/-inf).
)",
        .examples{{"roundHalfUp", "SELECT roundHalfUp(3.165,2)", "3.17"}},
        .categories{"Rounding"}};

    factory.registerFunction<FunctionRoundHalfUp>(documentation, FunctionFactory::CaseInsensitive);
}
} |
295 changes: 295 additions & 0 deletions
295
cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,295 @@ | ||
#pragma once | ||
|
||
#include <Functions/FunctionsRound.h> | ||
|
||
|
||
namespace local_engine | ||
{ | ||
using namespace DB; | ||
|
||
|
||
/// Implementation for round half up. Not vectorized. | ||
|
||
/// Rounds a single-precision value to the nearest integral value.
/// Halfway cases are rounded away from zero, which is the behaviour of the C library `roundf`.
///
/// @param x value to round
/// @return  the rounded value, still as a float
inline float roundHalfUp(float x)
{
    return roundf(x);
}
|
||
/// Rounds a double-precision value to the nearest integral value.
/// Halfway cases are rounded away from zero, which is the behaviour of the C library `round`.
///
/// @param x value to round
/// @return  the rounded value, still as a double
inline double roundHalfUp(double x)
{
    return round(x);
}
|
||
/// Scalar primitives used by the half-up rounding of floating-point values.
/// One value is processed per step (`data_count == 1`), so the "vector" type is simply T.
template <typename T>
class BaseFloatRoundingHalfUpComputation
{
public:
    using ScalarType = T;
    using VectorType = T;

    /// Number of values handled by a single load/compute/store step.
    static constexpr size_t data_count = 1;

    /// Reads one value from memory.
    static VectorType load(const ScalarType * src)
    {
        return *src;
    }

    /// Turns a single scalar into the working representation (identity here).
    static VectorType load1(const ScalarType value)
    {
        return value;
    }

    /// Writes `value` to `dst` and returns the stored value.
    static VectorType store(ScalarType * dst, ScalarType value)
    {
        *dst = value;
        return *dst;
    }

    static VectorType multiply(VectorType value, VectorType factor)
    {
        return value * factor;
    }

    static VectorType divide(VectorType value, VectorType divisor)
    {
        return value / divisor;
    }

    /// Rounds through the `roundHalfUp` overload matching T.
    static VectorType apply(VectorType value)
    {
        return roundHalfUp(value);
    }

    /// Converts the integer scale factor into the working floating-point type.
    static VectorType prepare(size_t scale)
    {
        return load1(static_cast<ScalarType>(scale));
    }
};
|
||
|
||
|
||
/** Implementation of low-level round-off functions for floating-point values. | ||
*/ | ||
template <typename T, ScaleMode scale_mode> | ||
class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T> | ||
{ | ||
using Base = BaseFloatRoundingHalfUpComputation<T>; | ||
|
||
public: | ||
static inline void compute(const T * __restrict in, const typename Base::VectorType & scale, T * __restrict out) | ||
{ | ||
auto val = Base::load(in); | ||
|
||
if (scale_mode == ScaleMode::Positive) | ||
val = Base::multiply(val, scale); | ||
else if (scale_mode == ScaleMode::Negative) | ||
val = Base::divide(val, scale); | ||
|
||
val = Base::apply(val); | ||
|
||
if (scale_mode == ScaleMode::Positive) | ||
val = Base::divide(val, scale); | ||
else if (scale_mode == ScaleMode::Negative) | ||
val = Base::multiply(val, scale); | ||
|
||
Base::store(out, val); | ||
} | ||
}; | ||
|
||
|
||
/** Implementing high-level rounding functions. | ||
*/ | ||
template <typename T, ScaleMode scale_mode> | ||
struct FloatRoundingHalfUpImpl | ||
{ | ||
private: | ||
static_assert(!is_decimal<T>); | ||
|
||
using Op = FloatRoundingHalfUpComputation<T, scale_mode>; | ||
using Data = std::array<T, Op::data_count>; | ||
using ColumnType = ColumnVector<T>; | ||
using Container = typename ColumnType::Container; | ||
|
||
public: | ||
static NO_INLINE void apply(const Container & in, size_t scale, Container & out) | ||
{ | ||
auto mm_scale = Op::prepare(scale); | ||
|
||
const size_t data_count = std::tuple_size<Data>(); | ||
|
||
const T* end_in = in.data() + in.size(); | ||
const T* limit = in.data() + in.size() / data_count * data_count; | ||
|
||
const T* __restrict p_in = in.data(); | ||
T* __restrict p_out = out.data(); | ||
|
||
while (p_in < limit) | ||
{ | ||
Op::compute(p_in, mm_scale, p_out); | ||
p_in += data_count; | ||
p_out += data_count; | ||
} | ||
|
||
if (p_in < end_in) | ||
{ | ||
Data tmp_src{{}}; | ||
Data tmp_dst; | ||
|
||
size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in); | ||
|
||
memcpy(&tmp_src, p_in, tail_size_bytes); | ||
Op::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst)); | ||
memcpy(p_out, &tmp_dst, tail_size_bytes); | ||
} | ||
} | ||
}; | ||
|
||
|
||
|
||
|
||
|
||
/** Select the appropriate processing algorithm depending on the scale. | ||
*/ | ||
template <typename T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode> | ||
struct DispatcherRoundingHalfUp | ||
{ | ||
template <ScaleMode scale_mode> | ||
using FunctionRoundingImpl = std::conditional_t<std::is_floating_point_v<T>, | ||
FloatRoundingHalfUpImpl<T, scale_mode>, | ||
IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>; | ||
|
||
static ColumnPtr apply(const IColumn * col_general, Scale scale_arg) | ||
{ | ||
const auto * const col = checkAndGetColumn<ColumnVector<T>>(col_general); | ||
auto col_res = ColumnVector<T>::create(); | ||
|
||
typename ColumnVector<T>::Container & vec_res = col_res->getData(); | ||
vec_res.resize(col->getData().size()); | ||
|
||
if (!vec_res.empty()) | ||
{ | ||
if (scale_arg == 0) | ||
{ | ||
size_t scale = 1; | ||
FunctionRoundingImpl<ScaleMode::Zero>::apply(col->getData(), scale, vec_res); | ||
} | ||
else if (scale_arg > 0) | ||
{ | ||
size_t scale = intExp10(scale_arg); | ||
FunctionRoundingImpl<ScaleMode::Positive>::apply(col->getData(), scale, vec_res); | ||
} | ||
else | ||
{ | ||
size_t scale = intExp10(-scale_arg); | ||
FunctionRoundingImpl<ScaleMode::Negative>::apply(col->getData(), scale, vec_res); | ||
} | ||
} | ||
|
||
return col_res; | ||
} | ||
}; | ||
|
||
template <is_decimal T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode> | ||
struct DispatcherRoundingHalfUp<T, rounding_mode, tie_breaking_mode> | ||
{ | ||
public: | ||
static ColumnPtr apply(const IColumn * col_general, Scale scale_arg) | ||
{ | ||
const auto * const col = checkAndGetColumn<ColumnDecimal<T>>(col_general); | ||
const typename ColumnDecimal<T>::Container & vec_src = col->getData(); | ||
|
||
auto col_res = ColumnDecimal<T>::create(vec_src.size(), col->getScale()); | ||
auto & vec_res = col_res->getData(); | ||
|
||
if (!vec_res.empty()) | ||
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(col->getData(), col->getScale(), vec_res, scale_arg); | ||
|
||
return col_res; | ||
} | ||
}; | ||
|
||
/** A template for functions that round the value of an input parameter of type | ||
* (U)Int8/16/32/64, Float32/64 or Decimal32/64/128, and accept an additional optional parameter (default is 0). | ||
*/ | ||
template <typename Name, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode> | ||
class FunctionRoundingHalfUp : public IFunction | ||
{ | ||
public: | ||
static constexpr auto name = "roundHalfUp"; | ||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionRoundingHalfUp>(); } | ||
|
||
String getName() const override | ||
{ | ||
return name; | ||
} | ||
|
||
bool isVariadic() const override { return true; } | ||
size_t getNumberOfArguments() const override { return 0; } | ||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } | ||
|
||
/// Get result types by argument types. If the function does not apply to these arguments, throw an exception. | ||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override | ||
{ | ||
if ((arguments.empty()) || (arguments.size() > 2)) | ||
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, | ||
"Number of arguments for function {} doesn't match: passed {}, should be 1 or 2.", | ||
getName(), arguments.size()); | ||
|
||
for (const auto & type : arguments) | ||
if (!isNumber(type)) | ||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", | ||
arguments[0]->getName(), getName()); | ||
|
||
return arguments[0]; | ||
} | ||
|
||
static Scale getScaleArg(const ColumnsWithTypeAndName & arguments) | ||
{ | ||
if (arguments.size() == 2) | ||
{ | ||
const IColumn & scale_column = *arguments[1].column; | ||
if (!isColumnConst(scale_column)) | ||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must be constant"); | ||
|
||
Field scale_field = assert_cast<const ColumnConst &>(scale_column).getField(); | ||
if (scale_field.getType() != Field::Types::UInt64 | ||
&& scale_field.getType() != Field::Types::Int64) | ||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type"); | ||
|
||
Int64 scale64 = scale_field.get<Int64>(); | ||
if (scale64 > std::numeric_limits<Scale>::max() | ||
|| scale64 < std::numeric_limits<Scale>::min()) | ||
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large"); | ||
|
||
return scale64; | ||
} | ||
return 0; | ||
} | ||
|
||
bool useDefaultImplementationForConstants() const override { return true; } | ||
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } | ||
|
||
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override | ||
{ | ||
const ColumnWithTypeAndName & column = arguments[0]; | ||
Scale scale_arg = getScaleArg(arguments); | ||
|
||
ColumnPtr res; | ||
auto call = [&](const auto & types) -> bool | ||
{ | ||
using Types = std::decay_t<decltype(types)>; | ||
using DataType = typename Types::LeftType; | ||
|
||
if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>) | ||
{ | ||
using FieldType = typename DataType::FieldType; | ||
res = DispatcherRoundingHalfUp<FieldType, rounding_mode, tie_breaking_mode>::apply(column.column.get(), scale_arg); | ||
return true; | ||
} | ||
return false; | ||
}; | ||
|
||
if (!callOnIndexAndDataType<void>(column.type->getTypeId(), call)) | ||
{ | ||
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", column.name, getName()); | ||
} | ||
|
||
return res; | ||
} | ||
|
||
bool hasInformationAboutMonotonicity() const override | ||
{ | ||
return true; | ||
} | ||
|
||
Monotonicity getMonotonicityForRange(const IDataType &, const Field &, const Field &) const override | ||
{ | ||
return { .is_monotonic = true, .is_always_monotonic = true }; | ||
} | ||
}; | ||
|
||
|
||
/// Name tag for the half-up rounding function.
struct NameRoundHalfUp
{
    static constexpr auto name = "roundHalfUp";
};
|
||
/// The `roundHalfUp` function: RoundingMode::Round with TieBreakingMode::Auto.
using FunctionRoundHalfUp = FunctionRoundingHalfUp<
    NameRoundHalfUp,
    RoundingMode::Round,
    TieBreakingMode::Auto>;
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters