Skip to content

Commit

Permalink
[GLUTEN-6156][CH]Fix least diff (#6155)
Browse files Browse the repository at this point in the history
What changes were proposed in this pull request?
(Please fill in changes proposed in this fix)

(Fixes: #6156)

How was this patch tested?
test by ut
  • Loading branch information
KevinyhZou authored Jun 26, 2024
1 parent 774c668 commit 10a663c
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2575,12 +2575,12 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
spark.sql("drop table test_tbl_5096")
}

test("GLUTEN-5896: Bug fix greatest diff") {
test("GLUTEN-5896: Bug fix greatest/least diff") {
val tbl_create_sql =
"create table test_tbl_5896(id bigint, x1 int, x2 int, x3 int) using parquet"
val tbl_insert_sql =
"insert into test_tbl_5896 values(1, 12, NULL, 13), (2, NULL, NULL, NULL), (3, 11, NULL, NULL), (4, 10, 9, 8)"
val select_sql = "select id, greatest(x1, x2, x3) from test_tbl_5896"
val select_sql = "select id, greatest(x1, x2, x3), least(x1, x2, x3) from test_tbl_5896"
spark.sql(tbl_create_sql)
spark.sql(tbl_insert_sql)
compareResultsAgainstVanillaSpark(select_sql, true, { _ => })
Expand Down
77 changes: 77 additions & 0 deletions cpp-ch/local-engine/Functions/FunctionGreatestLeast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/LeastGreatestGeneric.h>
#include <DataTypes/getLeastSupertype.h>
#include <DataTypes/DataTypeNullable.h>

namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}
namespace local_engine
{
template <DB::LeastGreatest kind>
class FunctionGreatestestLeast : public DB::FunctionLeastGreatestGeneric<kind>
{
public:
bool useDefaultImplementationForNulls() const override { return false; }
virtual String getName() const = 0;

private:
DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override
{
if (types.empty())
throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", getName());
return makeNullable(getLeastSupertype(types));
}

DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override
{
size_t num_arguments = arguments.size();
DB::Columns converted_columns(num_arguments);
for (size_t arg = 0; arg < num_arguments; ++arg)
converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst();
auto result_column = result_type->createColumn();
result_column->reserve(input_rows_count);
for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
{
size_t best_arg = 0;
for (size_t arg = 1; arg < num_arguments; ++arg)
{
if constexpr (kind == DB::LeastGreatest::Greatest)
{
auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1);
if (cmp_result > 0)
best_arg = arg;
}
else
{
auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], 1);
if (cmp_result < 0)
best_arg = arg;
}
}
result_column->insertFrom(*converted_columns[best_arg], row_num);
}
return result_column;
}
};

}
47 changes: 5 additions & 42 deletions cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,58 +14,21 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/LeastGreatestGeneric.h>
#include <DataTypes/getLeastSupertype.h>
#include <DataTypes/DataTypeNullable.h>

namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
}
#include <Functions/FunctionGreatestLeast.h>

namespace local_engine
{
class SparkFunctionGreatest : public DB::FunctionLeastGreatestGeneric<DB::LeastGreatest::Greatest>
class SparkFunctionGreatest : public FunctionGreatestestLeast<DB::LeastGreatest::Greatest>
{
public:
static constexpr auto name = "sparkGreatest";
static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<SparkFunctionGreatest>(); }
SparkFunctionGreatest() = default;
~SparkFunctionGreatest() override = default;
bool useDefaultImplementationForNulls() const override { return false; }

private:
DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override
{
if (types.empty())
throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", name);
return makeNullable(getLeastSupertype(types));
}

DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override
String getName() const override
{
size_t num_arguments = arguments.size();
DB::Columns converted_columns(num_arguments);
for (size_t arg = 0; arg < num_arguments; ++arg)
converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst();
auto result_column = result_type->createColumn();
result_column->reserve(input_rows_count);
for (size_t row_num = 0; row_num < input_rows_count; ++row_num)
{
size_t best_arg = 0;
for (size_t arg = 1; arg < num_arguments; ++arg)
{
auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1);
if (cmp_result > 0)
best_arg = arg;
}
result_column->insertFrom(*converted_columns[best_arg], row_num);
}
return result_column;
}
return name;
}
};

REGISTER_FUNCTION(SparkGreatest)
Expand Down
38 changes: 38 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <Functions/FunctionGreatestLeast.h>

namespace local_engine
{
class SparkFunctionLeast : public FunctionGreatestestLeast<DB::LeastGreatest::Least>
{
public:
static constexpr auto name = "sparkLeast";
static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared<SparkFunctionLeast>(); }
SparkFunctionLeast() = default;
~SparkFunctionLeast() override = default;
String getName() const override
{
return name;
}
};

REGISTER_FUNCTION(SparkLeast)
{
factory.registerFunction<SparkFunctionLeast>();
}
}
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/SerializedPlanParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ static const std::map<std::string, std::string> SCALAR_FUNCTIONS
{"sign", "sign"},
{"radians", "radians"},
{"greatest", "sparkGreatest"},
{"least", "least"},
{"least", "sparkLeast"},
{"shiftleft", "bitShiftLeft"},
{"shiftright", "bitShiftRight"},
{"check_overflow", "checkDecimalOverflowSpark"},
Expand Down

0 comments on commit 10a663c

Please sign in to comment.