Skip to content

Commit

Permalink
[GLUTEN-5896][CH]Fix greatest diff #5920
Browse files Browse the repository at this point in the history
What changes were proposed in this pull request?
(Please fill in changes proposed in this fix)

(Fixes: #5896)

How was this patch tested?
TEST BY UT
  • Loading branch information
KevinyhZou authored May 31, 2024
1 parent e870de8 commit c94cde4
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -2551,5 +2551,17 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
spark.sql("drop table test_tbl_5096")
}

test("GLUTEN-5896: Bug fix greatest diff") {
val tbl_create_sql =
"create table test_tbl_5896(id bigint, x1 int, x2 int, x3 int) using parquet"
val tbl_insert_sql =
"insert into test_tbl_5896 values(1, 12, NULL, 13), (2, NULL, NULL, NULL), (3, 11, NULL, NULL), (4, 10, 9, 8)"
val select_sql = "select id, greatest(x1, x2, x3) from test_tbl_5896"
spark.sql(tbl_create_sql)
spark.sql(tbl_insert_sql)
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
spark.sql("drop table test_tbl_5896")
}
}
// scalastyle:on line.size.limit
75 changes: 75 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/LeastGreatestGeneric.h>
#include <DataTypes/getLeastSupertype.h>
#include <DataTypes/DataTypeNullable.h>

namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}

namespace local_engine
{
class SparkFunctionGreatest : public DB::FunctionLeastGreatestGeneric<DB::LeastGreatest::Greatest>
{
public:
static constexpr auto name = "sparkGreatest";
static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<SparkFunctionGreatest>(); }
SparkFunctionGreatest() = default;
~SparkFunctionGreatest() override = default;
bool useDefaultImplementationForNulls() const override { return false; }

private:
DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override
{
if (types.empty())
throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", name);
return makeNullable(getLeastSupertype(types));
}

DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override
{
size_t num_arguments = arguments.size();
DB::Columns converted_columns(num_arguments);
for (size_t arg = 0; arg < num_arguments; ++arg)
converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst();
auto result_column = result_type->createColumn();
result_column->reserve(input_rows_count);
for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
{
size_t best_arg = 0;
for (size_t arg = 1; arg < num_arguments; ++arg)
{
auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1);
if (cmp_result > 0)
best_arg = arg;
}
result_column->insertFrom(*converted_columns[best_arg], row_num);
}
return result_column;
}
};

REGISTER_FUNCTION(SparkGreatest)
{
factory.registerFunction<SparkFunctionGreatest>();
}
}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS
{"hypot", "hypot"},
{"sign", "sign"},
{"radians", "radians"},
{"greatest", "greatest"},
{"greatest", "sparkGreatest"},
{"least", "least"},
{"shiftleft", "bitShiftLeft"},
{"shiftright", "bitShiftRight"},
Expand Down

0 comments on commit c94cde4

Please sign in to comment.