diff --git a/velox/functions/prestosql/InPredicate.cpp b/velox/functions/prestosql/InPredicate.cpp index 86d7fdeacb20..0819cd0d69f1 100644 --- a/velox/functions/prestosql/InPredicate.cpp +++ b/velox/functions/prestosql/InPredicate.cpp @@ -20,36 +20,39 @@ namespace facebook::velox::functions { namespace { +// Implements the IN predicate with a hash set of VectorValues (pairs of +// BaseVector pointer and index). Used in place of Filters for types that +// Filters do not support or that provide custom comparisons. // Returns NULL if // - input value is NULL // - in-list is NULL or empty // - input value doesn't have an exact match, but has an indeterminate match in // the in-list. E.g., 'array[null] in (array[1])' or 'array[1] in // (array[null])'. -class ComplexTypeInPredicate : public exec::VectorFunction { +class VectorSetInPredicate : public exec::VectorFunction { public: - struct ComplexValue { + struct VectorValue { BaseVector* vector; vector_size_t index; }; - struct ComplexValueHash { - size_t operator()(ComplexValue value) const { + struct VectorValueHash { + size_t operator()(VectorValue value) const { return value.vector->hashValueAt(value.index); } }; - struct ComplexValueEqualTo { - bool operator()(ComplexValue left, ComplexValue right) const { + struct VectorValueEqualTo { + bool operator()(VectorValue left, VectorValue right) const { return left.vector->equalValueAt(right.vector, left.index, right.index); } }; - using ComplexSet = - folly::F14FastSet; + using VectorSet = + folly::F14FastSet; - ComplexTypeInPredicate( - ComplexSet uniqueValues, + VectorSetInPredicate( + VectorSet uniqueValues, bool hasNull, VectorPtr originalValues) : uniqueValues_{std::move(uniqueValues)}, @@ -58,7 +61,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction { static std::shared_ptr create(const VectorPtr& values, vector_size_t offset, vector_size_t size) { - ComplexSet uniqueValues; + VectorSet uniqueValues; bool hasNull = false; for (auto i = offset; i < offset + 
size; i++) { @@ -68,7 +71,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction { uniqueValues.insert({values.get(), i}); } - return std::make_shared( + return std::make_shared( std::move(uniqueValues), hasNull, values); } @@ -126,7 +129,7 @@ class ComplexTypeInPredicate : public exec::VectorFunction { // Set of unique values to check against. This set doesn't include any value // that is null or contains null. - const ComplexSet uniqueValues_; + const VectorSet uniqueValues_; // Boolean indicating whether one of the value was null or contained null. const bool hasNull_; @@ -339,10 +342,15 @@ class InPredicate : public exec::VectorFunction { } const auto& elements = arrayVector->elements(); + const auto& elementType = elements->type(); + + if (elementType->providesCustomComparison()) { + return VectorSetInPredicate::create(elements, offset, size); + } std::pair, bool> filter; - switch (inListType->childAt(0)->kind()) { + switch (elementType->kind()) { case TypeKind::HUGEINT: filter = createHugeintValuesFilter(elements, offset, size); break; @@ -384,7 +392,7 @@ class InPredicate : public exec::VectorFunction { case TypeKind::MAP: [[fallthrough]]; case TypeKind::ROW: - return ComplexTypeInPredicate::create(elements, offset, size); + return VectorSetInPredicate::create(elements, offset, size); default: VELOX_UNSUPPORTED( "Unsupported in-list type for IN predicate: {}", diff --git a/velox/functions/prestosql/tests/InPredicateTest.cpp b/velox/functions/prestosql/tests/InPredicateTest.cpp index b08558148513..ca8dfb71db53 100644 --- a/velox/functions/prestosql/tests/InPredicateTest.cpp +++ b/velox/functions/prestosql/tests/InPredicateTest.cpp @@ -14,7 +14,10 @@ * limitations under the License. 
*/ #include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/lib/DateTimeFormatter.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/tz/TimeZoneMap.h" using namespace facebook::velox::test; using namespace facebook::velox::functions::test; @@ -25,32 +28,32 @@ namespace { class InPredicateTest : public FunctionBaseTest { protected: template - std::string getInList( + ArrayVectorPtr getInList( std::vector> input, - const TypePtr& type = CppToType::create()) { + const TypePtr& type) { FlatVectorPtr flatVec = makeNullableFlatVector(input, type); - std::string inList; - auto len = flatVec->size(); - auto toString = [&](vector_size_t idx) { - if (type->isDecimal()) { - if (flatVec->isNullAt(idx)) { - return std::string("null"); - } - return fmt::format( - "cast({} as {})", flatVec->toString(idx), type->toString()); - } - return flatVec->toString(idx); - }; - for (auto i = 0; i < len - 1; i++) { - inList += fmt::format("{}, ", toString(i)); - } - inList += toString(len - 1); - return inList; + return makeArrayVector({0, flatVec->size()}, flatVec); + } + + core::TypedExprPtr makeInExpression( + const std::string& probe, + const ArrayVectorPtr& inList, + const TypePtr& type) { + return std::make_shared( + BOOLEAN(), + std::vector{ + std::make_shared(type, probe), + std::make_shared(inList)}, + "in"); } template - void testValues(const TypePtr type = CppToType::create()) { + void testValues( + const TypePtr type = CppToType::create(), + std::function valueAt = [](auto row) { + return row % 17; + }) { if (type->isDecimal()) { this->options_.parseDecimalAsDouble = false; } @@ -58,17 +61,17 @@ class InPredicateTest : public FunctionBaseTest { memory::memoryManager()->addLeafPool()}; const vector_size_t size = 1'000; - auto inList = getInList({1, 3, 5}, type); + auto inList = getInList({valueAt(1), valueAt(3), valueAt(5)}, type); auto vector = 
makeFlatVector( - size, [](auto row) { return row % 17; }, nullptr, type); + size, [&](auto row) { return valueAt(row); }, nullptr, type); auto vectorWithNulls = makeFlatVector( - size, [](auto row) { return row % 17; }, nullEvery(7), type); + size, [&](auto row) { return valueAt(row); }, nullEvery(7), type); auto rowVector = makeRowVector({vector, vectorWithNulls}); // no nulls auto result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); auto expected = makeFlatVector(size, [](auto row) { auto n = row % 17; return n == 1 || n == 3 || n == 5; @@ -78,7 +81,7 @@ class InPredicateTest : public FunctionBaseTest { // some nulls result = evaluate>( - fmt::format("c1 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); expected = makeFlatVector( size, [](auto row) { @@ -91,9 +94,10 @@ class InPredicateTest : public FunctionBaseTest { // null values in the in-list // The results can be either true or null, but not false. 
- inList = getInList({1, 3, std::nullopt, 5}, type); + inList = + getInList({valueAt(1), valueAt(3), std::nullopt, valueAt(5)}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); expected = makeFlatVector( size, [](auto /* row */) { return true; }, @@ -105,7 +109,7 @@ class InPredicateTest : public FunctionBaseTest { assertEqualVectors(expected, result); result = evaluate>( - fmt::format("c1 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); expected = makeFlatVector( size, [](auto /* row */) { return true; }, @@ -116,9 +120,9 @@ class InPredicateTest : public FunctionBaseTest { assertEqualVectors(expected, result); - inList = getInList({2, std::nullopt}, type); + inList = getInList({valueAt(2), std::nullopt}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); expected = makeFlatVector( size, [](auto /* row */) { return true; }, @@ -130,7 +134,7 @@ class InPredicateTest : public FunctionBaseTest { assertEqualVectors(expected, result); result = evaluate>( - fmt::format("c1 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); expected = makeFlatVector( size, [](auto /* row */) { return true; }, @@ -173,9 +177,9 @@ class InPredicateTest : public FunctionBaseTest { rowVector = makeRowVector({dict}); - inList = getInList({2, 5, 9}, type); + inList = getInList({valueAt(2), valueAt(5), valueAt(9)}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); assertEqualVectors(expected, result); // an in list with nulls only is always null. 
@@ -186,12 +190,16 @@ class InPredicateTest : public FunctionBaseTest { } template - void testConstantValues(const TypePtr type = CppToType::create()) { + void testConstantValues( + const TypePtr type = CppToType::create(), + std::function valueAt = [](auto row) { + return row % 17; + }) { const vector_size_t size = 1'000; auto rowVector = makeRowVector( - {makeConstant(static_cast(123), size, type), + {makeConstant(valueAt(123), size, type), BaseVector::createNullConstant(type, size, pool())}); - auto inList = getInList({1, 3, 5}, type); + auto inList = getInList({valueAt(1), valueAt(3), valueAt(5)}, type); auto constTrue = makeConstant(true, size); auto constFalse = makeConstant(false, size); @@ -199,24 +207,24 @@ class InPredicateTest : public FunctionBaseTest { // a miss auto result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); assertEqualVectors(constFalse, result); // null result = evaluate>( - fmt::format("c1 IN ({})", inList), rowVector); + makeInExpression("c1", inList, type), rowVector); assertEqualVectors(constNull, result); // a hit - inList = getInList({1, 123, 5}, type); + inList = getInList({valueAt(1), valueAt(123), valueAt(5)}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); assertEqualVectors(constTrue, result); // a miss that is a null - inList = getInList({1, std::nullopt, 5}, type); + inList = getInList({valueAt(1), std::nullopt, valueAt(5)}, type); result = evaluate>( - fmt::format("c0 IN ({})", inList), rowVector); + makeInExpression("c0", inList, type), rowVector); assertEqualVectors(constNull, result); } @@ -1120,5 +1128,16 @@ TEST_F(InPredicateTest, nans) { testNaNs(); } +TEST_F(InPredicateTest, TimestampWithTimeZone) { + // The millis ranges from 0-16, but after every 17th row we increment the time + // zone ID, so that no two rows have the same millis and time zone. 
However, + // by the semantics of TimestampWithTimeZone's comparison, it's the same 17 + // values repeated. + auto valueAt = [](auto row) { return pack(row % 17, row / 17); }; + + testValues(TIMESTAMP_WITH_TIME_ZONE(), valueAt); + testConstantValues(TIMESTAMP_WITH_TIME_ZONE(), valueAt); +} + } // namespace } // namespace facebook::velox::functions