From 6518d76334e6b740016a724d30f3914ea98e9345 Mon Sep 17 00:00:00 2001 From: Chang chen Date: Sun, 7 Jan 2024 19:41:34 +0800 Subject: [PATCH] Revert "[GLUTEN-4249][CH]Improve cast (#4250)" (#4307) This reverts commit 026e67f2de2e6c8318d5dc060f47b36ee1501dc6. --- .../Functions/SparkFunctionFloor.h | 23 +++++++++++++++++ .../Parser/SerializedPlanParser.cpp | 25 ++++++++++++++----- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionFloor.h b/cpp-ch/local-engine/Functions/SparkFunctionFloor.h index 3ae19fe1813f..64792327590d 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionFloor.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionFloor.h @@ -41,6 +41,21 @@ static void checkAndSetNullable(T & t, UInt8 & null_flag) == 0b0111111111110000000000000000000000000000000000000000000000000000); null_flag = is_nan | is_inf; + + /* Equivalent code: + if (null_flag) + t = 0; + */ + if constexpr (std::is_same_v) + { + UInt32 * uint_data = reinterpret_cast(&t); + *uint_data &= ~(-null_flag); + } + else + { + UInt64 * uint_data = reinterpret_cast(&t); + *uint_data &= ~(-null_flag); + } } DECLARE_AVX2_SPECIFIC_CODE( @@ -48,6 +63,7 @@ DECLARE_AVX2_SPECIFIC_CODE( inline void checkFloat32AndSetNullables(Float32 * data, UInt8 * null_map, size_t size) { const __m256 inf = _mm256_set1_ps(INFINITY); const __m256 neg_inf = _mm256_set1_ps(-INFINITY); + const __m256 zero = _mm256_set1_ps(0.0f); size_t i = 0; for (; i + 7 < size; i += 8) @@ -58,6 +74,9 @@ DECLARE_AVX2_SPECIFIC_CODE( __m256 is_neg_inf = _mm256_cmp_ps(values, neg_inf, _CMP_EQ_OQ); __m256 is_nan = _mm256_cmp_ps(values, values, _CMP_NEQ_UQ); __m256 is_null = _mm256_or_ps(_mm256_or_ps(is_inf, is_neg_inf), is_nan); + __m256 new_values = _mm256_blendv_ps(values, zero, is_null); + + _mm256_storeu_ps(&data[i], new_values); UInt32 mask = static_cast(_mm256_movemask_ps(is_null)); for (size_t j = 0; j < 8; ++j) @@ -75,6 +94,7 @@ DECLARE_AVX2_SPECIFIC_CODE( inline void checkFloat64AndSetNullables(Float64 * data, UInt8 * null_map, size_t size) { const __m256d inf = _mm256_set1_pd(INFINITY); const __m256d neg_inf = _mm256_set1_pd(-INFINITY); + const __m256d zero = _mm256_set1_pd(0.0); size_t i = 0; for (; i + 3 < size; i += 4) @@ -85,6 +105,9 @@ DECLARE_AVX2_SPECIFIC_CODE( __m256d is_neg_inf = _mm256_cmp_pd(values, neg_inf, _CMP_EQ_OQ); __m256d is_nan = _mm256_cmp_pd(values, values, _CMP_NEQ_UQ); __m256d is_null = _mm256_or_pd(_mm256_or_pd(is_inf, is_neg_inf), is_nan); + __m256d new_values = _mm256_blendv_pd(values, zero, is_null); + + _mm256_storeu_pd(&data[i], new_values); UInt32 mask = static_cast(_mm256_movemask_pd(is_null)); for (size_t j = 0; j < 4; ++j) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index c427aaf6c83f..a6becc48b437 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -1020,12 +1020,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type) && !converted_decimal_args) { auto result_type = TypeParser::parseType(rel.scalar_function().output_type()); - bool castNullableFloatToInt = false; - if (function_node->result_type->isNullable() && result_type->isNullable() - && isFloat(DB::removeNullable(function_node->result_type)) - && isInt(DB::removeNullable(result_type))) - castNullableFloatToInt = true; - if (isDecimalOrNullableDecimal(result_type) || castNullableFloatToInt) + if (isDecimalOrNullableDecimal(result_type)) { result_node = ActionsDAGUtil::convertNodeType( actions_dag, @@ -1647,6 +1642,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr act args.emplace_back(parseExpression(actions_dag, input)); const auto & substrait_type = rel.cast().type(); + auto to_ch_type = TypeParser::parseType(substrait_type); const ActionsDAG::Node * function_node = nullptr; if (DB::isString(DB::removeNullable(args.back()->result_type)) && substrait_type.has_date()) { @@ -1665,6 +1661,23 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr act // Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x) function_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args); } + else if (DB::isFloat(DB::removeNullable(args[0]->result_type)) && DB::isNativeInteger(DB::removeNullable(to_ch_type))) + { + /// It looks like by design in CH that forbids cast NaN/Inf to integer. + auto zero_node = addColumn(actions_dag, args[0]->result_type, 0.0); + const auto * if_not_finite_node = toFunctionNode(actions_dag, "ifNotFinite", {args[0], zero_node}); + const auto * final_arg_node = if_not_finite_node; + if (args[0]->result_type->isNullable()) + { + DB::Field null_field; + const auto * null_value = addColumn(actions_dag, args[0]->result_type, null_field); + const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {args[0]}); + const auto * if_node = toFunctionNode(actions_dag, "if", {is_null_node, null_value, if_not_finite_node}); + final_arg_node = if_node; + } + function_node = toFunctionNode( + actions_dag, "CAST", {final_arg_node, addColumn(actions_dag, std::make_shared(), to_ch_type->getName())}); + } else { DataTypePtr ch_type = TypeParser::parseType(substrait_type);