From 758c3729687e83366aa531f7ee4b19ea67a5ef80 Mon Sep 17 00:00:00 2001 From: yan ma Date: Fri, 26 Apr 2024 18:27:25 +0800 Subject: [PATCH 1/2] [GLUTEN-4652] Fix min_by/max_by result mismatch --- .../gluten/utils/VeloxIntermediateData.scala | 5 +- .../VeloxAggregateFunctionsSuite.scala | 18 ++++++ cpp/velox/CMakeLists.txt | 1 + .../functions/RegistrationAllFunctions.cc | 13 +++- .../functions/RowConstructorWithAllNull.cc | 63 +++++++++++++++++++ .../functions/RowConstructorWithAllNull.h | 44 +++++++++++++ .../operators/functions/RowFunctionWithNull.h | 23 +++++-- 7 files changed, 159 insertions(+), 8 deletions(-) create mode 100644 cpp/velox/operators/functions/RowConstructorWithAllNull.cc create mode 100644 cpp/velox/operators/functions/RowConstructorWithAllNull.h diff --git a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala index e6a8bf2c8f20..8f95afe8d0e8 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala @@ -159,7 +159,10 @@ object VeloxIntermediateData { * row_constructor_with_null. */ def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc match { - case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] => "row_constructor" + case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] => + "row_constructor" + case _: MaxMinBy => + "row_constructor_with_all_null" case _ => "row_constructor_with_null" } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala index 394c4e01651e..70fff52b84d6 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxAggregateFunctionsSuite.scala @@ -27,6 +27,8 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu override protected val resourcePath: String = "/tpch-data-parquet-velox" override protected val fileFormat: String = "parquet" + import testImplicits._ + override def beforeAll(): Unit = { super.beforeAll() createTPCHNotNullTables() @@ -188,6 +190,22 @@ abstract class VeloxAggregateFunctionsSuite extends VeloxWholeStageTransformerSu } } + test("min_by/max_by") { + withTempPath { + path => + Seq((5: Integer, 6: Integer), (null: Integer, 11: Integer), (null: Integer, 5: Integer)) + .toDF("a", "b") + .write + .parquet(path.getCanonicalPath) + spark.read + .parquet(path.getCanonicalPath) + .createOrReplaceTempView("test") + runQueryAndCompare("select min_by(a, b), max_by(a, b) from test") { + checkGlutenOperatorMatch[HashAggregateExecTransformer] + } + } + } + test("groupby") { val df = runQueryAndCompare( "select l_orderkey, sum(l_partkey) as sum from lineitem " + diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt index 55f45881b108..146b094c5cbb 100644 --- a/cpp/velox/CMakeLists.txt +++ b/cpp/velox/CMakeLists.txt @@ -306,6 +306,7 @@ set(VELOX_SRCS memory/VeloxMemoryManager.cc operators/functions/RegistrationAllFunctions.cc operators/functions/RowConstructorWithNull.cc + operators/functions/RowConstructorWithAllNull.cc operators/functions/SparkTokenizer.cc operators/serializer/VeloxColumnarToRowConverter.cc operators/serializer/VeloxColumnarBatchSerializer.cc diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index 2d2e820f1d03..c77fa47e5bff 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -16,6 +16,7 @@ */ #include "operators/functions/RegistrationAllFunctions.h" #include "operators/functions/Arithmetic.h" +#include "operators/functions/RowConstructorWithAllNull.h" #include "operators/functions/RowConstructorWithNull.h" #include "operators/functions/RowFunctionWithNull.h" @@ -47,11 +48,19 @@ void registerFunctionOverwrite() { velox::exec::registerVectorFunction( "row_constructor_with_null", std::vector>{}, - std::make_unique(), - RowFunctionWithNull::metadata()); + std::make_unique>(), + RowFunctionWithNull::metadata()); velox::exec::registerFunctionCallToSpecialForm( RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull, std::make_unique()); + velox::exec::registerVectorFunction( + "row_constructor_with_all_null", + std::vector>{}, + std::make_unique>(), + RowFunctionWithNull::metadata()); + velox::exec::registerFunctionCallToSpecialForm( + RowConstructorWithAllNullCallToSpecialForm::kRowConstructorWithAllNull, + std::make_unique()); velox::functions::sparksql::registerBitwiseFunctions("spark_"); } } // namespace diff --git a/cpp/velox/operators/functions/RowConstructorWithAllNull.cc b/cpp/velox/operators/functions/RowConstructorWithAllNull.cc new file mode 100644 index 000000000000..9b9da2d8bf66 --- /dev/null +++ b/cpp/velox/operators/functions/RowConstructorWithAllNull.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "RowConstructorWithAllNull.h" +#include "velox/expression/VectorFunction.h" + +namespace gluten { +facebook::velox::TypePtr RowConstructorWithAllNullCallToSpecialForm::resolveType( + const std::vector& argTypes) { + auto numInput = argTypes.size(); + std::vector names(numInput); + std::vector types(numInput); + for (auto i = 0; i < numInput; i++) { + types[i] = argTypes[i]; + names[i] = fmt::format("c{}", i + 1); + } + return facebook::velox::ROW(std::move(names), std::move(types)); +} + +facebook::velox::exec::ExprPtr RowConstructorWithAllNullCallToSpecialForm::constructSpecialForm( + const std::string& name, + const facebook::velox::TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage, + const facebook::velox::core::QueryConfig& config) { + auto [function, metadata] = facebook::velox::exec::vectorFunctionFactories().withRLock( + [&config, &name](auto& functionMap) -> std::pair< + std::shared_ptr, + facebook::velox::exec::VectorFunctionMetadata> { + auto functionIterator = functionMap.find(name); + if (functionIterator != functionMap.end()) { + return {functionIterator->second.factory(name, {}, config), functionIterator->second.metadata}; + } else { + VELOX_FAIL("Function {} is not registered.", name); + } + }); + + return std::make_shared( + type, std::move(compiledChildren), function, metadata, name, trackCpuUsage); +} + +facebook::velox::exec::ExprPtr RowConstructorWithAllNullCallToSpecialForm::constructSpecialForm( + const facebook::velox::TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage, + const facebook::velox::core::QueryConfig& config) { + return constructSpecialForm(kRowConstructorWithAllNull, type, std::move(compiledChildren), trackCpuUsage, config); +} +} // namespace gluten diff --git a/cpp/velox/operators/functions/RowConstructorWithAllNull.h b/cpp/velox/operators/functions/RowConstructorWithAllNull.h new file mode 100644 index 000000000000..186498d8b30c --- /dev/null +++ b/cpp/velox/operators/functions/RowConstructorWithAllNull.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/expression/FunctionCallToSpecialForm.h" +#include "velox/expression/SpecialForm.h" + +namespace gluten { +class RowConstructorWithAllNullCallToSpecialForm : public facebook::velox::exec::FunctionCallToSpecialForm { + public: + facebook::velox::TypePtr resolveType(const std::vector& argTypes) override; + + facebook::velox::exec::ExprPtr constructSpecialForm( + const facebook::velox::TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage, + const facebook::velox::core::QueryConfig& config) override; + + static constexpr const char* kRowConstructorWithAllNull = "row_constructor_with_all_null"; + + protected: + facebook::velox::exec::ExprPtr constructSpecialForm( + const std::string& name, + const facebook::velox::TypePtr& type, + std::vector&& compiledChildren, + bool trackCpuUsage, + const facebook::velox::core::QueryConfig& config); +}; +} // namespace gluten diff --git a/cpp/velox/operators/functions/RowFunctionWithNull.h b/cpp/velox/operators/functions/RowFunctionWithNull.h index 9ed6bc27792a..9ee1828785f2 100644 --- a/cpp/velox/operators/functions/RowFunctionWithNull.h +++ b/cpp/velox/operators/functions/RowFunctionWithNull.h @@ -23,8 +23,10 @@ namespace gluten { /** - * A customized RowFunction to set struct as null when one of its argument is null. + * @tparam allNull If true, set struct as null when all of arguments are all, else will + * set it null when one of its arguments is null. */ +template class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction { public: void apply( @@ -42,15 +44,26 @@ class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction { rows.applyToSelected([&](facebook::velox::vector_size_t i) { facebook::velox::bits::clearNull(nullsPtr, i); if (!facebook::velox::bits::isBitNull(nullsPtr, i)) { + int argsNullCnt = 0; for (size_t c = 0; c < argsCopy.size(); c++) { auto arg = argsCopy[c].get(); if (arg->mayHaveNulls() && arg->isNullAt(i)) { - // If any argument of the struct is null, set the struct as null. - facebook::velox::bits::setNull(nullsPtr, i, true); - cntNull++; - break; + // For row_constructor_with_null, if any argument of the struct is null, + // set the struct as null. + if (!allNull) { + facebook::velox::bits::setNull(nullsPtr, i, true); + cntNull++; + break; + } else { + argsNullCnt++; + } } } + // For row_constructor_with_all_null, set the struct to be null when all arguments are all + if (allNull && argsNullCnt == argsCopy.size()) { + facebook::velox::bits::setNull(nullsPtr, i, true); + cntNull++; + } } }); From bd5ebc7b123598698257c93c0a37451553a364b2 Mon Sep 17 00:00:00 2001 From: yan ma Date: Mon, 29 Apr 2024 20:03:53 +0800 Subject: [PATCH 2/2] address comments --- .../gluten/utils/VeloxIntermediateData.scala | 3 + cpp/velox/CMakeLists.txt | 1 - .../functions/RowConstructorWithAllNull.cc | 63 ------------------- .../functions/RowConstructorWithAllNull.h | 17 ++--- .../operators/functions/RowFunctionWithNull.h | 10 +-- 5 files changed, 14 insertions(+), 80 deletions(-) delete mode 100644 cpp/velox/operators/functions/RowConstructorWithAllNull.cc diff --git a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala index 8f95afe8d0e8..a00bcae1ce70 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/utils/VeloxIntermediateData.scala @@ -161,6 +161,9 @@ object VeloxIntermediateData { def getRowConstructFuncName(aggFunc: AggregateFunction): String = aggFunc match { case _: Average | _: Sum if aggFunc.dataType.isInstanceOf[DecimalType] => "row_constructor" + // For agg function min_by/max_by, it needs to keep rows with null value but non-null + // comparison, such as . So we set the struct to null when all of the arguments + // are null case _: MaxMinBy => "row_constructor_with_all_null" case _ => "row_constructor_with_null" diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt index 146b094c5cbb..55f45881b108 100644 --- a/cpp/velox/CMakeLists.txt +++ b/cpp/velox/CMakeLists.txt @@ -306,7 +306,6 @@ set(VELOX_SRCS memory/VeloxMemoryManager.cc operators/functions/RegistrationAllFunctions.cc operators/functions/RowConstructorWithNull.cc - operators/functions/RowConstructorWithAllNull.cc operators/functions/SparkTokenizer.cc operators/serializer/VeloxColumnarToRowConverter.cc operators/serializer/VeloxColumnarBatchSerializer.cc diff --git a/cpp/velox/operators/functions/RowConstructorWithAllNull.cc b/cpp/velox/operators/functions/RowConstructorWithAllNull.cc deleted file mode 100644 index 9b9da2d8bf66..000000000000 --- a/cpp/velox/operators/functions/RowConstructorWithAllNull.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "RowConstructorWithAllNull.h" -#include "velox/expression/VectorFunction.h" - -namespace gluten { -facebook::velox::TypePtr RowConstructorWithAllNullCallToSpecialForm::resolveType( - const std::vector& argTypes) { - auto numInput = argTypes.size(); - std::vector names(numInput); - std::vector types(numInput); - for (auto i = 0; i < numInput; i++) { - types[i] = argTypes[i]; - names[i] = fmt::format("c{}", i + 1); - } - return facebook::velox::ROW(std::move(names), std::move(types)); -} - -facebook::velox::exec::ExprPtr RowConstructorWithAllNullCallToSpecialForm::constructSpecialForm( - const std::string& name, - const facebook::velox::TypePtr& type, - std::vector&& compiledChildren, - bool trackCpuUsage, - const facebook::velox::core::QueryConfig& config) { - auto [function, metadata] = facebook::velox::exec::vectorFunctionFactories().withRLock( - [&config, &name](auto& functionMap) -> std::pair< - std::shared_ptr, - facebook::velox::exec::VectorFunctionMetadata> { - auto functionIterator = functionMap.find(name); - if (functionIterator != functionMap.end()) { - return {functionIterator->second.factory(name, {}, config), functionIterator->second.metadata}; - } else { - VELOX_FAIL("Function {} is not registered.", name); - } - }); - - return std::make_shared( - type, std::move(compiledChildren), function, metadata, name, trackCpuUsage); -} - -facebook::velox::exec::ExprPtr RowConstructorWithAllNullCallToSpecialForm::constructSpecialForm( - const facebook::velox::TypePtr& type, - std::vector&& compiledChildren, - bool trackCpuUsage, - const facebook::velox::core::QueryConfig& config) { - return constructSpecialForm(kRowConstructorWithAllNull, type, std::move(compiledChildren), trackCpuUsage, config); -} -} // namespace gluten diff --git a/cpp/velox/operators/functions/RowConstructorWithAllNull.h b/cpp/velox/operators/functions/RowConstructorWithAllNull.h index 186498d8b30c..dfc79e1a977b 100644 --- a/cpp/velox/operators/functions/RowConstructorWithAllNull.h +++ b/cpp/velox/operators/functions/RowConstructorWithAllNull.h @@ -17,20 +17,11 @@ #pragma once -#include "velox/expression/FunctionCallToSpecialForm.h" -#include "velox/expression/SpecialForm.h" +#include "RowConstructorWithNull.h" namespace gluten { -class RowConstructorWithAllNullCallToSpecialForm : public facebook::velox::exec::FunctionCallToSpecialForm { +class RowConstructorWithAllNullCallToSpecialForm : public RowConstructorWithNullCallToSpecialForm { public: - facebook::velox::TypePtr resolveType(const std::vector& argTypes) override; - - facebook::velox::exec::ExprPtr constructSpecialForm( - const facebook::velox::TypePtr& type, - std::vector&& compiledChildren, - bool trackCpuUsage, - const facebook::velox::core::QueryConfig& config) override; - static constexpr const char* kRowConstructorWithAllNull = "row_constructor_with_all_null"; protected: @@ -39,6 +30,8 @@ class RowConstructorWithAllNullCallToSpecialForm : public facebook::velox::exec: const facebook::velox::TypePtr& type, std::vector&& compiledChildren, bool trackCpuUsage, - const facebook::velox::core::QueryConfig& config); + const facebook::velox::core::QueryConfig& config) { + return constructSpecialForm(kRowConstructorWithAllNull, type, std::move(compiledChildren), trackCpuUsage, config); + } }; } // namespace gluten diff --git a/cpp/velox/operators/functions/RowFunctionWithNull.h b/cpp/velox/operators/functions/RowFunctionWithNull.h index 9ee1828785f2..4131fb472ddd 100644 --- a/cpp/velox/operators/functions/RowFunctionWithNull.h +++ b/cpp/velox/operators/functions/RowFunctionWithNull.h @@ -50,7 +50,7 @@ class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction { if (arg->mayHaveNulls() && arg->isNullAt(i)) { // For row_constructor_with_null, if any argument of the struct is null, // set the struct as null. - if (!allNull) { + if constexpr (!allNull) { facebook::velox::bits::setNull(nullsPtr, i, true); cntNull++; break; @@ -60,9 +60,11 @@ class RowFunctionWithNull final : public facebook::velox::exec::VectorFunction { } } // For row_constructor_with_all_null, set the struct to be null when all arguments are all - if (allNull && argsNullCnt == argsCopy.size()) { - facebook::velox::bits::setNull(nullsPtr, i, true); - cntNull++; + if constexpr (allNull) { + if (argsNullCnt == argsCopy.size()) { + facebook::velox::bits::setNull(nullsPtr, i, true); + cntNull++; + } } } });