Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-6176][CH] Support aggreate avg return decimal #6177

Merged
merged 4 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ class GlutenClickHouseDecimalSuite
private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply(
(DecimalType.apply(9, 4), Seq()),
// 1: ch decimal avg is float
(DecimalType.apply(18, 8), Seq(1)),
(DecimalType.apply(18, 8), Seq()),
// 1: ch decimal avg is float, 3/10: all value is null and compare with limit
(DecimalType.apply(38, 19), Seq(1, 3, 10))
(DecimalType.apply(38, 19), Seq(3, 10))
)

private def createDecimalTables(dataType: DecimalType): Unit = {
Expand Down Expand Up @@ -337,7 +337,6 @@ class GlutenClickHouseDecimalSuite
allowPrecisionLoss =>
Range
.inclusive(1, 22)
.filter(_ != 17) // Ignore Q17 which include avg
.foreach {
sql_num =>
{
Expand Down
158 changes: 158 additions & 0 deletions cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <AggregateFunctions/AggregateFunctionAvg.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>
#include <DataTypes/DataTypeTuple.h>

#include <algorithm>

#include <Common/CHUtil.h>
#include <Common/GlutenDecimalUtils.h>

namespace DB
{
struct Settings;

namespace ErrorCodes
{

}
}

namespace local_engine
{
using namespace DB;


DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type)
{
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*arg_type) + 4, DecimalUtils::max_precision<Decimal128>);
const auto scale_value = std::min(getDecimalScale(*arg_type) + 4, precision_value);
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
}

template <typename T>
requires is_decimal<T>
class AggregateFunctionSparkAvg final : public AggregateFunctionAvg<T>
{
public:
using Base = AggregateFunctionAvg<T>;

explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_)
: Base(argument_types_, createResultType(argument_types_, num_scale_, round_scale_), num_scale_)
, num_scale(num_scale_)
, round_scale(round_scale_)
{
}

DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_)
{
const DataTypePtr & data_type = argument_types_[0];
const UInt32 precision_value = std::min<size_t>(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision<Decimal128>);
const auto scale_value = std::min(num_scale_ + 4, precision_value);
return createDecimal<DataTypeDecimal>(precision_value, scale_value);
}

void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
const DataTypePtr & result_type = this->getResultType();
auto result_scale = getDecimalScale(*result_type);
WhichDataType which(result_type);
if (which.isDecimal32())
{
assert_cast<ColumnDecimal<Decimal32> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else if (which.isDecimal64())
{
assert_cast<ColumnDecimal<Decimal64> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else if (which.isDecimal128())
{
assert_cast<ColumnDecimal<Decimal128> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
else
{
assert_cast<ColumnDecimal<Decimal256> &>(to).getData().push_back(
divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale));
}
}

String getName() const override { return "sparkAvg"; }

private:
Int128 NO_SANITIZE_UNDEFINED
divideDecimalAndUInt(AvgFraction<AvgFieldType<T>, UInt64> avg, UInt32 num_scale, UInt32 result_scale, UInt32 round_scale) const
{
auto value = avg.numerator.value;
if (result_scale > num_scale)
{
auto diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - num_scale);
value = value * diff;
}
else if (result_scale < num_scale)
{
auto diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(num_scale - result_scale);
value = value / diff;
}

auto result = value / avg.denominator;

if (round_scale > result_scale)
return result;

auto round_diff = DecimalUtils::scaleMultiplier<AvgFieldType<T>>(result_scale - round_scale);
return (result + round_diff / 2) / round_diff * round_diff;
}

private:
UInt32 num_scale;
UInt32 round_scale;
};

AggregateFunctionPtr
createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings)
{
assertNoParameters(name, parameters);
assertUnary(name, argument_types);

AggregateFunctionPtr res;
const DataTypePtr & data_type = argument_types[0];
if (!isDecimal(data_type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name);

bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get<bool>();
const UInt32 p1 = DB::getDecimalPrecision(*data_type);
const UInt32 s1 = DB::getDecimalScale(*data_type);
auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL;
auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss);

res.reset(createWithDecimalType<AggregateFunctionSparkAvg>(*data_type, argument_types, getDecimalScale(*data_type), round_scale));
return res;
}

void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory)
{
factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg);
}

}
9 changes: 8 additions & 1 deletion cpp-ch/local-engine/Common/CHUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ void BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
settings.set("date_time_input_format", "best_effort");
settings.set(MERGETREE_MERGE_AFTER_INSERT, true);
settings.set(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, false);
settings.set(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS, true);

for (const auto & [key, value] : backend_conf_map)
{
Expand Down Expand Up @@ -664,6 +665,11 @@ void BackendInitializerUtil::initSettings(std::map<std::string, std::string> & b
settings.set("session_timezone", time_zone_val);
LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", "session_timezone", time_zone_val);
}
else if (key == DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)
{
settings.set(key, toField(key, value));
LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", key, value);
}
}

/// Finally apply some fixed kvs to settings.
Expand Down Expand Up @@ -787,6 +793,7 @@ void BackendInitializerUtil::updateNewSettings(const DB::ContextMutablePtr & con

extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory &);
extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &);
extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &);
extern void registerFunctions(FunctionFactory &);

void registerAllFunctions()
Expand All @@ -796,7 +803,7 @@ void registerAllFunctions()
DB::registerAggregateFunctions();
auto & agg_factory = AggregateFunctionFactory::instance();
registerAggregateFunctionsBloomFilter(agg_factory);

registerAggregateFunctionSparkAvg(agg_factory);
{
/// register aggregate function combinators from local_engine
auto & factory = AggregateFunctionCombinatorFactory::instance();
Expand Down
5 changes: 4 additions & 1 deletion cpp-ch/local-engine/Common/CHUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ namespace local_engine
{
static const String MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE = "mergetree.insert_without_local_storage";
static const String MERGETREE_MERGE_AFTER_INSERT = "mergetree.merge_after_insert";
static const std::unordered_set<String> BOOL_VALUE_SETTINGS{MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE};
static const std::string DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = "spark.sql.decimalOperations.allowPrecisionLoss";

static const std::unordered_set<String> BOOL_VALUE_SETTINGS{
MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, DECIMAL_OPERATIONS_ALLOW_PREC_LOSS};
static const std::unordered_set<String> LONG_VALUE_SETTINGS{
"optimize.maxfilesize", "optimize.minFileSize", "mergetree.max_num_part_per_merge_task"};

Expand Down
108 changes: 108 additions & 0 deletions cpp-ch/local-engine/Common/GlutenDecimalUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once


namespace local_engine
{

class GlutenDecimalUtils
{
public:
static constexpr size_t MAX_PRECISION = 38;
static constexpr size_t MAX_SCALE = 38;
static constexpr auto system_Default = std::tuple(MAX_PRECISION, 18);
static constexpr auto user_Default = std::tuple(10, 0);
static constexpr size_t MINIMUM_ADJUSTED_SCALE = 6;

// The decimal types compatible with other numeric types
static constexpr auto BOOLEAN_DECIMAL = std::tuple(1, 0);
static constexpr auto BYTE_DECIMAL = std::tuple(3, 0);
static constexpr auto SHORT_DECIMAL = std::tuple(5, 0);
static constexpr auto INT_DECIMAL = std::tuple(10, 0);
static constexpr auto LONG_DECIMAL = std::tuple(20, 0);
static constexpr auto FLOAT_DECIMAL = std::tuple(14, 7);
static constexpr auto DOUBLE_DECIMAL = std::tuple(30, 15);
static constexpr auto BIGINT_DECIMAL = std::tuple(MAX_PRECISION, 0);

static std::tuple<size_t, size_t> adjustPrecisionScale(size_t precision, size_t scale)
{
if (precision <= MAX_PRECISION)
{
// Adjustment only needed when we exceed max precision
return std::tuple(precision, scale);
}
else if (scale < 0)
{
// Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision
// loss since we would cause a loss of digits in the integer part.
// In this case, we are likely to meet an overflow.
return std::tuple(GlutenDecimalUtils::MAX_PRECISION, scale);
}
else
{
// Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION.
auto intDigits = precision - scale;
// If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise
// preserve at least MINIMUM_ADJUSTED_SCALE fractional digits
auto minScaleValue = std::min(scale, GlutenDecimalUtils::MINIMUM_ADJUSTED_SCALE);
// The resulting scale is the maximum between what is available without causing a loss of
// digits for the integer part of the decimal and the minimum guaranteed scale, which is
// computed above
auto adjustedScale = std::max(GlutenDecimalUtils::MAX_PRECISION - intDigits, minScaleValue);

return std::tuple(GlutenDecimalUtils::MAX_PRECISION, adjustedScale);
}
}

static std::tuple<size_t, size_t> dividePrecisionScale(size_t p1, size_t s1, size_t p2, size_t s2, bool allowPrecisionLoss)
{
if (allowPrecisionLoss)
{
// Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
// Scale: max(6, s1 + p2 + 1)
const size_t intDig = p1 - s1 + s2;
const size_t scale = std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1);
const size_t precision = intDig + scale;
return adjustPrecisionScale(precision, scale);
}
else
{
auto intDig = std::min(MAX_SCALE, p1 - s1 + s2);
auto decDig = std::min(MAX_SCALE, std::max(static_cast<size_t>(6), s1 + p2 + 1));
auto diff = (intDig + decDig) - MAX_SCALE;
if (diff > 0)
{
decDig -= diff / 2 + 1;
intDig = MAX_SCALE - decDig;
}
return std::tuple(intDig + decDig, decDig);
}
}

static std::tuple<size_t, size_t> widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2)
{
// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
auto scale = std::max(s1, s2);
auto range = std::max(p1 - s1, p2 - s2);
return std::tuple(range + scale, scale);
}

};

}
23 changes: 20 additions & 3 deletions cpp-ch/local-engine/Parser/RelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
* limitations under the License.
*/
#include "RelParser.h"

#include <string>
#include <google/protobuf/wrappers.pb.h>

#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/IDataType.h>
#include <Common/Exception.h>
#include <google/protobuf/wrappers.pb.h>
#include <Poco/StringTokenizer.h>
#include <Common/Exception.h>


namespace DB
{
Expand All @@ -38,7 +42,20 @@ AggregateFunctionPtr RelParser::getAggregateFunction(
{
auto & factory = AggregateFunctionFactory::instance();
auto action = NullsAction::EMPTY;
return factory.get(name, action, arg_types, parameters, properties);

String function_name = name;
if (name == "avg" && isDecimal(removeNullable(arg_types[0])))
function_name = "sparkAvg";
else if (name == "avgPartialMerge")
{
if (auto agg_func = typeid_cast<const DataTypeAggregateFunction *>(arg_types[0].get());
!agg_func->getArgumentsDataTypes().empty() && isDecimal(removeNullable(agg_func->getArgumentsDataTypes()[0])))
{
function_name = "sparkAvgPartialMerge";
}
}

return factory.get(function_name, action, arg_types, parameters, properties);
}

std::optional<String> RelParser::parseSignatureFunctionName(UInt32 function_ref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,9 @@ object GlutenConfig {

GLUTEN_OFFHEAP_SIZE_IN_BYTES_KEY,
GLUTEN_TASK_OFFHEAP_SIZE_IN_BYTES_KEY,
GLUTEN_OFFHEAP_ENABLED
GLUTEN_OFFHEAP_ENABLED,
SESSION_LOCAL_TIMEZONE.key,
DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key
)
nativeConfMap.putAll(conf.filter(e => keys.contains(e._1)).asJava)

Expand All @@ -735,10 +737,6 @@ object GlutenConfig {
.filter(_._1.startsWith(SPARK_ABFS_ACCOUNT_KEY))
.foreach(entry => nativeConfMap.put(entry._1, entry._2))

conf
.filter(_._1.startsWith(SQLConf.SESSION_LOCAL_TIMEZONE.key))
.foreach(entry => nativeConfMap.put(entry._1, entry._2))

// return
nativeConfMap
}
Expand Down
Loading