From c49a91117a478574b560efeb8ce511d3a647fec6 Mon Sep 17 00:00:00 2001 From: loneylee Date: Fri, 21 Jun 2024 16:40:31 +0800 Subject: [PATCH] Support aggreate avg return decimal --- .../GlutenClickHouseDecimalSuite.scala | 4 +- cpp-ch/clickhouse.version | 6 +- .../AggregateFunctionSparkAvg.cpp | 145 ++++++++++++++++++ cpp-ch/local-engine/Common/CHUtil.cpp | 9 +- cpp-ch/local-engine/Common/CHUtil.h | 4 +- .../local-engine/Common/GlutenDecimalUtils.h | 108 +++++++++++++ cpp-ch/local-engine/Parser/RelParser.cpp | 23 ++- .../org/apache/gluten/GlutenConfig.scala | 8 +- 8 files changed, 292 insertions(+), 15 deletions(-) create mode 100644 cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp create mode 100644 cpp-ch/local-engine/Common/GlutenDecimalUtils.h diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala index 0884871010812..3195f907e579a 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala @@ -67,9 +67,9 @@ class GlutenClickHouseDecimalSuite private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply( (DecimalType.apply(9, 4), Seq()), // 1: ch decimal avg is float - (DecimalType.apply(18, 8), Seq(1)), + (DecimalType.apply(18, 8), Seq()), // 1: ch decimal avg is float, 3/10: all value is null and compare with limit - (DecimalType.apply(38, 19), Seq(1, 3, 10)) + (DecimalType.apply(38, 19), Seq(3, 10)) ) private def createDecimalTables(dataType: DecimalType): Unit = { diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version index 1e3ac8d88ea91..6671ce4111631 100644 --- a/cpp-ch/clickhouse.version +++ b/cpp-ch/clickhouse.version @@ -1,3 +1,3 @@ -CH_ORG=Kyligence -CH_BRANCH=rebase_ch/20240620 -CH_COMMIT=f9c3886a767 +CH_ORG=loneylee +CH_BRANCH=6176 +CH_COMMIT=39ff090345f5b05ab6f61bfcd8244647b08b01d9 diff --git a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp new file mode 100644 index 0000000000000..809dfb100744f --- /dev/null +++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace DB +{ +struct Settings; + +namespace ErrorCodes +{ + +} +} + +namespace local_engine +{ +using namespace DB; + + +DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type) +{ + const UInt32 precision_value = std::min(getDecimalPrecision(*arg_type) + 4, DecimalUtils::max_precision); + const auto scale_value = std::min(getDecimalScale(*arg_type) + 4, precision_value); + return createDecimal(precision_value, scale_value); +} + +template +requires is_decimal +class AggregateFunctionSparkAvg final : public AggregateFunctionAvg +{ +public: + using Base = AggregateFunctionAvg; + + explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) + : Base(argument_types_, createResultType(argument_types_, num_scale_, round_scale_), num_scale_) + , num_scale(num_scale_) + , round_scale(round_scale_) + { + // auto calculate = ; + // + // + // const UInt32 p1 = DB::getDecimalPrecision(*data_type); + // const UInt32 s1 = DB::getDecimalScale(*data_type); + // auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL; + // auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale( + // p1, s1, p2, s2, ); + } + + DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) + { + const DataTypePtr & data_type = argument_types_[0]; + const UInt32 precision_value = std::min(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision); + const auto scale_value = std::min(num_scale_ + 4, precision_value); + + // if (scale_value == round_scale_) + // return createDecimal(precision_value + 1, scale_value + 1); + + return createDecimal(precision_value, scale_value); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override + { + const DataTypePtr & result_type = this->getResultType(); + auto result_scale = getDecimalScale(*result_type); + WhichDataType which(result_type); + if (which.isDecimal32()) + { + assert_cast &>(to).getData().push_back( + this->data(place).divideDecimalAndUInt(num_scale, result_scale, round_scale)); + } + else if (which.isDecimal64()) + { + assert_cast &>(to).getData().push_back( + this->data(place).divideDecimalAndUInt(num_scale, result_scale, round_scale)); + } + else if (which.isDecimal128()) + { + assert_cast &>(to).getData().push_back( + this->data(place).divideDecimalAndUInt(num_scale, result_scale, round_scale)); + } + else + { + assert_cast &>(to).getData().push_back( + this->data(place).divideDecimalAndUInt(num_scale, result_scale, round_scale)); + } + } + + String getName() const override { return "sparkAvg"; } + +private: + UInt32 num_scale; + UInt32 round_scale; +}; + +AggregateFunctionPtr +createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings) +{ + assertNoParameters(name, parameters); + assertUnary(name, argument_types); + + AggregateFunctionPtr res; + const DataTypePtr & data_type = argument_types[0]; + if (!isDecimal(data_type)) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name); + + bool allowPrecisionLoss = settings->get(BackendInitializerUtil::DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get(); + const UInt32 p1 = DB::getDecimalPrecision(*data_type); + const UInt32 s1 = DB::getDecimalScale(*data_type); + auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL; + auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss); + + res.reset(createWithDecimalType(*data_type, argument_types, getDecimalScale(*data_type), round_scale)); + return res; +} + +void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory) +{ + factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg); +} + +} diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 937beae99a6b0..7773f5a5de714 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -624,6 +624,7 @@ void BackendInitializerUtil::initSettings(std::map & b /// Initialize default setting. settings.set("date_time_input_format", "best_effort"); settings.set("mergetree.merge_after_insert", true); + settings.set(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS, true); for (const auto & [key, value] : backend_conf_map) { @@ -663,6 +664,11 @@ void BackendInitializerUtil::initSettings(std::map & b settings.set("session_timezone", time_zone_val); LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", "session_timezone", time_zone_val); } + else if (key == DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + { + settings.set(key, toField(key, value)); + LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", key, value); + } } /// Finally apply some fixed kvs to settings. @@ -786,6 +792,7 @@ void BackendInitializerUtil::updateNewSettings(const DB::ContextMutablePtr & con extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory &); extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &); +extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &); extern void registerFunctions(FunctionFactory &); void registerAllFunctions() @@ -795,7 +802,7 @@ void registerAllFunctions() DB::registerAggregateFunctions(); auto & agg_factory = AggregateFunctionFactory::instance(); registerAggregateFunctionsBloomFilter(agg_factory); - + registerAggregateFunctionSparkAvg(agg_factory); { /// register aggregate function combinators from local_engine auto & factory = AggregateFunctionCombinatorFactory::instance(); diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 50de9461f4dee..fe9845111e308 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -35,7 +35,8 @@ class QueryPlan; namespace local_engine { -static const std::unordered_set BOOL_VALUE_SETTINGS{"mergetree.merge_after_insert"}; +static const std::unordered_set BOOL_VALUE_SETTINGS{ + "mergetree.merge_after_insert", "spark.sql.decimalOperations.allowPrecisionLoss"}; static const std::unordered_set LONG_VALUE_SETTINGS{ "optimize.maxfilesize", "optimize.minFileSize", "mergetree.max_num_part_per_merge_task"}; @@ -169,6 +170,7 @@ class BackendInitializerUtil inline static const std::string S3A_PREFIX = "fs.s3a."; inline static const std::string SPARK_DELTA_PREFIX = "spark.databricks.delta."; inline static const std::string SPARK_SESSION_TIME_ZONE = "spark.sql.session.timeZone"; + inline static const std::string DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = "spark.sql.decimalOperations.allowPrecisionLoss"; inline static const String GLUTEN_TASK_OFFHEAP = "spark.gluten.memory.task.offHeap.size.in.bytes"; inline static const String CH_TASK_MEMORY = "off_heap_per_task"; diff --git a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h new file mode 100644 index 0000000000000..32af66ec590e0 --- /dev/null +++ b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h @@ -0,0 +1,108 @@ +/* +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + + +namespace local_engine +{ + +class GlutenDecimalUtils +{ +public: + static constexpr size_t MAX_PRECISION = 38; + static constexpr size_t MAX_SCALE = 38; + static constexpr auto system_Default = std::tuple(MAX_PRECISION, 18); + static constexpr auto user_Default = std::tuple(10, 0); + static constexpr size_t MINIMUM_ADJUSTED_SCALE = 6; + + // The decimal types compatible with other numeric types + static constexpr auto BOOLEAN_DECIMAL = std::tuple(1, 0); + static constexpr auto BYTE_DECIMAL = std::tuple(3, 0); + static constexpr auto SHORT_DECIMAL = std::tuple(5, 0); + static constexpr auto INT_DECIMAL = std::tuple(10, 0); + static constexpr auto LONG_DECIMAL = std::tuple(20, 0); + static constexpr auto FLOAT_DECIMAL = std::tuple(14, 7); + static constexpr auto DOUBLE_DECIMAL = std::tuple(30, 15); + static constexpr auto BIGINT_DECIMAL = std::tuple(MAX_PRECISION, 0); + + static std::tuple adjustPrecisionScale(size_t precision, size_t scale) + { + if (precision <= MAX_PRECISION) + { + // Adjustment only needed when we exceed max precision + return std::tuple(precision, scale); + } + else if (scale < 0) + { + // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision + // loss since we would cause a loss of digits in the integer part. + // In this case, we are likely to meet an overflow. + return std::tuple(GlutenDecimalUtils::MAX_PRECISION, scale); + } + else + { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. + auto intDigits = precision - scale; + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + auto minScaleValue = std::min(scale, GlutenDecimalUtils::MINIMUM_ADJUSTED_SCALE); + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + auto adjustedScale = std::max(GlutenDecimalUtils::MAX_PRECISION - intDigits, minScaleValue); + + return std::tuple(GlutenDecimalUtils::MAX_PRECISION, adjustedScale); + } + } + + static std::tuple dividePrecisionScale(size_t p1, size_t s1, size_t p2, size_t s2, bool allowPrecisionLoss) + { + if (allowPrecisionLoss) + { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + const size_t intDig = p1 - s1 + s2; + const size_t scale = std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1); + const size_t precision = intDig + scale; + return adjustPrecisionScale(precision, scale); + } + else + { + auto intDig = std::min(MAX_SCALE, p1 - s1 + s2); + auto decDig = std::min(MAX_SCALE, std::max(static_cast(6), s1 + p2 + 1)); + auto diff = (intDig + decDig) - MAX_SCALE; + if (diff > 0) + { + decDig -= diff / 2 + 1; + intDig = MAX_SCALE - decDig; + } + return std::tuple(intDig + decDig, decDig); + } + } + + static std::tuple widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2) + { + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + auto scale = std::max(s1, s2); + auto range = std::max(p1 - s1, p2 - s2); + return std::tuple(range + scale, scale); + } + +}; + +} diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp index 7fc8078271093..282339c4d641f 100644 --- a/cpp-ch/local-engine/Parser/RelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParser.cpp @@ -15,12 +15,16 @@ * limitations under the License. */ #include "RelParser.h" + #include +#include + #include +#include #include -#include -#include #include +#include + namespace DB { @@ -38,7 +42,20 @@ AggregateFunctionPtr RelParser::getAggregateFunction( { auto & factory = AggregateFunctionFactory::instance(); auto action = NullsAction::EMPTY; - return factory.get(name, action, arg_types, parameters, properties); + + String function_name = name; + if (name == "avg" && isDecimal(removeNullable(arg_types[0]))) + function_name = "sparkAvg"; + else if (name == "avgPartialMerge") + { + if (auto agg_func = typeid_cast(arg_types[0].get()); + !agg_func->getArgumentsDataTypes().empty() && isDecimal(removeNullable(agg_func->getArgumentsDataTypes()[0]))) + { + function_name = "sparkAvgPartialMerge"; + } + } + + return factory.get(function_name, action, arg_types, parameters, properties); } std::optional RelParser::parseSignatureFunctionName(UInt32 function_ref) diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index 148e8cdc067c0..4b4e29e7d0fbf 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -718,7 +718,9 @@ object GlutenConfig { GLUTEN_OFFHEAP_SIZE_IN_BYTES_KEY, GLUTEN_TASK_OFFHEAP_SIZE_IN_BYTES_KEY, - GLUTEN_OFFHEAP_ENABLED + GLUTEN_OFFHEAP_ENABLED, + SESSION_LOCAL_TIMEZONE.key, + DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key ) nativeConfMap.putAll(conf.filter(e => keys.contains(e._1)).asJava) @@ -735,10 +737,6 @@ object GlutenConfig { .filter(_._1.startsWith(SPARK_ABFS_ACCOUNT_KEY)) .foreach(entry => nativeConfMap.put(entry._1, entry._2)) - conf - .filter(_._1.startsWith(SQLConf.SESSION_LOCAL_TIMEZONE.key)) - .foreach(entry => nativeConfMap.put(entry._1, entry._2)) - // return nativeConfMap }