From 7f51251288a5c183de6e9b6b3ccd1fc6066a6265 Mon Sep 17 00:00:00 2001 From: loneylee Date: Thu, 12 Sep 2024 19:01:11 +0800 Subject: [PATCH] fix modulo --- .../SparkFunctionDecimalBinaryArithmetic.cpp | 49 ++++++++++++++----- .../scalar_function_parser/arithmetic.cpp | 2 +- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp index 27a8e3fbefa32..7e9d25919a411 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp @@ -383,19 +383,40 @@ struct SparkDecimalBinaryOperation const size_t & max_scale) { if constexpr (CalculateWith256) - return calculateImpl(l, r, scale_left, scale_right, res, resultDataType, max_scale); + return calculateImpl( + static_cast(l), + static_cast(r), + static_cast(scale_left), + static_cast(scale_right), + res, + resultDataType, + max_scale); else if (is_division) - return calculateImpl(l, r, scale_left, scale_right, res, resultDataType, max_scale); + return calculateImpl( + static_cast(l), + static_cast(r), + static_cast(scale_left), + static_cast(scale_right), + res, + resultDataType, + max_scale); else - return calculateImpl(l, r, scale_left, scale_right, res, resultDataType, max_scale); + return calculateImpl( + static_cast(l), + static_cast(r), + static_cast(scale_left), + static_cast(scale_right), + res, + resultDataType, + max_scale); } - template + template static NO_SANITIZE_UNDEFINED bool calculateImpl( - LeftNativeType l, - RightNativeType r, - NativeResultType scale_left, - NativeResultType scale_right, + CalcType l, + CalcType r, + CalcType scale_left, + CalcType scale_right, NativeResultType & res, const ResultDataType & resultDataType, const size_t & max_scale) @@ -412,13 +433,19 @@ struct SparkDecimalBinaryOperation auto scale_diff = max_scale - result_scale; chassert(scale_diff >= 0); if (scale_diff) - c_res = c_res / DecimalUtils::scaleMultiplier(scale_diff); + { + auto scaled_diff = DecimalUtils::scaleMultiplier(scale_diff); + DecimalDivideImpl::apply(c_res, scaled_diff, c_res); + } auto max_value = intExp10OfSize(resultDataType.getPrecision()); // check overflow - if (c_res <= -max_value || c_res >= max_value) - return false; + if constexpr (std::is_same_v || is_division) + { + if (c_res <= -max_value || c_res >= max_value) + return false; + } res = static_cast(c_res); diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp index dae74fe4c32bd..d96c09e44f11f 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/arithmetic.cpp @@ -326,7 +326,7 @@ class FunctionParserModulo final : public FunctionParserBinaryArithmetic return toFunctionNode(actions_dag, function_name, {left_arg, right_arg, type_node}); } - return toFunctionNode(actions_dag, "sparkDivide", {left_arg, right_arg}); + return toFunctionNode(actions_dag, "modulo", {left_arg, right_arg}); } };