Skip to content

Commit

Permalink
Revert "[GLUTEN-4249][CH]Improve cast (apache#4250)" (apache#4307)
Browse files Browse the repository at this point in the history
This reverts commit 026e67f.
  • Loading branch information
baibaichen authored Jan 7, 2024
1 parent d8248ea commit 6518d76
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
23 changes: 23 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionFloor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,29 @@ static void checkAndSetNullable(T & t, UInt8 & null_flag)
== 0b0111111111110000000000000000000000000000000000000000000000000000);

null_flag = is_nan | is_inf;

/* Equivalent code:
if (null_flag)
t = 0;
*/
if constexpr (std::is_same_v<T, float>)
{
UInt32 * uint_data = reinterpret_cast<UInt32 *>(&t);
*uint_data &= ~(-null_flag);
}
else
{
UInt64 * uint_data = reinterpret_cast<UInt64 *>(&t);
*uint_data &= ~(-null_flag);
}
}

DECLARE_AVX2_SPECIFIC_CODE(

inline void checkFloat32AndSetNullables(Float32 * data, UInt8 * null_map, size_t size) {
const __m256 inf = _mm256_set1_ps(INFINITY);
const __m256 neg_inf = _mm256_set1_ps(-INFINITY);
const __m256 zero = _mm256_set1_ps(0.0f);

size_t i = 0;
for (; i + 7 < size; i += 8)
Expand All @@ -58,6 +74,9 @@ DECLARE_AVX2_SPECIFIC_CODE(
__m256 is_neg_inf = _mm256_cmp_ps(values, neg_inf, _CMP_EQ_OQ);
__m256 is_nan = _mm256_cmp_ps(values, values, _CMP_NEQ_UQ);
__m256 is_null = _mm256_or_ps(_mm256_or_ps(is_inf, is_neg_inf), is_nan);
__m256 new_values = _mm256_blendv_ps(values, zero, is_null);

_mm256_storeu_ps(&data[i], new_values);

UInt32 mask = static_cast<UInt32>(_mm256_movemask_ps(is_null));
for (size_t j = 0; j < 8; ++j)
Expand All @@ -75,6 +94,7 @@ DECLARE_AVX2_SPECIFIC_CODE(
inline void checkFloat64AndSetNullables(Float64 * data, UInt8 * null_map, size_t size) {
const __m256d inf = _mm256_set1_pd(INFINITY);
const __m256d neg_inf = _mm256_set1_pd(-INFINITY);
const __m256d zero = _mm256_set1_pd(0.0);

size_t i = 0;
for (; i + 3 < size; i += 4)
Expand All @@ -85,6 +105,9 @@ DECLARE_AVX2_SPECIFIC_CODE(
__m256d is_neg_inf = _mm256_cmp_pd(values, neg_inf, _CMP_EQ_OQ);
__m256d is_nan = _mm256_cmp_pd(values, values, _CMP_NEQ_UQ);
__m256d is_null = _mm256_or_pd(_mm256_or_pd(is_inf, is_neg_inf), is_nan);
__m256d new_values = _mm256_blendv_pd(values, zero, is_null);

_mm256_storeu_pd(&data[i], new_values);

UInt32 mask = static_cast<UInt32>(_mm256_movemask_pd(is_null));
for (size_t j = 0; j < 4; ++j)
Expand Down
25 changes: 19 additions & 6 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,12 +1020,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type) && !converted_decimal_args)
{
auto result_type = TypeParser::parseType(rel.scalar_function().output_type());
bool castNullableFloatToInt = false;
if (function_node->result_type->isNullable() && result_type->isNullable()
&& isFloat(DB::removeNullable(function_node->result_type))
&& isInt(DB::removeNullable(result_type)))
castNullableFloatToInt = true;
if (isDecimalOrNullableDecimal(result_type) || castNullableFloatToInt)
if (isDecimalOrNullableDecimal(result_type))
{
result_node = ActionsDAGUtil::convertNodeType(
actions_dag,
Expand Down Expand Up @@ -1647,6 +1642,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr act
args.emplace_back(parseExpression(actions_dag, input));

const auto & substrait_type = rel.cast().type();
auto to_ch_type = TypeParser::parseType(substrait_type);
const ActionsDAG::Node * function_node = nullptr;
if (DB::isString(DB::removeNullable(args.back()->result_type)) && substrait_type.has_date())
{
Expand All @@ -1665,6 +1661,23 @@ const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr act
// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
function_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args);
}
else if (DB::isFloat(DB::removeNullable(args[0]->result_type)) && DB::isNativeInteger(DB::removeNullable(to_ch_type)))
{
/// It looks like by design in CH that forbids cast NaN/Inf to integer.
auto zero_node = addColumn(actions_dag, args[0]->result_type, 0.0);
const auto * if_not_finite_node = toFunctionNode(actions_dag, "ifNotFinite", {args[0], zero_node});
const auto * final_arg_node = if_not_finite_node;
if (args[0]->result_type->isNullable())
{
DB::Field null_field;
const auto * null_value = addColumn(actions_dag, args[0]->result_type, null_field);
const auto * is_null_node = toFunctionNode(actions_dag, "isNull", {args[0]});
const auto * if_node = toFunctionNode(actions_dag, "if", {is_null_node, null_value, if_not_finite_node});
final_arg_node = if_node;
}
function_node = toFunctionNode(
actions_dag, "CAST", {final_arg_node, addColumn(actions_dag, std::make_shared<DataTypeString>(), to_ch_type->getName())});
}
else
{
DataTypePtr ch_type = TypeParser::parseType(substrait_type);
Expand Down

0 comments on commit 6518d76

Please sign in to comment.