diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 96662f969238..236d34d540b7 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -110,6 +110,17 @@ class JoinFuzzer { numGroups(_numGroups) {} }; + struct JoinData { + core::JoinType joinType; + bool nullAware; + std::vector probeKeys; + std::vector buildKeys; + core::PlanNodePtr probeInput; + std::vector buildInput; + std::vector outputColumns; + std::string filter; + }; + static core::PlanNodePtr tryFlipJoinSides(const core::HashJoinNode& joinNode); static core::PlanNodePtr tryFlipJoinSides( const core::MergeJoinNode& joinNode); @@ -160,6 +171,13 @@ class JoinFuzzer { const std::vector& outputColumns, const std::string& filter); + // Constructs a cascading multi-join plan with hash join nodes. + // joinDataList[0].probeInput should be a single values node made using the + // same planNodeIdGenerator. + JoinFuzzer::PlanWithSplits makeDefaultPlan( + std::shared_ptr planNodeIdGenerator, + const std::vector& joinDataList); + JoinFuzzer::PlanWithSplits makeMergeJoinPlan( core::JoinType joinType, const std::vector& probeKeys, @@ -169,6 +187,13 @@ class JoinFuzzer { const std::vector& outputColumns, const std::string& filter); + // Constructs a cascading multi-join plan with merge join nodes. + // joinDataList[0].probeInput should be a single values node made using the + // same planNodeIdGenerator. + JoinFuzzer::PlanWithSplits makeMergeJoinPlan( + std::shared_ptr planNodeIdGenerator, + const std::vector& joinDataList); + // Returns a PlanWithSplits for NestedLoopJoin with inputs from Values nodes. // If withFilter is true, uses the equality filter between probeKeys and // buildKeys as the join filter. Uses empty join filter otherwise. @@ -181,6 +206,13 @@ class JoinFuzzer { const std::vector& outputColumns, const std::string& filter); + // Constructs a cascading multi-join plan with nested loop join nodes. + // joinDataList[0].probeInput should be a single values node made using the + // same planNodeIdGenerator. + JoinFuzzer::PlanWithSplits makeNestedLoopJoinPlan( + std::shared_ptr planNodeIdGenerator, + const std::vector& joinDataList); + // Makes the default query plan with table scan as inputs for both probe and // build sides. JoinFuzzer::PlanWithSplits makeDefaultPlanWithTableScan( @@ -762,6 +794,26 @@ JoinFuzzer::PlanWithSplits JoinFuzzer::makeDefaultPlan( return PlanWithSplits{plan}; } +JoinFuzzer::PlanWithSplits makeDefaultPlan( + std::shared_ptr planNodeIdGenerator, + const std::vector& joinDataList) { + VELOX_CHECK_GT(joinDataList.size(), 0); + PlanBuilder plan = PlanBuilder( + /*initialPlanNode=*/joinDataList[0].probeInput, planNodeIdGenerator); + for (const JoinFuzzer::JoinData& joinData : joinDataList) { + plan.hashJoin( + joinData.probeKeys, + joinData.buildKeys, + /*build=*/ + PlanBuilder(planNodeIdGenerator).values(joinData.buildInput).planNode(), + joinData.filter, + joinData.outputColumns, + joinData.joinType, + joinData.nullAware); + } + return JoinFuzzer::PlanWithSplits{plan.planNode()}; +} + JoinFuzzer::PlanWithSplits JoinFuzzer::makeDefaultPlanWithTableScan( core::JoinType joinType, bool nullAware, @@ -896,6 +948,29 @@ JoinFuzzer::PlanWithSplits JoinFuzzer::makeMergeJoinPlan( .planNode()}; } +JoinFuzzer::PlanWithSplits makeMergeJoinPlan( + std::shared_ptr planNodeIdGenerator, + const std::vector& joinDataList) { + VELOX_CHECK_GT(joinDataList.size(), 0); + PlanBuilder plan = PlanBuilder( + /*initialPlanNode=*/joinDataList[0].probeInput, planNodeIdGenerator); + for (const JoinFuzzer::JoinData& joinData : joinDataList) { + plan.orderBy(joinData.probeKeys, false) + .mergeJoin( + joinData.probeKeys, + joinData.buildKeys, + /*build=*/ + PlanBuilder(planNodeIdGenerator) + .values(joinData.buildInput) + .orderBy(joinData.buildKeys, false) + .planNode(), + joinData.filter, + joinData.outputColumns, + joinData.joinType); + } + return JoinFuzzer::PlanWithSplits{plan.planNode()}; +} + JoinFuzzer::PlanWithSplits JoinFuzzer::makeNestedLoopJoinPlan( core::JoinType joinType, const std::vector& probeKeys, @@ -916,6 +991,25 @@ JoinFuzzer::PlanWithSplits JoinFuzzer::makeNestedLoopJoinPlan( .planNode()}; } +JoinFuzzer::PlanWithSplits makeNestedLoopJoinPlan( + std::shared_ptr planNodeIdGenerator, + const std::vector& joinDataList) { + VELOX_CHECK_GT(joinDataList.size(), 0); + PlanBuilder plan = PlanBuilder( + /*initialPlanNode=*/joinDataList[0].probeInput, planNodeIdGenerator); + for (const JoinFuzzer::JoinData& joinData : joinDataList) { + plan.orderBy(joinData.probeKeys, false) + .nestedLoopJoin( + /*right=*/PlanBuilder(planNodeIdGenerator) + .values(joinData.buildInput) + .planNode(), + joinData.filter, + joinData.outputColumns, + joinData.joinType); + } + return JoinFuzzer::PlanWithSplits{plan.planNode()}; +} + void JoinFuzzer::makeAlternativePlans( const core::PlanNodePtr& plan, const std::vector& probeInput,