diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 118f8418609d..188995f11058 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -2575,12 +2575,12 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr spark.sql("drop table test_tbl_5096") } - test("GLUTEN-5896: Bug fix greatest diff") { + test("GLUTEN-5896: Bug fix greatest/least diff") { val tbl_create_sql = "create table test_tbl_5896(id bigint, x1 int, x2 int, x3 int) using parquet" val tbl_insert_sql = "insert into test_tbl_5896 values(1, 12, NULL, 13), (2, NULL, NULL, NULL), (3, 11, NULL, NULL), (4, 10, 9, 8)" - val select_sql = "select id, greatest(x1, x2, x3) from test_tbl_5896" + val select_sql = "select id, greatest(x1, x2, x3), least(x1, x2, x3) from test_tbl_5896" spark.sql(tbl_create_sql) spark.sql(tbl_insert_sql) compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) diff --git a/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h new file mode 100644 index 000000000000..6930c1d75b79 --- /dev/null +++ b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} +namespace local_engine +{ +template +class FunctionGreatestestLeast : public DB::FunctionLeastGreatestGeneric +{ +public: + bool useDefaultImplementationForNulls() const override { return false; } + virtual String getName() const = 0; + +private: + DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override + { + if (types.empty()) + throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", getName()); + return makeNullable(getLeastSupertype(types)); + } + + DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override + { + size_t num_arguments = arguments.size(); + DB::Columns converted_columns(num_arguments); + for (size_t arg = 0; arg < num_arguments; ++arg) + converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst(); + auto result_column = result_type->createColumn(); + result_column->reserve(input_rows_count); + for (size_t row_num = 0; row_num < input_rows_count; ++row_num) + { + size_t best_arg = 0; + for (size_t arg = 1; arg < num_arguments; ++arg) + { + if constexpr (kind == DB::LeastGreatest::Greatest) + { + auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1); + if (cmp_result > 0) + best_arg = arg; + } + else + { + auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], 1); + if (cmp_result < 0) + best_arg = arg; + } + } + result_column->insertFrom(*converted_columns[best_arg], row_num); + } + return result_column; + } +}; + +} diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp index 9577d65ec5f7..920fe1b9c9cc 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp @@ -14,58 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} -} +#include namespace local_engine { -class SparkFunctionGreatest : public DB::FunctionLeastGreatestGeneric +class SparkFunctionGreatest : public FunctionGreatestestLeast { public: static constexpr auto name = "sparkGreatest"; static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } SparkFunctionGreatest() = default; ~SparkFunctionGreatest() override = default; - bool useDefaultImplementationForNulls() const override { return false; } - -private: - DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override - { - if (types.empty()) - throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", name); - return makeNullable(getLeastSupertype(types)); - } - - DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override + String getName() const override { - size_t num_arguments = arguments.size(); - DB::Columns converted_columns(num_arguments); - for (size_t arg = 0; arg < num_arguments; ++arg) - converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst(); - auto result_column = result_type->createColumn(); - result_column->reserve(input_rows_count); - for (size_t row_num = 0; row_num < input_rows_count; ++row_num) - { - size_t best_arg = 0; - for (size_t arg = 1; arg < num_arguments; ++arg) - { - auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1); - if (cmp_result > 0) - best_arg = arg; - } - result_column->insertFrom(*converted_columns[best_arg], row_num); - } - return result_column; - } + return name; + } }; REGISTER_FUNCTION(SparkGreatest) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp new file mode 100644 index 000000000000..70aafdf07209 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +namespace local_engine +{ +class SparkFunctionLeast : public FunctionGreatestestLeast +{ +public: + static constexpr auto name = "sparkLeast"; + static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } + SparkFunctionLeast() = default; + ~SparkFunctionLeast() override = default; + String getName() const override + { + return name; + } +}; + +REGISTER_FUNCTION(SparkLeast) +{ + factory.registerFunction(); +} +} diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 6ce92b558b73..184065836e65 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -105,7 +105,7 @@ static const std::map SCALAR_FUNCTIONS {"sign", "sign"}, {"radians", "radians"}, {"greatest", "sparkGreatest"}, - {"least", "least"}, + {"least", "sparkLeast"}, {"shiftleft", "bitShiftLeft"}, {"shiftright", "bitShiftRight"}, {"check_overflow", "checkDecimalOverflowSpark"},