From 7d5a0517e8951e5cbefb23f96f51a92d1bd4c3e6 Mon Sep 17 00:00:00 2001 From: duanmeng Date: Wed, 16 Oct 2024 19:43:25 +0800 Subject: [PATCH] Add hash join replayer --- velox/exec/HashBuild.cpp | 1 + velox/exec/HashProbe.cpp | 3 + velox/exec/Operator.cpp | 5 + velox/exec/QueryTraceTraits.h | 8 + velox/tool/trace/CMakeLists.txt | 3 +- velox/tool/trace/HashJoinReplayer.cpp | 45 +++ velox/tool/trace/HashJoinReplayer.h | 47 ++++ velox/tool/trace/OperatorReplayerBase.cpp | 2 +- velox/tool/trace/OperatorReplayerBase.h | 4 + velox/tool/trace/QueryReplayer.cpp | 8 + velox/tool/trace/tests/CMakeLists.txt | 2 +- .../tool/trace/tests/HashJoinReplayerTest.cpp | 256 ++++++++++++++++++ 12 files changed, 381 insertions(+), 3 deletions(-) create mode 100644 velox/tool/trace/HashJoinReplayer.cpp create mode 100644 velox/tool/trace/HashJoinReplayer.h create mode 100644 velox/tool/trace/tests/HashJoinReplayerTest.cpp diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp index 498f4569c627e..c86f444afbbd9 100644 --- a/velox/exec/HashBuild.cpp +++ b/velox/exec/HashBuild.cpp @@ -312,6 +312,7 @@ void HashBuild::removeInputRowsForAntiJoinFilter() { } void HashBuild::addInput(RowVectorPtr input) { + traceInput(input); checkRunning(); ensureInputFits(input); diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index 3b9dc93f1c786..3a4849e006cc1 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -600,6 +600,9 @@ void HashProbe::decodeAndDetectNonNullKeys() { } void HashProbe::addInput(RowVectorPtr input) { + if (!isSpillInput()) { + traceInput(input); + } if (skipInput_) { VELOX_CHECK_NULL(input_); return; diff --git a/velox/exec/Operator.cpp b/velox/exec/Operator.cpp index 3b65536ba04cd..18b0a22f609e8 100644 --- a/velox/exec/Operator.cpp +++ b/velox/exec/Operator.cpp @@ -121,6 +121,11 @@ void Operator::maybeSetTracer() { return; } + if (trace::QueryTraceTraits::kSupportedOperatorTypes.count(operatorType()) == + 0) { + VELOX_UNSUPPORTED("{} does not support tracing", operatorType()); + } + auto& tracedOpMap = operatorCtx_->driverCtx()->tracedOperatorMap; if (const auto iter = tracedOpMap.find(operatorId()); iter != tracedOpMap.end()) { diff --git a/velox/exec/QueryTraceTraits.h b/velox/exec/QueryTraceTraits.h index ad817115b4902..a11e85be65789 100644 --- a/velox/exec/QueryTraceTraits.h +++ b/velox/exec/QueryTraceTraits.h @@ -31,5 +31,13 @@ struct QueryTraceTraits { static inline const std::string kQueryMetaFileName = "query_meta.json"; static inline const std::string kDataSummaryFileName = "data_summary.json"; static inline const std::string kDataFileName = "trace.data"; + + static inline const std::unordered_set kSupportedOperatorTypes{ + "TableWrite", + "Aggregation", + "PartialAggregation", + "PartitionedOutput", + "HashBuild", + "HashProbe"}; }; } // namespace facebook::velox::exec::trace diff --git a/velox/tool/trace/CMakeLists.txt b/velox/tool/trace/CMakeLists.txt index 3218b4ee30b0b..98c0f6d524227 100644 --- a/velox/tool/trace/CMakeLists.txt +++ b/velox/tool/trace/CMakeLists.txt @@ -17,7 +17,8 @@ velox_add_library( AggregationReplayer.cpp OperatorReplayerBase.cpp PartitionedOutputReplayer.cpp - TableWriterReplayer.cpp) + TableWriterReplayer.cpp + HashJoinReplayer.cpp) velox_link_libraries( velox_query_trace_replayer_base diff --git a/velox/tool/trace/HashJoinReplayer.cpp b/velox/tool/trace/HashJoinReplayer.cpp new file mode 100644 index 0000000000000..9e871bfd5fbab --- /dev/null +++ b/velox/tool/trace/HashJoinReplayer.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/tool/trace/HashJoinReplayer.h" +#include "velox/exec/QueryTraceUtil.h" +#include "velox/exec/tests/utils/PlanBuilder.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; + +namespace facebook::velox::tool::trace { +core::PlanNodePtr HashJoinReplayer::createPlanNode( + const core::PlanNode* node, + const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source) const { + const auto* hashJoinNode = dynamic_cast(node); + return std::make_shared( + nodeId, + hashJoinNode->joinType(), + hashJoinNode->isNullAware(), + hashJoinNode->leftKeys(), + hashJoinNode->rightKeys(), + hashJoinNode->filter(), + source, + PlanBuilder(planNodeIdGenerator_) + .traceScan( + nodeDir_, exec::trace::getDataType(planFragment_, nodeId_, 1)) + .planNode(), + hashJoinNode->outputType()); +} +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/HashJoinReplayer.h b/velox/tool/trace/HashJoinReplayer.h new file mode 100644 index 0000000000000..e77811759398c --- /dev/null +++ b/velox/tool/trace/HashJoinReplayer.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/core/PlanNode.h" +#include "velox/tool/trace/OperatorReplayerBase.h" + +#include + +namespace facebook::velox::tool::trace { +/// The replayer to replay the traced 'HashJoin' operator. +class HashJoinReplayer : public OperatorReplayerBase { + public: + HashJoinReplayer( + const std::string& rootDir, + const std::string& taskId, + const std::string& nodeId, + const int32_t pipelineId, + const std::string& operatorType) + : OperatorReplayerBase( + rootDir, + taskId, + nodeId, + pipelineId, + operatorType) {} + + private: + core::PlanNodePtr createPlanNode( + const core::PlanNode* node, + const core::PlanNodeId& nodeId, + const core::PlanNodePtr& source) const override; +}; +} // namespace facebook::velox::tool::trace diff --git a/velox/tool/trace/OperatorReplayerBase.cpp b/velox/tool/trace/OperatorReplayerBase.cpp index b54ad323a4471..ff0af99de061b 100644 --- a/velox/tool/trace/OperatorReplayerBase.cpp +++ b/velox/tool/trace/OperatorReplayerBase.cpp @@ -68,7 +68,7 @@ core::PlanNodePtr OperatorReplayerBase::createPlan() const { const auto* replayNode = core::PlanNode::findFirstNode( planFragment_.get(), [this](const core::PlanNode* node) { return node->id() == nodeId_; }); - return exec::test::PlanBuilder() + return exec::test::PlanBuilder(planNodeIdGenerator_) .traceScan(nodeDir_, exec::trace::getDataType(planFragment_, nodeId_)) .addNode(replayNodeFactory(replayNode)) .planNode(); diff --git a/velox/tool/trace/OperatorReplayerBase.h b/velox/tool/trace/OperatorReplayerBase.h index 863a8f5f80f69..62a217d6ef2e1 100644 --- a/velox/tool/trace/OperatorReplayerBase.h +++ b/velox/tool/trace/OperatorReplayerBase.h @@ -18,6 +18,7 @@ #include "velox/common/file/FileSystems.h" #include "velox/core/PlanNode.h" +#include "velox/exec/tests/utils/PlanBuilder.h" namespace facebook::velox::exec { class Task; @@ -59,6 +60,9 @@ class OperatorReplayerBase { const std::string taskDir_; const std::string nodeDir_; + const std::shared_ptr planNodeIdGenerator_ = + std::make_shared(); + std::unordered_map queryConfigs_; std::unordered_map> connectorConfigs_; diff --git a/velox/tool/trace/QueryReplayer.cpp b/velox/tool/trace/QueryReplayer.cpp index a0680d30a387a..234054da8ce5b 100644 --- a/velox/tool/trace/QueryReplayer.cpp +++ b/velox/tool/trace/QueryReplayer.cpp @@ -40,6 +40,7 @@ #include "velox/parse/ExpressionsParser.h" #include "velox/parse/TypeResolver.h" #include "velox/tool/trace/AggregationReplayer.h" +#include "velox/tool/trace/HashJoinReplayer.h" #include "velox/tool/trace/OperatorReplayerBase.h" #include "velox/tool/trace/PartitionedOutputReplayer.h" #include "velox/tool/trace/TableWriterReplayer.h" @@ -148,6 +149,13 @@ std::unique_ptr createReplayer() { FLAGS_node_id, FLAGS_pipeline_id, FLAGS_operator_type); + } else if (FLAGS_operator_type == "HashJoin") { + replayer = std::make_unique( + FLAGS_root_dir, + FLAGS_task_id, + FLAGS_node_id, + FLAGS_pipeline_id, + FLAGS_operator_type); } else { VELOX_UNSUPPORTED("Unsupported operator type: {}", FLAGS_operator_type); } diff --git a/velox/tool/trace/tests/CMakeLists.txt b/velox/tool/trace/tests/CMakeLists.txt index 03d1705eebe76..42197fff233fc 100644 --- a/velox/tool/trace/tests/CMakeLists.txt +++ b/velox/tool/trace/tests/CMakeLists.txt @@ -15,7 +15,7 @@ add_executable( velox_tool_trace_test AggregationReplayerTest.cpp PartitionedOutputReplayerTest.cpp - TableWriterReplayerTest.cpp) + TableWriterReplayerTest.cpp HashJoinReplayerTest.cpp) add_test( NAME velox_tool_trace_test diff --git a/velox/tool/trace/tests/HashJoinReplayerTest.cpp b/velox/tool/trace/tests/HashJoinReplayerTest.cpp new file mode 100644 index 0000000000000..5552bc7cca0f0 --- /dev/null +++ b/velox/tool/trace/tests/HashJoinReplayerTest.cpp @@ -0,0 +1,256 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include +#include + +#include "folly/experimental/EventCount.h" +#include "velox/common/file/FileSystems.h" +#include "velox/common/hyperloglog/SparseHll.h" +#include "velox/common/testutil/TestValue.h" +#include "velox/dwio/dwrf/writer/Writer.h" +#include "velox/exec/PartitionFunction.h" +#include "velox/exec/QueryDataReader.h" +#include "velox/exec/QueryTraceUtil.h" +#include "velox/exec/TableWriter.h" +#include "velox/exec/tests/utils/ArbitratorTestUtil.h" +#include "velox/exec/tests/utils/AssertQueryBuilder.h" +#include "velox/exec/tests/utils/HiveConnectorTestBase.h" +#include "velox/exec/tests/utils/PlanBuilder.h" +#include "velox/exec/tests/utils/TempDirectoryPath.h" +#include "velox/serializers/PrestoSerializer.h" +#include "velox/tool/trace/HashJoinReplayer.h" + +#include + +#include "velox/vector/fuzzer/VectorFuzzer.h" +#include "velox/vector/tests/utils/VectorTestBase.h" + +using namespace facebook::velox; +using namespace facebook::velox::core; +using namespace facebook::velox::common; +using namespace facebook::velox::exec; +using namespace facebook::velox::exec::test; +using namespace facebook::velox::connector; +using namespace facebook::velox::connector::hive; +using namespace facebook::velox::dwio::common; +using namespace facebook::velox::common::testutil; +using namespace facebook::velox::common::hll; + +namespace facebook::velox::tool::trace::test { +class HashJoinReplayerTest : public HiveConnectorTestBase { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + HiveConnectorTestBase::SetUpTestCase(); + filesystems::registerLocalFileSystem(); + if (!isRegisteredVectorSerde()) { + serializer::presto::PrestoVectorSerde::registerVectorSerde(); + } + Type::registerSerDe(); + common::Filter::registerSerDe(); + connector::hive::HiveTableHandle::registerSerDe(); + connector::hive::LocationHandle::registerSerDe(); + connector::hive::HiveColumnHandle::registerSerDe(); + connector::hive::HiveInsertTableHandle::registerSerDe(); + connector::hive::HiveConnectorSplit::registerSerDe(); + core::PlanNode::registerSerDe(); + core::ITypedExpr::registerSerDe(); + registerPartitionFunctionSerDe(); + } + + struct PlanWithSplits { + core::PlanNodePtr plan; + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + std::unordered_map> splits; + + explicit PlanWithSplits( + const core::PlanNodePtr& _plan, + const core::PlanNodeId& _probeScanId = "", + const core::PlanNodeId& _buildScanId = "", + const std::unordered_map< + core::PlanNodeId, + std::vector>& _splits = {}) + : plan(_plan), + probeScanId(_probeScanId), + buildScanId(_buildScanId), + splits(_splits) {} + }; + + int32_t randInt(int32_t min, int32_t max) { + return boost::random::uniform_int_distribution(min, max)(rng_); + } + + RowTypePtr concat(const RowTypePtr& a, const RowTypePtr& b) { + std::vector names = a->names(); + std::vector types = a->children(); + + for (auto i = 0; i < b->size(); ++i) { + names.push_back(b->nameOf(i)); + types.push_back(b->childAt(i)); + } + + return ROW(std::move(names), std::move(types)); + } + + std::vector + makeVectors(int32_t count, int32_t rowsPerVector, const RowTypePtr& rowType) { + return HiveConnectorTestBase::makeVectors(rowType, count, rowsPerVector); + } + + std::vector makeSplits( + const std::vector& inputs, + const std::string& path, + memory::MemoryPool* writerPool) { + std::vector splits; + for (auto i = 0; i < 4; ++i) { + const std::string filePath = fmt::format("{}/{}", path, i); + writeToFile(filePath, inputs); + splits.emplace_back(makeHiveConnectorSplit(filePath)); + } + + return splits; + } + + PlanWithSplits createPlan( + const std::string& tableDir, + core::JoinType joinType, + const std::vector& probeKeys, + const std::vector& buildKeys, + const std::vector& probeInput, + const std::vector& buildInput, + const std::vector& outputColumns) { + auto planNodeIdGenerator = std::make_shared(); + const std::vector probeSplits = + makeSplits(probeInput, fmt::format("{}/probe", tableDir), pool()); + const std::vector buildSplits = + makeSplits(buildInput, fmt::format("{}/build", tableDir), pool()); + core::PlanNodeId probeScanId; + core::PlanNodeId buildScanId; + auto plan = PlanBuilder(planNodeIdGenerator) + .tableScan(probeType_) + .capturePlanNodeId(probeScanId) + .hashJoin( + probeKeys, + buildKeys, + PlanBuilder(planNodeIdGenerator) + .tableScan(buildType_) + .capturePlanNodeId(buildScanId) + .planNode(), + /*filter=*/"", + outputColumns, + joinType, + false) + .capturePlanNodeId(traceNodeId_) + .planNode(); + return PlanWithSplits{ + plan, + probeScanId, + buildScanId, + {{probeScanId, probeSplits}, {buildScanId, buildSplits}}}; + } + + core::PlanNodeId traceNodeId_; + std::mt19937 rng_; + const RowTypePtr probeType_{ + ROW({"t0", "t1", "t2", "t3"}, {BIGINT(), INTEGER(), SMALLINT(), REAL()})}; + + const RowTypePtr buildType_{ + ROW({"u0", "u1", "u2", "u3"}, {BIGINT(), INTEGER(), SMALLINT(), REAL()})}; +}; + +TEST_F(HashJoinReplayerTest, test) { + static const std::vector kJoinTypes = { + core::JoinType::kInner, + core::JoinType::kLeft, + core::JoinType::kRight, + core::JoinType::kFull, + core::JoinType::kLeftSemiFilter, + core::JoinType::kLeftSemiProject, + core::JoinType::kAnti}; + const auto testDir = TempDirectoryPath::create(); + const auto tableDir = fmt::format("{}/{}", testDir->getPath(), "table"); + std::vector probeKeys{"t0", "t1"}; + std::vector buildKeys{"u0", "u1"}; + const auto probeInput = makeVectors(randInt(2, 5), 100, probeType_); + const auto buildInput = makeVectors(randInt(2, 5), 100, buildType_); + + for (const auto joinType : kJoinTypes) { + auto outputColumns = + (core::isLeftSemiProjectJoin(joinType) || + core::isLeftSemiFilterJoin(joinType) || core::isAntiJoin(joinType)) + ? asRowType(probeInput[0]->type())->names() + : concat( + asRowType(probeInput[0]->type()), + asRowType(buildInput[0]->type())) + ->names(); + + if (core::isLeftSemiProjectJoin(joinType) || + core::isRightSemiProjectJoin(joinType)) { + outputColumns.emplace_back("match"); + } + + const auto planWithSplits = createPlan( + tableDir, + joinType, + probeKeys, + buildKeys, + probeInput, + buildInput, + outputColumns); + AssertQueryBuilder builder(planWithSplits.plan); + for (const auto& [planNodeId, nodeSplits] : planWithSplits.splits) { + builder.splits(planNodeId, nodeSplits); + } + const auto result = builder.copyResults(pool()); + + const auto traceRoot = + fmt::format("{}/{}/traceRoot/", testDir->getPath(), joinType); + std::shared_ptr task; + auto tracePlanWithSplits = createPlan( + tableDir, + joinType, + probeKeys, + buildKeys, + probeInput, + buildInput, + outputColumns); + AssertQueryBuilder traceBuilder(tracePlanWithSplits.plan); + traceBuilder.config(core::QueryConfig::kQueryTraceEnabled, true) + .config(core::QueryConfig::kQueryTraceDir, traceRoot) + .config(core::QueryConfig::kQueryTraceMaxBytes, 100UL << 30) + .config(core::QueryConfig::kQueryTraceTaskRegExp, ".*") + .config(core::QueryConfig::kQueryTraceNodeIds, traceNodeId_); + for (const auto& [planNodeId, nodeSplits] : tracePlanWithSplits.splits) { + traceBuilder.splits(planNodeId, nodeSplits); + } + auto traceResult = traceBuilder.copyResults(pool(), task); + + assertEqualResults({result}, {traceResult}); + + const auto taskId = task->taskId(); + const auto replayingResult = + HashJoinReplayer(traceRoot, task->taskId(), traceNodeId_, 0, "HashJoin") + .run(); + assertEqualResults({result}, {replayingResult}); + } +} +} // namespace facebook::velox::tool::trace::test