diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index 60ad9d69c329..a788cc8e2ae5 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -428,7 +428,8 @@ void HashProbe::asyncWaitForHashTable() { } } else if ( (isInnerJoin(joinType_) || isLeftSemiFilterJoin(joinType_) || - isRightSemiFilterJoin(joinType_) || isRightSemiProjectJoin(joinType_)) && + isRightSemiFilterJoin(joinType_) || + (isRightSemiProjectJoin(joinType_) && !nullAware_)) && table_->hashMode() != BaseHashTable::HashMode::kHash && !isSpillInput() && !hasMoreSpillData()) { // Find out whether there are any upstream operators that can accept dynamic @@ -443,13 +444,9 @@ void HashProbe::asyncWaitForHashTable() { const auto channels = operatorCtx_->driverCtx()->driver->canPushdownFilters( this, keyChannels_); - // Null aware Right Semi Project join needs to know whether there are any - // nulls on the probe side. Hence, cannot filter these out. - const auto nullAllowed = isRightSemiProjectJoin(joinType_) && nullAware_; - for (auto i = 0; i < keyChannels_.size(); ++i) { if (channels.find(keyChannels_[i]) != channels.end()) { - if (auto filter = buildHashers[i]->getFilter(nullAllowed)) { + if (auto filter = buildHashers[i]->getFilter(/*nullAllowed=*/false)) { dynamicFilters_.emplace(keyChannels_[i], std::move(filter)); } } diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index a0b92c178717..0d423a98f214 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -3294,62 +3294,80 @@ TEST_P(MultiThreadedHashJoinTest, noSpillLevelLimit) { .run(); } -// Verify that dynamic filter pushed down from null-aware right semi project -// join into table scan doesn't filter out nulls. +// Verify that dynamic filter pushed down is turned off for null-aware right +// semi project join. TEST_F(HashJoinTest, nullAwareRightSemiProjectOverScan) { - auto probe = makeRowVector( + std::vector probes; + std::vector builds; + // Matches present: + probes.push_back(makeRowVector( {"t0"}, { makeNullableFlatVector({1, std::nullopt, 2}), - }); + })); + builds.push_back(makeRowVector( + {"u0"}, + { + makeNullableFlatVector({1, 2, 3, std::nullopt}), + })); - auto build = makeRowVector( + // No matches present: + probes.push_back(makeRowVector( + {"t0"}, + { + makeFlatVector({5, 6}), + })); + builds.push_back(makeRowVector( {"u0"}, { makeNullableFlatVector({1, 2, 3, std::nullopt}), - }); + })); - std::shared_ptr probeFile = TempFilePath::create(); - writeToFile(probeFile->getPath(), {probe}); + for (int i = 0; i < probes.size(); i++) { + RowVectorPtr& probe = probes[i]; + RowVectorPtr& build = builds[i]; + std::shared_ptr probeFile = TempFilePath::create(); + writeToFile(probeFile->getPath(), {probe}); - std::shared_ptr buildFile = TempFilePath::create(); - writeToFile(buildFile->getPath(), {build}); + std::shared_ptr buildFile = TempFilePath::create(); + writeToFile(buildFile->getPath(), {build}); - createDuckDbTable("t", {probe}); - createDuckDbTable("u", {build}); + createDuckDbTable("t", {probe}); + createDuckDbTable("u", {build}); - core::PlanNodeId probeScanId; - core::PlanNodeId buildScanId; - auto planNodeIdGenerator = std::make_shared(); - auto plan = PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(probe->type())) - .capturePlanNodeId(probeScanId) - .hashJoin( - {"t0"}, - {"u0"}, - PlanBuilder(planNodeIdGenerator) - .tableScan(asRowType(build->type())) - .capturePlanNodeId(buildScanId) - .planNode(), - "", - {"u0", "match"}, - core::JoinType::kRightSemiProject, - true /*nullAware*/) - .planNode(); + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(probe->type())) + .capturePlanNodeId(probeScanId) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .tableScan(asRowType(build->type())) + .capturePlanNodeId(buildScanId) + .planNode(), + "", + {"u0", "match"}, + core::JoinType::kRightSemiProject, + true /*nullAware*/) + .planNode(); - SplitInput splitInput = { - {probeScanId, - {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}}, - {buildScanId, - {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, - }; + SplitInput splitInput = { + {probeScanId, + {exec::Split(makeHiveConnectorSplit(probeFile->getPath()))}}, + {buildScanId, + {exec::Split(makeHiveConnectorSplit(buildFile->getPath()))}}, + }; - HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) - .planNode(plan) - .inputSplits(splitInput) - .checkSpillStats(false) - .referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u") - .run(); + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .planNode(plan) + .inputSplits(splitInput) + .checkSpillStats(false) + .referenceQuery("SELECT u0, u0 IN (SELECT t0 FROM t) FROM u") + .run(); + } } TEST_F(HashJoinTest, duplicateJoinKeys) {