Skip to content

Commit

Permalink
[GLUTEN-6176][CH] Support aggreate avg return decimal (#6177)
Browse files Browse the repository at this point in the history
* Support aggreate avg return decimal

* update version

* fix rebase

* add ut
  • Loading branch information
loneylee authored Jun 25, 2024
1 parent 1fbdbc4 commit 4c52976
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class GlutenClickHouseDecimalSuite
private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply(
(DecimalType.apply(9, 4), Seq()),
// 1: ch decimal avg is float
(DecimalType.apply(18, 8), Seq(1)),
(DecimalType.apply(18, 8), Seq()),
// 1: ch decimal avg is float, 3/10: all value is null and compare with limit
(DecimalType.apply(38, 19), Seq(1, 3, 10))
(DecimalType.apply(38, 19), Seq(3, 10))
)

private def createDecimalTables(dataType: DecimalType): Unit = {
Expand Down Expand Up @@ -337,7 +337,6 @@ class GlutenClickHouseDecimalSuite
allowPrecisionLoss =>
Range
.inclusive(1, 22)
.filter(_ != 17) // Ignore Q17 which include avg
.foreach {
sql_num =>
{
Expand Down
158 changes: 158 additions & 0 deletions cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <AggregateFunctions/AggregateFunctionAvg.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include <DataTypes/DataTypeTuple.h>

#include <algorithm>

#include <Common/CHUtil.h>
#include <Common/GlutenDecimalUtils.h>

namespace DB
{
struct Settings;

namespace ErrorCodes
{

}
}

namespace local_engine
{
using namespace DB;


DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type)
{
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*arg_type) + 4, DecimalUtils::max_precision<Decimal128>);
const auto scale_value = std::min(getDecimalScale(*arg_type) + 4, precision_value);
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
}

template <typename T>
requires is_decimal<T>
class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
{
public:
using Base = AggregateFunctionAvg<T>;

explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_)
: Base(argument_types_, createResultType(argument_types_, num_scale_, round_scale_), num_scale_)
, num_scale(num_scale_)
, round_scale(round_scale_)
{
}

DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_)
{
const DataTypePtr & data_type = argument_types_[0];
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision<Decimal128>);
const auto scale_value = std::min(num_scale_ + 4, precision_value);
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
}

void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
const DataTypePtr & result_type = this->getResultType();
auto result_scale = getDecimalScale(*result_type);
WhichDataType which(result_type);
if (which.isDecimal32())
{
assert_cast<ColumnDecimal<Decimal32> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else if (which.isDecimal64())
{
assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else if (which.isDecimal128())
{
assert_cast<ColumnDecimal<Decimal128> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else
{
assert_cast<ColumnDecimal<Decimal256> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
}

String getName() const override { return "sparkAvg"; }

private:
Int128 NO_SANITIZE_UNDEFINED
divideDecimalAndUInt(AvgFraction<AvgFieldType<T>, UInt64> avg, UInt32 num_scale, UInt32 result_scale, UInt32 round_scale) const
{
auto value = avg.numerator.value;
if (result_scale > num_scale)
{
auto diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - num_scale);
value = value * diff;
}
else if (result_scale < num_scale)
{
auto diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(num_scale - result_scale);
value = value / diff;
}

auto result = value / avg.denominator;

if (round_scale > result_scale)
return result;

auto round_diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - round_scale);
return (result + round_diff / 2) / round_diff * round_diff;
}

private:
UInt32 num_scale;
UInt32 round_scale;
};

AggregateFunctionPtr
createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);

AggregateFunctionPtr res;
const DataTypePtr & data_type = argument_types[0];
if (!isDecimal(data_type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name);

bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get<bool>();
const UInt32 p1 = DB::getDecimalPrecision(*data_type);
const UInt32 s1 = DB::getDecimalScale(*data_type);
auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss);

res.reset(createWithDecimalType<AggregateFunctionSparkAvg>(*data_type, argument_types, getDecimalScale(*data_type), round_scale));
return res;
}

void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory)
{
factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg);
}

}
9 changes: 8 additions & 1 deletion cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ void BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
settings.set("date_time_input_format", "best_effort");
settings.set(MERGETREE_MERGE_AFTER_INSERT, true);
settings.set(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, false);
settings.set(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS, true);

for (const auto & [key, value] : backend_conf_map)
{
Expand Down Expand Up @@ -665,6 +666,11 @@ void BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
settings.set("session_timezone", time_zone_val);
LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", "session_timezone", time_zone_val);
}
else if (key == DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
{
settings.set(key, toField(key, value));
LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", key, value);
}
}

/// Finally apply some fixed kvs to settings.
Expand Down Expand Up @@ -788,6 +794,7 @@ void BackendInitializerUtil::updateNewSettings(const DB::ContextMutablePtr & con

extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory &);
extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &);
extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &);
extern void registerFunctions(FunctionFactory &);

void registerAllFunctions()
Expand All @@ -797,7 +804,7 @@ void registerAllFunctions()
DB::registerAggregateFunctions();
auto & agg_factory = AggregateFunctionFactory::instance();
registerAggregateFunctionsBloomFilter(agg_factory);

registerAggregateFunctionSparkAvg(agg_factory);
{
/// register aggregate function combinators from local_engine
auto & factory = AggregateFunctionCombinatorFactory::instance();
Expand Down
5 changes: 4 additions & 1 deletion cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ namespace local_engine
{
static const String MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE = "mergetree.insert_without_local_storage";
static const String MERGETREE_MERGE_AFTER_INSERT = "mergetree.merge_after_insert";
static const std::unordered_set<String> BOOL_VALUE_SETTINGS{MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE};
static const std::string DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = "spark.sql.decimalOperations.allowPrecisionLoss";

static const std::unordered_set<String> BOOL_VALUE_SETTINGS{
MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, DECIMAL_OPERATIONS_ALLOW_PREC_LOSS};
static const std::unordered_set<String> LONG_VALUE_SETTINGS{
"optimize.maxfilesize", "optimize.minFileSize", "mergetree.max_num_part_per_merge_task"};

Expand Down
108 changes: 108 additions & 0 deletions cpp-ch/local-engine/Common/GlutenDecimalUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once


namespace local_engine
{

class GlutenDecimalUtils
{
public:
static constexpr size_t MAX_PRECISION = 38;
static constexpr size_t MAX_SCALE = 38;
static constexpr auto system_Default = std::tuple(MAX_PRECISION, 18);
static constexpr auto user_Default = std::tuple(10, 0);
static constexpr size_t MINIMUM_ADJUSTED_SCALE = 6;

// The decimal types compatible with other numeric types
static constexpr auto BOOLEAN_DECIMAL = std::tuple(1, 0);
static constexpr auto BYTE_DECIMAL = std::tuple(3, 0);
static constexpr auto SHORT_DECIMAL = std::tuple(5, 0);
static constexpr auto INT_DECIMAL = std::tuple(10, 0);
static constexpr auto LONG_DECIMAL = std::tuple(20, 0);
static constexpr auto FLOAT_DECIMAL = std::tuple(14, 7);
static constexpr auto DOUBLE_DECIMAL = std::tuple(30, 15);
static constexpr auto BIGINT_DECIMAL = std::tuple(MAX_PRECISION, 0);

static std::tuple<size_t, size_t> adjustPrecisionScale(size_t precision, size_t scale)
{
if (precision <= MAX_PRECISION)
{
// Adjustment only needed when we exceed max precision
return std::tuple(precision, scale);
}
else if (scale < 0)
{
// Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
// loss since we would cause a loss of digits in the integer part.
// In this case, we are likely to meet an overflow.
return std::tuple(GlutenDecimalUtils::MAX_PRECISION, scale);
}
else
{
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
auto intDigits = precision - scale;
// If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
// preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
auto minScaleValue = std::min(scale, GlutenDecimalUtils::MINIMUM_ADJUSTED_SCALE);
// The resulting scale is the maximum between what is available without causing a loss of
// digits for the integer part of the decimal and the minimum guaranteed scale, which is
// computed above
auto adjustedScale = std::max(GlutenDecimalUtils::MAX_PRECISION - intDigits, minScaleValue);

return std::tuple(GlutenDecimalUtils::MAX_PRECISION, adjustedScale);
}
}

static std::tuple<size_t, size_t> dividePrecisionScale(size_t p1, size_t s1, size_t p2, size_t s2, bool allowPrecisionLoss)
{
if (allowPrecisionLoss)
{
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
// Scale: max(6, s1 + p2 + 1)
const size_t intDig = p1 - s1 + s2;
const size_t scale = std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1);
const size_t precision = intDig + scale;
return adjustPrecisionScale(precision, scale);
}
else
{
auto intDig = std::min(MAX_SCALE, p1 - s1 + s2);
auto decDig = std::min(MAX_SCALE, std::max(static_cast<size_t>(6), s1 + p2 + 1));
auto diff = (intDig + decDig) - MAX_SCALE;
if (diff > 0)
{
decDig -= diff / 2 + 1;
intDig = MAX_SCALE - decDig;
}
return std::tuple(intDig + decDig, decDig);
}
}

static std::tuple<size_t, size_t> widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2)
{
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
auto scale = std::max(s1, s2);
auto range = std::max(p1 - s1, p2 - s2);
return std::tuple(range + scale, scale);
}

};

}
23 changes: 20 additions & 3 deletions cpp-ch/local-engine/Parser/RelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
* limitations under the License.
*/
#include "RelParser.h"

#include <string>
#include <google/protobuf/wrappers.pb.h>

#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/IDataType.h>
#include <Common/Exception.h>
#include <google/protobuf/wrappers.pb.h>
#include <Poco/StringTokenizer.h>
#include <Common/Exception.h>


namespace DB
{
Expand All @@ -38,7 +42,20 @@ AggregateFunctionPtr RelParser::getAggregateFunction(
{
auto & factory = AggregateFunctionFactory::instance();
auto action = NullsAction::EMPTY;
return factory.get(name, action, arg_types, parameters, properties);

String function_name = name;
if (name == "avg" && isDecimal(removeNullable(arg_types[0])))
function_name = "sparkAvg";
else if (name == "avgPartialMerge")
{
if (auto agg_func = typeid_cast<const DataTypeAggregateFunction *>(arg_types[0].get());
!agg_func->getArgumentsDataTypes().empty() && isDecimal(removeNullable(agg_func->getArgumentsDataTypes()[0])))
{
function_name = "sparkAvgPartialMerge";
}
}

return factory.get(function_name, action, arg_types, parameters, properties);
}

std::optional<String> RelParser::parseSignatureFunctionName(UInt32 function_ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,9 @@ object GlutenConfig {

GLUTEN_OFFHEAP_SIZE_IN_BYTES_KEY,
GLUTEN_TASK_OFFHEAP_SIZE_IN_BYTES_KEY,
GLUTEN_OFFHEAP_ENABLED
GLUTEN_OFFHEAP_ENABLED,
SESSION_LOCAL_TIMEZONE.key,
DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key
)
nativeConfMap.putAll(conf.filter(e => keys.contains(e._1)).asJava)

Expand All @@ -735,10 +737,6 @@ object GlutenConfig {
.filter(_._1.startsWith(SPARK_ABFS_ACCOUNT_KEY))
.foreach(entry => nativeConfMap.put(entry._1, entry._2))

conf
.filter(_._1.startsWith(SQLConf.SESSION_LOCAL_TIMEZONE.key))
.foreach(entry => nativeConfMap.put(entry._1, entry._2))

// return
nativeConfMap
}
Expand Down

0 comments on commit 4c52976

Please sign in to comment.