From db2c5e30a7dc5b0fcd58d932084f061d753f2736 Mon Sep 17 00:00:00 2001 From: glutenperfbot Date: Fri, 15 Mar 2024 00:06:18 +0000 Subject: [PATCH] Rebase velox (2024_03_15) Signed-off-by: glutenperfbot --- velox/docs/functions/spark/array.rst | 9 + velox/functions/sparksql/ArrayUnionFunction.h | 89 ++++++++ velox/functions/sparksql/Register.cpp | 23 ++ .../sparksql/tests/ArrayUnionTest.cpp | 208 ++++++++++++++++++ velox/functions/sparksql/tests/CMakeLists.txt | 1 + 5 files changed, 330 insertions(+) create mode 100644 velox/functions/sparksql/ArrayUnionFunction.h create mode 100644 velox/functions/sparksql/tests/ArrayUnionTest.cpp diff --git a/velox/docs/functions/spark/array.rst b/velox/docs/functions/spark/array.rst index f80e13923dc05..280945fabe9ca 100644 --- a/velox/docs/functions/spark/array.rst +++ b/velox/docs/functions/spark/array.rst @@ -74,6 +74,15 @@ Array Functions SELECT array_sort(ARRAY [NULL, 1, NULL]); -- [1, NULL, NULL] SELECT array_sort(ARRAY [NULL, 2, 1]); -- [1, 2, NULL] +.. spark:function:: array_union(array(E), array(E1)) -> array(E2) + + Returns an array of the elements in the union of array1 and array2, without duplicates. :: + + SELECT array_union(array(1, 2, 3), array(1, 3, 5)); -- [1, 2, 3, 5] + SELECT array_union(array(1, 3, 5), array(1, 2, 3)); -- [1, 3, 5, 2] + SELECT array_union(array(1, 2, 3), array(1, 3, 5, null)); -- [1, 2, 3, 5, null] + SELECT array_union(array(1, 2, NaN), array(1, 3, NaN)); -- [1, 2, NaN, 3] + .. spark:function:: concat(array(E), array(E1), ..., array(En)) -> array(E, E1, ..., En) Returns the concatenation of array(E), array(E1), ..., array(En). :: diff --git a/velox/functions/sparksql/ArrayUnionFunction.h b/velox/functions/sparksql/ArrayUnionFunction.h new file mode 100644 index 0000000000000..c8a4f21f5af27 --- /dev/null +++ b/velox/functions/sparksql/ArrayUnionFunction.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace facebook::velox::functions::sparksql { + +/// This class implements the array union function. +/// +/// DEFINITION: +/// array_union(x, y) → array +/// Returns an array of the elements in the union of x and y, without +/// duplicates. +template +struct ArrayUnionFunction { + VELOX_DEFINE_FUNCTION_TYPES(T) + + // Fast path for primitives. + template + void call(Out& out, const In& inputArray1, const In& inputArray2) { + folly::F14FastSet elementSet; + bool nullAdded = false; + bool nanAdded = false; + auto addItems = [&](auto& inputArray) { + for (const auto& item : inputArray) { + if (item.has_value()) { + if constexpr ( + std::is_same_v>> || + std::is_same_v>>) { + bool isNaN = std::isnan(item.value()); + if ((isNaN && !nanAdded) || + (!isNaN && elementSet.insert(item.value()).second)) { + auto& newItem = out.add_item(); + newItem = item.value(); + } + if (!nanAdded && isNaN) { + nanAdded = true; + } + } else if (elementSet.insert(item.value()).second) { + auto& newItem = out.add_item(); + newItem = item.value(); + } + } else if (!nullAdded) { + nullAdded = true; + out.add_null(); + } + } + }; + addItems(inputArray1); + addItems(inputArray2); + } + + void call( + out_type>>& out, + const arg_type>>& inputArray1, + const arg_type>>& inputArray2) { + folly::F14FastSet elementSet; + bool nullAdded = false; + auto addItems = [&](auto& inputArray) { + for (const auto& item : inputArray) { + if (item.has_value()) { + if (elementSet.insert(item.value()).second) { + auto& newItem = out.add_item(); + newItem.copy_from(item.value()); + } + } else if (!nullAdded) { + nullAdded = true; + out.add_null(); + } + } + }; + addItems(inputArray1); + addItems(inputArray2); + } +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 12c65fa3b4f0d..f0ac791fcf835 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -27,6 +27,7 @@ #include "velox/functions/prestosql/StringFunctions.h" #include "velox/functions/sparksql/ArrayMinMaxFunction.h" #include "velox/functions/sparksql/ArraySort.h" +#include "velox/functions/sparksql/ArrayUnionFunction.h" #include "velox/functions/sparksql/Bitwise.h" #include "velox/functions/sparksql/DateTimeFunctions.h" #include "velox/functions/sparksql/Hash.h" @@ -122,6 +123,12 @@ inline void registerArrayMinMaxFunctions(const std::string& prefix) { } } // namespace +template +inline void registerArrayUnionFunctions(const std::string& prefix) { + registerFunction, Array, Array>( + {prefix + "array_union"}); +} + void registerFunctions(const std::string& prefix) { registerAllSpecialFormGeneralFunctions(); @@ -357,8 +364,24 @@ void registerFunctions(const std::string& prefix) { registerFunction( {prefix + "monotonically_increasing_id"}); +<<<<<<< HEAD registerFunction>({prefix + "uuid"}); +======= + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions(prefix); + registerArrayUnionFunctions>(prefix); +>>>>>>> Fix array_union on NaN (7086) } } // namespace sparksql diff --git a/velox/functions/sparksql/tests/ArrayUnionTest.cpp b/velox/functions/sparksql/tests/ArrayUnionTest.cpp new file mode 100644 index 0000000000000..e75719bc5a156 --- /dev/null +++ b/velox/functions/sparksql/tests/ArrayUnionTest.cpp @@ -0,0 +1,208 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +using namespace facebook::velox; +using namespace facebook::velox::test; + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class ArrayUnionTest : public SparkFunctionBaseTest { + protected: + void testExpression( + const std::string& expression, + const std::vector& input, + const VectorPtr& expected) { + auto result = evaluate(expression, makeRowVector(input)); + assertEqualVectors(expected, result); + } + + template + void testFloatArray() { + const auto array1 = makeArrayVector( + {{1.99, 2.78, 3.98, 4.01}, + {3.89, 4.99, 5.13}, + {7.13, 8.91, std::numeric_limits::quiet_NaN()}, + {10.02, 20.01, std::numeric_limits::quiet_NaN()}}); + const auto array2 = makeArrayVector( + {{2.78, 4.01, 5.99}, + {3.89, 4.99, 5.13}, + {7.13, 8.91, std::numeric_limits::quiet_NaN()}, + {40.99, 50.12}}); + + VectorPtr expected; + expected = makeArrayVector({ + {1.99, 2.78, 3.98, 4.01, 5.99}, + {3.89, 4.99, 5.13}, + {7.13, 8.91, std::numeric_limits::quiet_NaN()}, + {10.02, 20.01, std::numeric_limits::quiet_NaN(), 40.99, 50.12}, + }); + testExpression("array_union(c0, c1)", {array1, array2}, expected); + + expected = makeArrayVector({ + {2.78, 4.01, 5.99, 1.99, 3.98}, + {3.89, 4.99, 5.13}, + {7.13, 8.91, std::numeric_limits::quiet_NaN()}, + {40.99, 50.12, 10.02, 20.01, std::numeric_limits::quiet_NaN()}, + }); + testExpression("array_union(c0, c1)", {array2, array1}, expected); + } +}; + +// Union two integer arrays. +TEST_F(ArrayUnionTest, intArray) { + const auto array1 = makeArrayVector( + {{1, 2, 3, 4}, {3, 4, 5}, {7, 8, 9}, {10, 20, 30}}); + const auto array2 = + makeArrayVector({{2, 4, 5}, {3, 4, 5}, {}, {40, 50}}); + VectorPtr expected; + + expected = makeArrayVector({ + {1, 2, 3, 4, 5}, + {3, 4, 5}, + {7, 8, 9}, + {10, 20, 30, 40, 50}, + }); + testExpression("array_union(c0, c1)", {array1, array2}, expected); + + expected = makeArrayVector({ + {2, 4, 5, 1, 3}, + {3, 4, 5}, + {7, 8, 9}, + {40, 50, 10, 20, 30}, + }); + testExpression("array_union(c0, c1)", {array2, array1}, expected); +} + +// Union two float or double arrays. +TEST_F(ArrayUnionTest, floatArray) { + testFloatArray(); + testFloatArray(); +} + +// Union two string arrays. +TEST_F(ArrayUnionTest, stringArray) { + const auto array1 = + makeArrayVector({{"foo", "bar"}, {"foo", "baz"}}); + const auto array2 = + makeArrayVector({{"foo", "bar"}, {"bar", "baz"}}); + VectorPtr expected; + + expected = makeArrayVector({ + {"foo", "bar"}, + {"foo", "baz", "bar"}, + }); + testExpression("array_union(c0, c1)", {array1, array2}, expected); +} + +// Union two integer arrays with null. +TEST_F(ArrayUnionTest, nullArray) { + const auto array1 = makeNullableArrayVector({ + {{1, std::nullopt, 3, 4}}, + {7, 8, 9}, + {{10, std::nullopt, std::nullopt}}, + }); + const auto array2 = makeNullableArrayVector({ + {{std::nullopt, std::nullopt, 3, 5}}, + std::nullopt, + {{1, 10}}, + }); + VectorPtr expected; + + expected = makeNullableArrayVector({ + {{1, std::nullopt, 3, 4, 5}}, + std::nullopt, + {{10, std::nullopt, 1}}, + }); + testExpression("array_union(c0, c1)", {array1, array2}, expected); + + expected = makeNullableArrayVector({ + {{std::nullopt, 3, 5, 1, 4}}, + std::nullopt, + {{1, 10, std::nullopt}}, + }); + testExpression("array_union(c0, c1)", {array2, array1}, expected); +} + +// Union array vectors. +TEST_F(ArrayUnionTest, complexTypes) { + auto baseVector = makeArrayVector( + {{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}, {6, 6}}); + + // Create arrays of array vector using above base vector. + // [[1, 1], [2, 2]] + // [[3, 3], [4, 4]] + // [[5, 5], [6, 6]] + auto arrayOfArrays1 = makeArrayVector({0, 2, 4}, baseVector); + // [[1, 1], [2, 2], [3, 3]] + // [[4, 4]] + // [[5, 5], [6, 6]] + auto arrayOfArrays2 = makeArrayVector({0, 3, 4}, baseVector); + + // [[1, 1], [2, 2], [3, 3]] + // [[3, 3], [4, 4]] + // [[5, 5], [6, 6]] + auto expected = makeArrayVector( + {0, 3, 5}, + makeArrayVector( + {{1, 1}, {2, 2}, {3, 3}, {3, 3}, {4, 4}, {5, 5}, {6, 6}})); + + testExpression( + "array_union(c0, c1)", {arrayOfArrays1, arrayOfArrays2}, expected); +} + +// Union double array vectors. +TEST_F(ArrayUnionTest, complexDoubleType) { + auto baseVector = makeArrayVector( + {{1.0, 1.0}, + {2.0, 2.0}, + {3.0, 3.0}, + {4.0, 4.0}, + {5.0, std::numeric_limits::quiet_NaN()}, + {6.0, 6.0}}); + + // Create arrays of array vector using above base vector. + // [[1.0, 1.0], [2.0, 2.0]] + // [[3.0, 3.0], [4.0, 4.0]] + // [[5.0, NaN], [6.0, 6.0]] + auto arrayOfArrays1 = makeArrayVector({0, 2, 4}, baseVector); + // [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]] + // [[4.0, 4.0]] + // [[5.0, NaN], [6.0, 6.0]] + auto arrayOfArrays2 = makeArrayVector({0, 3, 4}, baseVector); + + // [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]] + // [[3.0, 3.0], [4.0, 4.0]] + // [[5.0, NaN], [6.0, 6.0]] + auto expected = makeArrayVector( + {0, 3, 5}, + makeArrayVector( + {{1.0, 1.0}, + {2.0, 2.0}, + {3.0, 3.0}, + {3.0, 3.0}, + {4.0, 4.0}, + {5.0, std::numeric_limits::quiet_NaN()}, + {6.0, 6.0}})); + + testExpression( + "array_union(c0, c1)", {arrayOfArrays1, arrayOfArrays2}, expected); +} +} // namespace +} // namespace facebook::velox::functions::sparksql::test diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index 55823183a6047..3a9a5da0d1c55 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -18,6 +18,7 @@ add_executable( ArrayMaxTest.cpp ArrayMinTest.cpp ArraySortTest.cpp + ArrayUnionTest.cpp BitwiseTest.cpp ComparisonsTest.cpp DateTimeFunctionsTest.cpp