diff --git a/velox/functions/prestosql/ArrayIntersectExcept.cpp b/velox/functions/prestosql/ArrayIntersectExcept.cpp index 9826996b7d729..c13c812914dc3 100644 --- a/velox/functions/prestosql/ArrayIntersectExcept.cpp +++ b/velox/functions/prestosql/ArrayIntersectExcept.cpp @@ -111,11 +111,6 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction { /// If the rhs values passed to either array_intersect() or array_except() /// are constant (array literals) we create a set before instantiating the /// object and pass as a constructor parameter (constantSet). - /// - /// Smallest array optimization: - /// - /// If the rhs values passed to array_intersect() are not constant we create - /// sets from whichever side has the smallest sum of lengths in the batch. ArrayIntersectExceptFunction() = default; @@ -133,30 +128,6 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction { BaseVector* right = args[1].get(); exec::LocalDecodedVector leftHolder(context, *left, rows); - exec::LocalDecodedVector rightHolder(context, *right, rows); - - if (isIntersect && !constantSet_.has_value()) { - // Swap left and right if needed so that the right array has the smaller - // number of elements since the right will be made into a hash set. - vector_size_t leftSize = 0; - vector_size_t rightSize = 0; - const ArrayVector* leftArrayVector = - leftHolder.get()->base()->as(); - const ArrayVector* rightArrayVector = - rightHolder.get()->base()->as(); - rows.applyToSelected([&](vector_size_t row) { - vector_size_t leftidx = leftHolder.get()->index(row); - leftSize += leftArrayVector->sizeAt(leftidx); - - vector_size_t rightidx = rightHolder.get()->index(row); - rightSize += rightArrayVector->sizeAt(rightidx); - }); - if (leftSize < rightSize) { - std::swap(left, right); - std::swap(leftHolder, rightHolder); - } - } - auto decodedLeftArray = leftHolder.get(); auto baseLeftArray = decodedLeftArray->base()->as(); @@ -221,9 +192,9 @@ class ArrayIntersectExceptFunction : public exec::VectorFunction { // (check outputSet). bool addValue = false; if constexpr (isIntersect) { - addValue = rightSet.set.contains(val); + addValue = rightSet.set.count(val) > 0; } else { - addValue = !rightSet.set.contains(val); + addValue = rightSet.set.count(val) == 0; } if (addValue) { auto it = outputSet.set.insert(val); diff --git a/velox/functions/prestosql/tests/ArrayIntersectTest.cpp b/velox/functions/prestosql/tests/ArrayIntersectTest.cpp index 1ca25c876d233..8260f9fbdf9c2 100644 --- a/velox/functions/prestosql/tests/ArrayIntersectTest.cpp +++ b/velox/functions/prestosql/tests/ArrayIntersectTest.cpp @@ -271,12 +271,7 @@ TEST_F(ArrayIntersectTest, constant) { {}, {1, -2, 4}, }); - // Test wiith right hand side being constant and larger than left hand side. - testExpr( - expected, - "array_intersect(C0, ARRAY[1,4,-2,10,11,12,13,14,15])", - {array1}); - + testExpr(expected, "array_intersect(C0, ARRAY[1,4,-2])", {array1}); testExpr(expected, "array_intersect(ARRAY[1,-2,4], C0)", {array1}); testExpr( expected, "array_intersect(ARRAY[1,1,-2,1,-2,4,1,4,4], C0)", {array1}); @@ -313,10 +308,10 @@ TEST_F(ArrayIntersectTest, deterministic) { testExpr(expectedC0C1, "array_intersect(C0, C1)", {c0, c1}); // C1 then C0. - // Since C1 has more elements, it should be swapped with C0 and - // the order of the result should be based on C1. - auto expectedC1C0 = expectedC0C1; - + auto expectedC1C0 = makeNullableArrayVector({ + {1, 4, -2}, + {1, 4, -2}, + }); testExpr(expectedC1C0, "array_intersect(C1, C0)", {c0, c1}); testExpr(expectedC1C0, "array_intersect(ARRAY[1,4,-2], C0)", {c0}); } @@ -331,6 +326,6 @@ TEST_F(ArrayIntersectTest, dictionaryEncodedElementsInConstant) { auto expected = makeArrayVector({{1, 3}, {2}, {}}); testExpr( expected, - "array_intersect(testing_dictionary_array_elements(ARRAY [2, 2, 3, 1, 2, 2]), c0)", + "array_intersect(c0, testing_dictionary_array_elements(ARRAY [2, 2, 3, 1, 2, 2]))", {array}); }