Skip to content

Commit

Permalink
fix(hashjoin): Resolve dangling pointer access in listNullKeyRows for…
Browse files Browse the repository at this point in the history
… hash mode tables
  • Loading branch information
zhli1142015 committed Jan 22, 2025
1 parent f81a8c4 commit 9ea00a3
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
4 changes: 4 additions & 0 deletions velox/exec/HashTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ bool HashTable<ignoreNullKeys>::compareKeys(
const char* group,
HashLookup& lookup,
vector_size_t row) {
if (lookup.probeForNullKeyOnly) {
return true;
}
int32_t numKeys = lookup.hashers.size();
// The loop runs at least once. Allow for first comparison to fail
// before loop end check.
Expand Down Expand Up @@ -2025,6 +2028,7 @@ int32_t HashTable<false>::listNullKeyRows(
VELOX_CHECK_GT(nextOffset_, 0);
VELOX_CHECK_EQ(hashers_.size(), 1);
HashLookup lookup(hashers_);
lookup.probeForNullKeyOnly = true;
if (hashMode_ == HashMode::kHash) {
lookup.hashes.push_back(VectorHasher::kNullHash);
} else {
Expand Down
3 changes: 3 additions & 0 deletions velox/exec/HashTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ struct HashLookup {
/// If using valueIds, list of concatenated valueIds. 1:1 with 'hashes'.
/// Populated by groupProbe and joinProbe.
raw_vector<uint64_t> normalizedKeys;

/// If true, only probe for null keys. Used by listNullKeyRows.
bool probeForNullKeyOnly{false};
};

struct HashTableStats {
Expand Down
45 changes: 45 additions & 0 deletions velox/exec/tests/HashJoinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2138,6 +2138,51 @@ TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterAndNullKey) {
}
}

TEST_P(
    MultiThreadedHashJoinTest,
    hashModeNullAwareAntiJoinWithFilterAndNullKey) {
  // Null-aware anti join with a join filter where both sides carry a null
  // join key. The keys are floats so that the join table is built in hash
  // mode.
  //
  // Probe side: 50 batches, each with keys {null, 1, 2}.
  auto probeVectors = makeBatches(50, [&](int32_t /*batchIndex*/) {
    auto probeKey = makeNullableFlatVector<float>({std::nullopt, 1, 2});
    auto probePayload = makeFlatVector<int32_t>({1, 1, 2});
    return makeRowVector({"t0", "t1"}, {probeKey, probePayload});
  });

  // Build side: 5 batches, each with keys {null, 2, 3}.
  auto buildVectors = makeBatches(5, [&](int32_t /*batchIndex*/) {
    auto buildKey = makeNullableFlatVector<float>({std::nullopt, 2, 3});
    auto buildPayload = makeFlatVector<int32_t>({0, 2, 3});
    return makeRowVector({"u0", "u1"}, {buildKey, buildPayload});
  });

  // Run the same join once per filter expression and compare each run
  // against the equivalent NOT IN query.
  const std::vector<std::string> joinFilters = {"u1 < t1", "u1 + t1 = 0"};
  for (const auto& joinFilter : joinFilters) {
    const auto expectedSql = fmt::format(
        "SELECT t.* FROM t WHERE t0 NOT IN (SELECT u0 FROM u WHERE {})",
        joinFilter);

    // The builder takes its inputs by move, so hand it fresh copies and keep
    // the originals intact for the next iteration.
    auto probeCopy = probeVectors;
    auto buildCopy = buildVectors;

    HashJoinBuilder joinBuilder(
        *pool_, duckDbQueryRunner_, driverExecutor_.get());
    joinBuilder.numDrivers(numDrivers_)
        .probeKeys({"t0"})
        .probeVectors(std::move(probeCopy))
        .buildKeys({"u0"})
        .buildVectors(std::move(buildCopy))
        .joinType(core::JoinType::kAnti)
        .nullAware(true)
        .joinFilter(joinFilter)
        .joinOutputLayout({"t0", "t1"})
        .referenceQuery(expectedSql)
        .checkSpillStats(false)
        .run();
  }
}

TEST_P(MultiThreadedHashJoinTest, nullAwareAntiJoinWithFilterOnNullableColumn) {
const std::string referenceSql =
"SELECT t.* FROM t WHERE t0 NOT IN (SELECT u0 FROM u WHERE t1 <> u1)";
Expand Down

0 comments on commit 9ea00a3

Please sign in to comment.