From 8d77bebbeaf50529307ee82e211bcdbb3c926987 Mon Sep 17 00:00:00 2001 From: Ke Date: Thu, 24 Oct 2024 09:46:00 -0700 Subject: [PATCH] Add single parameter (array<array<T>>) support for array_intersect (#11305) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/11305 Reviewed By: kevinwilfong Differential Revision: D64698174 Pulled By: kewang1024 fbshipit-source-id: 6a9cefb24565c004fad956d8ba265172723fb29c --- .../prestosql/ArrayIntersectExcept.cpp | 173 ++++++++++++- .../prestosql/tests/ArrayIntersectTest.cpp | 235 ++++++++++++++++++ 2 files changed, 403 insertions(+), 5 deletions(-) diff --git a/velox/functions/prestosql/ArrayIntersectExcept.cpp b/velox/functions/prestosql/ArrayIntersectExcept.cpp index 23648b450b19..36f80612822d 100644 --- a/velox/functions/prestosql/ArrayIntersectExcept.cpp +++ b/velox/functions/prestosql/ArrayIntersectExcept.cpp @@ -138,19 +138,146 @@ void generateSet( DecodedVector* decodeArrayElements( exec::LocalDecodedVector& arrayDecoder, exec::LocalDecodedVector& elementsDecoder, - const SelectivityVector& rows) { + const SelectivityVector& rows, + SelectivityVector* elementRows) { auto decodedVector = arrayDecoder.get(); auto baseArrayVector = arrayDecoder->base()->as(); // Decode and acquire array elements vector. 
auto elementsVector = baseArrayVector->elements(); - auto elementsSelectivityRows = toElementRows( + *elementRows = toElementRows( elementsVector->size(), rows, baseArrayVector, decodedVector->indices()); - elementsDecoder.get()->decode(*elementsVector, elementsSelectivityRows); + elementsDecoder.get()->decode(*elementsVector, *elementRows); auto decodedElementsVector = elementsDecoder.get(); return decodedElementsVector; } +DecodedVector* decodeArrayElements( + exec::LocalDecodedVector& arrayDecoder, + exec::LocalDecodedVector& elementsDecoder, + const SelectivityVector& rows) { + SelectivityVector elementRows; + return decodeArrayElements(arrayDecoder, elementsDecoder, rows, &elementRows); +} + +template +class ArraysIntersectSingleParam : public exec::VectorFunction { + public: + /// This class is used for array_intersect function with single parameter. + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + memory::MemoryPool* pool = context.pool(); + + exec::LocalDecodedVector outerArrayDecoder(context, *args[0], rows); + auto decodedOuterArray = outerArrayDecoder.get(); + auto outerArray = decodedOuterArray->base()->as(); + + exec::LocalDecodedVector innerArrayDecoder(context); + SelectivityVector innerRows; + auto decodedInnerArray = decodeArrayElements( + outerArrayDecoder, innerArrayDecoder, rows, &innerRows); + auto innerArray = decodedInnerArray->base()->as(); + + exec::LocalDecodedVector elementDecoder(context); + SelectivityVector elementRows; + auto decodedInnerElement = decodeArrayElements( + innerArrayDecoder, elementDecoder, innerRows, &elementRows); + + const auto elementCount = + countElements(innerRows, *decodedInnerArray); + const auto rowCount = args[0]->size(); + + // Allocate new vectors for indices, nulls, length and offsets. 
+ BufferPtr newIndices = allocateIndices(elementCount, pool); + BufferPtr newElementNulls = + AlignedBuffer::allocate(elementCount, pool, bits::kNotNull); + BufferPtr newLengths = allocateSizes(rowCount, pool); + BufferPtr newOffsets = allocateOffsets(rowCount, pool); + BufferPtr newNulls = allocateNulls(rowCount, pool); + + // Pointers and cursors to the raw data. + auto rawNewIndices = newIndices->asMutable(); + auto rawNewElementNulls = newElementNulls->asMutable(); + auto rawNewOffsets = newOffsets->asMutable(); + auto rawNewLengths = newLengths->asMutable(); + auto rawNewNulls = newNulls->asMutable(); + auto indicesCursor = 0; + + rows.applyToSelected([&](vector_size_t row) { + rawNewOffsets[row] = indicesCursor; + std::optional finalNullIndex; + SetWithNull finalSet; + + auto idx = decodedOuterArray->index(row); + auto offset = outerArray->offsetAt(idx); + auto size = outerArray->sizeAt(idx); + bool setInitialized = false; + for (auto i = offset; i < (offset + size); ++i) { + auto innerIdx = decodedInnerArray->index(i); + auto innerOffset = innerArray->offsetAt(innerIdx); + auto innerSize = innerArray->sizeAt(innerIdx); + + // 1. prepare for next iteration + indicesCursor = rawNewOffsets[row]; + SetWithNull intermediateSet; + std::optional intermediateNullIndex; + + // 2. Null array + if (decodedInnerArray->isNullAt(i)) { + bits::setNull(rawNewNulls, row, true); + finalNullIndex = intermediateNullIndex; + rawNewLengths[row] = indicesCursor - rawNewOffsets[row]; + break; + } + + // 3. 
Regular array + for (auto j = innerOffset; j < (innerOffset + innerSize); ++j) { + // null element + if (decodedInnerElement->isNullAt(j)) { + if ((!setInitialized || finalSet.hasNull) && + !intermediateNullIndex.has_value()) { + intermediateSet.hasNull = true; + intermediateNullIndex = std::optional(indicesCursor++); + } + continue; + } + // regular element + if (!setInitialized || finalSet.count(decodedInnerElement, j)) { + auto success = intermediateSet.insert(decodedInnerElement, j); + if (success) { + rawNewIndices[indicesCursor++] = j; + } + } + } + setInitialized = true; + finalSet = intermediateSet; + finalNullIndex = intermediateNullIndex; + rawNewLengths[row] = indicesCursor - rawNewOffsets[row]; + } + + if (finalNullIndex.has_value()) { + bits::setNull(rawNewElementNulls, finalNullIndex.value(), true); + } + }); + + auto newElements = BaseVector::wrapInDictionary( + newElementNulls, newIndices, indicesCursor, innerArray->elements()); + auto resultArray = std::make_shared( + pool, + outputType, + std::move(newNulls), + rowCount, + newOffsets, + newLengths, + newElements); + context.moveOrCopyResult(resultArray, rows, result); + } +}; + // See documentation at https://prestodb.io/docs/current/functions/array.html template class ArrayIntersectExceptFunction : public exec::VectorFunction { @@ -414,7 +541,7 @@ class ArraysOverlapFunction : public exec::VectorFunction { void validateMatchingArrayTypes( const std::vector& inputArgs, const std::string& name, - vector_size_t expectedArgCount) { + size_t expectedArgCount) { VELOX_USER_CHECK_EQ( inputArgs.size(), expectedArgCount, @@ -504,10 +631,30 @@ std::shared_ptr createTypedArraysIntersectExcept( } } +template +std::shared_ptr createArraysIntersectSingleParam( + const TypePtr& elementType) { + if (elementType->providesCustomComparison()) { + return std::make_shared>(); + } else { + using T = std::conditional_t< + TypeTraits::isPrimitiveType, + typename TypeTraits::NativeType, + WrappedVectorEntry>; + return 
std::make_shared>(); + } +} + std::shared_ptr createArrayIntersect( const std::string& name, const std::vector& inputArgs, const core::QueryConfig& /*config*/) { + if (inputArgs.size() == 1) { + auto elementType = inputArgs.front().type->childAt(0)->childAt(0); + return VELOX_DYNAMIC_TYPE_DISPATCH( + createArraysIntersectSingleParam, elementType->kind(), elementType); + } + validateMatchingArrayTypes(inputArgs, name, 2); auto elementType = inputArgs.front().type->childAt(0); @@ -534,6 +681,22 @@ std::shared_ptr createArrayExcept( elementType); } +std::vector> +arrayIntersectSignatures() { + return std::vector>{ + exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("array(T)") + .argumentType("array(T)") + .argumentType("array(T)") + .build(), + exec::FunctionSignatureBuilder() + .typeVariable("T") + .returnType("array(T)") + .argumentType("array(array(T))") + .build()}; +} + std::vector> signatures( const std::string& returnType) { return std::vector>{ @@ -600,7 +763,7 @@ VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( udf_array_intersect, - signatures("array(T)"), + arrayIntersectSignatures(), createArrayIntersect); VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION( diff --git a/velox/functions/prestosql/tests/ArrayIntersectTest.cpp b/velox/functions/prestosql/tests/ArrayIntersectTest.cpp index 2a1e8e4168a5..0c862b648c11 100644 --- a/velox/functions/prestosql/tests/ArrayIntersectTest.cpp +++ b/velox/functions/prestosql/tests/ArrayIntersectTest.cpp @@ -15,6 +15,7 @@ */ #include +#include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/vector/tests/TestingDictionaryArrayElementsFunction.h" @@ -105,6 +106,47 @@ class ArrayIntersectTest : public FunctionBaseTest { testExpr(expected, "array_intersect(C0, C1)", {array1, array2}); } + template + void testIntNestedArray() { + auto array1 = 
makeNestedArrayVectorFromJson({ + "[[1, 2, 3], [1, null]]", + "[[], [1], [2, 1], []]", + "[[1], [], [2, 1]]", + "[[1, 2, 3, null]]", + "[[], [1, 2, 3, null]]", + "[[1, null], [null]]", + "[[null], [null]]", + "[]", + }); + auto expected1 = makeNullableArrayVector({ + {1}, + {}, + {}, + {1, 2, 3, std::nullopt}, + {}, + {std::nullopt}, + {std::nullopt}, + {}, + }); + testExpr(expected1, "array_intersect(C0)", {array1}); + + auto array2 = makeNestedArrayVectorFromJson({ + "[null, [1, 2, 3], [1, null]]", + "[[], [1], [2, 1], [], null]", + "[[1], [], null, [2, 1]]", + "[[1, 2, 3, null], null]", + "[[null], [null], null]", + }); + auto expected2 = makeNullableArrayVector({ + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + std::nullopt, + }); + testExpr(expected2, "array_intersect(C0)", {array2}); + } + template void testFloatingPoint() { auto array1 = makeNullableArrayVector({ @@ -138,6 +180,37 @@ class ArrayIntersectTest : public FunctionBaseTest { testExpr(expected, "array_intersect(C0, C1)", {array1, array2}); testExpr(expected, "array_intersect(C1, C0)", {array1, array2}); } + + template + void testFloatingPointNestedArray() { + using innerArrayType = std::vector>; + using outerArrayType = + std::vector>>>; + + innerArrayType a1{1.0001, -2.0, 3.03, std::nullopt, 4.00004}; + innerArrayType a2{1.0, -2.0, 4.0}; + + innerArrayType b1{std::numeric_limits::min(), std::nullopt}; + innerArrayType b2{std::numeric_limits::min()}; + + innerArrayType c1{ + std::numeric_limits::infinity(), std::numeric_limits::max()}; + innerArrayType c2{ + std::numeric_limits::infinity(), std::numeric_limits::max()}; + + outerArrayType row1{{a1}, {a2}}; + outerArrayType row2{{b1}, {b2}}; + outerArrayType row3{{c1}, {c2}}; + outerArrayType row4{{a1}, {{}}}; + auto arrayVector = + makeNullableNestedArrayVector({{row1}, {row2}, {row3}, {row4}}); + auto expected = makeNullableArrayVector( + {{-2.0}, + {std::numeric_limits::min()}, + {std::numeric_limits::infinity(), 
std::numeric_limits::max()}, + {}}); + testExpr(expected, "array_intersect(C0)", {arrayVector}); + } }; } // namespace @@ -149,11 +222,23 @@ TEST_F(ArrayIntersectTest, intArrays) { testInt(); } +TEST_F(ArrayIntersectTest, intNestedArrays) { + testIntNestedArray(); + testIntNestedArray(); + testIntNestedArray(); + testIntNestedArray(); +} + TEST_F(ArrayIntersectTest, floatArrays) { testFloatingPoint(); testFloatingPoint(); } +TEST_F(ArrayIntersectTest, floatNestedArrays) { + testFloatingPointNestedArray(); + testFloatingPointNestedArray(); +} + TEST_F(ArrayIntersectTest, boolArrays) { auto array1 = makeNullableArrayVector( {{true, false}, @@ -189,6 +274,32 @@ TEST_F(ArrayIntersectTest, boolArrays) { testExpr(expected, "array_intersect(C1, C0)", {array1, array2}); } +TEST_F(ArrayIntersectTest, boolNestedArrays) { + using innerArrayType = std::vector>; + using outerArrayType = + std::vector>>>; + + innerArrayType a1{true, true, std::nullopt}; + innerArrayType a2{true, false, std::nullopt}; + + innerArrayType b1{std::nullopt, std::nullopt, true}; + innerArrayType b2{true, false, std::nullopt}; + innerArrayType b3{true, true, std::nullopt}; + + innerArrayType c1{true, true, std::nullopt}; + innerArrayType c2{false, false, false}; + + outerArrayType row1{{a1}, {a2}}; + outerArrayType row2{{b1}, {b2}, {b3}}; + outerArrayType row3{{c1}, {c2}}; + outerArrayType row4{{a1}, {{}}}; + auto arrayVector = + makeNullableNestedArrayVector({{row1}, {row2}, {row3}, {row4}}); + auto expected = makeNullableArrayVector( + {{true, std::nullopt}, {true, std::nullopt}, {}, {}}); + testExpr(expected, "array_intersect(C0)", {arrayVector}); +} + // Test inline strings. 
TEST_F(ArrayIntersectTest, strArrays) { using S = StringView; @@ -215,6 +326,32 @@ TEST_F(ArrayIntersectTest, strArrays) { testExpr(expected, "array_intersect(C1, C0)", {array1, array2}); } +TEST_F(ArrayIntersectTest, strNestedArrays) { + using innerArrayType = std::vector>; + using outerArrayType = + std::vector>>>; + + innerArrayType a1{"a", "a", std::nullopt}; + innerArrayType a2{"a", "b", std::nullopt}; + + innerArrayType b1{std::nullopt, std::nullopt, "c"}; + innerArrayType b2{"c", "d", std::nullopt}; + innerArrayType b3{"ce", "d", "c", std::nullopt}; + + innerArrayType c1{"a", "a", std::nullopt}; + innerArrayType c2{"ba", "ab", "b"}; + + outerArrayType row1{{a1}, {a2}}; + outerArrayType row2{{b1}, {b2}, {b3}}; + outerArrayType row3{{c1}, {c2}}; + outerArrayType row4{{a1}, {{}}}; + auto arrayVector = makeNullableNestedArrayVector( + {{row1}, {row2}, {row3}, {row4}}); + auto expected = makeNullableArrayVector( + {{"a", std::nullopt}, {"c", std::nullopt}, {}, {}}); + testExpr(expected, "array_intersect(C0)", {arrayVector}); +} + // Test non-inline (> 12 length) strings. 
TEST_F(ArrayIntersectTest, longStrArrays) { using S = StringView; @@ -246,6 +383,50 @@ TEST_F(ArrayIntersectTest, longStrArrays) { testExpr(expected, "array_intersect(C1, C0)", {array1, array2}); } +TEST_F(ArrayIntersectTest, longStrNestedArrays) { + using innerArrayType = std::vector>; + using outerArrayType = + std::vector>>>; + + innerArrayType a1{ + "red shiny car ahead", "blue clear sky above", std::nullopt}; + innerArrayType a2{ + "red shiny car ahead", "I see blue clear sky above", std::nullopt}; + + innerArrayType b1{ + std::nullopt, + std::nullopt, + "blue clear sky above", + "orange beautiful sunset"}; + innerArrayType b2{ + "orange beautiful sunset", "blue clear sky above", std::nullopt}; + innerArrayType b3{ + "orange beautiful sunset", + "blue clear sky above", + "green plants make us happy", + std::nullopt}; + + innerArrayType c1{ + "orange beautiful sunset", "blue clear sky above", std::nullopt}; + innerArrayType c2{ + "a orange beautiful sunset", + "a blue clear sky above", + }; + + outerArrayType row1{{a1}, {a2}}; + outerArrayType row2{{b1}, {b2}, {b3}}; + outerArrayType row3{{c1}, {c2}}; + outerArrayType row4{{a1}, {{}}}; + auto arrayVector = makeNullableNestedArrayVector( + {{row1}, {row2}, {row3}, {row4}}); + auto expected = makeNullableArrayVector( + {{"red shiny car ahead", std::nullopt}, + {"orange beautiful sunset", "blue clear sky above", std::nullopt}, + {}, + {}}); + testExpr(expected, "array_intersect(C0)", {arrayVector}); +} + TEST_F(ArrayIntersectTest, varbinary) { auto left = makeNullableArrayVector( {{"a"_sv, "b"_sv, "c"_sv}}, ARRAY(VARBINARY())); @@ -281,6 +462,23 @@ TEST_F(ArrayIntersectTest, complexTypeArray) { testExpr(expected, "array_intersect(c0, c1)", {left, right}); } +TEST_F(ArrayIntersectTest, complexTypeArrayNestedArrays) { + auto inputData = makeNestedArrayVectorFromJson( + {"[null, [1, 2, 3], [null, null]]", + "[[1, 2, 3]]", + "[[1], [2], []]", + "[[1]]", + "[[1, null, 3]]", + "[[1, null, 3], [1, 2]]", + "[[1, null, 
3]]", + "[[1, null, 3, null]]"}); + + auto input = makeArrayVector({0, 2, 4, 6}, inputData); + auto expected = makeNestedArrayVectorFromJson( + {"[[1, 2, 3]]", "[[1]]", "[[1, null, 3]]", "[]"}); + testExpr(expected, "array_intersect(c0)", {input}); +} + TEST_F(ArrayIntersectTest, complexTypeMap) { std::vector> a{{"blue", 1}, {"red", 2}}; std::vector> b{{"green", std::nullopt}}; @@ -299,6 +497,22 @@ TEST_F(ArrayIntersectTest, complexTypeMap) { testExpr(expected, "array_intersect(c0, c1)", {left, right}); } +TEST_F(ArrayIntersectTest, complexTypeMapNestedArrays) { + std::vector> a{{"blue", 1}, {"red", 2}}; + std::vector> b{{"green", std::nullopt}}; + std::vector> c{{"yellow", 4}, {"purple", 5}}; + std::vector>>> inputData{ + {b, a}, {a, b}, {b}, {}, {c, a}, {a}}; + std::vector>>> expectedData{ + {a, b}, {}, {a}}; + + auto input = makeArrayVector( + {0, 2, 4}, makeArrayOfMapVector(inputData)); + auto expected = makeArrayOfMapVector(expectedData); + + testExpr(expected, "array_intersect(c0)", {input}); +} + TEST_F(ArrayIntersectTest, complexTypeRow) { RowTypePtr rowType = ROW({INTEGER(), VARCHAR()}); @@ -317,6 +531,27 @@ TEST_F(ArrayIntersectTest, complexTypeRow) { testExpr(expected, "array_intersect(c0, c1)", {left, right}); } +TEST_F(ArrayIntersectTest, complexTypeRowNestedArrays) { + RowTypePtr rowType = ROW({INTEGER(), VARCHAR()}); + using ArrayOfRow = std::vector>>; + + std::vector data = { + {{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}}, + {{{1, "red"}}}, + {{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}}, + {{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}}, + {{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}}, + {{{1, "red"}}, std::nullopt, std::nullopt}}; + auto input = makeArrayVector({0, 2, 4}, makeArrayOfRowVector(data, rowType)); + + std::vector expectedData = { + {{{1, "red"}}}, + {{{1, "red"}}, {{2, "blue"}}, {{3, "green"}}}, + {{{1, "red"}}}}; + auto expected = makeArrayOfRowVector(expectedData, rowType); + testExpr(expected, "array_intersect(c0)", {input}); +} + // 
When one of the arrays is constant. TEST_F(ArrayIntersectTest, constant) { auto array1 = makeNullableArrayVector({