Skip to content

Commit

Permalink
convert nan to null from stddev
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed May 30, 2024
1 parent d35d1dc commit 42f3c58
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
AAAAAAAADOCAAAAA|-|Little, national services will buy young molecules. In part video-taped activities join now|-|TN|-|1|-|24.0|-|NaN|-|NaN|-|1|-|11.0|-|NaN|-|NaN|-|1|-|49.0|-|NaN|-|NaN
AAAAAAAAEBOBAAAA|-|Special words should tell by a follower|-|TN|-|1|-|66.0|-|NaN|-|NaN|-|1|-|38.0|-|NaN|-|NaN|-|1|-|56.0|-|NaN|-|NaN
AAAAAAAADOCAAAAA|-|Little, national services will buy young molecules. In part video-taped activities join now|-|TN|-|1|-|24.0|-|null|-|null|-|1|-|11.0|-|null|-|null|-|1|-|49.0|-|null|-|null
AAAAAAAAEBOBAAAA|-|Special words should tell by a follower|-|TN|-|1|-|66.0|-|null|-|null|-|1|-|38.0|-|null|-|null|-|1|-|56.0|-|null|-|null
Original file line number Diff line number Diff line change
Expand Up @@ -481,5 +481,20 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {

spark.sql(table_drop_sql)
}

test("GLUTEN-5904 NaN values from stddev") {
// stddev over b/c, where one of the two rows in the group has c = 0.
val divideByZeroQuery =
"""
|select a, stddev(b/c) from (select * from values (1,2, 1), (1,3,0) as data(a,b,c))
|group by a
|""".stripMargin
// stddev over a group that contains exactly one row.
val singleRowQuery =
"""
|select a, stddev(b) from (select * from values (1,2, 1) as data(a,b,c)) group by a
|""".stripMargin
// Run the queries in the same order as before; the trailing callback is intentionally a no-op.
Seq(divideByZeroQuery, singleRowQuery).foreach {
query => compareResultsAgainstVanillaSpark(query, true, { _ => })
}
}
}
// scalastyle:off line.size.limit
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Sum, sum, sum)
// NOTE(review): the macro definition is not visible here; from the entries below the arguments
// look like (parser class suffix, Spark function name, ClickHouse function name) - confirm.
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Avg, avg, avg)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Min, min, min)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(Max, max, max)
// NOTE(review): stddev and stddev_samp are also registered by the dedicated
// AggregateFunctionParserStddev template. The hunk header (-25,8 +25,6) suggests the next two
// lines are the ones this commit deletes - confirm both registrations never coexist.
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(StdDev, stddev, stddev_samp)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(StdDevSamp, stddev_samp, stddev_samp)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(StdDevPop, stddev_pop, stddev_pop)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(BitAnd, bit_and, groupBitAnd)
REGISTER_COMMON_AGGREGATE_FUNCTION_PARSER(BitOr, bit_or, groupBitOr)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Parser/AggregateFunctionParser.h>
#include <DataTypes/DataTypeNullable.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>

namespace local_engine
{
/// Name pair used to instantiate the parser for Spark's `stddev`.
struct StddevNameStruct
{
    /// Returned by the parser's getName() and exposed as its static `name`.
    static constexpr const char * spark_name = "stddev";
    /// Returned by both getCHFunctionName() overloads.
    /// NOTE(review): confirm "stddev" resolves to a registered ClickHouse aggregate function.
    static constexpr const char * ch_name = "stddev";
};

/// Name pair used to instantiate the parser for Spark's `stddev_samp`.
struct StddevSampNameStruct
{
    /// Returned by the parser's getName() and exposed as its static `name`.
    static constexpr const char * spark_name = "stddev_samp";
    /// Returned by both getCHFunctionName() overloads.
    static constexpr const char * ch_name = "stddev_samp";
};
/// Aggregate function parser for the Spark stddev family.
///
/// `NameStruct` must expose two string constants: `spark_name` (reported by getName()) and
/// `ch_name` (reported by both getCHFunctionName() overloads). After the aggregate is built,
/// convertNodeTypeIfNeeded() wraps its result so that a NaN value is replaced by NULL.
template <typename NameStruct>
class AggregateFunctionParserStddev final : public AggregateFunctionParser
{
public:
    static constexpr auto name = NameStruct::spark_name;

    AggregateFunctionParserStddev(SerializedPlanParser * plan_parser_) : AggregateFunctionParser(plan_parser_) { }
    ~AggregateFunctionParserStddev() override = default;

    String getName() const override { return NameStruct::spark_name; }

    /// Both overloads ignore their argument and always answer with the configured name.
    String getCHFunctionName(const CommonFunctionInfo &) const override { return NameStruct::ch_name; }
    String getCHFunctionName(DB::DataTypes &) const override { return NameStruct::ch_name; }

    /// Appends `if(isNaN(result), NULL, result)` to `actions_dag` and returns the new node.
    ///
    /// `func_info` and `with_nullability` are not consulted. The new node reuses the result
    /// name of `func_node` and is put into the DAG outputs via addOrReplaceInOutputs().
    const DB::ActionsDAG::Node * convertNodeTypeIfNeeded(
        const CommonFunctionInfo & func_info,
        const DB::ActionsDAG::Node * func_node,
        DB::ActionsDAGPtr & actions_dag,
        bool with_nullability) const override
    {
        /// Condition: is the aggregate result NaN?
        const DB::ActionsDAG::Node * nan_check = toFunctionNode(actions_dag, "isNaN", getUniqueName("isNaN"), {func_node});

        /// Replacement value: a one-entry column of Nullable(<result type>) holding the type's
        /// default entry (expected to be NULL for a Nullable column).
        auto nullable_result_type = DB::makeNullable(func_node->result_type);
        auto null_holder = nullable_result_type->createColumn();
        null_holder->insertDefault();
        DB::ColumnWithTypeAndName null_column(std::move(null_holder), nullable_result_type, getUniqueName("null"));
        const DB::ActionsDAG::Node & null_node = actions_dag->addColumn(std::move(null_column));

        /// The result is nullable: NULL where the aggregate produced NaN, the original value otherwise.
        DB::ActionsDAG::NodeRawConstPtrs if_args{nan_check, &null_node, func_node};
        const DB::ActionsDAG::Node * nan_as_null = toFunctionNode(actions_dag, "if", func_node->result_name, if_args);
        actions_dag->addOrReplaceInOutputs(*nan_as_null);
        return nan_as_null;
    }
};

/// File-scope registrar objects, one per Spark function name handled by AggregateFunctionParserStddev.
/// NOTE(review): AggregateFunctionParserRegister is defined elsewhere; presumably its constructor
/// registers the parser under the parser's static `name` - confirm.
static const AggregateFunctionParserRegister<AggregateFunctionParserStddev<StddevNameStruct>> registerer_stddev;
static const AggregateFunctionParserRegister<AggregateFunctionParserStddev<StddevSampNameStruct>> registerer_stddev_samp;
}

0 comments on commit 42f3c58

Please sign in to comment.