From 2e130cf2284b07b9d1afca5d2e219e31ae5cd8c0 Mon Sep 17 00:00:00 2001 From: arnavb Date: Thu, 12 Dec 2024 15:47:42 +0000 Subject: [PATCH] update --- velox/expression/fuzzer/ExpressionFuzzer.cpp | 9 ++ velox/functions/sparksql/CMakeLists.txt | 1 + velox/functions/sparksql/Factorial.cpp | 124 ++++++++++++++++++ velox/functions/sparksql/Factorial.h | 35 +++++ .../sparksql/registration/RegisterMath.cpp | 3 + velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../sparksql/tests/FactorialTest.cpp | 64 +++++++++ 7 files changed, 237 insertions(+) create mode 100644 velox/functions/sparksql/Factorial.cpp create mode 100644 velox/functions/sparksql/Factorial.h create mode 100644 velox/functions/sparksql/tests/FactorialTest.cpp diff --git a/velox/expression/fuzzer/ExpressionFuzzer.cpp b/velox/expression/fuzzer/ExpressionFuzzer.cpp index 876501414501..46f515bc5064 100644 --- a/velox/expression/fuzzer/ExpressionFuzzer.cpp +++ b/velox/expression/fuzzer/ExpressionFuzzer.cpp @@ -359,6 +359,15 @@ static void appendSpecialForms( /// them to fuzzer instead of hard-coding signatures here. getSignaturesForCast(), }, + { + "factorial", + std::vector{ + // Signature: factorial (integer) -> integer + facebook::velox::exec::FunctionSignatureBuilder() + .returnType("T") + .argumentType("T") + .build()}, + }, }; auto specialFormNames = splitNames(specialForms); diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index 5e2f5ad58271..507501d40fc9 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -20,6 +20,7 @@ velox_add_library( Comparisons.cpp DecimalArithmetic.cpp DecimalCompare.cpp + Factorial.cpp Hash.cpp In.cpp LeastGreatest.cpp diff --git a/velox/functions/sparksql/Factorial.cpp b/velox/functions/sparksql/Factorial.cpp new file mode 100644 index 000000000000..df8d79520028 --- /dev/null +++ b/velox/functions/sparksql/Factorial.cpp @@ -0,0 +1,124 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/sparksql/Factorial.h" +#include "velox/expression/ConstantExpr.h" +#include "velox/expression/VectorFunction.h" +#include +#include + +namespace facebook::velox::functions::sparksql { + +namespace { + +/** + * Computes the factorial of integers in the range [0...20] + * + * Returns NULL for inputs which are outside the range [0...20]. + * Leverages a lookup table for O(1) computation, similar to Spark JVM. + */ +class Factorial : public exec::VectorFunction { + public: + Factorial() = default; + + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + + context.ensureWritable(rows, BIGINT(), result); + auto* flatResult = result->asFlatVector(); + + exec::DecodedArgs decodedArgs(rows, args, context); + auto* inputVector = decodedArgs.at(0); + + rows.applyToSelected([&](vector_size_t row) { + if (inputVector->isNullAt(row)) { + flatResult->setNull(row, true); + } else { + int32_t value = inputVector->valueAt(row); + if (value < LOWER_BOUND || value > UPPER_BOUND) { + flatResult->setNull(row, true); + } else { + flatResult->set(row, kFactorials[value]); + } + } + }); + } + + private: + static constexpr int64_t LOWER_BOUND = 0; + static constexpr int64_t UPPER_BOUND = 20; + static constexpr int64_t MAX_INT64 = std::numeric_limits::max(); + + static constexpr int64_t kFactorials[21] = { + 1, + 1, + 2, + 6, + 24, + 120, + 720, + 5040, + 40320, + 362880, + 3628800, + 39916800, + 479001600, + 6227020800L, + 87178291200L, + 1307674368000L, + 20922789888000L, + 355687428096000L, + 6402373705728000L, + 121645100408832000L, + 2432902008176640000L + }; +}; +} // namespace + +TypePtr FactorialCallToSpecialForm::resolveType( + const std::vector&) { + return BIGINT(); +} + +exec::ExprPtr FactorialCallToSpecialForm::constructSpecialForm( + const TypePtr& type, + std::vector&& args, + bool trackCpuUsage, + const core::QueryConfig& config) { + auto numArgs = args.size(); + + VELOX_USER_CHECK_EQ( + numArgs, + 1, + "factorial requires exactly 1 argument, but got {}.", + numArgs); + VELOX_USER_CHECK( + args[0]->type()->isInteger(), + "The argument of factorial must be an integer."); + + auto factorial = std::make_shared(); + return std::make_shared( + type, + std::move(args), + std::move(factorial), + exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(), + "factorial", + trackCpuUsage); +} +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Factorial.h b/velox/functions/sparksql/Factorial.h new file mode 100644 index 000000000000..ae7dc18774c7 --- /dev/null +++ b/velox/functions/sparksql/Factorial.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "velox/expression/FunctionCallToSpecialForm.h" + +namespace facebook::velox::functions::sparksql { + +class FactorialCallToSpecialForm : public exec::FunctionCallToSpecialForm { + public: + TypePtr resolveType(const std::vector& argTypes) override; + + exec::ExprPtr constructSpecialForm( + const TypePtr& type, + std::vector&& args, + bool trackCpuUsage, + const core::QueryConfig& config) override; + + static constexpr const char* factorial = "factorial"; +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterMath.cpp b/velox/functions/sparksql/registration/RegisterMath.cpp index 0532c5a82055..0da70cfb7612 100644 --- a/velox/functions/sparksql/registration/RegisterMath.cpp +++ b/velox/functions/sparksql/registration/RegisterMath.cpp @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "velox/expression/SpecialFormRegistry.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/Arithmetic.h" #include "velox/functions/sparksql/Arithmetic.h" #include "velox/functions/sparksql/DecimalArithmetic.h" #include "velox/functions/sparksql/Rand.h" +#include "velox/functions/sparksql/Factorial.h" namespace facebook::velox::functions::sparksql { @@ -120,6 +122,7 @@ void registerMathFunctions(const std::string& prefix) { registerBinaryNumeric({prefix + "checked_subtract"}); registerBinaryNumeric({prefix + "checked_multiply"}); registerBinaryNumeric({prefix + "checked_divide"}); + registerFunctionCallToSpecialForm("factorial",std::make_unique()); } } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index d17a9e0ed855..7d61b01733fc 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -32,6 +32,7 @@ add_executable( DecimalRoundTest.cpp DecimalUtilTest.cpp ElementAtTest.cpp + FactorialTest.cpp HashTest.cpp InTest.cpp JsonObjectKeysTest.cpp diff --git a/velox/functions/sparksql/tests/FactorialTest.cpp b/velox/functions/sparksql/tests/FactorialTest.cpp new file mode 100644 index 000000000000..e2d0f5630ace --- /dev/null +++ b/velox/functions/sparksql/tests/FactorialTest.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/Factorial.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +namespace facebook::velox::functions::sparksql::test { +namespace { + +class FactorialTest : public SparkFunctionBaseTest { + protected: + void testFactorial( + const VectorPtr& input, + const VectorPtr& expected) { + auto result = evaluate>( + "factorial(c0)", makeRowVector({input})); + velox::test::assertEqualVectors(expected, result); + } +}; + +TEST_F(FactorialTest, basic) { + auto input = makeFlatVector({0, 1, 2, 5, 10, 15, 20}); + auto expected = makeFlatVector( + {1, 1, 2, 120, 3628800, 1307674368000L, 2432902008176640000L}); + testFactorial(input, expected); +} + +TEST_F(FactorialTest, nullInput) { + auto input = makeNullableFlatVector( + {0, std::nullopt, 5, 20, std::nullopt}); + auto expected = makeNullableFlatVector( + {1, std::nullopt, 120, 2432902008176640000L, std::nullopt}); + testFactorial(input, expected); +} + +TEST_F(FactorialTest, outOfRangeInput) { + auto input = makeFlatVector({-1, 21, -5, 25}); + auto expected = makeNullConstant(TypeKind::BIGINT, input->size()); + testFactorial(input, expected); +} + +TEST_F(FactorialTest, mixedInputs) { + auto input = makeNullableFlatVector( + {3, 5, std::nullopt, 25, -3, 10, 15}); + auto expected = makeNullableFlatVector( + {6, 120, std::nullopt, std::nullopt, std::nullopt, 3628800, 1307674368000L}); + testFactorial(input, expected); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test