diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala index b401e065a107..3ed87caa2998 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSAbstractSuite.scala @@ -93,7 +93,6 @@ abstract class GlutenClickHouseTPCDSAbstractSuite extends WholeStageTransformerS "q49", // inconsistent results "q61", // inconsistent results "q67", // inconsistent results - "q78", // inconsistent results "q83", // decimal error "q90" // inconsistent results(decimal) ) diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala index ac878ffb1849..aff2c21bc371 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCDSParquetSuite.scala @@ -342,5 +342,61 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui |""".stripMargin compareResultsAgainstVanillaSpark(sql, true, _ => {}) } + + test("GLUTEN-1626: test 'roundHalfup'") { + val sql0 = + """ + |select cast(ss_wholesale_cost as Int) a, round(sum(ss_wholesale_cost),2), + |round(sum(ss_wholesale_cost+0.06),2), round(sum(ss_wholesale_cost-0.04),2) + |from store_sales + |group by a order by a + |""".stripMargin + compareResultsAgainstVanillaSpark(sql0, true, _ => {}) + + val sql1 = + """ + |select cast(ss_sales_price as Int) a, round(sum(ss_sales_price),2), + |round(sum(ss_sales_price+0.06),2), round(sum(ss_sales_price-0.04),2) + |from store_sales + |group by a order by a + |""".stripMargin + compareResultsAgainstVanillaSpark(sql1, true, _ => {}) + + val sql2 = + """ + |select cast(cs_wholesale_cost as Int) a, round(sum(cs_wholesale_cost),2), + |round(sum(cs_wholesale_cost+0.06),2), round(sum(cs_wholesale_cost-0.04),2) + |from catalog_sales + |group by a order by a + |""".stripMargin + compareResultsAgainstVanillaSpark(sql2, true, _ => {}) + + val sql3 = + """ + |select cast(cs_sales_price as Int) a, round(sum(cs_sales_price),2), + |round(sum(cs_sales_price+0.06),2), round(sum(cs_sales_price-0.04),2) + |from catalog_sales + |group by a order by a + |""".stripMargin + compareResultsAgainstVanillaSpark(sql3, true, _ => {}) + + val sql4 = + """ + |select cast(ws_wholesale_cost as Int) a, round(sum(ws_wholesale_cost),2), + |round(sum(ws_wholesale_cost+0.06),2), round(sum(ws_wholesale_cost-0.04),2) + |from web_sales + |group by a order by a + |""".stripMargin + compareResultsAgainstVanillaSpark(sql4, true, _ => {}) + + val sql5 = + """ + |select cast(ws_sales_price as Int) a, round(sum(ws_sales_price),2), + |round(sum(ws_sales_price+0.06),2), round(sum(ws_sales_price-0.04),2) + |from web_sales + |group by a order by a + |""".stripMargin + compareResultsAgainstVanillaSpark(sql5, true, _ => {}) + } } // scalastyle:on line.size.limit diff --git a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.cpp b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.cpp new file mode 100644 index 000000000000..bc2557dd3539 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.cpp @@ -0,0 +1,19 @@ +#include "Functions/SparkFunctionRoundHalfUp.h" +#include + + +namespace local_engine +{ +REGISTER_FUNCTION(RoundSpark) +{ + factory.registerFunction( + FunctionDocumentation{ + .description=R"( +Similar to function round,except that in case when given number has equal distance to surrounding numbers, the function rounds away from zero(towards +inf/-inf). + )", + .examples{{"roundHalfUp", "SELECT roundHalfUp(3.165,2)", "3.17"}}, + .categories{"Rounding"} + }, FunctionFactory::CaseInsensitive); + +} +} diff --git a/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h new file mode 100644 index 000000000000..7d70c6fd71bc --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionRoundHalfUp.h @@ -0,0 +1,295 @@ +#pragma once + +#include + + +namespace local_engine +{ +using namespace DB; + + +/// Implementation for round half up. Not vectorized. + +inline float roundHalfUp(float x) +{ + return roundf(x); + + UNREACHABLE(); +} + +inline double roundHalfUp(double x) +{ + return round(x); + + UNREACHABLE(); +} + +template +class BaseFloatRoundingHalfUpComputation +{ +public: + using ScalarType = T; + using VectorType = T; + static const size_t data_count = 1; + + static VectorType load(const ScalarType * in) { return *in; } + static VectorType load1(const ScalarType in) { return in; } + static VectorType store(ScalarType * out, ScalarType val) { return *out = val;} + static VectorType multiply(VectorType val, VectorType scale) { return val * scale; } + static VectorType divide(VectorType val, VectorType scale) { return val / scale; } + static VectorType apply(VectorType val){return roundHalfUp(val);} + + static VectorType prepare(size_t scale) + { + return load1(scale); + } +}; + + + +/** Implementation of low-level round-off functions for floating-point values. + */ +template +class FloatRoundingHalfUpComputation : public BaseFloatRoundingHalfUpComputation +{ + using Base = BaseFloatRoundingHalfUpComputation; + +public: + static inline void compute(const T * __restrict in, const typename Base::VectorType & scale, T * __restrict out) + { + auto val = Base::load(in); + + if (scale_mode == ScaleMode::Positive) + val = Base::multiply(val, scale); + else if (scale_mode == ScaleMode::Negative) + val = Base::divide(val, scale); + + val = Base::apply(val); + + if (scale_mode == ScaleMode::Positive) + val = Base::divide(val, scale); + else if (scale_mode == ScaleMode::Negative) + val = Base::multiply(val, scale); + + Base::store(out, val); + } +}; + + +/** Implementing high-level rounding functions. + */ +template +struct FloatRoundingHalfUpImpl +{ +private: + static_assert(!is_decimal); + + using Op = FloatRoundingHalfUpComputation; + using Data = std::array; + using ColumnType = ColumnVector; + using Container = typename ColumnType::Container; + +public: + static NO_INLINE void apply(const Container & in, size_t scale, Container & out) + { + auto mm_scale = Op::prepare(scale); + + const size_t data_count = std::tuple_size(); + + const T* end_in = in.data() + in.size(); + const T* limit = in.data() + in.size() / data_count * data_count; + + const T* __restrict p_in = in.data(); + T* __restrict p_out = out.data(); + + while (p_in < limit) + { + Op::compute(p_in, mm_scale, p_out); + p_in += data_count; + p_out += data_count; + } + + if (p_in < end_in) + { + Data tmp_src{{}}; + Data tmp_dst; + + size_t tail_size_bytes = (end_in - p_in) * sizeof(*p_in); + + memcpy(&tmp_src, p_in, tail_size_bytes); + Op::compute(reinterpret_cast(&tmp_src), mm_scale, reinterpret_cast(&tmp_dst)); + memcpy(p_out, &tmp_dst, tail_size_bytes); + } + } +}; + + + + + +/** Select the appropriate processing algorithm depending on the scale. + */ +template +struct DispatcherRoundingHalfUp +{ + template + using FunctionRoundingImpl = std::conditional_t, + FloatRoundingHalfUpImpl, + IntegerRoundingImpl>; + + static ColumnPtr apply(const IColumn * col_general, Scale scale_arg) + { + const auto * const col = checkAndGetColumn>(col_general); + auto col_res = ColumnVector::create(); + + typename ColumnVector::Container & vec_res = col_res->getData(); + vec_res.resize(col->getData().size()); + + if (!vec_res.empty()) + { + if (scale_arg == 0) + { + size_t scale = 1; + FunctionRoundingImpl::apply(col->getData(), scale, vec_res); + } + else if (scale_arg > 0) + { + size_t scale = intExp10(scale_arg); + FunctionRoundingImpl::apply(col->getData(), scale, vec_res); + } + else + { + size_t scale = intExp10(-scale_arg); + FunctionRoundingImpl::apply(col->getData(), scale, vec_res); + } + } + + return col_res; + } +}; + +template +struct DispatcherRoundingHalfUp +{ +public: + static ColumnPtr apply(const IColumn * col_general, Scale scale_arg) + { + const auto * const col = checkAndGetColumn>(col_general); + const typename ColumnDecimal::Container & vec_src = col->getData(); + + auto col_res = ColumnDecimal::create(vec_src.size(), col->getScale()); + auto & vec_res = col_res->getData(); + + if (!vec_res.empty()) + DecimalRoundingImpl::apply(col->getData(), col->getScale(), vec_res, scale_arg); + + return col_res; + } +}; + +/** A template for functions that round the value of an input parameter of type + * (U)Int8/16/32/64, Float32/64 or Decimal32/64/128, and accept an additional optional parameter (default is 0). + */ +template +class FunctionRoundingHalfUp : public IFunction +{ +public: + static constexpr auto name = "roundHalfUp"; + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + + String getName() const override + { + return name; + } + + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } + + /// Get result types by argument types. If the function does not apply to these arguments, throw an exception. + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if ((arguments.empty()) || (arguments.size() > 2)) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Number of arguments for function {} doesn't match: passed {}, should be 1 or 2.", + getName(), arguments.size()); + + for (const auto & type : arguments) + if (!isNumber(type)) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", + arguments[0]->getName(), getName()); + + return arguments[0]; + } + + static Scale getScaleArg(const ColumnsWithTypeAndName & arguments) + { + if (arguments.size() == 2) + { + const IColumn & scale_column = *arguments[1].column; + if (!isColumnConst(scale_column)) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must be constant"); + + Field scale_field = assert_cast(scale_column).getField(); + if (scale_field.getType() != Field::Types::UInt64 + && scale_field.getType() != Field::Types::Int64) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type"); + + Int64 scale64 = scale_field.get(); + if (scale64 > std::numeric_limits::max() + || scale64 < std::numeric_limits::min()) + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Scale argument for rounding function is too large"); + + return scale64; + } + return 0; + } + + bool useDefaultImplementationForConstants() const override { return true; } + ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override + { + const ColumnWithTypeAndName & column = arguments[0]; + Scale scale_arg = getScaleArg(arguments); + + ColumnPtr res; + auto call = [&](const auto & types) -> bool + { + using Types = std::decay_t; + using DataType = typename Types::LeftType; + + if constexpr (IsDataTypeNumber || IsDataTypeDecimal) + { + using FieldType = typename DataType::FieldType; + res = DispatcherRoundingHalfUp::apply(column.column.get(), scale_arg); + return true; + } + return false; + }; + + if (!callOnIndexAndDataType(column.type->getTypeId(), call)) + { + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of argument of function {}", column.name, getName()); + } + + return res; + } + + bool hasInformationAboutMonotonicity() const override + { + return true; + } + + Monotonicity getMonotonicityForRange(const IDataType &, const Field &, const Field &) const override + { + return { .is_monotonic = true, .is_always_monotonic = true }; + } +}; + + +struct NameRoundHalfUp { static constexpr auto name = "roundHalfUp"; }; + +using FunctionRoundHalfUp = FunctionRoundingHalfUp; + +} diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 131b2b5daec1..16ff004362ee 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -62,7 +62,7 @@ static const std::map SCALAR_FUNCTIONS {"abs", "abs"}, {"ceil", "ceil"}, {"floor", "floor"}, - {"round", "round"}, + {"round", "roundHalfUp"}, {"bround", "roundBankers"}, {"exp", "exp"}, {"power", "power"},