diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
index 11b1f994e6557..808aba02a96d4 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.cpp
@@ -133,31 +133,46 @@ struct SparkDecimalBinaryOperation
         const ColVecRight * col_right_vec = checkAndGetColumn(col_right.get());
         size_t rows = col_left->size();
 
-        if constexpr (Mode == OpMode::Effect)
+        size_t max_scale = getMaxScaled(left_type.getScale(), right_type.getScale(), result_type.getScale());
+
+        bool calculate_with_i256 = false;
+        if constexpr (Mode != OpMode::Effect)
         {
-            return executeDecimalImpl, false>(
-                left_type, right_type, col_left_const, col_right_const, col_left_vec, col_right_vec, rows, result_type);
+            if (shouldPromoteTo256(left_type, right_type, result_type))
+                calculate_with_i256 = true;
+
+            if (is_division && max_scale - left_type.getScale() + max_scale > ResultDataType::maxPrecision())
+                calculate_with_i256 = true;
         }
 
-        if (shouldPromoteTo256(left_type, right_type, result_type))
+        auto p1 = left_type.getPrecision();
+        auto p2 = right_type.getPrecision();
+        if (DataTypeDecimal::maxPrecision() < p1 + max_scale - left_type.getScale()
+            || DataTypeDecimal::maxPrecision() < p2 + max_scale - right_type.getScale())
+            calculate_with_i256 = true;
+
+        if (calculate_with_i256)
         {
-            return executeDecimalImpl(
+            /// Use Int256 for calculation
+            return executeDecimalImpl(
                 left_type, right_type, col_left_const, col_right_const, col_left_vec, col_right_vec, rows, result_type);
         }
-
-        size_t max_scale = getMaxScaled(left_type.getScale(), right_type.getScale(), result_type.getScale());
-        if (is_division && max_scale - left_type.getScale() + max_scale > ResultDataType::maxPrecision())
+        else if constexpr (is_division)
        {
-            return executeDecimalImpl(
+            /// Use Int128 for calculation
+            return executeDecimalImpl(
+                left_type, right_type, col_left_const, col_right_const, col_left_vec, col_right_vec, rows, result_type);
+        }
+        else
+        {
+            /// Use ResultNativeType for calculation
+            return executeDecimalImpl>(
                 left_type, right_type, col_left_const, col_right_const, col_left_vec, col_right_vec, rows, result_type);
         }
-
-        return executeDecimalImpl, false>(
-            left_type, right_type, col_left_const, col_right_const, col_left_vec, col_right_vec, rows, result_type);
     }
 
 private:
-    template
+    template
     static ColumnPtr executeDecimalImpl(
         const LeftDataType & left_type,
         const RightDataType & right_type,
@@ -195,17 +210,15 @@ struct SparkDecimalBinaryOperation
             return DecimalUtils::scaleMultiplier(max_scale - right_type.getScale());
         }();
 
-        bool calculate_with_256 = false;
-        if constexpr (calculate_with_256_)
-            calculate_with_256 = true;
-        else
+        ScaledNativeType unscale_result = [&]
         {
-            auto p1 = left_type.getPrecision();
-            auto p2 = right_type.getPrecision();
-            if (DataTypeDecimal::maxPrecision() < p1 + max_scale - left_type.getScale()
-                || DataTypeDecimal::maxPrecision() < p2 + max_scale - right_type.getScale())
-                calculate_with_256 = true;
-        }
+            auto result_scale = result_type.getScale();
+            auto diff = max_scale - result_scale;
+            chassert(diff >= 0);
+            return DecimalUtils::scaleMultiplier(diff);
+        }();
+
+        ScaledNativeType max_value = intExp10OfSize(result_type.getPrecision());
 
         auto res_vec = ColVecResult::create(rows, result_type.getScale());
         auto & res_vec_data = res_vec->getData();
@@ -214,22 +227,7 @@ struct SparkDecimalBinaryOperation
         if (col_left_vec && col_right_vec)
         {
-            if (calculate_with_256)
-            {
-                process(
-                    col_left_vec->getData().data(),
-                    col_right_vec->getData().data(),
-                    res_vec_data,
-                    res_nullmap_data,
-                    rows,
-                    scale_left,
-                    scale_right,
-                    max_scale,
-                    result_type);
-            }
-            else
-            {
-                process(
+            process(
                 col_left_vec->getData().data(),
                 col_right_vec->getData().data(),
                 res_vec_data,
@@ -237,69 +235,36 @@ struct SparkDecimalBinaryOperation
                 rows,
                 scale_left,
                 scale_right,
-                max_scale,
-                result_type);
-            }
+                unscale_result,
+                max_value);
         }
         else if (col_left_const && col_right_vec)
         {
             LeftFieldType left_value = col_left_const->getValue();
-            if (calculate_with_256)
-            {
-                process(
-                    &left_value,
-                    col_right_vec->getData().data(),
-                    res_vec_data,
-                    res_nullmap_data,
-                    rows,
-                    scale_left,
-                    scale_right,
-                    max_scale,
-                    result_type);
-            }
-            else
-            {
-                process(
-                    &left_value,
-                    col_right_vec->getData().data(),
-                    res_vec_data,
-                    res_nullmap_data,
-                    rows,
-                    scale_left,
-                    scale_right,
-                    max_scale,
-                    result_type);
-            }
+            process(
+                &left_value,
+                col_right_vec->getData().data(),
+                res_vec_data,
+                res_nullmap_data,
+                rows,
+                scale_left,
+                scale_right,
+                unscale_result,
+                max_value);
         }
         else if (col_left_vec && col_right_const)
        {
             RightFieldType right_value = col_right_const->getValue();
-            if (calculate_with_256)
-            {
-                process(
-                    col_left_vec->getData().data(),
-                    &right_value,
-                    res_vec_data,
-                    res_nullmap_data,
-                    rows,
-                    scale_left,
-                    scale_right,
-                    max_scale,
-                    result_type);
-            }
-            else
-            {
-                process(
-                    col_left_vec->getData().data(),
-                    &right_value,
-                    res_vec_data,
-                    res_nullmap_data,
-                    rows,
-                    scale_left,
-                    scale_right,
-                    max_scale,
-                    result_type);
-            }
+            process(
+                col_left_vec->getData().data(),
+                &right_value,
+                res_vec_data,
+                res_nullmap_data,
+                rows,
+                scale_left,
+                scale_right,
+                unscale_result,
+                max_value);
         }
         else
             throw Exception(
@@ -312,150 +277,98 @@ struct SparkDecimalBinaryOperation
         return ColumnNullable::create(std::move(res_vec), std::move(res_null_map));
     }
 
-    template <
-        OpCase op_case,
-        bool calculate_with_256,
-        typename LeftFieldType,
-        typename RightFieldType,
-        typename ResultDataType,
-        typename ScaledNativeType>
-    static void NO_INLINE process(
-        const LeftFieldType * __restrict left_data, // maybe scalar or vector
-        const RightFieldType * __restrict right_data, // maybe scalar or vector
-        PaddedPODArray & __restrict res_vec_data, // should be vector
-        NullMap & res_nullmap_data,
-        size_t rows,
-        const ScaledNativeType & scale_left,
-        const ScaledNativeType & scale_right,
-        size_t max_scale,
-        const ResultDataType & result_type)
-    {
-        using ResultNativeType = NativeType;
-
-        if constexpr (op_case == OpCase::Vector)
+    template <
+        OpCase op_case,
+        typename LeftFieldType,
+        typename RightFieldType,
+        typename ResultFieldType,
+        typename ScaledNativeType>
+    static void NO_INLINE process(
+        const LeftFieldType * __restrict left_data, // maybe scalar or vector
+        const RightFieldType * __restrict right_data, // maybe scalar or vector
+        PaddedPODArray & __restrict res_vec_data, // should be vector
+        NullMap & res_nullmap_data,
+        size_t rows,
+        const ScaledNativeType & scale_left,
+        const ScaledNativeType & scale_right,
+        const ScaledNativeType & unscale_result,
+        const ScaledNativeType & max_value)
     {
-        for (size_t i = 0; i < rows; ++i)
+        using ResultNativeType = NativeType;
+
+        if constexpr (op_case == OpCase::Vector)
         {
-            ResultNativeType res;
-            if (calculate(
+            for (size_t i = 0; i < rows; ++i)
+                res_nullmap_data[i] = !calculate(
                     static_cast(unwrap(left_data, i)),
                     static_cast(unwrap(right_data, i)),
                     scale_left,
                     scale_right,
-                    max_scale,
-                    result_type,
-                    res))
-                res_vec_data[i] = res;
-            else
-                res_nullmap_data[i] = 1;
+                    unscale_result,
+                    max_value,
+                    res_vec_data[i].value);
         }
-        }
-        else if constexpr (op_case == OpCase::LeftConstant)
-        {
-            ScaledNativeType scaled_left
-                = applyScaled(static_cast(unwrap(left_data, 0)), scale_left);
-
-            for (size_t i = 0; i < rows; ++i)
+        else if constexpr (op_case == OpCase::LeftConstant)
         {
-            ResultNativeType res;
-            if (calculate(
+            ScaledNativeType scaled_left
+                = applyScaled(static_cast(unwrap(left_data, 0)), scale_left);
+
+            for (size_t i = 0; i < rows; ++i)
+                res_nullmap_data[i] = !calculate(
                    scaled_left,
                     static_cast(unwrap(right_data, i)),
-                    static_cast(0),
+                    static_cast(1),
                     scale_right,
-                    max_scale,
-                    result_type,
-                    res))
-                res_vec_data[i] = res;
-            else
-                res_nullmap_data[i] = 1;
+                    unscale_result,
+                    max_value,
+                    res_vec_data[i].value);
         }
-        }
-        else if constexpr (op_case == OpCase::RightConstant)
-        {
-            ScaledNativeType scaled_right
-                = applyScaled(static_cast(unwrap(right_data, 0)), scale_right);
-
-            for (size_t i = 0; i < rows; ++i)
+        else if constexpr (op_case == OpCase::RightConstant)
        {
-            ResultNativeType res;
-            if (calculate(
+            ScaledNativeType scaled_right
+                = applyScaled(static_cast(unwrap(right_data, 0)), scale_right);
+
+            for (size_t i = 0; i < rows; ++i)
+                res_nullmap_data[i] = !calculate(
                     static_cast(unwrap(left_data, i)),
                     scaled_right,
                     scale_left,
-                    static_cast(0),
-                    max_scale,
-                    result_type,
-                    res))
-                res_vec_data[i] = res;
-            else
-                res_nullmap_data[i] = 1;
+                    static_cast(1),
+                    unscale_result,
+                    max_value,
+                    res_vec_data[i].value);
         }
-        }
     }
 
     template <
-        bool calculate_with_256,
         typename ScaledNativeType,
-        typename ResultNativeType,
-        typename ResultDataType>
-    static NO_SANITIZE_UNDEFINED bool calculate(
+        typename ResultNativeType>
+    static ALWAYS_INLINE bool calculate(
         const ScaledNativeType & left,
         const ScaledNativeType & right,
         const ScaledNativeType & scale_left,
         const ScaledNativeType & scale_right,
-        size_t max_scale,
-        const ResultDataType & result_type,
+        const ScaledNativeType & unscale_result,
+        const ScaledNativeType & max_value,
         ResultNativeType & res)
     {
-        if constexpr (calculate_with_256)
-            return calculateImpl(left, right, scale_left, scale_right, max_scale, result_type, res);
-        else if constexpr (is_division)
-            return calculateImpl(left, right, scale_left, scale_right, max_scale, result_type, res);
-        else
-            return calculateImpl(left, right, scale_left, scale_right, max_scale, result_type, res);
-    }
+        auto scaled_left = scale_left > 1 ? applyScaled(left, scale_left) : left;
+        auto scaled_right = scale_right > 1 ? applyScaled(right, scale_right) : right;
 
-    template <
-        typename CalculateType,
-        typename ScaledNativeType,
-        typename ResultNativeType,
-        typename ResultDataType>
-    static NO_SANITIZE_UNDEFINED bool calculateImpl(
-        const ScaledNativeType & left,
-        const ScaledNativeType & right,
-        const ScaledNativeType & scale_left,
-        const ScaledNativeType & scale_right,
-        size_t max_scale,
-        const ResultDataType & result_type,
-        ResultNativeType & res)
-    {
-        CalculateType scaled_left = applyScaled(static_cast(left), static_cast(scale_left));
-        CalculateType scaled_right = applyScaled(static_cast(right), static_cast(scale_right));
-        CalculateType c_res = 0;
-        auto success = Operation::template apply(scaled_left, scaled_right, c_res);
+        ScaledNativeType c_res = 0;
+        auto success = Operation::template apply(scaled_left, scaled_right, c_res);
         if (!success)
             return false;
 
-        auto result_scale = result_type.getScale();
-        auto scale_diff = max_scale - result_scale;
-        chassert(scale_diff >= 0);
-        if (scale_diff)
-        {
-            auto scaled_diff = DecimalUtils::scaleMultiplier(scale_diff);
-            DecimalDivideImpl::apply(c_res, scaled_diff, c_res);
-        }
-
-        // check overflow
-        if constexpr (std::is_same_v || is_division)
-        {
-            auto max_value = intExp10OfSize(result_type.getPrecision());
-            if (c_res <= -max_value || c_res >= max_value)
-                return false;
-        }
+        if (unscale_result > 1)
+            c_res = applyUnscaled(c_res, unscale_result);
 
         res = static_cast(c_res);
-        return true;
+
+        if constexpr (std::is_same_v || is_division)
+            return c_res > -max_value && c_res < max_value;
+        else
+            return true;
     }
 
     /// Unwrap underlying native type from decimal type
@@ -468,13 +381,25 @@ struct SparkDecimalBinaryOperation
         return elem[i].value;
     }
 
+    template
+    static ALWAYS_INLINE T applyScaled(T n, T scale)
+    {
+        chassert(scale != 0);
+
+        T res;
+        DecimalMultiplyImpl::apply(n, scale, res);
+        return res;
+    }
+
     template
-    static T applyScaled(T n, T scale)
+    static ALWAYS_INLINE T applyUnscaled(T n, T scale)
    {
-        if (scale > 1)
-            return common::mulIgnoreOverflow(n, scale);
+        chassert(scale != 0);
 
-        return n;
+        T res;
+        DecimalDivideImpl::apply(n, scale, res);
+        return res;
     }
 };
 
@@ -652,19 +577,14 @@ class SparkFunctionDecimalBinaryArithmetic final : public IFunction
         const RightDataType & right_type,
         const ResultDataType & result_type)
     {
-        // std::cout << "left_type:" << left_type.getName() << " right_type:" << right_type.getName()
-        //           << " result_type:" << result_type.getName() << std::endl;
-
         auto & b = static_cast &>(builder);
         DataTypePtr calculate_type = std::make_shared>();
-        // std::cout << "calculate_type_bytes:" << sizeof(calculate_type) << std::endl;
 
         auto * left = nativeCast(b, arguments[0], calculate_type);
         auto * right = nativeCast(b, arguments[1], calculate_type);
 
         size_t max_scale = SparkDecimalBinaryOperation::getMaxScaled(
             left_type.getScale(), right_type.getScale(), result_type.getScale());
-        // std::cout << "max_scale:" << max_scale << std::endl;
 
         CalculateType scale_left = [&]
         {
@@ -677,7 +597,6 @@ class SparkFunctionDecimalBinaryArithmetic final : public IFunction
             else
                 return DecimalUtils::scaleMultiplier(diff);
         }();
-        // std::cout << "scale_left:" << toString(Field{scale_left}) << std::endl;
 
         CalculateType scale_right = [&]
         {
@@ -686,7 +605,6 @@ class SparkFunctionDecimalBinaryArithmetic final : public IFunction
                 return DecimalUtils::scaleMultiplier(max_scale - right_type.getScale());
         }();
-        // std::cout << "scale_right:" << toString(Field{scale_right}) << std::endl;
 
         auto * scaled_left = b.CreateMul(left, getNativeConstant(b, scale_left));
         auto * scaled_right = b.CreateMul(right, getNativeConstant(b, scale_right));
@@ -718,7 +636,6 @@ class SparkFunctionDecimalBinaryArithmetic final : public IFunction
 
         auto result_scale = result_type.getScale();
         auto scale_diff = max_scale - result_scale;
-        // std::cout << "result_scale:" << result_scale << " scale_diff:" << scale_diff << std::endl;
         auto * unscaled_result = scaled_result;
         if (scale_diff)
         {
diff --git a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
index 9957ebdd607ee..e62abda1508e1 100644
--- a/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
+++ b/cpp-ch/local-engine/Functions/SparkFunctionDecimalBinaryArithmetic.h
@@ -42,6 +42,11 @@ static bool canCastLower(const Int128 & a, const Int128 & b)
     return a.items[1] == 0 && b.items[1] == 0;
 }
 
+static bool canCastLower(const UInt128 & a, const UInt128 & b)
+{
+    return a.items[1] == 0 && b.items[1] == 0;
+}
+
 static const Int256 & toInt256(const NewInt256 & value)
 {
     return *reinterpret_cast(&value);
@@ -58,7 +63,8 @@ struct DecimalPlusImpl
     template
     static bool apply(T a, T b, T & r)
     {
-        return !common::addOverflow(a, b, r);
+        r = a + b;
+        return true;
     }
 
     template <>
@@ -67,13 +73,16 @@ struct DecimalPlusImpl
         if (canCastLower(a, b))
         {
             UInt64 low_result;
-            if (common::addOverflow(static_cast(a), static_cast(b), low_result))
-                return !common::addOverflow(a, b, r);
-
-            r = static_cast(low_result);
-            return true;
+            if (!common::addOverflow(static_cast(a), static_cast(b), low_result))
+            {
+                r = static_cast(low_result);
+                chassert(r == a + b);
+                return true;
+            }
         }
-        return !common::addOverflow(a, b, r);
+
+        r = a + b;
+        return true;
     }
 
     template <>
@@ -82,16 +91,17 @@ struct DecimalPlusImpl
         if (canCastLower(a, b))
         {
             UInt128 low_result;
-            if (common::addOverflow(static_cast(a), static_cast(b), low_result))
-                return !common::addOverflow(a, b, r);
-
-            r = static_cast(low_result);
-            return true;
+            if (!common::addOverflow(static_cast(a), static_cast(b), low_result))
+            {
+                r = static_cast(low_result);
+                chassert(r == a + b);
+                return true;
+            }
         }
 
-        return !common::addOverflow(a, b, r);
-        // r = toInt256(toNewInt256(a) + toNewInt256(b));
-        // return true;
+        r = toInt256(toNewInt256(a) + toNewInt256(b));
+        chassert(r == a + b);
+        return true;
     }
 
 #if USE_EMBEDDED_COMPILER
@@ -110,7 +120,8 @@ struct DecimalMinusImpl
     template
     static bool apply(T a, T b, T & r)
     {
-        return !common::subOverflow(a, b, r);
+        r = a - b;
+        return true;
     }
 
     template <>
@@ -118,15 +129,17 @@ struct DecimalMinusImpl
     {
         if (canCastLower(a, b))
         {
-            UInt64 low_result;
-            if (common::subOverflow(static_cast(a), static_cast(b), low_result))
-                return !common::subOverflow(a, b, r);
-
-            r = static_cast(low_result);
-            return true;
+            Int64 low_result;
+            if (!common::subOverflow(static_cast(a), static_cast(b), low_result))
+            {
+                r = static_cast(low_result);
+                chassert(r == a - b);
+                return true;
+            }
         }
 
-        return !common::subOverflow(a, b, r);
+        r = a - b;
+        return true;
     }
 
     template <>
@@ -134,17 +147,18 @@ struct DecimalMinusImpl
     {
         if (canCastLower(a, b))
         {
-            UInt128 low_result;
-            if (common::subOverflow(static_cast(a), static_cast(b), low_result))
-                return !common::subOverflow(a, b, r);
-
-            r = static_cast(low_result);
-            return true;
+            Int128 low_result;
+            if (!common::subOverflow(static_cast(a), static_cast(b), low_result))
+            {
+                r = static_cast(low_result);
+                chassert(r == a - b);
+                return true;
+            }
         }
 
-        return !common::subOverflow(a, b, r);
-        // r = toInt256(toNewInt256(a) - toNewInt256(b));
-        // return true;
+        r = toInt256(toNewInt256(a) - toNewInt256(b));
+        chassert(r == a - b);
+        return true;
     }
 
 
@@ -165,30 +179,34 @@ struct DecimalMultiplyImpl
     template
     static bool apply(T a, T b, T & c)
     {
-        return !common::mulOverflow(a, b, c);
+        c = a * b;
+        return true;
     }
 
-    template
+    template <>
     static bool apply(Int128 a, Int128 b, Int128 & r)
     {
         if (canCastLower(a, b))
         {
             UInt64 low_result = 0;
-            if (common::mulOverflow(static_cast(a), static_cast(b), low_result))
-                return !common::mulOverflow(a, b, r);
-
-            r = static_cast(low_result);
-            return true;
+            if (!common::mulOverflow(static_cast(a), static_cast(b), low_result))
+            {
+                r = static_cast(low_result);
+                chassert(r == a * b);
+                return true;
+            }
         }
 
-        return !common::mulOverflow(a, b, r);
+        r = a * b;
+        return true;
     }
 
     template <>
     static bool apply(Int256 a, Int256 b, Int256 & r)
     {
-        // r = toInt256(toNewInt256(a) * toNewInt256(b));
-        r = a * b;
+        /// Notice that we can't use common::mulOverflow here because it doesn't support checking overflow on Int128 multiplication.
+        r = toInt256(toNewInt256(a) * toNewInt256(b));
+        chassert(r == a * b);
         return true;
     }
 
@@ -222,7 +240,9 @@ struct DecimalDivideImpl
         if (canCastLower(a, b))
         {
+            /// We must cast to UInt64 to avoid overflow in the division.
             r = static_cast(static_cast(a) / static_cast(b));
+            chassert(r == a / b);
             return true;
         }
 
@@ -230,6 +250,25 @@
         r = a / b;
         return true;
     }
 
+    template <>
+    static bool apply(UInt128 a, UInt128 b, UInt128 & r)
+    {
+        if (b == 0)
+            return false;
+
+        if (canCastLower(a, b))
+        {
+            /// We must cast to UInt64 to avoid overflow in the division.
+            r = static_cast(static_cast(a) / static_cast(b));
+            chassert(r == a / b);
+            return true;
+        }
+
+        r = a / b;
+        return true;
+    }
+
+
     template <>
     static bool apply(Int256 a, Int256 b, Int256 & r)
     {
@@ -238,16 +277,16 @@ struct DecimalDivideImpl
         if (canCastLower(a, b))
         {
-            UInt128 low_result = 0;
-            UInt128 low_a = static_cast(a);
-            UInt128 low_b = static_cast(b);
-            apply(low_a, low_b, low_result);
+            /// We must cast to UInt128 to avoid overflow in the division.
+            UInt128 low_result;
+            apply(static_cast(a), static_cast(b), low_result);
             r = static_cast(low_result);
+            chassert(r == a / b);
             return true;
         }
 
-        r = a / b;
-        // r = toInt256(toNewInt256(a) / toNewInt256(b));
+        r = toInt256(toNewInt256(a) / toNewInt256(b));
+        chassert(r == a / b);
         return true;
     }
 
@@ -275,6 +314,44 @@ struct DecimalModuloImpl
         return true;
     }
 
+    template <>
+    static bool apply(Int128 a, Int128 b, Int128 & r)
+    {
+        if (b == 0)
+            return false;
+
+        if (canCastLower(a, b))
+        {
+            /// We must cast to UInt64 to avoid overflow in the division.
+            r = static_cast(static_cast(a) % static_cast(b));
+            chassert(r == a % b);
+            return true;
+        }
+
+        r = a % b;
+        return true;
+    }
+
+
+    template <>
+    static bool apply(Int256 a, Int256 b, Int256 & r)
+    {
+        if (b == 0)
+            return false;
+
+        if (canCastLower(a, b))
+        {
+            /// We must cast to UInt128 to avoid overflow in the division.
+            r = static_cast(static_cast(a) % static_cast(b));
+            chassert(r == a % b);
+            return true;
+        }
+
+        r = toInt256(toNewInt256(a) % toNewInt256(b));
+        chassert(r == a % b);
+        return true;
+    }
+
 #if USE_EMBEDDED_COMPILER
 
     static constexpr bool compilable = true;