From dce89f3bb0071407dfa821d135ad86b3ae531a02 Mon Sep 17 00:00:00 2001 From: Daniel Hunte Date: Fri, 10 Jan 2025 09:45:42 -0800 Subject: [PATCH] feat(fuzzer): Support multiple joins in the join node "toSql" methods for reference query runners (#11801) Summary: Currently, the hash join and nested loop join "toSql" methods for all reference query runners only support a single join. This change extends it to support multiple joins, only needing the join node of the last join in the tree. It traverses up the tree and recursively builds the sql query. Differential Revision: D66977480 --- velox/core/PlanNode.h | 2 + velox/exec/fuzzer/CMakeLists.txt | 9 +- velox/exec/fuzzer/DuckQueryRunner.cpp | 160 ++------------ velox/exec/fuzzer/DuckQueryRunner.h | 23 +- velox/exec/fuzzer/JoinFuzzer.cpp | 4 +- velox/exec/fuzzer/PrestoQueryRunner.cpp | 229 +++---------------- velox/exec/fuzzer/PrestoQueryRunner.h | 28 +-- velox/exec/fuzzer/ReferenceQueryRunner.cpp | 242 +++++++++++++++++++++ velox/exec/fuzzer/ReferenceQueryRunner.h | 49 ++++- velox/exec/tests/PrestoQueryRunnerTest.cpp | 118 ++++++++++ 10 files changed, 475 insertions(+), 389 deletions(-) create mode 100644 velox/exec/fuzzer/ReferenceQueryRunner.cpp diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index aac5583e3e7e1..f9970f3f22675 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -324,6 +324,8 @@ class ValuesNode : public PlanNode { const size_t repeatTimes_; }; +using ValuesNodePtr = std::shared_ptr; + class ArrowStreamNode : public PlanNode { public: ArrowStreamNode( diff --git a/velox/exec/fuzzer/CMakeLists.txt b/velox/exec/fuzzer/CMakeLists.txt index 856373b54fb48..260dc5353cd0f 100644 --- a/velox/exec/fuzzer/CMakeLists.txt +++ b/velox/exec/fuzzer/CMakeLists.txt @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_library(velox_fuzzer_util DuckQueryRunner.cpp PrestoQueryRunner.cpp - FuzzerUtil.cpp ToSQLUtil.cpp) +add_library( + velox_fuzzer_util + ReferenceQueryRunner.cpp + DuckQueryRunner.cpp + PrestoQueryRunner.cpp + FuzzerUtil.cpp + ToSQLUtil.cpp) target_link_libraries( velox_fuzzer_util diff --git a/velox/exec/fuzzer/DuckQueryRunner.cpp b/velox/exec/fuzzer/DuckQueryRunner.cpp index d6d606f6497e1..7758263e5223a 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.cpp +++ b/velox/exec/fuzzer/DuckQueryRunner.cpp @@ -13,6 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +#include +#include +#include + #include "velox/exec/fuzzer/DuckQueryRunner.h" #include "velox/exec/fuzzer/ToSQLUtil.h" #include "velox/exec/tests/utils/QueryAssertions.h" @@ -104,21 +109,22 @@ DuckQueryRunner::aggregationFunctionDataSpecs() const { std::multiset> DuckQueryRunner::execute( const std::string& sql, - const std::vector& input, - const RowTypePtr& resultType) { + const core::PlanNodePtr& plan) { DuckDbQueryRunner queryRunner; - queryRunner.createTable("tmp", input); - return queryRunner.execute(sql, resultType); + std::unordered_map> inputMap = + getAllTables(plan); + for (const auto& [tableName, input] : inputMap) { + queryRunner.createTable(tableName, input); + } + return queryRunner.execute(sql, plan->outputType()); } std::multiset> DuckQueryRunner::execute( const std::string& sql, - const std::vector& probeInput, - const std::vector& buildInput, + const std::vector& input, const RowTypePtr& resultType) { DuckDbQueryRunner queryRunner; - queryRunner.createTable("t", probeInput); - queryRunner.createTable("u", buildInput); + queryRunner.createTable("tmp", input); return queryRunner.execute(sql, resultType); } @@ -164,6 +170,11 @@ std::optional DuckQueryRunner::toSql( return toSql(joinNode); } + if (const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + return toSql(valuesNode); + } + VELOX_NYI(); } @@ -340,137 +351,4 @@ std::optional DuckQueryRunner::toSql( return sql.str(); } - -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - const auto& joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); - } - return out.str(); - }; - - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; - - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); - } - return out.str(); - }; - - const auto& outputNames = joinNode->outputType()->names(); - - std::stringstream sql; - if (joinNode->isLeftSemiProjectJoin()) { - sql << "SELECT " - << folly::join(", ", outputNames.begin(), --outputNames.end()); - } else { - sql << "SELECT " << folly::join(", ", outputNames); - } - - switch (joinNode->joinType()) { - case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kLeftSemiFilter: - // Multiple columns returned by a scalar subquery is not supported in - // DuckDB. A scalar subquery expression is a subquery that returns one - // result row from exactly one column for every input row. - if (joinNode->leftKeys().size() > 1) { - return std::nullopt; - } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ")"; - break; - case core::JoinType::kLeftSemiProject: - if (joinNode->isNullAware()) { - sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ") FROM t"; - } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; - } - break; - case core::JoinType::kAnti: - if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ")"; - } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ")"; - } - break; - default: - VELOX_UNREACHABLE( - "Unknown join type: {}", static_cast(joinNode->joinType())); - } - - return sql.str(); -} - -std::optional DuckQueryRunner::toSql( - const std::shared_ptr& joinNode) { - std::stringstream sql; - sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); - - // Nested loop join without filter. - VELOX_CHECK( - joinNode->joinCondition() == nullptr, - "This code path should be called only for nested loop join without filter"); - const std::string joinCondition{"(1 = 1)"}; - switch (joinNode->joinType()) { - case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; - break; - case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; - break; - case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; - break; - default: - VELOX_UNREACHABLE( - "Unknown join type: {}", static_cast(joinNode->joinType())); - } - - return sql.str(); -} } // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/DuckQueryRunner.h b/velox/exec/fuzzer/DuckQueryRunner.h index 4fa826af04884..3336ffa1668ca 100644 --- a/velox/exec/fuzzer/DuckQueryRunner.h +++ b/velox/exec/fuzzer/DuckQueryRunner.h @@ -15,6 +15,10 @@ */ #pragma once +#include +#include +#include + #include "velox/exec/fuzzer/ReferenceQueryRunner.h" namespace facebook::velox::exec::test { @@ -46,20 +50,21 @@ class DuckQueryRunner : public ReferenceQueryRunner { /// Assumes that source of AggregationNode or Window Node is 'tmp' table. std::optional toSql(const core::PlanNodePtr& plan) override; - /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns - /// results according to 'resultType' schema. + /// Executes SQL query returned by the 'toSql' method based on the plan. std::multiset> execute( const std::string& sql, - const std::vector& input, - const RowTypePtr& resultType) override; + const core::PlanNodePtr& plan) override; + /// Creates 'tmp' table with 'input' data and runs 'sql' query. Returns + /// results according to 'resultType' schema. std::multiset> execute( const std::string& sql, - const std::vector& probeInput, - const std::vector& buildInput, + const std::vector& input, const RowTypePtr& resultType) override; private: + using ReferenceQueryRunner::toSql; + std::optional toSql( const std::shared_ptr& aggregationNode); @@ -72,12 +77,6 @@ class DuckQueryRunner : public ReferenceQueryRunner { std::optional toSql( const std::shared_ptr& rowNumberNode); - std::optional toSql( - const std::shared_ptr& joinNode); - - std::optional toSql( - const std::shared_ptr& joinNode); - std::unordered_set aggregateFunctionNames_; }; diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 1860eca9df0bb..21e4a7de81fcc 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -680,10 +680,8 @@ std::optional JoinFuzzer::computeReferenceResults( } if (auto sql = referenceQueryRunner_->toSql(plan)) { - return referenceQueryRunner_->execute( - sql.value(), probeInput, buildInput, plan->outputType()); + return referenceQueryRunner_->execute(*sql, plan); } - LOG(INFO) << "Query not supported by the reference DB"; return std::nullopt; } diff --git a/velox/exec/fuzzer/PrestoQueryRunner.cpp b/velox/exec/fuzzer/PrestoQueryRunner.cpp index c8bba9cdb64df..913b0215b4a6f 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.cpp +++ b/velox/exec/fuzzer/PrestoQueryRunner.cpp @@ -14,13 +14,13 @@ * limitations under the License. */ -#include "velox/exec/fuzzer/PrestoQueryRunner.h" #include // @manual #include #include +#include + #include "velox/common/base/Fs.h" #include "velox/common/encode/Base64.h" -#include "velox/common/file/FileSystems.h" #include "velox/connectors/hive/HiveDataSink.h" #include "velox/connectors/hive/TableHandle.h" #include "velox/core/Expressions.h" @@ -28,6 +28,7 @@ #include "velox/dwio/common/WriterFactory.h" #include "velox/dwio/dwrf/writer/Writer.h" #include "velox/exec/fuzzer/FuzzerUtil.h" +#include "velox/exec/fuzzer/PrestoQueryRunner.h" #include "velox/exec/fuzzer/ToSQLUtil.h" #include "velox/exec/tests/utils/QueryAssertions.h" #include "velox/functions/prestosql/types/IPAddressType.h" @@ -36,8 +37,6 @@ #include "velox/serializers/PrestoSerializer.h" #include "velox/type/parser/TypeParser.h" -#include - using namespace facebook::velox; namespace facebook::velox::exec::test { @@ -221,20 +220,6 @@ std::string toWindowCallSql( return sql.str(); } -bool isSupportedDwrfType(const TypePtr& type) { - if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown()) { - return false; - } - - for (auto i = 0; i < type->size(); ++i) { - if (!isSupportedDwrfType(type->childAt(i))) { - return false; - } - } - - return true; -} - } // namespace const std::vector& PrestoQueryRunner::supportedScalarTypes() const { @@ -554,152 +539,10 @@ std::optional PrestoQueryRunner::toSql( return sql.str(); } -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - if (!isSupportedDwrfType(joinNode->sources()[0]->outputType())) { - return std::nullopt; - } - - if (!isSupportedDwrfType(joinNode->sources()[1]->outputType())) { - return std::nullopt; - } - - const auto joinKeysToSql = [](auto keys) { - std::stringstream out; - for (auto i = 0; i < keys.size(); ++i) { - if (i > 0) { - out << ", "; - } - out << keys[i]->name(); - } - return out.str(); - }; - - const auto filterToSql = [](core::TypedExprPtr filter) { - auto call = std::dynamic_pointer_cast(filter); - return toCallSql(call); - }; - - const auto& joinConditionAsSql = [&](auto joinNode) { - std::stringstream out; - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { - if (i > 0) { - out << " AND "; - } - out << joinNode->leftKeys()[i]->name() << " = " - << joinNode->rightKeys()[i]->name(); - } - if (joinNode->filter()) { - out << " AND " << filterToSql(joinNode->filter()); - } - return out.str(); - }; - - const auto& outputNames = joinNode->outputType()->names(); - - std::stringstream sql; - if (joinNode->isLeftSemiProjectJoin()) { - sql << "SELECT " - << folly::join(", ", outputNames.begin(), --outputNames.end()); - } else { - sql << "SELECT " << folly::join(", ", outputNames); - } - - switch (joinNode->joinType()) { - case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); - break; - case core::JoinType::kLeftSemiFilter: - // Multiple columns returned by a scalar subquery is not supported in - // Presto. A scalar subquery expression is a subquery that returns one - // result row from exactly one column for every input row. - if (joinNode->leftKeys().size() > 1) { - return std::nullopt; - } - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ")"; - break; - case core::JoinType::kLeftSemiProject: - if (joinNode->isNullAware()) { - sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ") FROM t"; - } else { - sql << ", EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ") FROM t"; - } - break; - case core::JoinType::kAnti: - if (joinNode->isNullAware()) { - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) - << " FROM u"; - if (joinNode->filter()) { - sql << " WHERE " << filterToSql(joinNode->filter()); - } - sql << ")"; - } else { - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " - << joinConditionAsSql(joinNode); - sql << ")"; - } - break; - default: - VELOX_UNREACHABLE( - "Unknown join type: {}", static_cast(joinNode->joinType())); - } - return sql.str(); -} - -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& joinNode) { - std::stringstream sql; - sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); - - // Nested loop join without filter. - VELOX_CHECK( - joinNode->joinCondition() == nullptr, - "This code path should be called only for nested loop join without filter"); - const std::string joinCondition{"(1 = 1)"}; - switch (joinNode->joinType()) { - case core::JoinType::kInner: - sql << " FROM t INNER JOIN u ON " << joinCondition; - break; - case core::JoinType::kLeft: - sql << " FROM t LEFT JOIN u ON " << joinCondition; - break; - case core::JoinType::kFull: - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; - break; - default: - VELOX_UNREACHABLE( - "Unknown join type: {}", static_cast(joinNode->joinType())); - } - - return sql.str(); -} - -std::optional PrestoQueryRunner::toSql( - const std::shared_ptr& valuesNode) { - if (!isSupportedDwrfType(valuesNode->outputType())) { - return std::nullopt; - } - return "tmp"; +std::multiset> PrestoQueryRunner::execute( + const std::string& sql, + const core::PlanNodePtr& plan) { + return exec::test::materialize(executeAndReturnVector(sql, plan)); } std::multiset> PrestoQueryRunner::execute( @@ -709,15 +552,6 @@ std::multiset> PrestoQueryRunner::execute( return exec::test::materialize(executeVector(sql, input, resultType)); } -std::multiset> PrestoQueryRunner::execute( - const std::string& sql, - const std::vector& probeInput, - const std::vector& buildInput, - const RowTypePtr& resultType) { - return exec::test::materialize( - executeVector(sql, probeInput, buildInput, resultType)); -} - std::string PrestoQueryRunner::createTable( const std::string& name, const TypePtr& type) { @@ -749,40 +583,31 @@ std::string PrestoQueryRunner::createTable( return tableDirectoryPath; } -std::vector PrestoQueryRunner::executeVector( +std::vector PrestoQueryRunner::executeAndReturnVector( const std::string& sql, - const std::vector& probeInput, - const std::vector& buildInput, - const velox::RowTypePtr& resultType) { - auto probeType = asRowType(probeInput[0]->type()); - if (probeType->size() == 0) { - auto rowVector = makeNullRows(probeInput, "x", pool()); - return executeVector(sql, {rowVector}, buildInput, resultType); - } - - auto buildType = asRowType(buildInput[0]->type()); - if (probeType->size() == 0) { - auto rowVector = makeNullRows(buildInput, "y", pool()); - return executeVector(sql, probeInput, {rowVector}, resultType); + const core::PlanNodePtr& plan) { + std::unordered_map> inputMap = + getAllTables(plan); + for (const auto& [tableName, input] : inputMap) { + auto inputType = asRowType(input[0]->type()); + if (inputType->size() == 0) { + inputMap[tableName] = { + makeNullRows(input, fmt::format("{}x", tableName), pool())}; + } } - auto probeTableDirectoryPath = createTable("t", probeInput[0]->type()); - auto buildTableDirectoryPath = createTable("u", buildInput[0]->type()); - - // Create a new file in table's directory with fuzzer-generated data. - auto probeFilePath = fs::path(probeTableDirectoryPath) - .append("probe.dwrf") - .string() - .substr(strlen("file:")); + auto writerPool = aggregatePool()->addAggregateChild("writer"); + for (const auto& [tableName, input] : inputMap) { + auto tableDirectoryPath = createTable(tableName, input[0]->type()); - auto buildFilePath = fs::path(buildTableDirectoryPath) - .append("build.dwrf") - .string() - .substr(strlen("file:")); + // Create a new file in table's directory with fuzzer-generated data. + auto filePath = fs::path(tableDirectoryPath) + .append(fmt::format("{}.dwrf", tableName)) + .string() + .substr(strlen("file:")); - auto writerPool = aggregatePool()->addAggregateChild("writer"); - writeToFile(probeFilePath, probeInput, writerPool.get()); - writeToFile(buildFilePath, buildInput, writerPool.get()); + writeToFile(filePath, input, writerPool.get()); + } // Run the query. return execute(sql); diff --git a/velox/exec/fuzzer/PrestoQueryRunner.h b/velox/exec/fuzzer/PrestoQueryRunner.h index a72cae913e101..31d1604cacfcc 100644 --- a/velox/exec/fuzzer/PrestoQueryRunner.h +++ b/velox/exec/fuzzer/PrestoQueryRunner.h @@ -18,7 +18,6 @@ #include #include -#include "velox/common/memory/Memory.h" #include "velox/exec/fuzzer/ReferenceQueryRunner.h" #include "velox/vector/ComplexVector.h" @@ -83,11 +82,11 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::vector& input, const velox::RowTypePtr& resultType) override; + /// Executes SQL query returned by the 'toSql' method based on the plan. + /// Returns std::nullopt if the plan is not supported. std::multiset> execute( const std::string& sql, - const std::vector& probeInput, - const std::vector& buildInput, - const RowTypePtr& resultType) override; + const core::PlanNodePtr& plan) override; /// Executes Presto SQL query and returns the results. Tables referenced by /// the query must already exist. @@ -105,17 +104,13 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { const std::vector& input, const RowTypePtr& resultType) override; - std::vector executeVector( - const std::string& sql, - const std::vector& probeInput, - const std::vector& buildInput, - const RowTypePtr& resultType) override; - std::shared_ptr queryRunnerContext() { return queryRunnerContext_; } private: + using ReferenceQueryRunner::toSql; + memory::MemoryPool* pool() { return pool_.get(); } @@ -136,14 +131,11 @@ class PrestoQueryRunner : public velox::exec::test::ReferenceQueryRunner { std::optional toSql( const std::shared_ptr& tableWriteNode); - std::optional toSql( - const std::shared_ptr& joinNode); - - std::optional toSql( - const std::shared_ptr& joinNode); - - std::optional toSql( - const std::shared_ptr& valuesNode); + /// Executes SQL query returned by the 'toSql' method based on the plan. + /// Returns std::nullopt if the plan is not supported. + std::vector executeAndReturnVector( + const std::string& sql, + const core::PlanNodePtr& plan); std::string startQuery( const std::string& sql, diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.cpp b/velox/exec/fuzzer/ReferenceQueryRunner.cpp new file mode 100644 index 0000000000000..6dcf7540e5ef6 --- /dev/null +++ b/velox/exec/fuzzer/ReferenceQueryRunner.cpp @@ -0,0 +1,242 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "velox/core/PlanNode.h" +#include "velox/exec/fuzzer/ReferenceQueryRunner.h" +#include "velox/exec/fuzzer/ToSQLUtil.h" + +namespace facebook::velox::exec::test { + +namespace { + +std::string joinKeysToSql( + const std::vector& keys) { + std::vector keyNames; + keyNames.reserve(keys.size()); + for (const core::FieldAccessTypedExprPtr& key : keys) { + keyNames.push_back(key->name()); + } + return folly::join(", ", keyNames); +} + +std::string filterToSql(const core::TypedExprPtr& filter) { + auto call = std::dynamic_pointer_cast(filter); + return toCallSql(call); +} + +std::string joinConditionAsSql(const core::AbstractJoinNode& joinNode) { + std::stringstream out; + for (auto i = 0; i < joinNode.leftKeys().size(); ++i) { + if (i > 0) { + out << " AND "; + } + out << joinNode.leftKeys()[i]->name() << " = " + << joinNode.rightKeys()[i]->name(); + } + if (joinNode.filter()) { + if (!joinNode.leftKeys().empty()) { + out << " AND "; + } + out << filterToSql(joinNode.filter()); + } + return out.str(); +} + +} // namespace + +bool ReferenceQueryRunner::isSupportedDwrfType(const TypePtr& type) { + if (type->isDate() || type->isIntervalDayTime() || type->isUnKnown()) { + return false; + } + + for (auto i = 0; i < type->size(); ++i) { + if (!isSupportedDwrfType(type->childAt(i))) { + return false; + } + } + + return true; +} + +std::unordered_map> +ReferenceQueryRunner::getAllTables(const core::PlanNodePtr& plan) { + std::unordered_map> result; + if (const auto valuesNode = + std::dynamic_pointer_cast(plan)) { + result.insert({getTableName(valuesNode), valuesNode->values()}); + } else { + for (const auto& source : plan->sources()) { + auto tablesAndNames = getAllTables(source); + result.insert(tablesAndNames.begin(), tablesAndNames.end()); + } + } + return result; +} + +std::optional ReferenceQueryRunner::joinSourceToSql( + const core::PlanNodePtr& planNode) { + const std::optional subQuery = toSql(planNode); + if (subQuery) { + return subQuery->find(" ") != std::string::npos + ? fmt::format("({})", *subQuery) + : *subQuery; + } + return std::nullopt; +} + +std::optional ReferenceQueryRunner::toSql( + const core::ValuesNodePtr& valuesNode) { + if (!isSupportedDwrfType(valuesNode->outputType())) { + return std::nullopt; + } + return getTableName(valuesNode); +} + +std::optional ReferenceQueryRunner::toSql( + const std::shared_ptr& joinNode) { + if (!isSupportedDwrfType(joinNode->sources()[0]->outputType()) || + !isSupportedDwrfType(joinNode->sources()[1]->outputType())) { + return std::nullopt; + } + + std::optional probeTableName = + joinSourceToSql(joinNode->sources()[0]); + std::optional buildTableName = + joinSourceToSql(joinNode->sources()[1]); + if (!probeTableName || !buildTableName) { + return std::nullopt; + } + + const auto& outputNames = joinNode->outputType()->names(); + + std::stringstream sql; + if (joinNode->isLeftSemiProjectJoin()) { + sql << "SELECT " + << folly::join(", ", outputNames.begin(), --outputNames.end()); + } else { + sql << "SELECT " << folly::join(", ", outputNames); + } + + switch (joinNode->joinType()) { + case core::JoinType::kInner: + sql << " FROM " << *probeTableName << " INNER JOIN " << *buildTableName + << " ON " << joinConditionAsSql(*joinNode); + break; + case core::JoinType::kLeft: + sql << " FROM " << *probeTableName << " LEFT JOIN " << *buildTableName + << " ON " << joinConditionAsSql(*joinNode); + break; + case core::JoinType::kFull: + sql << " FROM " << *probeTableName << " FULL OUTER JOIN " + << *buildTableName << " ON " << joinConditionAsSql(*joinNode); + break; + case core::JoinType::kLeftSemiFilter: + // Multiple columns returned by a scalar subquery is not supported. A + // scalar subquery expression is a subquery that returns one result row + // from exactly one column for every input row. + if (joinNode->leftKeys().size() > 1) { + return std::nullopt; + } + sql << " FROM " << *probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << *buildTableName; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ")"; + break; + case core::JoinType::kLeftSemiProject: + if (joinNode->isNullAware()) { + sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << *buildTableName; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ") FROM " << *probeTableName; + } else { + sql << ", EXISTS (SELECT * FROM " << *buildTableName << " WHERE " + << joinConditionAsSql(*joinNode); + sql << ") FROM " << *probeTableName; + } + break; + case core::JoinType::kAnti: + if (joinNode->isNullAware()) { + sql << " FROM " << *probeTableName << " WHERE " + << joinKeysToSql(joinNode->leftKeys()) << " NOT IN (SELECT " + << joinKeysToSql(joinNode->rightKeys()) << " FROM " + << *buildTableName; + if (joinNode->filter()) { + sql << " WHERE " << filterToSql(joinNode->filter()); + } + sql << ")"; + } else { + sql << " FROM " << *probeTableName + << " WHERE NOT EXISTS (SELECT * FROM " << *buildTableName + << " WHERE " << joinConditionAsSql(*joinNode); + sql << ")"; + } + break; + default: + VELOX_UNREACHABLE( + "Unknown join type: {}", static_cast(joinNode->joinType())); + } + return sql.str(); +} + +std::optional ReferenceQueryRunner::toSql( + const std::shared_ptr& joinNode) { + std::optional probeTableName = + joinSourceToSql(joinNode->sources()[0]); + std::optional buildTableName = + joinSourceToSql(joinNode->sources()[1]); + if (!probeTableName || !buildTableName) { + return std::nullopt; + } + + std::stringstream sql; + sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); + + // Nested loop join without filter. + VELOX_CHECK_NULL( + joinNode->joinCondition(), + "This code path should be called only for nested loop join without filter"); + const std::string joinCondition{"(1 = 1)"}; + switch (joinNode->joinType()) { + case core::JoinType::kInner: + sql << " FROM " << *probeTableName << " INNER JOIN " << *buildTableName + << " ON " << joinCondition; + break; + case core::JoinType::kLeft: + sql << " FROM " << *probeTableName << " LEFT JOIN " << *buildTableName + << " ON " << joinCondition; + break; + case core::JoinType::kFull: + sql << " FROM " << *probeTableName << " FULL OUTER JOIN " + << *buildTableName << " ON " << joinCondition; + break; + default: + VELOX_UNREACHABLE( + "Unknown join type: {}", static_cast(joinNode->joinType())); + } + return sql.str(); +} + +} // namespace facebook::velox::exec::test diff --git a/velox/exec/fuzzer/ReferenceQueryRunner.h b/velox/exec/fuzzer/ReferenceQueryRunner.h index 5d0c24afdc246..3a73791b1eaf3 100644 --- a/velox/exec/fuzzer/ReferenceQueryRunner.h +++ b/velox/exec/fuzzer/ReferenceQueryRunner.h @@ -15,7 +15,10 @@ */ #pragma once +#include #include +#include + #include "velox/core/PlanNode.h" #include "velox/expression/FunctionSignature.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -54,6 +57,18 @@ class ReferenceQueryRunner { /// reference database. virtual std::optional toSql(const core::PlanNodePtr& plan) = 0; + /// Same as the above toSql but for values nodes. + virtual std::optional toSql( + const core::ValuesNodePtr& valuesNode); + + /// Same as the above toSql but for hash join nodes. + virtual std::optional toSql( + const std::shared_ptr& joinNode); + + /// Same as the above toSql but for nested loop join nodes. + virtual std::optional toSql( + const std::shared_ptr& joinNode); + /// Returns whether a constant expression is supported by the reference /// database. virtual bool isConstantExprSupported(const core::TypedExprPtr& /*expr*/) { @@ -66,6 +81,13 @@ class ReferenceQueryRunner { return true; } + /// Executes SQL query returned by the 'toSql' method based on the plan. + virtual std::multiset> execute( + const std::string& sql, + const core::PlanNodePtr& plan) { + VELOX_UNSUPPORTED(); + } + /// Executes SQL query returned by the 'toSql' method using 'input' data. /// Converts results using 'resultType' schema. virtual std::multiset> execute( @@ -80,7 +102,9 @@ class ReferenceQueryRunner { const std::string& sql, const std::vector& probeInput, const std::vector& buildInput, - const RowTypePtr& resultType) = 0; + const RowTypePtr& resultType) { + VELOX_UNSUPPORTED(); + } /// Returns true if 'executeVector' can be called to get results as Velox /// Vector. @@ -97,15 +121,6 @@ class ReferenceQueryRunner { VELOX_UNSUPPORTED(); } - /// Similar to above but for join node with 'probeInput' and 'buildInput'. - virtual std::vector executeVector( - const std::string& sql, - const std::vector& probeInput, - const std::vector& buildInput, - const RowTypePtr& resultType) { - VELOX_UNSUPPORTED(); - } - virtual std::vector execute(const std::string& sql) { VELOX_UNSUPPORTED(); } @@ -121,8 +136,20 @@ class ReferenceQueryRunner { return aggregatePool_; } + bool isSupportedDwrfType(const TypePtr& type); + + /// Returns the name of the values node table in the form t_. + std::string getTableName(const core::ValuesNodePtr& valuesNode) { + return fmt::format("t_{}", valuesNode->id()); + } + + // Traverses all nodes in the plan and returns all tables and their names. + std::unordered_map> + getAllTables(const core::PlanNodePtr& plan); + private: memory::MemoryPool* aggregatePool_; -}; + std::optional joinSourceToSql(const core::PlanNodePtr& planNode); +}; } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/PrestoQueryRunnerTest.cpp b/velox/exec/tests/PrestoQueryRunnerTest.cpp index 25b231dc6c7c1..14447f5eb3967 100644 --- a/velox/exec/tests/PrestoQueryRunnerTest.cpp +++ b/velox/exec/tests/PrestoQueryRunnerTest.cpp @@ -255,4 +255,122 @@ TEST_F(PrestoQueryRunnerTest, toSql) { } } +TEST_F(PrestoQueryRunnerTest, toSqlJoins) { + auto aggregatePool = rootPool_->addAggregateChild("toSqlJoins"); + auto queryRunner = std::make_unique( + aggregatePool.get(), + "http://unused", + "hive", + static_cast(1000)); + + auto t = makeRowVector( + {"t0", "t1", "t2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto u = makeRowVector( + {"u0", "u1", "u2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto v = makeRowVector( + {"v0", "v1", "v2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + auto w = makeRowVector( + {"w0", "w1", "w2"}, + { + makeFlatVector({}), + makeFlatVector({}), + makeFlatVector({}), + }); + + // Single join. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, t1 FROM t_0 INNER JOIN t_1 ON t0 = u0"); + } + + // Two joins with a filter. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0"}, + core::JoinType::kLeftSemiFilter) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + "v1 > 0", + {"t0", "v1"}, + core::JoinType::kInner) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, v1" + " FROM (SELECT t0 FROM t_0 WHERE t0 IN (SELECT u0 FROM t_1))" + " INNER JOIN t_3 ON t0 = v0 AND (cast(v1 as BIGINT) > BIGINT '0')"); + } + + // Three joins. + { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({t}) + .hashJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({u}).planNode(), + /*filter=*/"", + {"t0", "t1"}, + core::JoinType::kLeft) + .hashJoin( + {"t0"}, + {"v0"}, + PlanBuilder(planNodeIdGenerator).values({v}).planNode(), + /*filter=*/"", + {"t0", "v1"}, + core::JoinType::kInner) + .hashJoin( + {"t0", "v1"}, + {"w0", "w1"}, + PlanBuilder(planNodeIdGenerator).values({w}).planNode(), + /*filter=*/"", + {"t0", "w1"}, + core::JoinType::kFull) + .planNode(); + EXPECT_EQ( + *queryRunner->toSql(plan), + "SELECT t0, w1" + " FROM (SELECT t0, v1 FROM (SELECT t0, t1 FROM t_0 LEFT JOIN t_1 ON t0 = u0)" + " INNER JOIN t_3 ON t0 = v0)" + " FULL OUTER JOIN t_5 ON t0 = w0 AND v1 = w1"); + } +} + } // namespace facebook::velox::exec::test