diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index eec0ad874c5d..748bd5a7f7f6 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -2551,5 +2551,17 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) spark.sql("drop table test_tbl_5096") } + + test("GLUTEN-5896: Bug fix greatest diff") { + val tbl_create_sql = + "create table test_tbl_5896(id bigint, x1 int, x2 int, x3 int) using parquet" + val tbl_insert_sql = + "insert into test_tbl_5896 values(1, 12, NULL, 13), (2, NULL, NULL, NULL), (3, 11, NULL, NULL), (4, 10, 9, 8)" + val select_sql = "select id, greatest(x1, x2, x3) from test_tbl_5896" + spark.sql(tbl_create_sql) + spark.sql(tbl_insert_sql) + compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) + spark.sql("drop table test_tbl_5896") + } } // scalastyle:on line.size.limit diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp new file mode 100644 index 000000000000..9577d65ec5f7 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ +class SparkFunctionGreatest : public DB::FunctionLeastGreatestGeneric +{ +public: + static constexpr auto name = "sparkGreatest"; + static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } + SparkFunctionGreatest() = default; + ~SparkFunctionGreatest() override = default; + bool useDefaultImplementationForNulls() const override { return false; } + +private: + DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override + { + if (types.empty()) + throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", name); + return makeNullable(getLeastSupertype(types)); + } + + DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override + { + size_t num_arguments = arguments.size(); + DB::Columns converted_columns(num_arguments); + for (size_t arg = 0; arg < num_arguments; ++arg) + converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst(); + auto result_column = result_type->createColumn(); + result_column->reserve(input_rows_count); + for (size_t row_num = 0; row_num < input_rows_count; ++row_num) + { + size_t best_arg = 0; + for (size_t arg = 1; arg < num_arguments; ++arg) + { + auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1); + if (cmp_result > 0) + best_arg = arg; + } + result_column->insertFrom(*converted_columns[best_arg], row_num); + } + return result_column; + } +}; + +REGISTER_FUNCTION(SparkGreatest) +{ + factory.registerFunction(); +} +} diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index a636ebb9352f..73448b0690c2 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -104,7 +104,7 @@ static const std::map SCALAR_FUNCTIONS {"hypot", "hypot"}, {"sign", "sign"}, {"radians", "radians"}, - {"greatest", "greatest"}, + {"greatest", "sparkGreatest"}, {"least", "least"}, {"shiftleft", "bitShiftLeft"}, {"shiftright", "bitShiftRight"},