Skip to content

Commit

Permalink
Fix Presto checksum aggregate for complex types with null rows (faceb…
Browse files Browse the repository at this point in the history
…ookincubator#10765)

Summary:
Pull Request resolved: facebookincubator#10765

The Presto checksum aggregate function encounters crashes when calculating checksums on vectors containing dictionary-encoded complex types with null rows. The issue arises because the function fails to consider null values in the decoded vector, leading to attempts to access uninitialized indices, which causes the crashes. This PR addresses the issue by properly accounting for nulls in the vector, thereby resolving the bug.

Differential Revision: D61321414
  • Loading branch information
Krishna Pai authored and facebook-github-bot committed Aug 23, 2024
1 parent b228e09 commit 1611960
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 6 deletions.
31 changes: 25 additions & 6 deletions velox/functions/prestosql/aggregates/PrestoHasher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ safeXor(const int64_t& hash, const int64_t& a, const int64_t& b) {
return hash + (a ^ b);
}

// Narrows 'decodedRows' to the rows that are selected in 'rows' and are not
// null according to 'decodedVector'.
//
// @param rows Rows requested by the caller.
// @param decodedVector Decoded vector whose nulls are consulted.
// @param decodedRows In/out selection. Must span at least
// decodedVector->size() rows (checked below). On return it holds its previous
// selection intersected with 'rows', with null positions in
// [0, decodedVector->size()) deselected.
void decodedSelectivityVector(
    const SelectivityVector& rows,
    std::shared_ptr<DecodedVector>& decodedVector,
    SelectivityVector& decodedRows) {
  VELOX_CHECK_GE(decodedRows.end(), decodedVector->size());

  // Keep only the rows the caller asked for.
  decodedRows.intersect(rows);

  // No nulls reported for 'rows': the intersection is the final answer.
  if (!decodedVector->nulls(&rows)) {
    return;
  }

  // Drop every null position so that callers never follow the index stored
  // for a null row.
  decodedRows.deselectNulls(decodedVector->nulls(), 0, decodedVector->size());
}

} // namespace

template <TypeKind kind>
Expand Down Expand Up @@ -233,8 +246,11 @@ void PrestoHasher::hash<TypeKind::ARRAY>(
BufferPtr& hashes) {
auto baseArray = vector_->base()->as<ArrayVector>();
auto indices = vector_->indices();
auto decodedRows = SelectivityVector(rows.size(), true);
decodedSelectivityVector(rows, vector_, decodedRows);

auto elementRows = functions::toElementRows(
baseArray->elements()->size(), rows, baseArray, indices);
baseArray->elements()->size(), decodedRows, baseArray, indices);

BufferPtr elementHashes =
AlignedBuffer::allocate<int64_t>(elementRows.end(), baseArray->pool());
Expand All @@ -243,13 +259,13 @@ void PrestoHasher::hash<TypeKind::ARRAY>(

auto rawSizes = baseArray->rawSizes();
auto rawOffsets = baseArray->rawOffsets();
auto rawNulls = baseArray->rawNulls();
auto rawElementHashes = elementHashes->as<int64_t>();
auto rawHashes = hashes->asMutable<int64_t>();
auto decodedNulls = vector_->nulls();

rows.applyToSelected([&](auto row) {
int64_t hash = 0;
if (!(rawNulls && bits::isBitNull(rawNulls, indices[row]))) {
if (!((decodedNulls && bits::isBitNull(decodedNulls, row)))) {
auto size = rawSizes[indices[row]];
auto offset = rawOffsets[indices[row]];

Expand All @@ -269,8 +285,11 @@ void PrestoHasher::hash<TypeKind::MAP>(
auto indices = vector_->indices();
VELOX_CHECK_EQ(children_.size(), 2)

auto decodedRows = SelectivityVector(rows.size(), true);
decodedSelectivityVector(rows, vector_, decodedRows);

auto elementRows = functions::toElementRows(
baseMap->mapKeys()->size(), rows, baseMap, indices);
baseMap->mapKeys()->size(), decodedRows, baseMap, indices);
BufferPtr keyHashes =
AlignedBuffer::allocate<int64_t>(elementRows.end(), baseMap->pool());

Expand All @@ -286,11 +305,11 @@ void PrestoHasher::hash<TypeKind::MAP>(

auto rawSizes = baseMap->rawSizes();
auto rawOffsets = baseMap->rawOffsets();
auto rawNulls = baseMap->rawNulls();
auto decodedNulls = vector_->nulls();

rows.applyToSelected([&](auto row) {
int64_t hash = 0;
if (!(rawNulls && bits::isBitNull(rawNulls, indices[row]))) {
if (!((decodedNulls && bits::isBitNull(decodedNulls, row)))) {
auto size = rawSizes[indices[row]];
auto offset = rawOffsets[indices[row]];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,42 @@ TEST_F(ChecksumAggregateTest, unknown) {
assertChecksum(data, "vBwbUFiJq80=");
}

// Checksums dictionary-wrapped complex vectors (map, then array) in which some
// dictionary rows are null. Null rows deliberately carry an invalid index
// (-1000), so the aggregate must not follow the index of a null row.
TEST_F(ChecksumAggregateTest, complexVectorWithNulls) {
  // Scenario 1: dictionary over a map vector built from one JSON literal.
  {
    auto mapBase = makeMapVectorFromJson<int32_t, int64_t>({
        "{1: 10, 2: null, 3: 30}",
    });
    auto wrappedSize = mapBase->size() * 3;

    // Rows with row % 7 == 0 are the null rows; give them the bad index.
    auto mapIndices = makeIndices(wrappedSize, [mapBase](auto row) {
      return row % 7 == 0 ? -1000 : row % mapBase->size();
    });
    auto mapNulls =
        makeNulls(wrappedSize, [](auto row) { return row % 7 == 0; });
    auto wrappedMap = BaseVector::wrapInDictionary(
        mapNulls, mapIndices, wrappedSize, mapBase);

    auto mapInput = makeRowVector({wrappedMap});
    assertChecksum(mapInput, "r4PlPOShD0w=");
  }

  // Scenario 2: the same setup over an array vector.
  {
    auto arrayBase = makeArrayVectorFromJson<int64_t>({
        "[1, 2, null, 3, 4]",
    });
    auto wrappedSize = arrayBase->size() * 3;

    // Rows with row % 7 == 0 are the null rows; give them the bad index.
    auto arrayIndices = makeIndices(wrappedSize, [arrayBase](auto row) {
      return row % 7 == 0 ? -1000 : row % arrayBase->size();
    });
    auto arrayNulls =
        makeNulls(wrappedSize, [](auto row) { return row % 7 == 0; });
    auto wrappedArray = BaseVector::wrapInDictionary(
        arrayNulls, arrayIndices, wrappedSize, arrayBase);

    auto arrayInput = makeRowVector({wrappedArray});
    assertChecksum(arrayInput, "i5mk/hSs+AQ=");
  }
}

} // namespace facebook::velox::aggregate::test

0 comments on commit 1611960

Please sign in to comment.