diff --git a/velox/exec/fuzzer/ExprTransformer.h b/velox/exec/fuzzer/ExprTransformer.h
new file mode 100644
index 000000000000..40bff211461a
--- /dev/null
+++ b/velox/exec/fuzzer/ExprTransformer.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cstdint>
+
+#include "velox/core/Expressions.h"
+
+namespace facebook::velox::exec::test {
+
+/// Interface for rewriting an expression produced by the expression fuzzer,
+/// e.g. wrapping it in another function call.
+class ExprTransformer {
+ public:
+  virtual ~ExprTransformer() = default;
+
+  /// Transforms the given expression into a new expression. This should be
+  /// called during the expression generation in expression fuzzer.
+  ///
+  /// @param expr Expression to transform.
+  /// @return The expression that replaces 'expr'.
+  virtual core::TypedExprPtr transform(core::TypedExprPtr expr) const = 0;
+
+  /// Returns the additional number of levels of nesting introduced by the
+  /// transformation. The expression fuzzer deducts this amount from the
+  /// remaining nesting budget while it generates the expression to transform.
+  virtual int32_t extraLevelOfNesting() const = 0;
+};
+
+} // namespace facebook::velox::exec::test
diff --git a/velox/expression/fuzzer/ExpressionFuzzer.cpp b/velox/expression/fuzzer/ExpressionFuzzer.cpp
index f310707384b9..45ed0b927ded 100644
--- a/velox/expression/fuzzer/ExpressionFuzzer.cpp
+++ b/velox/expression/fuzzer/ExpressionFuzzer.cpp
@@ -1004,23 +1004,43 @@ core::TypedExprPtr ExpressionFuzzer::generateExpression(
       chosenFunctionName = templateList[chosenExprIndex];
     }
 
-    if (chosenFunctionName == "cast") {
-      expression = generateCastExpression(returnType);
-    } else if (chosenFunctionName == "row_constructor") {
-      // Avoid generating deeply nested types that is rarely used in practice.
-      if (levelOfNesting(returnType) < 3) {
-        expression = generateRowConstructorExpression(returnType);
-      }
-    } else if (chosenFunctionName == "dereference") {
-      expression = generateDereferenceExpression(returnType);
-    } else {
-      expression = generateExpressionFromConcreteSignatures(
-          returnType, chosenFunctionName);
-      if (!expression &&
-          (options_.enableComplexTypes || options_.enableDecimalType)) {
-        expression = generateExpressionFromSignatureTemplate(
+    // A transformer registered for the chosen function adds extra levels of
+    // nesting, so reserve them before generating the expression.
+    auto exprTransformer = options_.exprTransformers.find(chosenFunctionName);
+    if (exprTransformer != options_.exprTransformers.end()) {
+      state.remainingLevelOfNesting_ -=
+          exprTransformer->second->extraLevelOfNesting();
+    }
+
+    // Skip generation when the reserved levels exceed the remaining budget.
+    if (state.remainingLevelOfNesting_ >= 0) {
+      if (chosenFunctionName == "cast") {
+        expression = generateCastExpression(returnType);
+      } else if (chosenFunctionName == "row_constructor") {
+        // Avoid generating deeply nested types that are rare in practice.
+        if (levelOfNesting(returnType) < 3) {
+          expression = generateRowConstructorExpression(returnType);
+        }
+      } else if (chosenFunctionName == "dereference") {
+        expression = generateDereferenceExpression(returnType);
+      } else {
+        expression = generateExpressionFromConcreteSignatures(
             returnType, chosenFunctionName);
+        if (!expression &&
+            (options_.enableComplexTypes || options_.enableDecimalType)) {
+          expression = generateExpressionFromSignatureTemplate(
+              returnType, chosenFunctionName);
+        }
+      }
+    }
+
+    if (exprTransformer != options_.exprTransformers.end()) {
+      if (expression) {
+        expression = exprTransformer->second->transform(std::move(expression));
       }
+      // Return the reserved levels of nesting.
+      state.remainingLevelOfNesting_ +=
+          exprTransformer->second->extraLevelOfNesting();
     }
   }
   if (!expression) {
diff --git a/velox/expression/fuzzer/ExpressionFuzzer.h b/velox/expression/fuzzer/ExpressionFuzzer.h
index a2012de3e250..9f89abe48bb0 100644
--- a/velox/expression/fuzzer/ExpressionFuzzer.h
+++ b/velox/expression/fuzzer/ExpressionFuzzer.h
@@ -18,6 +18,7 @@
 
 #include "velox/core/ITypedExpr.h"
 #include "velox/core/QueryCtx.h"
+#include "velox/exec/fuzzer/ExprTransformer.h"
 #include "velox/exec/fuzzer/ReferenceQueryRunner.h"
 #include "velox/expression/Expr.h"
 #include "velox/expression/fuzzer/ArgGenerator.h"
@@ -30,6 +31,7 @@
 namespace facebook::velox::fuzzer {
 
 using exec::test::ReferenceQueryRunner;
+using facebook::velox::exec::test::ExprTransformer;
 
 // A tool that can be used to generate random expressions.
 class ExpressionFuzzer {
@@ -101,6 +103,13 @@ class ExpressionFuzzer {
     // "array_sort(array(T),constant function(T,T,bigint)) -> array(T)"}
     std::unordered_set skipFunctions;
 
+    // Maps a function name to the transformer applied to expressions
+    // generated for that function. The transformer's extraLevelOfNesting() is
+    // deducted from the remaining nesting budget while such an expression is
+    // generated.
+    std::unordered_map<std::string, std::shared_ptr<ExprTransformer>>
+        exprTransformers;
+
     // When set, when the input size of the generated expressions reaches
     // maxInputsThreshold, fuzzing input columns will reuse one of the existing
     // columns if any is already generated with the same type.
diff --git a/velox/expression/fuzzer/ExpressionFuzzerTest.cpp b/velox/expression/fuzzer/ExpressionFuzzerTest.cpp
index 3d218bf30ea3..fd0ec3236f30 100644
--- a/velox/expression/fuzzer/ExpressionFuzzerTest.cpp
+++ b/velox/expression/fuzzer/ExpressionFuzzerTest.cpp
@@ -25,6 +25,7 @@
 #include "velox/functions/prestosql/fuzzer/ModulusArgGenerator.h"
 #include "velox/functions/prestosql/fuzzer/MultiplyArgGenerator.h"
 #include "velox/functions/prestosql/fuzzer/PlusMinusArgGenerator.h"
+#include "velox/functions/prestosql/fuzzer/SortArrayTransformer.h"
 #include "velox/functions/prestosql/fuzzer/TruncateArgGenerator.h"
 #include "velox/functions/prestosql/registration/RegistrationFunctions.h"
 
@@ -113,6 +114,13 @@ int main(int argc, char** argv) {
        {"mod", std::make_shared()},
        {"truncate", std::make_shared()}};
 
+  std::unordered_map<std::string, std::shared_ptr<ExprTransformer>>
+      exprTransformers = {
+          {"array_intersect", std::make_shared<SortArrayTransformer>()},
+          {"array_except", std::make_shared<SortArrayTransformer>()},
+          {"map_keys", std::make_shared<SortArrayTransformer>()},
+          {"map_values", std::make_shared<SortArrayTransformer>()}};
+
   std::shared_ptr rootPool{
       facebook::velox::memory::memoryManager()->addRootPool()};
   std::shared_ptr referenceQueryRunner{nullptr};
@@ -127,6 +135,7 @@ int main(int argc, char** argv) {
   FuzzerRunner::runFromGtest(
       initialSeed,
       skipFunctions,
+      exprTransformers,
       {{"session_timezone", "America/Los_Angeles"},
        {"adjust_timestamp_to_session_timezone", "true"}},
       argGenerators,
diff --git a/velox/expression/fuzzer/FuzzerRunner.cpp b/velox/expression/fuzzer/FuzzerRunner.cpp
index 52d93c762af2..22ca6c957a74 100644
--- a/velox/expression/fuzzer/FuzzerRunner.cpp
+++ b/velox/expression/fuzzer/FuzzerRunner.cpp
@@ -169,6 +169,8 @@ VectorFuzzer::Options getVectorFuzzerOptions() {
 
 ExpressionFuzzer::Options getExpressionFuzzerOptions(
     const std::unordered_set& skipFunctions,
+    const std::unordered_map<std::string, std::shared_ptr<ExprTransformer>>&
+        exprTransformers,
     std::shared_ptr referenceQueryRunner) {
   ExpressionFuzzer::Options opts;
   opts.maxLevelOfNesting = FLAGS_velox_fuzzer_max_level_of_nesting;
@@ -185,11 +187,14 @@ ExpressionFuzzer::Options getExpressionFuzzerOptions(
   opts.useOnlyFunctions = FLAGS_only;
   opts.skipFunctions = skipFunctions;
   opts.referenceQueryRunner = referenceQueryRunner;
+  opts.exprTransformers = exprTransformers;
   return opts;
 }
 
 ExpressionFuzzerVerifier::Options getExpressionFuzzerVerifierOptions(
     const std::unordered_set& skipFunctions,
+    const std::unordered_map<std::string, std::shared_ptr<ExprTransformer>>&
+        exprTransformers,
     const std::unordered_map& queryConfigs,
     std::shared_ptr referenceQueryRunner) {
   ExpressionFuzzerVerifier::Options opts;
@@ -204,8 +209,8 @@ ExpressionFuzzerVerifier::Options getExpressionFuzzerVerifierOptions(
   opts.lazyVectorGenerationRatio = FLAGS_lazy_vector_generation_ratio;
   opts.maxExpressionTreesPerStep = FLAGS_max_expression_trees_per_step;
   opts.vectorFuzzerOptions = getVectorFuzzerOptions();
-  opts.expressionFuzzerOptions =
-      getExpressionFuzzerOptions(skipFunctions, referenceQueryRunner);
+  opts.expressionFuzzerOptions = getExpressionFuzzerOptions(
+      skipFunctions, exprTransformers, referenceQueryRunner);
   opts.queryConfigs = queryConfigs;
   return opts;
 }
@@ -216,12 +221,19 @@ ExpressionFuzzerVerifier::Options getExpressionFuzzerVerifierOptions(
 int FuzzerRunner::run(
     size_t seed,
     const std::unordered_set& skipFunctions,
+    const std::unordered_map<std::string, std::shared_ptr<ExprTransformer>>&
+        exprTransformers,
     const std::unordered_map& queryConfigs,
     const std::unordered_map>&
         argGenerators,
     std::shared_ptr referenceQueryRunner) {
   runFromGtest(
-      seed, skipFunctions, queryConfigs, argGenerators, referenceQueryRunner);
+      seed,
+      skipFunctions,
+      exprTransformers,
+      queryConfigs,
+      argGenerators,
+      referenceQueryRunner);
   return RUN_ALL_TESTS();
 }
 
@@ -229,6 +241,8 @@ int FuzzerRunner::run(
 void FuzzerRunner::runFromGtest(
     size_t seed,
     const std::unordered_set& skipFunctions,
+    const std::unordered_map<std::string, std::shared_ptr<ExprTransformer>>&
+        exprTransformers,
     const std::unordered_map& queryConfigs,
     const std::unordered_map>&
         argGenerators,
@@ -241,7 +255,7 @@ void FuzzerRunner::runFromGtest(
       signatures,
       seed,
       getExpressionFuzzerVerifierOptions(
-          skipFunctions, queryConfigs, referenceQueryRunner),
+          skipFunctions, exprTransformers, queryConfigs, referenceQueryRunner),
       argGenerators)
       .go();
 }
diff --git a/velox/expression/fuzzer/FuzzerRunner.h b/velox/expression/fuzzer/FuzzerRunner.h
index 1a722bce9f36..c05477f270b8 100644
--- a/velox/expression/fuzzer/FuzzerRunner.h
+++ b/velox/expression/fuzzer/FuzzerRunner.h
@@ -22,18 +22,23 @@
 #include
 #include
 
+#include "velox/exec/fuzzer/ExprTransformer.h"
 #include "velox/exec/fuzzer/ReferenceQueryRunner.h"
 #include "velox/expression/fuzzer/ExpressionFuzzerVerifier.h"
 #include "velox/functions/FunctionRegistry.h"
 
 namespace facebook::velox::fuzzer {
 
+using facebook::velox::exec::test::ExprTransformer;
+
 /// FuzzerRunner leverages ExpressionFuzzerVerifier to create a gtest unit test.
 class FuzzerRunner {
  public:
   static int run(
       size_t seed,
       const std::unordered_set& skipFunctions,
+      const std::unordered_map<std::string, std::shared_ptr<ExprTransformer>>&
+          exprTransformers,
       const std::unordered_map& queryConfigs,
       const std::unordered_map>&
           argGenerators,
@@ -42,6 +47,8 @@ class FuzzerRunner {
   static void runFromGtest(
       size_t seed,
       const std::unordered_set& skipFunctions,
+      const std::unordered_map<std::string, std::shared_ptr<ExprTransformer>>&
+          exprTransformers,
       const std::unordered_map& queryConfigs,
       const std::unordered_map>&
           argGenerators,
diff --git a/velox/expression/fuzzer/SparkExpressionFuzzerTest.cpp b/velox/expression/fuzzer/SparkExpressionFuzzerTest.cpp
index ff413337144d..bfe2d809f6ec 100644
--- a/velox/expression/fuzzer/SparkExpressionFuzzerTest.cpp
+++ b/velox/expression/fuzzer/SparkExpressionFuzzerTest.cpp
@@ -91,6 +91,7 @@ int main(int argc, char** argv) {
   return FuzzerRunner::run(
       FLAGS_seed,
       skipFunctions,
+      /*exprTransformers=*/{},
       queryConfigs,
       argGenerators,
       referenceQueryRunner);
diff --git a/velox/functions/prestosql/fuzzer/SortArrayTransformer.h b/velox/functions/prestosql/fuzzer/SortArrayTransformer.h
new file mode 100644
index 000000000000..905ea3419dcd
--- /dev/null
+++ b/velox/functions/prestosql/fuzzer/SortArrayTransformer.h
@@ -0,0 +1,69 @@
+/*
+ * 
Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "velox/core/Expressions.h"
+#include "velox/exec/fuzzer/ExprTransformer.h"
+#include "velox/type/Type.h"
+
+namespace facebook::velox::exec::test {
+
+using facebook::velox::TypePtr;
+using facebook::velox::core::TypedExprPtr;
+
+class SortArrayTransformer : public ExprTransformer {
+ public:
+  ~SortArrayTransformer() override = default;
+
+  /// Wraps 'expr' in a call to array_sort. If the type of 'expr' contains a
+  /// map, array_sort doesn't support this type, so we return a constant null
+  /// instead.
+  TypedExprPtr transform(TypedExprPtr expr) const override {
+    TypePtr type = expr->type();
+    if (containsMap(type)) {
+      // TODO: support map type by using array_sort with a lambda that casts
+      // array elements to JSON before comparison.
+      return std::make_shared<facebook::velox::core::ConstantTypedExpr>(
+          type, facebook::velox::variant::null(type->kind()));
+    }
+    return std::make_shared<facebook::velox::core::CallTypedExpr>(
+        type, std::vector<TypedExprPtr>{std::move(expr)}, "array_sort");
+  }
+
+  /// The wrapping array_sort call adds one level of nesting.
+  int32_t extraLevelOfNesting() const override {
+    return 1;
+  }
+
+ private:
+  /// Returns true if 'type' is a map or contains a map at any nesting level.
+  static bool containsMap(const TypePtr& type) {
+    if (type->isMap()) {
+      return true;
+    } else if (type->isArray()) {
+      return containsMap(type->asArray().elementType());
+    } else if (type->isRow()) {
+      for (const auto& child : type->asRow().children()) {
+        if (containsMap(child)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+};
+
+} // namespace facebook::velox::exec::test