-
Notifications
You must be signed in to change notification settings - Fork 447
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into test_simd_re
- Loading branch information
Showing
9 changed files
with
358 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
158 changes: 158 additions & 0 deletions
158
cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include <AggregateFunctions/AggregateFunctionAvg.h> | ||
#include <AggregateFunctions/AggregateFunctionFactory.h> | ||
#include <AggregateFunctions/FactoryHelpers.h> | ||
#include <AggregateFunctions/Helpers.h> | ||
#include <DataTypes/DataTypeTuple.h> | ||
|
||
#include <algorithm> | ||
|
||
#include <Common/CHUtil.h> | ||
#include <Common/GlutenDecimalUtils.h> | ||
|
||
namespace DB | ||
{ | ||
struct Settings; | ||
|
||
namespace ErrorCodes | ||
{ | ||
|
||
} | ||
} | ||
|
||
namespace local_engine | ||
{ | ||
using namespace DB; | ||
|
||
|
||
DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type) | ||
{ | ||
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*arg_type) + 4, DecimalUtils::max_precision<Decimal128>); | ||
const auto scale_value = std::min(getDecimalScale(*arg_type) + 4, precision_value); | ||
return createDecimal<DataTypeDecimal>(precision_value, scale_value); | ||
} | ||
|
||
template <typename T> | ||
requires is_decimal<T> | ||
class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T> | ||
{ | ||
public: | ||
using Base = AggregateFunctionAvg<T>; | ||
|
||
explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) | ||
: Base(argument_types_, createResultType(argument_types_, num_scale_, round_scale_), num_scale_) | ||
, num_scale(num_scale_) | ||
, round_scale(round_scale_) | ||
{ | ||
} | ||
|
||
DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) | ||
{ | ||
const DataTypePtr & data_type = argument_types_[0]; | ||
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision<Decimal128>); | ||
const auto scale_value = std::min(num_scale_ + 4, precision_value); | ||
return createDecimal<DataTypeDecimal>(precision_value, scale_value); | ||
} | ||
|
||
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override | ||
{ | ||
const DataTypePtr & result_type = this->getResultType(); | ||
auto result_scale = getDecimalScale(*result_type); | ||
WhichDataType which(result_type); | ||
if (which.isDecimal32()) | ||
{ | ||
assert_cast<ColumnDecimal<Decimal32> &>(to).getData().push_back( | ||
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); | ||
} | ||
else if (which.isDecimal64()) | ||
{ | ||
assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back( | ||
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); | ||
} | ||
else if (which.isDecimal128()) | ||
{ | ||
assert_cast<ColumnDecimal<Decimal128> &>(to).getData().push_back( | ||
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); | ||
} | ||
else | ||
{ | ||
assert_cast<ColumnDecimal<Decimal256> &>(to).getData().push_back( | ||
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); | ||
} | ||
} | ||
|
||
String getName() const override { return "sparkAvg"; } | ||
|
||
private: | ||
Int128 NO_SANITIZE_UNDEFINED | ||
divideDecimalAndUInt(AvgFraction<AvgFieldType<T>, UInt64> avg, UInt32 num_scale, UInt32 result_scale, UInt32 round_scale) const | ||
{ | ||
auto value = avg.numerator.value; | ||
if (result_scale > num_scale) | ||
{ | ||
auto diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - num_scale); | ||
value = value * diff; | ||
} | ||
else if (result_scale < num_scale) | ||
{ | ||
auto diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(num_scale - result_scale); | ||
value = value / diff; | ||
} | ||
|
||
auto result = value / avg.denominator; | ||
|
||
if (round_scale > result_scale) | ||
return result; | ||
|
||
auto round_diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - round_scale); | ||
return (result + round_diff / 2) / round_diff * round_diff; | ||
} | ||
|
||
private: | ||
UInt32 num_scale; | ||
UInt32 round_scale; | ||
}; | ||
|
||
AggregateFunctionPtr | ||
createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings) | ||
{ | ||
assertNoParameters(name, parameters); | ||
assertUnary(name, argument_types); | ||
|
||
AggregateFunctionPtr res; | ||
const DataTypePtr & data_type = argument_types[0]; | ||
if (!isDecimal(data_type)) | ||
throw Exception( | ||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name); | ||
|
||
bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get<bool>(); | ||
const UInt32 p1 = DB::getDecimalPrecision(*data_type); | ||
const UInt32 s1 = DB::getDecimalScale(*data_type); | ||
auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL; | ||
auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss); | ||
|
||
res.reset(createWithDecimalType<AggregateFunctionSparkAvg>(*data_type, argument_types, getDecimalScale(*data_type), round_scale)); | ||
return res; | ||
} | ||
|
||
void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory) | ||
{ | ||
factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg); | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/* | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
|
||
namespace local_engine | ||
{ | ||
|
||
class GlutenDecimalUtils | ||
{ | ||
public: | ||
static constexpr size_t MAX_PRECISION = 38; | ||
static constexpr size_t MAX_SCALE = 38; | ||
static constexpr auto system_Default = std::tuple(MAX_PRECISION, 18); | ||
static constexpr auto user_Default = std::tuple(10, 0); | ||
static constexpr size_t MINIMUM_ADJUSTED_SCALE = 6; | ||
|
||
// The decimal types compatible with other numeric types | ||
static constexpr auto BOOLEAN_DECIMAL = std::tuple(1, 0); | ||
static constexpr auto BYTE_DECIMAL = std::tuple(3, 0); | ||
static constexpr auto SHORT_DECIMAL = std::tuple(5, 0); | ||
static constexpr auto INT_DECIMAL = std::tuple(10, 0); | ||
static constexpr auto LONG_DECIMAL = std::tuple(20, 0); | ||
static constexpr auto FLOAT_DECIMAL = std::tuple(14, 7); | ||
static constexpr auto DOUBLE_DECIMAL = std::tuple(30, 15); | ||
static constexpr auto BIGINT_DECIMAL = std::tuple(MAX_PRECISION, 0); | ||
|
||
static std::tuple<size_t, size_t> adjustPrecisionScale(size_t precision, size_t scale) | ||
{ | ||
if (precision <= MAX_PRECISION) | ||
{ | ||
// Adjustment only needed when we exceed max precision | ||
return std::tuple(precision, scale); | ||
} | ||
else if (scale < 0) | ||
{ | ||
// Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision | ||
// loss since we would cause a loss of digits in the integer part. | ||
// In this case, we are likely to meet an overflow. | ||
return std::tuple(GlutenDecimalUtils::MAX_PRECISION, scale); | ||
} | ||
else | ||
{ | ||
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. | ||
auto intDigits = precision - scale; | ||
// If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise | ||
// preserve at least MINIMUM_ADJUSTED_SCALE fractional digits | ||
auto minScaleValue = std::min(scale, GlutenDecimalUtils::MINIMUM_ADJUSTED_SCALE); | ||
// The resulting scale is the maximum between what is available without causing a loss of | ||
// digits for the integer part of the decimal and the minimum guaranteed scale, which is | ||
// computed above | ||
auto adjustedScale = std::max(GlutenDecimalUtils::MAX_PRECISION - intDigits, minScaleValue); | ||
|
||
return std::tuple(GlutenDecimalUtils::MAX_PRECISION, adjustedScale); | ||
} | ||
} | ||
|
||
static std::tuple<size_t, size_t> dividePrecisionScale(size_t p1, size_t s1, size_t p2, size_t s2, bool allowPrecisionLoss) | ||
{ | ||
if (allowPrecisionLoss) | ||
{ | ||
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) | ||
// Scale: max(6, s1 + p2 + 1) | ||
const size_t intDig = p1 - s1 + s2; | ||
const size_t scale = std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1); | ||
const size_t precision = intDig + scale; | ||
return adjustPrecisionScale(precision, scale); | ||
} | ||
else | ||
{ | ||
auto intDig = std::min(MAX_SCALE, p1 - s1 + s2); | ||
auto decDig = std::min(MAX_SCALE, std::max(static_cast<size_t>(6), s1 + p2 + 1)); | ||
auto diff = (intDig + decDig) - MAX_SCALE; | ||
if (diff > 0) | ||
{ | ||
decDig -= diff / 2 + 1; | ||
intDig = MAX_SCALE - decDig; | ||
} | ||
return std::tuple(intDig + decDig, decDig); | ||
} | ||
} | ||
|
||
static std::tuple<size_t, size_t> widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2) | ||
{ | ||
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) | ||
auto scale = std::max(s1, s2); | ||
auto range = std::max(p1 - s1, p2 - s2); | ||
return std::tuple(range + scale, scale); | ||
} | ||
|
||
}; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.