diff --git a/velox/functions/prestosql/aggregates/PrestoHasher.cpp b/velox/functions/prestosql/aggregates/PrestoHasher.cpp index 277f25d965d8c..66781ef2a45b5 100644 --- a/velox/functions/prestosql/aggregates/PrestoHasher.cpp +++ b/velox/functions/prestosql/aggregates/PrestoHasher.cpp @@ -110,6 +110,19 @@ safeXor(const int64_t& hash, const int64_t& a, const int64_t& b) { return hash + (a ^ b); } +void decodedSelectivityVector( + const SelectivityVector& rows, + std::shared_ptr& decodedVector, + SelectivityVector& decodedRows) { + VELOX_CHECK_GE(decodedRows.end(), decodedVector->size()); + decodedRows.intersect(rows); + + // Remove all nulls . + if (decodedVector->nulls(&rows)) { + decodedRows.deselectNulls(decodedVector->nulls(), 0, decodedVector->size()); + } +} + } // namespace template @@ -233,8 +246,11 @@ void PrestoHasher::hash( BufferPtr& hashes) { auto baseArray = vector_->base()->as(); auto indices = vector_->indices(); + auto decodedRows = SelectivityVector(rows.size(), true); + decodedSelectivityVector(rows, vector_, decodedRows); + auto elementRows = functions::toElementRows( - baseArray->elements()->size(), rows, baseArray, indices); + baseArray->elements()->size(), decodedRows, baseArray, indices); BufferPtr elementHashes = AlignedBuffer::allocate(elementRows.end(), baseArray->pool()); @@ -243,13 +259,13 @@ void PrestoHasher::hash( auto rawSizes = baseArray->rawSizes(); auto rawOffsets = baseArray->rawOffsets(); - auto rawNulls = baseArray->rawNulls(); auto rawElementHashes = elementHashes->as(); auto rawHashes = hashes->asMutable(); + auto decodedNulls = vector_->nulls(); rows.applyToSelected([&](auto row) { int64_t hash = 0; - if (!(rawNulls && bits::isBitNull(rawNulls, indices[row]))) { + if (!((decodedNulls && bits::isBitNull(decodedNulls, row)))) { auto size = rawSizes[indices[row]]; auto offset = rawOffsets[indices[row]]; @@ -269,8 +285,11 @@ void PrestoHasher::hash( auto indices = vector_->indices(); VELOX_CHECK_EQ(children_.size(), 2) + auto decodedRows = SelectivityVector(rows.size(), true); + decodedSelectivityVector(rows, vector_, decodedRows); + auto elementRows = functions::toElementRows( - baseMap->mapKeys()->size(), rows, baseMap, indices); + baseMap->mapKeys()->size(), decodedRows, baseMap, indices); BufferPtr keyHashes = AlignedBuffer::allocate(elementRows.end(), baseMap->pool()); @@ -286,11 +305,11 @@ void PrestoHasher::hash( auto rawSizes = baseMap->rawSizes(); auto rawOffsets = baseMap->rawOffsets(); - auto rawNulls = baseMap->rawNulls(); + auto decodedNulls = vector_->nulls(); rows.applyToSelected([&](auto row) { int64_t hash = 0; - if (!(rawNulls && bits::isBitNull(rawNulls, indices[row]))) { + if (!((decodedNulls && bits::isBitNull(decodedNulls, row)))) { auto size = rawSizes[indices[row]]; auto offset = rawOffsets[indices[row]]; diff --git a/velox/functions/prestosql/aggregates/tests/ChecksumAggregateTest.cpp b/velox/functions/prestosql/aggregates/tests/ChecksumAggregateTest.cpp index fb877881d921d..65c8a20a34d8f 100644 --- a/velox/functions/prestosql/aggregates/tests/ChecksumAggregateTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/ChecksumAggregateTest.cpp @@ -384,4 +384,42 @@ TEST_F(ChecksumAggregateTest, unknown) { assertChecksum(data, "vBwbUFiJq80="); } +TEST_F(ChecksumAggregateTest, complexVectorWithNulls) { + // Create a dictionary on a map vector with null rows. + auto baseMap = makeMapVectorFromJson({ + "{1: 10, 2: null, 3: 30}", + }); + + auto dictionarySize = baseMap->size() * 3; + // Set bad index for null value. + auto indexBuffer = makeIndices(dictionarySize, [baseMap](auto row) { + return row % 7 == 0 ? -1000 : row % baseMap->size(); + }); + auto nulls = makeNulls(dictionarySize, [](auto row) { return row % 7 == 0; }); + auto dictionary = + BaseVector::wrapInDictionary(nulls, indexBuffer, dictionarySize, baseMap); + + auto row = makeRowVector({dictionary}); + + assertChecksum(row, "r4PlPOShD0w="); + + // Create a dictionary on a array vector with null rows. + auto baseArray = makeArrayVectorFromJson({ + "[1, 2, null, 3, 4]", + }); + + dictionarySize = baseArray->size() * 3; + // Set bad index for null value. + indexBuffer = makeIndices(dictionarySize, [baseArray](auto row) { + return row % 7 == 0 ? -1000 : row % baseArray->size(); + }); + nulls = makeNulls(dictionarySize, [](auto row) { return row % 7 == 0; }); + dictionary = BaseVector::wrapInDictionary( + nulls, indexBuffer, dictionarySize, baseArray); + + row = makeRowVector({dictionary}); + + assertChecksum(row, "i5mk/hSs+AQ="); +} + } // namespace facebook::velox::aggregate::test