Skip to content

Commit

Permalink
[GLUTEN-1626][CH] make round func act in the same way as vanilla spark (
Browse files Browse the repository at this point in the history
apache#2303)

What changes were proposed in this pull request?
Add a new ch roundHalfUp func to implement rounding away from zero (towards +inf/-inf) for float and make gluten round func mapping to it.
No suiltable SSE instructions to realize a fast version func, so use std round func instead.
Test round and roundHalfUp with local ch server,
use
select roundHalfUp(x) as round_half_up from system.numbers_mt limit 10000;
select round(x) as round from system.numbers_mt limit 10000;
performance difference is very little

(Fixes: apache#1626)

How was this patch tested?
GlutenClickHouseTPCDSParquetSuite Q78
GlutenClickHouseTPCDSParquetSuite GLUTEN-1626: test 'roundHalfup'
clickhouse server response
  • Loading branch information
lhuang09287750 authored Jul 14, 2023
1 parent d942ee1 commit b3b80d5
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ abstract class GlutenClickHouseTPCDSAbstractSuite extends WholeStageTransformerS
"q49", // inconsistent results
"q61", // inconsistent results
"q67", // inconsistent results
"q78", // inconsistent results
"q83", // decimal error
"q90" // inconsistent results(decimal)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,5 +342,61 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, _ => {})
}

test("GLUTEN-1626: test 'roundHalfup'") {
val sql0 =
"""
|select cast(ss_wholesale_cost as Int) a, round(sum(ss_wholesale_cost),2),
|round(sum(ss_wholesale_cost+0.06),2), round(sum(ss_wholesale_cost-0.04),2)
|from store_sales
|group by a order by a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql0, true, _ => {})

val sql1 =
"""
|select cast(ss_sales_price as Int) a, round(sum(ss_sales_price),2),
|round(sum(ss_sales_price+0.06),2), round(sum(ss_sales_price-0.04),2)
|from store_sales
|group by a order by a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql1, true, _ => {})

val sql2 =
"""
|select cast(cs_wholesale_cost as Int) a, round(sum(cs_wholesale_cost),2),
|round(sum(cs_wholesale_cost+0.06),2), round(sum(cs_wholesale_cost-0.04),2)
|from catalog_sales
|group by a order by a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql2, true, _ => {})

val sql3 =
"""
|select cast(cs_sales_price as Int) a, round(sum(cs_sales_price),2),
|round(sum(cs_sales_price+0.06),2), round(sum(cs_sales_price-0.04),2)
|from catalog_sales
|group by a order by a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql3, true, _ => {})

val sql4 =
"""
|select cast(ws_wholesale_cost as Int) a, round(sum(ws_wholesale_cost),2),
|round(sum(ws_wholesale_cost+0.06),2), round(sum(ws_wholesale_cost-0.04),2)
|from web_sales
|group by a order by a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql4, true, _ => {})

val sql5 =
"""
|select cast(ws_sales_price as Int) a, round(sum(ws_sales_price),2),
|round(sum(ws_sales_price+0.06),2), round(sum(ws_sales_price-0.04),2)
|from web_sales
|group by a order by a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql5, true, _ => {})
}
}
// scalastyle:on line.size.limit
19 changes: 19 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "Functions/SparkFunctionRoundHalfUp.h"
#include <Functions/FunctionFactory.h>


namespace local_engine
{
REGISTER_FUNCTION(RoundSpark)
{
factory.registerFunction<FunctionRoundHalfUp>(
FunctionDocumentation{
.description=R"(
Similar to function round,except that in case when given number has equal distance to surrounding numbers, the function rounds away from zero(towards +inf/-inf).
)",
.examples{{"roundHalfUp", "SELECT roundHalfUp(3.165,2)", "3.17"}},
.categories{"Rounding"}
}, FunctionFactory::CaseInsensitive);

}
}
295 changes: 295 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
#pragma once

#include <Functions/FunctionsRound.h>


namespace local_engine
{
using namespace DB;


/// Implementation for round half up. Not vectorized.

inline float roundHalfUp(float x)
{
return roundf(x);

UNREACHABLE();
}

inline double roundHalfUp(double x)
{
return round(x);

UNREACHABLE();
}

template <typename T>
class BaseFloatRoundingHalfUpComputation
{
public:
using ScalarType = T;
using VectorType = T;
static const size_t data_count = 1;

static VectorType load(const ScalarType * in) { return *in; }
static VectorType load1(const ScalarType in) { return in; }
static VectorType store(ScalarType * out, ScalarType val) { return *out = val;}
static VectorType multiply(VectorType val, VectorType scale) { return val * scale; }
static VectorType divide(VectorType val, VectorType scale) { return val / scale; }
static VectorType apply(VectorType val){return roundHalfUp(val);}

static VectorType prepare(size_t scale)
{
return load1(scale);
}
};



/** Implementation of low-level round-off functions for floating-point values.
*/
template <typename T, ScaleMode scale_mode>
class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation<T>
{
using Base = BaseFloatRoundingHalfUpComputation<T>;

public:
static inline void compute(const T * __restrict in, const typename Base::VectorType & scale, T * __restrict out)
{
auto val = Base::load(in);

if (scale_mode == ScaleMode::Positive)
val = Base::multiply(val, scale);
else if (scale_mode == ScaleMode::Negative)
val = Base::divide(val, scale);

val = Base::apply(val);

if (scale_mode == ScaleMode::Positive)
val = Base::divide(val, scale);
else if (scale_mode == ScaleMode::Negative)
val = Base::multiply(val, scale);

Base::store(out, val);
}
};


/** Implementing high-level rounding functions.
*/
template <typename T, ScaleMode scale_mode>
struct FloatRoundingHalfUpImpl
{
private:
static_assert(!is_decimal<T>);

using Op = FloatRoundingHalfUpComputation<T, scale_mode>;
using Data = std::array<T, Op::data_count>;
using ColumnType = ColumnVector<T>;
using Container = typename ColumnType::Container;

public:
static NO_INLINE void apply(const Container & in, size_t scale, Container & out)
{
auto mm_scale = Op::prepare(scale);

const size_t data_count = std::tuple_size<Data>();

const T* end_in = in.data() + in.size();
const T* limit = in.data() + in.size() / data_count * data_count;

const T* __restrict p_in = in.data();
T* __restrict p_out = out.data();

while (p_in < limit)
{
Op::compute(p_in, mm_scale, p_out);
p_in += data_count;
p_out += data_count;
}

if (p_in < end_in)
{
Data tmp_src{{}};
Data tmp_dst;

size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in);

memcpy(&tmp_src, p_in, tail_size_bytes);
Op::compute(reinterpret_cast<T *>(&tmp_src), mm_scale, reinterpret_cast<T *>(&tmp_dst));
memcpy(p_out, &tmp_dst, tail_size_bytes);
}
}
};





/** Select the appropriate processing algorithm depending on the scale.
*/
template <typename T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
struct DispatcherRoundingHalfUp
{
template <ScaleMode scale_mode>
using FunctionRoundingImpl = std::conditional_t<std::is_floating_point_v<T>,
FloatRoundingHalfUpImpl<T, scale_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>;

static ColumnPtr apply(const IColumn * col_general, Scale scale_arg)
{
const auto * const col = checkAndGetColumn<ColumnVector<T>>(col_general);
auto col_res = ColumnVector<T>::create();

typename ColumnVector<T>::Container & vec_res = col_res->getData();
vec_res.resize(col->getData().size());

if (!vec_res.empty())
{
if (scale_arg == 0)
{
size_t scale = 1;
FunctionRoundingImpl<ScaleMode::Zero>::apply(col->getData(), scale, vec_res);
}
else if (scale_arg > 0)
{
size_t scale = intExp10(scale_arg);
FunctionRoundingImpl<ScaleMode::Positive>::apply(col->getData(), scale, vec_res);
}
else
{
size_t scale = intExp10(-scale_arg);
FunctionRoundingImpl<ScaleMode::Negative>::apply(col->getData(), scale, vec_res);
}
}

return col_res;
}
};

template <is_decimal T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
struct DispatcherRoundingHalfUp<T, rounding_mode, tie_breaking_mode>
{
public:
static ColumnPtr apply(const IColumn * col_general, Scale scale_arg)
{
const auto * const col = checkAndGetColumn<ColumnDecimal<T>>(col_general);
const typename ColumnDecimal<T>::Container & vec_src = col->getData();

auto col_res = ColumnDecimal<T>::create(vec_src.size(), col->getScale());
auto & vec_res = col_res->getData();

if (!vec_res.empty())
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(col->getData(), col->getScale(), vec_res, scale_arg);

return col_res;
}
};

/** A template for functions that round the value of an input parameter of type
* (U)Int8/16/32/64, Float32/64 or Decimal32/64/128, and accept an additional optional parameter (default is 0).
*/
template <typename Name, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
class FunctionRoundingHalfUp : public IFunction
{
public:
static constexpr auto name = "roundHalfUp";
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionRoundingHalfUp>(); }

String getName() const override
{
return name;
}

bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }

/// Get result types by argument types. If the function does not apply to these arguments, throw an exception.
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if ((arguments.empty()) || (arguments.size() > 2))
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 1 or 2.",
getName(), arguments.size());

for (const auto & type : arguments)
if (!isNumber(type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}",
arguments[0]->getName(), getName());

return arguments[0];
}

static Scale getScaleArg(const ColumnsWithTypeAndName & arguments)
{
if (arguments.size() == 2)
{
const IColumn & scale_column = *arguments[1].column;
if (!isColumnConst(scale_column))
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must be constant");

Field scale_field = assert_cast<const ColumnConst &>(scale_column).getField();
if (scale_field.getType() != Field::Types::UInt64
&& scale_field.getType() != Field::Types::Int64)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type");

Int64 scale64 = scale_field.get<Int64>();
if (scale64 > std::numeric_limits<Scale>::max()
|| scale64 < std::numeric_limits<Scale>::min())
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large");

return scale64;
}
return 0;
}

bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }

ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override
{
const ColumnWithTypeAndName & column = arguments[0];
Scale scale_arg = getScaleArg(arguments);

ColumnPtr res;
auto call = [&](const auto & types) -> bool
{
using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::LeftType;

if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>)
{
using FieldType = typename DataType::FieldType;
res = DispatcherRoundingHalfUp<FieldType, rounding_mode, tie_breaking_mode>::apply(column.column.get(), scale_arg);
return true;
}
return false;
};

if (!callOnIndexAndDataType<void>(column.type->getTypeId(), call))
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", column.name, getName());
}

return res;
}

bool hasInformationAboutMonotonicity() const override
{
return true;
}

Monotonicity getMonotonicityForRange(const IDataType &, const Field &, const Field &) const override
{
return { .is_monotonic = true, .is_always_monotonic = true };
}
};


struct NameRoundHalfUp { static constexpr auto name = "roundHalfUp"; };

using FunctionRoundHalfUp = FunctionRoundingHalfUp<NameRoundHalfUp, RoundingMode::Round, TieBreakingMode::Auto>;

}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS
{"abs", "abs"},
{"ceil", "ceil"},
{"floor", "floor"},
{"round", "round"},
{"round", "roundHalfUp"},
{"bround", "roundBankers"},
{"exp", "exp"},
{"power", "power"},
Expand Down

0 comments on commit b3b80d5

Please sign in to comment.