From c9018cdd884c5e93a171ba1561888e719026898c Mon Sep 17 00:00:00 2001 From: Chang chen Date: Wed, 8 May 2024 13:55:59 +0800 Subject: [PATCH] [GLUTEN-5620][CORE] Simplify Decimal process logic (#5621) * rescaleCastForDecimal refactor * refactor isPromoteCast * Simplify Decimal process logic and re-implement FunctionParserDivide, so divide.cpp is deleted. * remove SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs * rename noCheckOverflow to dontTransformCheckOverflow * update per comments * fix warning * fix style warning * fix typo --- .../backendsapi/clickhouse/CHBackend.scala | 1 + cpp-ch/local-engine/Common/CHUtil.cpp | 63 ++- cpp-ch/local-engine/Common/CHUtil.h | 16 +- .../Parser/SerializedPlanParser.cpp | 82 +--- .../Parser/SerializedPlanParser.h | 12 +- .../scalar_function_parser/arithmetic.cpp | 399 ++++++++++++++++++ .../Parser/scalar_function_parser/divide.cpp | 68 --- .../backendsapi/BackendSettingsApi.scala | 8 + .../expression/ExpressionConverter.scala | 134 +++--- .../gluten/utils/DecimalArithmeticUtil.scala | 83 ++-- 10 files changed, 593 insertions(+), 273 deletions(-) create mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp delete mode 100644 cpp-ch/local-engine/Parser/scalar_function_parser/divide.cpp diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 276ce11fb4dc..da6c60d8aea1 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -257,6 +257,7 @@ object CHBackendSettings extends BackendSettingsApi with Logging { override def needOutputSchemaForPlan(): Boolean = true override def allowDecimalArithmetic: Boolean = !SQLConf.get.decimalOperationsAllowPrecisionLoss + override def transformCheckOverflow: Boolean = false override def requiredInputFilePaths(): Boolean = true diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 9704b3041cd9..9e2ce6304718 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -14,14 +14,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include "CHUtil.h" #include #include #include +#include #include #include #include #include -#include #include #include #include @@ -30,14 +32,17 @@ #include #include #include -#include +#include #include #include #include +#include #include #include #include +#include #include +#include #include #include #include @@ -51,8 +56,11 @@ #include #include #include +#include +#include #include #include +#include #include #include #include @@ -63,20 +71,12 @@ #include #include -#include -#include - -#include "CHUtil.h" -#include "Disks/registerGlutenDisks.h" - -#include -#include - namespace DB { namespace ErrorCodes { extern const int BAD_ARGUMENTS; +extern const int UNKNOWN_TYPE; } } @@ -311,16 +311,48 @@ size_t PODArrayUtil::adjustMemoryEfficientSize(size_t n) std::string PlanUtil::explainPlan(DB::QueryPlan & plan) { - std::string plan_str; - DB::QueryPlan::ExplainPlanOptions buf_opt{ + constexpr DB::QueryPlan::ExplainPlanOptions buf_opt{ .header = true, .actions = true, .indexes = true, }; DB::WriteBufferFromOwnString buf; plan.explainPlan(buf, buf_opt); - plan_str = buf.str(); - return plan_str; + + return buf.str(); +} + +void PlanUtil::checkOuputType(const DB::QueryPlan & plan) +{ + // QueryPlan::checkInitialized is a private method, so we assume plan is initialized, otherwise there is a core dump here. + // It's okay, because it's impossible for us not to initialize where we call this method. + const auto & step = *plan.getRootNode()->step; + if (!step.hasOutputStream()) + return; + if (!step.getOutputStream().header) + return; + for (const auto & elem : step.getOutputStream().header) + { + const DB::DataTypePtr & ch_type = elem.type; + const auto ch_type_without_nullable = DB::removeNullable(ch_type); + const DB::WhichDataType which(ch_type_without_nullable); + if (which.isDateTime64()) + { + const auto * ch_type_datetime64 = checkAndGetDataType(ch_type_without_nullable.get()); + if (ch_type_datetime64->getScale() != 6) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); + } + else if (which.isDecimal()) + { + if (which.isDecimal256()) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); + + const auto scale = getDecimalScale(*ch_type_without_nullable); + const auto precision = getDecimalPrecision(*ch_type_without_nullable); + if (scale == 0 && precision == 0) + throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support converting from {}", ch_type->getName()); + } + } } NestedColumnExtractHelper::NestedColumnExtractHelper(const DB::Block & block_, bool case_insentive_) @@ -713,7 +745,6 @@ void registerAllFunctions() auto & factory = AggregateFunctionCombinatorFactory::instance(); registerAggregateFunctionCombinatorPartialMerge(factory); } - } void registerGlutenDisks() diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 574cdbe4c8d7..edbd91c50d22 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -16,19 +16,21 @@ */ #pragma once #include -#include #include #include #include -#include #include #include #include #include -#include #include #include -#include + +namespace DB +{ +class QueryPipeline; +class QueryPlan; +} namespace local_engine { @@ -96,10 +98,10 @@ class NestedColumnExtractHelper const DB::ColumnWithTypeAndName * findColumn(const DB::Block & block, const std::string & name) const; }; -class PlanUtil +namespace PlanUtil { -public: - static std::string explainPlan(DB::QueryPlan & plan); +std::string explainPlan(DB::QueryPlan & plan); +void checkOuputType(const DB::QueryPlan & plan); }; class ActionsDAGUtil diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 543489c2e08f..82acba37f7d8 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -867,8 +867,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( auto pos = function_signature.find(':'); auto func_name = function_signature.substr(0, pos); - auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this); - if (func_parser) + if (auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this)) { LOG_DEBUG( &Poco::Logger::get("SerializedPlanParser"), @@ -971,13 +970,12 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( args = std::move(new_args); } - bool converted_decimal_args = convertBinaryArithmeticFunDecimalArgs(actions_dag, args, scalar_function); auto function_builder = FunctionFactory::instance().get(ch_func_name, context); std::string args_name = join(args, ','); result_name = ch_func_name + "(" + args_name + ")"; const auto * function_node = &actions_dag->addFunction(function_builder, args, result_name); result_node = function_node; - if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type) && !converted_decimal_args) + if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type)) { auto result_type = TypeParser::parseType(rel.scalar_function().output_type()); if (isDecimalOrNullableDecimal(result_type)) @@ -1014,76 +1012,6 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( return result_node; } -bool SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs( - ActionsDAGPtr actions_dag, - ActionsDAG::NodeRawConstPtrs & args, - const substrait::Expression_ScalarFunction & arithmeticFun) -{ - auto function_signature = function_mapping.at(std::to_string(arithmeticFun.function_reference())); - auto pos = function_signature.find(':'); - auto func_name = function_signature.substr(0, pos); - - if (func_name == "divide" || func_name == "multiply" || func_name == "plus" || func_name == "minus") - { - /// for divide/plus/minus, we need to convert first arg to result precision and scale - /// for multiply, we need to convert first arg to result precision, but keep scale - auto arg1_type = removeNullable(args[0]->result_type); - auto arg2_type = removeNullable(args[1]->result_type); - if (isDecimal(arg1_type) && isDecimal(arg2_type)) - { - UInt32 p1 = getDecimalPrecision(*arg1_type); - UInt32 s1 = getDecimalScale(*arg1_type); - UInt32 p2 = getDecimalPrecision(*arg2_type); - UInt32 s2 = getDecimalScale(*arg2_type); - - UInt32 precision; - UInt32 scale; - - if (func_name == "plus" || func_name == "minus") - { - scale = s1; - precision = scale + std::max(p1 - s1, p2 - s2) + 1; - } - else if (func_name == "divide") - { - scale = std::max(static_cast(6), s1 + p2 + 1); - precision = p1 - s1 + s2 + scale; - } - else // multiply - { - scale = s1; - precision = p1 + p2 + 1; - } - - UInt32 maxPrecision = DataTypeDecimal256::maxPrecision(); - UInt32 maxScale = DataTypeDecimal128::maxPrecision(); - precision = std::min(precision, maxPrecision); - scale = std::min(scale, maxScale); - - ActionsDAG::NodeRawConstPtrs new_args; - new_args.reserve(args.size()); - - ActionsDAG::NodeRawConstPtrs cast_args; - cast_args.reserve(2); - cast_args.emplace_back(args[0]); - DataTypePtr ch_type = createDecimal(precision, scale); - ch_type = wrapNullableType(arithmeticFun.output_type().decimal().nullability(), ch_type); - String type_name = ch_type->getName(); - DataTypePtr str_type = std::make_shared(); - const ActionsDAG::Node * type_node = &actions_dag->addColumn( - ColumnWithTypeAndName(str_type->createColumnConst(1, type_name), str_type, getUniqueName(type_name))); - cast_args.emplace_back(type_node); - const ActionsDAG::Node * cast_node = toFunctionNode(actions_dag, "CAST", cast_args); - actions_dag->addOrReplaceInOutputs(*cast_node); - new_args.emplace_back(cast_node); - new_args.emplace_back(args[1]); - args = std::move(new_args); - return true; - } - } - return false; -} - void SerializedPlanParser::parseFunctionArguments( ActionsDAGPtr & actions_dag, ActionsDAG::NodeRawConstPtrs & parsed_args, @@ -1835,11 +1763,15 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) auto res = parse(std::move(plan_ptr)); +#ifndef NDEBUG + PlanUtil::checkOuputType(*res); +#endif + auto * logger = &Poco::Logger::get("SerializedPlanParser"); if (logger->debug()) { auto out = PlanUtil::explainPlan(*res); - LOG_DEBUG(logger, "clickhouse plan:\n{}", out); + LOG_ERROR(logger, "clickhouse plan:\n{}", out); } return res; } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 5bf7da25d32c..a636ebb9352f 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include #include @@ -25,14 +24,10 @@ #include #include #include -#include #include #include -#include #include -#include #include -#include #include #include #include @@ -301,9 +296,6 @@ class SerializedPlanParser static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); - bool convertBinaryArithmeticFunDecimalArgs( - ActionsDAGPtr actions_dag, ActionsDAG::NodeRawConstPtrs & args, const substrait::Expression_ScalarFunction & arithmeticFun); - IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); static ContextMutablePtr global_context; @@ -383,7 +375,6 @@ class SerializedPlanParser void wrapNullable( const std::vector & columns, ActionsDAGPtr actions_dag, std::map & nullable_measure_names); static std::pair convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field); - const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field); int name_no = 0; std::unordered_map function_mapping; @@ -395,6 +386,9 @@ class SerializedPlanParser // for parse rel node, collect steps from a rel node std::vector temp_step_collection; std::vector metrics; + +public: + const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field); }; struct SparkBuffer diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp new file mode 100644 index 000000000000..ec056da45e07 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +namespace DB::ErrorCodes +{ +extern const int BAD_ARGUMENTS; +extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +namespace local_engine +{ + +class DecimalType +{ + static constexpr Int32 spark_max_precision = 38; + static constexpr Int32 spark_max_scale = 38; + static constexpr Int32 minimum_adjusted_scale = 6; + + static constexpr Int32 chickhouse_max_precision = DB::DataTypeDecimal256::maxPrecision(); + static constexpr Int32 chickhouse_max_scale = DB::DataTypeDecimal128::maxPrecision(); + +public: + Int32 precision; + Int32 scale; + +private: + static DecimalType bounded_to_spark(const Int32 precision, const Int32 scale) + { + return DecimalType(std::min(precision, spark_max_precision), std::min(scale, spark_max_scale)); + } + static DecimalType bounded_to_click_house(const Int32 precision, const Int32 scale) + { + return DecimalType(std::min(precision, chickhouse_max_precision), std::min(scale, chickhouse_max_scale)); + } + static void check_negative_scale(const Int32 scale) + { + /// only support spark.sql.legacy.allowNegativeScaleOfDecimal == false + if (scale < 0) + throw Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Negative scale is not supported"); + } + + static DecimalType adjust_precision_scale(const Int32 precision, const Int32 scale) + { + check_negative_scale(scale); + assert(precision >= scale); + + if (precision <= spark_max_precision) + { + // Adjustment only needed when we exceed max precision + return DecimalType(precision, scale); + } + else if (scale < 0) + { + // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision + // loss since we would cause a loss of digits in the integer part. + // In this case, we are likely to meet an overflow. + return DecimalType(spark_max_precision, scale); + } + else + { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. + const int intDigits = precision - scale; + + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + const int minScaleValue = std::min(scale, minimum_adjusted_scale); + + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + const int adjusted_scale = std::max(spark_max_precision - intDigits, minScaleValue); + return DecimalType(spark_max_precision, adjusted_scale); + } + } + +public: + /// The formula follows Hive which is based on the SQL standard and MS SQL: + /// https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + /// https://msdn.microsoft.com/en-us/library/ms190476.aspx + /// Result Precision: max(s1, s2) + max(p1-s1, p2-s2) + 1 + /// Result Scale: max(s1, s2) + /// +, - + static DecimalType + resultAddSubstractDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2, bool allowPrecisionLoss = true) + { + const Int32 scale = std::max(s1, s2); + const Int32 precision = std::max(p1 - s1, p2 - s2) + scale + 1; + + if (allowPrecisionLoss) + return adjust_precision_scale(precision, scale); + else + return bounded_to_spark(precision, scale); + } + + /// The formula follows Hive which is based on the SQL standard and MS SQL: + /// https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + /// https://msdn.microsoft.com/en-us/library/ms190476.aspx + /// Result Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + /// Result Scale: max(6, s1 + p2 + 1) + static DecimalType + resultDivideDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2, bool allowPrecisionLoss = true) + { + if (allowPrecisionLoss) + { + const Int32 Int32Dig = p1 - s1 + s2; + const Int32 scale = std::max(minimum_adjusted_scale, s1 + p2 + 1); + const Int32 prec = Int32Dig + scale; + return adjust_precision_scale(prec, scale); + } + else + { + Int32 Int32Dig = std::min(spark_max_scale, p1 - s1 + s2); + Int32 decDig = std::min(spark_max_scale, std::max(minimum_adjusted_scale, s1 + p2 + 1)); + Int32 diff = (Int32Dig + decDig) - spark_max_scale; + + if (diff > 0) + { + decDig -= diff / 2 + 1; + Int32Dig = spark_max_scale - decDig; + } + + return bounded_to_spark(Int32Dig + decDig, decDig); + } + } + + /// The formula follows Hive which is based on the SQL standard and MS SQL: + /// https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + /// https://msdn.microsoft.com/en-us/library/ms190476.aspx + /// Result Precision: p1 + p2 + 1 + /// Result Scale: s1 + s2 + static DecimalType + resultMultiplyDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2, bool allowPrecisionLoss = true) + { + const Int32 scale = s1 + s2; + const Int32 precision = p1 + p2 + 1; + + if (allowPrecisionLoss) + return adjust_precision_scale(precision, scale); + else + return bounded_to_spark(precision, scale); + } + + static DecimalType evalAddSubstractDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) + { + const Int32 scale = s1; + const Int32 precision = scale + std::max(p1 - s1, p2 - s2) + 1; + return bounded_to_click_house(precision, scale); + } + + static DecimalType evalDividetDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) + { + const Int32 scale = std::max(minimum_adjusted_scale, s1 + p2 + 1); + const Int32 precision = p1 - s1 + s2 + scale; + return bounded_to_click_house(precision, scale); + } + + static DecimalType evalMultiplyDecimalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) + { + const Int32 scale = s1; + const Int32 precision = p1 + p2 + 1; + return bounded_to_click_house(precision, scale); + } +}; + +class FunctionParserBinaryArithmetic : public FunctionParser +{ +protected: + ActionsDAG::NodeRawConstPtrs convertBinaryArithmeticFunDecimalArgs( + ActionsDAGPtr & actions_dag, + const ActionsDAG::NodeRawConstPtrs & args, + const DecimalType & eval_type, + const substrait::Expression_ScalarFunction & arithmeticFun) const + { + const Int32 precision = eval_type.precision; + const Int32 scale = eval_type.scale; + + ActionsDAG::NodeRawConstPtrs new_args; + new_args.reserve(args.size()); + + ActionsDAG::NodeRawConstPtrs cast_args; + cast_args.reserve(2); + cast_args.emplace_back(args[0]); + DataTypePtr ch_type = createDecimal(precision, scale); + ch_type = wrapNullableType(arithmeticFun.output_type().decimal().nullability(), ch_type); + const String type_name = ch_type->getName(); + const DataTypePtr str_type = std::make_shared(); + const ActionsDAG::Node * type_node + = &actions_dag->addColumn(ColumnWithTypeAndName(str_type->createColumnConst(1, type_name), str_type, getUniqueName(type_name))); + cast_args.emplace_back(type_node); + const ActionsDAG::Node * cast_node = toFunctionNode(actions_dag, "CAST", cast_args); + actions_dag->addOrReplaceInOutputs(*cast_node); + new_args.emplace_back(cast_node); + new_args.emplace_back(args[1]); + return new_args; + } + + DecimalType getDecimalType(const DataTypePtr & left, const DataTypePtr & right, const bool resultType) const + { + assert(isDecimal(left) && isDecimal(right)); + const Int32 p1 = getDecimalPrecision(*left); + const Int32 s1 = getDecimalScale(*left); + const Int32 p2 = getDecimalPrecision(*right); + const Int32 s2 = getDecimalScale(*right); + return resultType ? internalResultType(p1, s1, p2, s2) : internalEvalType(p1, s1, p2, s2); + } + + virtual DecimalType internalResultType(Int32 p1, Int32 s1, Int32 p2, Int32 s2) const = 0; + virtual DecimalType internalEvalType(Int32 p1, Int32 s1, Int32 p2, Int32 s2) const = 0; + + const ActionsDAG::Node * + checkDecimalOverflow(ActionsDAGPtr & actions_dag, const ActionsDAG::Node * func_node, Int32 precision, Int32 scale) const + { + const DB::ActionsDAG::NodeRawConstPtrs overflow_args + = {func_node, + plan_parser->addColumn(actions_dag, std::make_shared(), precision), + plan_parser->addColumn(actions_dag, std::make_shared(), scale)}; + return toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", overflow_args); + } + const DB::ActionsDAG::Node * convertNodeTypeIfNeeded( + const substrait::Expression_ScalarFunction & substrait_func, + const DB::ActionsDAG::Node * func_node, + DB::ActionsDAGPtr & actions_dag) const override + { + const auto & substrait_type = substrait_func.output_type(); + if (const auto result_type = TypeParser::parseType(substrait_type); isDecimalOrNullableDecimal(result_type)) + { + const auto a = removeNullable(result_type); + const auto b = removeNullable(func_node->result_type); + if (a->equals(*b)) + return func_node; + + // as stated in isTypeMatched, currently we don't change nullability of the result type + const std::string type_name = func_node->result_type->isNullable() ? wrapNullableType(true, result_type)->getName() + : removeNullable(result_type)->getName(); + return ActionsDAGUtil::convertNodeType(actions_dag, func_node, type_name, func_node->result_name, DB::CastType::accurateOrNull); + } + return FunctionParser::convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } + + virtual const DB::ActionsDAG::Node * + createFunctionNode(DB::ActionsDAGPtr & actions_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & args) const + { + return toFunctionNode(actions_dag, func_name, args); + } + +public: + explicit FunctionParserBinaryArithmetic(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } + const ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, ActionsDAGPtr & actions_dag) const override + { + const auto ch_func_name = getCHFunctionName(substrait_func); + auto parsed_args = parseFunctionArguments(substrait_func, ch_func_name, actions_dag); + + if (parsed_args.size() != 2) + throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); + + const auto left_type = DB::removeNullable(parsed_args[0]->result_type); + const auto right_type = DB::removeNullable(parsed_args[1]->result_type); + const bool converted = isDecimal(left_type) && isDecimal(right_type); + + if (converted) + { + const DecimalType evalType = getDecimalType(left_type, right_type, false); + parsed_args = convertBinaryArithmeticFunDecimalArgs(actions_dag, parsed_args, evalType, substrait_func); + } + + const auto * func_node = createFunctionNode(actions_dag, ch_func_name, parsed_args); + + if (converted) + { + const auto parsed_outputType = removeNullable(TypeParser::parseType(substrait_func.output_type())); + assert(isDecimal(parsed_outputType)); + const Int32 parsed_precision = getDecimalPrecision(*parsed_outputType); + const Int32 parsed_scale = getDecimalScale(*parsed_outputType); + +#ifndef NDEBUG + const auto [precision, scale] = getDecimalType(left_type, right_type, true); + // assert(parsed_precision == precision); + // assert(parsed_scale == scale); +#endif + func_node = checkDecimalOverflow(actions_dag, func_node, parsed_precision, parsed_scale); + } + return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag); + } +}; + +class FunctionParserPlus final : public FunctionParserBinaryArithmetic +{ +public: + explicit FunctionParserPlus(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { } + + static constexpr auto name = "add"; + String getName() const override { return name; } + +protected: + DecimalType internalResultType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override + { + return DecimalType::resultAddSubstractDecimalType(p1, s1, p2, s2); + } + DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override + { + return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2); + } +}; + +class FunctionParserMinus final : public FunctionParserBinaryArithmetic +{ +public: + explicit FunctionParserMinus(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { } + + static constexpr auto name = "subtract"; + String getName() const override { return name; } + +protected: + DecimalType internalResultType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override + { + return DecimalType::resultAddSubstractDecimalType(p1, s1, p2, s2); + } + DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override + { + return DecimalType::evalAddSubstractDecimalType(p1, s1, p2, s2); + } +}; + +class FunctionParserMultiply final : public FunctionParserBinaryArithmetic +{ +public: + explicit FunctionParserMultiply(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { } + static constexpr auto name = "multiply"; + String getName() const override { return name; } + +protected: + DecimalType internalResultType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override + { + return DecimalType::resultMultiplyDecimalType(p1, s1, p2, s2); + } + DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override + { + return DecimalType::evalMultiplyDecimalType(p1, s1, p2, s2); + } +}; + +class FunctionParserDivide final : public FunctionParserBinaryArithmetic +{ +public: + explicit FunctionParserDivide(SerializedPlanParser * plan_parser_) : FunctionParserBinaryArithmetic(plan_parser_) { } + static constexpr auto name = "divide"; + String getName() const override { return name; } + +protected: + DecimalType internalResultType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override + { + return DecimalType::resultDivideDecimalType(p1, s1, p2, s2); + } + DecimalType internalEvalType(const Int32 p1, const Int32 s1, const Int32 p2, const Int32 s2) const override + { + return DecimalType::evalDividetDecimalType(p1, s1, p2, s2); + } + + const DB::ActionsDAG::Node * createFunctionNode( + DB::ActionsDAGPtr & actions_dag, const String & func_name, const DB::ActionsDAG::NodeRawConstPtrs & new_args) const override + { + assert(func_name == name); + const auto * left_arg = new_args[0]; + const auto * right_arg = new_args[1]; + + if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type))) + return toFunctionNode(actions_dag, "sparkDivideDecimal", {left_arg, right_arg}); + else + return toFunctionNode(actions_dag, "sparkDivide", {left_arg, right_arg}); + } +}; + +static FunctionParserRegister register_plus; +static FunctionParserRegister register_minus; +static FunctionParserRegister register_mltiply; +static FunctionParserRegister register_divide; + +} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/divide.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/divide.cpp deleted file mode 100644 index 5c1eb358b31d..000000000000 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/divide.cpp +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int BAD_ARGUMENTS; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} -} - -namespace local_engine -{ - -class FunctionParserDivide : public FunctionParser -{ -public: - explicit FunctionParserDivide(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) { } - ~FunctionParserDivide() override = default; - - static constexpr auto name = "divide"; - - String getName() const override { return name; } - - const ActionsDAG::Node * parse( - const substrait::Expression_ScalarFunction & substrait_func, - ActionsDAGPtr & actions_dag) const override - { - /// Parse divide(left, right) as if (right == 0) null else left / right - auto parsed_args = parseFunctionArguments(substrait_func, "", actions_dag); - if (parsed_args.size() != 2) - throw Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} requires exactly two arguments", getName()); - - ActionsDAG::NodeRawConstPtrs new_args{parsed_args[0], parsed_args[1]}; - plan_parser->convertBinaryArithmeticFunDecimalArgs(actions_dag, new_args, substrait_func); - - const auto * left_arg = new_args[0]; - const auto * right_arg = new_args[1]; - - if (isDecimal(removeNullable(left_arg->result_type)) || isDecimal(removeNullable(right_arg->result_type))) - return toFunctionNode(actions_dag, "sparkDivideDecimal", {left_arg, right_arg}); - else - return toFunctionNode(actions_dag, "sparkDivide", {left_arg, right_arg}); - } -}; - -static FunctionParserRegister register_divide; -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index fa795b84b64d..ddf62201daeb 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -98,6 +98,14 @@ trait BackendSettingsApi { def allowDecimalArithmetic: Boolean = true + /** + * After https://github.com/apache/spark/pull/36698, every arithmetic should report the accurate + * result decimal type and implement `CheckOverflow` by itself.

Regardless of whether there + * is 36698 or not, this option is used to indicate whether to transform `CheckOverflow`. `false` + * means the backend will implement `CheckOverflow` by default and no need to transform it. + */ + def transformCheckOverflow: Boolean = true + def rescaleDecimalIntegralExpression(): Boolean = false def shuffleSupportedCodec(): Set[String] diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 7815cbf69ebd..562ae294e2c0 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -119,7 +119,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { case i: StaticInvoke => val objectName = i.staticObject.getName.stripSuffix("$") if (objectName.endsWith("UrlCodec")) { - val child = i.arguments(0) + val child = i.arguments.head i.functionName match { case "decode" => return GenericExpressionTransformer( @@ -138,20 +138,8 @@ object ExpressionConverter extends SQLConfHelper with Logging { case _ => } - TestStats.addExpressionClassName(expr.getClass.getName) - // Check whether Gluten supports this expression - val substraitExprNameOpt = expressionsMap.get(expr.getClass) - if (substraitExprNameOpt.isEmpty) { - throw new GlutenNotSupportException( - s"Not supported to map spark function name" + - s" to substrait function name: $expr, class name: ${expr.getClass.getSimpleName}.") - } - val substraitExprName = substraitExprNameOpt.get + val substraitExprName: String = getAndCheckSubstraitName(expr, expressionsMap) - // Check whether each backend supports this expression - if (!BackendsApiManager.getValidatorApiInstance.doExprValidate(substraitExprName, expr)) { - throw new GlutenNotSupportException(s"Not supported: $expr.") - } expr match { case extendedExpr if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains( @@ -162,7 +150,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { case c: CreateArray => val children = c.children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)) - CreateArrayTransformer(substraitExprName, children, true, c) + CreateArrayTransformer(substraitExprName, children, useStringTypeWhenEmpty = true, c) case g: GetArrayItem => GetArrayItemTransformer( substraitExprName, @@ -319,7 +307,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { i.hset, i.child.dataType, i) - case s: org.apache.spark.sql.execution.ScalarSubquery => + case s: ScalarSubquery => ScalarSubqueryTransformer(s.plan, s.exprId, s) case c: Cast => // Add trim node, as necessary. @@ -463,7 +451,6 @@ object ExpressionConverter extends SQLConfHelper with Logging { expressionsMap), arguments = lambdaFunction.arguments.map( replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), - hidden = false, original = lambdaFunction ) case j: JsonTuple => @@ -477,11 +464,6 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(l.right, attributeSeq, expressionsMap), l ) - case c: CheckOverflow => - CheckOverflowTransformer( - substraitExprName, - replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap), - c) case m: MakeDecimal => MakeDecimalTransformer( substraitExprName, @@ -510,42 +492,71 @@ object ExpressionConverter extends SQLConfHelper with Logging { expr.children.map( replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)), expr) - case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) => - // PrecisionLoss=true: velox support / ch not support - // PrecisionLoss=false: velox not support / ch support - // TODO ch support PrecisionLoss=true - if (!BackendsApiManager.getSettings.allowDecimalArithmetic) { - throw new GlutenNotSupportException( - s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " + - s"${conf.decimalOperationsAllowPrecisionLoss} mode") - } - val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) { - DecimalArithmeticUtil.rescaleLiteral(b) - } else { - b - } - val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal( - DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left), - DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right)) - val leftChild = replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap) - val rightChild = - replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap) - val resultType = DecimalArithmeticUtil.getResultTypeForOperation( - DecimalArithmeticUtil.getOperationType(b), - DecimalArithmeticUtil - .getResultType(leftChild) - .getOrElse(left.dataType.asInstanceOf[DecimalType]), - DecimalArithmeticUtil - .getResultType(rightChild) - .getOrElse(right.dataType.asInstanceOf[DecimalType]) - ) + case CheckOverflow(b: BinaryArithmetic, decimalType, _) + if !BackendsApiManager.getSettings.transformCheckOverflow && + DecimalArithmeticUtil.isDecimalArithmetic(b) => + DecimalArithmeticUtil.checkAllowDecimalArithmetic() + val leftChild = + replaceWithExpressionTransformerInternal(b.left, attributeSeq, expressionsMap) + val rightChild = + replaceWithExpressionTransformerInternal(b.right, attributeSeq, expressionsMap) DecimalArithmeticExpressionTransformer( - substraitExprName, + getAndCheckSubstraitName(b, expressionsMap), leftChild, rightChild, - resultType, + decimalType, b) + + case c: CheckOverflow => + CheckOverflowTransformer( + substraitExprName, + replaceWithExpressionTransformerInternal(c.child, attributeSeq, expressionsMap), + c) + + case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) => + DecimalArithmeticUtil.checkAllowDecimalArithmetic() + if (!BackendsApiManager.getSettings.transformCheckOverflow) { + val leftChild = + replaceWithExpressionTransformerInternal(b.left, attributeSeq, expressionsMap) + val rightChild = + replaceWithExpressionTransformerInternal(b.right, attributeSeq, expressionsMap) + DecimalArithmeticExpressionTransformer( + substraitExprName, + leftChild, + rightChild, + b.dataType.asInstanceOf[DecimalType], + b) + } else { + val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) { + DecimalArithmeticUtil.rescaleLiteral(b) + } else { + b + } + val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal( + DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left), + DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right)) + val leftChild = + replaceWithExpressionTransformerInternal(left, attributeSeq, expressionsMap) + val rightChild = + replaceWithExpressionTransformerInternal(right, attributeSeq, expressionsMap) + + val resultType = DecimalArithmeticUtil.getResultTypeForOperation( + DecimalArithmeticUtil.getOperationType(b), + DecimalArithmeticUtil + .getResultType(leftChild) + .getOrElse(left.dataType.asInstanceOf[DecimalType]), + DecimalArithmeticUtil + .getResultType(rightChild) + .getOrElse(right.dataType.asInstanceOf[DecimalType]) + ) + DecimalArithmeticExpressionTransformer( + substraitExprName, + leftChild, + rightChild, + resultType, + b) + } case n: NaNvl => BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer( substraitExprName, @@ -651,6 +662,23 @@ object ExpressionConverter extends SQLConfHelper with Logging { } } + private def getAndCheckSubstraitName(expr: Expression, expressionsMap: Map[Class[_], String]) = { + TestStats.addExpressionClassName(expr.getClass.getName) + // Check whether Gluten supports this expression + val substraitExprNameOpt = expressionsMap.get(expr.getClass) + if (substraitExprNameOpt.isEmpty) { + throw new GlutenNotSupportException( + s"Not supported to map spark function name" + + s" to substrait function name: $expr, class name: ${expr.getClass.getSimpleName}.") + } + val substraitExprName = substraitExprNameOpt.get + // Check whether each backend supports this expression + if (!BackendsApiManager.getValidatorApiInstance.doExprValidate(substraitExprName, expr)) { + throw new GlutenNotSupportException(s"Not supported: $expr.") + } + substraitExprName + } + /** * Transform BroadcastExchangeExec to ColumnarBroadcastExchangeExec in DynamicPruningExpression. * diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala index 621dcc061ec7..ff63a1726393 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/DecimalArithmeticUtil.scala @@ -19,11 +19,15 @@ package org.apache.gluten.utils import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.expression.{CheckOverflowTransformer, ChildTransformer, DecimalArithmeticExpressionTransformer, ExpressionTransformer} +import org.apache.gluten.expression.ExpressionConverter.conf import org.apache.spark.sql.catalyst.analysis.DecimalPrecision import org.apache.spark.sql.catalyst.expressions.{Add, BinaryArithmetic, Cast, Divide, Expression, Literal, Multiply, Pmod, PromotePrecision, Remainder, Subtract} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ByteType, Decimal, DecimalType, IntegerType, LongType, ShortType} +import scala.annotation.tailrec + object DecimalArithmeticUtil { object OperationType extends Enumeration { @@ -31,7 +35,7 @@ object DecimalArithmeticUtil { val ADD, SUBTRACT, MULTIPLY, DIVIDE, MOD = Value } - val MIN_ADJUSTED_SCALE = 6 + private val MIN_ADJUSTED_SCALE = 6 val MAX_PRECISION = 38 // Returns the result decimal type of a decimal arithmetic computing. @@ -67,7 +71,7 @@ object DecimalArithmeticUtil { } // Returns the adjusted decimal type when the precision is larger the maximum. - def adjustScaleIfNeeded(precision: Int, scale: Int): DecimalType = { + private def adjustScaleIfNeeded(precision: Int, scale: Int): DecimalType = { var typePrecision = precision var typeScale = scale if (precision > MAX_PRECISION) { @@ -159,56 +163,33 @@ object DecimalArithmeticUtil { } // Returns whether the input expression is a combination of PromotePrecision(Cast as DecimalType). - private def isPromoteCast(expr: Expression): Boolean = { - expr match { - case precision: PromotePrecision => - precision.child match { - case cast: Cast if cast.dataType.isInstanceOf[DecimalType] => true - case _ => false - } - case _ => false - } + private def isPromoteCast(expr: Expression): Boolean = expr match { + case PromotePrecision(Cast(_, _: DecimalType, _, _)) => true + case _ => false } def rescaleCastForDecimal(left: Expression, right: Expression): (Expression, Expression) = { - if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) { - return (left, right) + + def doScale(e1: Expression, e2: Expression): (Expression, Expression) = { + val newE2 = rescaleCastForOneSide(e2) + val isWiderType = checkIsWiderType( + e1.dataType.asInstanceOf[DecimalType], + newE2.dataType.asInstanceOf[DecimalType], + e2.dataType.asInstanceOf[DecimalType]) + if (isWiderType) (e1, newE2) else (e1, e2) } - // Decimal * cast int. - if (!isPromoteCast(left)) { + + if (!BackendsApiManager.getSettings.rescaleDecimalIntegralExpression()) { + (left, right) + } else if (!isPromoteCast(left) && isPromoteCastIntegral(right)) { // Have removed PromotePrecision(Cast(DecimalType)). - if (isPromoteCastIntegral(right)) { - val newRight = rescaleCastForOneSide(right) - val isWiderType = checkIsWiderType( - left.dataType.asInstanceOf[DecimalType], - newRight.dataType.asInstanceOf[DecimalType], - right.dataType.asInstanceOf[DecimalType]) - if (isWiderType) { - (left, newRight) - } else { - (left, right) - } - } else { - (left, right) - } + // Decimal * cast int. + doScale(left, right) + } else if (!isPromoteCast(right) && isPromoteCastIntegral(left)) { // Cast int * decimal. - } else if (!isPromoteCast(right)) { - if (isPromoteCastIntegral(left)) { - val newLeft = rescaleCastForOneSide(left) - val isWiderType = checkIsWiderType( - newLeft.dataType.asInstanceOf[DecimalType], - right.dataType.asInstanceOf[DecimalType], - left.dataType.asInstanceOf[DecimalType]) - if (isWiderType) { - (newLeft, right) - } else { - (left, right) - } - } else { - (left, right) - } + val (r, l) = doScale(right, left) + (l, r) } else { - // Cast int * cast int. Usually user defined cast. (left, right) } } @@ -235,6 +216,7 @@ object DecimalArithmeticUtil { } } + @tailrec def getResultType(transformer: ExpressionTransformer): Option[DecimalType] = { transformer match { case ChildTransformer(child) => @@ -289,4 +271,15 @@ object DecimalArithmeticUtil { val widerType = DecimalPrecision.widerDecimalType(left, right) widerType.equals(wider) } + + def checkAllowDecimalArithmetic(): Unit = { + // PrecisionLoss=true: velox support / ch not support + // PrecisionLoss=false: velox not support / ch support + // TODO ch support PrecisionLoss=true + if (!BackendsApiManager.getSettings.allowDecimalArithmetic) { + throw new GlutenNotSupportException( + s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " + + s"${conf.decimalOperationsAllowPrecisionLoss} mode") + } + } }