diff --git a/velox/functions/prestosql/InPredicate.cpp b/velox/functions/prestosql/InPredicate.cpp index 6074a5af6a7b..30bbc958af9b 100644 --- a/velox/functions/prestosql/InPredicate.cpp +++ b/velox/functions/prestosql/InPredicate.cpp @@ -47,18 +47,18 @@ class GenericInPredicate : public exec::VectorFunction { return; } - const auto valueBaseRow = value->index(row); - if (valueBase->containsNullAt(valueBaseRow)) { - boolResult->setNull(row, true); - return; - } - const auto arrayRow = inList->index(row); const auto offset = inListBaseArray->offsetAt(arrayRow); const auto size = inListBaseArray->sizeAt(arrayRow); VELOX_USER_CHECK_GT(size, 0, "IN list must not be empty"); + const auto valueBaseRow = value->index(row); + if (valueBase->containsNullAt(valueBaseRow)) { + boolResult->setNull(row, true); + return; + } + bool hasNull = false; for (auto i = 0; i < size; ++i) { if (inListElements->containsNullAt(offset + i)) { diff --git a/velox/functions/prestosql/tests/InPredicateTest.cpp b/velox/functions/prestosql/tests/InPredicateTest.cpp index be82d5f21a04..88608a605b85 100644 --- a/velox/functions/prestosql/tests/InPredicateTest.cpp +++ b/velox/functions/prestosql/tests/InPredicateTest.cpp @@ -975,5 +975,56 @@ TEST_F(InPredicateTest, nonConstantInList) { assertEqualVectors(expected, result, rows); } +TEST_F(InPredicateTest, nonConstantComplexInList) { + auto data = makeRowVector({ + makeArrayVectorFromJson({ + "null", + "[1, null, 3]", + "[1, null, 3]", + "[1, 2, 3]", + }), + makeArrayVector( + {0, 1, 1, 1}, + makeArrayVectorFromJson({ + "[1, 2, 3]", + "[1, 2, 3]", + }), + {1}), + }); + + auto expected = makeNullableFlatVector({ + std::nullopt, // Input is null + std::nullopt, // in-list is null + std::nullopt, // in-list is empty + true, + }); + + auto in = std::make_shared( + BOOLEAN(), + std::vector{ + field(ARRAY(INTEGER()), "c0"), + field(ARRAY(ARRAY(INTEGER())), "c1"), + }, + "in"); + + auto tryIn = std::make_shared( + BOOLEAN(), std::vector{in}, "try"); + + // Evaluate "in" on all rows. Expect an error. + VELOX_ASSERT_THROW(evaluate(in, data), "IN list must not be empty"); + + // Evaluate "try(in)" on all rows. + auto result = evaluate(tryIn, data); + assertEqualVectors(expected, result); + + // Evaluate "in" on a subset of rows that should not generate an error. + SelectivityVector rows(data->size()); + rows.setValid(2, false); + rows.updateBounds(); + + result = evaluate(in, data, rows); + assertEqualVectors(expected, result, rows); +} + } // namespace } // namespace facebook::velox::functions