Skip to content

Commit

Permalink
fix(fuzzer): Add filter parsing to toSql methods for hasJoinNode in R…
Browse files Browse the repository at this point in the history
…eferenceQueryRunners (facebookincubator#11566)

Summary:

This change updates both the DuckQueryRunner and PrestoQueryRunner to parse filters in their hasJoinNode toSql methods.

Differential Revision: D66021799
  • Loading branch information
Daniel Hunte authored and facebook-github-bot committed Nov 18, 2024
1 parent aedf91c commit 57faacd
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
19 changes: 12 additions & 7 deletions velox/exec/fuzzer/DuckQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ std::optional<std::string> DuckQueryRunner::toSql(
return out.str();
};

const auto& equiClausesToSql = [](auto joinNode) {
const auto& joinConditionAsSql = [](auto joinNode) {
std::stringstream out;
for (auto i = 0; i < joinNode->leftKeys().size(); ++i) {
if (i > 0) {
Expand All @@ -363,6 +363,11 @@ std::optional<std::string> DuckQueryRunner::toSql(
out << joinNode->leftKeys()[i]->name() << " = "
<< joinNode->rightKeys()[i]->name();
}
if (joinNode->filter()) {
auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(
joinNode->filter());
out << " AND " << toCallSql(call);
}
return out.str();
};

Expand All @@ -378,13 +383,13 @@ std::optional<std::string> DuckQueryRunner::toSql(

switch (joinNode->joinType()) {
case core::JoinType::kInner:
sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeft:
sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kFull:
sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeftSemiFilter:
if (joinNode->leftKeys().size() > 1) {
Expand All @@ -399,8 +404,8 @@ std::optional<std::string> DuckQueryRunner::toSql(
sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t";
} else {
sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode)
<< ") FROM t";
sql << ", EXISTS (SELECT * FROM u WHERE "
<< joinConditionAsSql(joinNode) << ") FROM t";
}
break;
case core::JoinType::kAnti:
Expand All @@ -410,7 +415,7 @@ std::optional<std::string> DuckQueryRunner::toSql(
<< " FROM u)";
} else {
sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
<< equiClausesToSql(joinNode) << ")";
<< joinConditionAsSql(joinNode) << ")";
}
break;
default:
Expand Down
19 changes: 12 additions & 7 deletions velox/exec/fuzzer/PrestoQueryRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ std::optional<std::string> PrestoQueryRunner::toSql(
return out.str();
};

const auto equiClausesToSql = [](auto joinNode) {
const auto& joinConditionAsSql = [](auto joinNode) {
std::stringstream out;
for (auto i = 0; i < joinNode->leftKeys().size(); ++i) {
if (i > 0) {
Expand All @@ -578,6 +578,11 @@ std::optional<std::string> PrestoQueryRunner::toSql(
out << joinNode->leftKeys()[i]->name() << " = "
<< joinNode->rightKeys()[i]->name();
}
if (joinNode->filter()) {
auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(
joinNode->filter());
out << " AND " << toCallSql(call);
}
return out.str();
};

Expand All @@ -593,13 +598,13 @@ std::optional<std::string> PrestoQueryRunner::toSql(

switch (joinNode->joinType()) {
case core::JoinType::kInner:
sql << " FROM t INNER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeft:
sql << " FROM t LEFT JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kFull:
sql << " FROM t FULL OUTER JOIN u ON " << equiClausesToSql(joinNode);
sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode);
break;
case core::JoinType::kLeftSemiFilter:
if (joinNode->leftKeys().size() > 1) {
Expand All @@ -614,8 +619,8 @@ std::optional<std::string> PrestoQueryRunner::toSql(
sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT "
<< joinKeysToSql(joinNode->rightKeys()) << " FROM u) FROM t";
} else {
sql << ", EXISTS (SELECT * FROM u WHERE " << equiClausesToSql(joinNode)
<< ") FROM t";
sql << ", EXISTS (SELECT * FROM u WHERE "
<< joinConditionAsSql(joinNode) << ") FROM t";
}
break;
case core::JoinType::kAnti:
Expand All @@ -625,7 +630,7 @@ std::optional<std::string> PrestoQueryRunner::toSql(
<< " FROM u)";
} else {
sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE "
<< equiClausesToSql(joinNode) << ")";
<< joinConditionAsSql(joinNode) << ")";
}
break;
default:
Expand Down

0 comments on commit 57faacd

Please sign in to comment.