From 80deccd511dbefcf9df195c955481341afc0fe1d Mon Sep 17 00:00:00 2001 From: rui-mo Date: Wed, 18 Oct 2023 14:13:47 +0800 Subject: [PATCH] Add Spark atan2 function (7113) --- velox/docs/functions/spark/math.rst | 4 ++++ velox/functions/sparksql/Arithmetic.h | 9 +++++++++ velox/functions/sparksql/RegisterArithmetic.cpp | 1 + velox/functions/sparksql/tests/ArithmeticTest.cpp | 9 ++++++++- 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/velox/docs/functions/spark/math.rst b/velox/docs/functions/spark/math.rst index 85870c33a992..7c583537d24e 100644 --- a/velox/docs/functions/spark/math.rst +++ b/velox/docs/functions/spark/math.rst @@ -18,6 +18,10 @@ Mathematical Functions Returns inverse hyperbolic sine of ``x``. +.. spark:function:: atan2(x, y) -> double + + Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates(x, y). + .. spark:function:: atanh(x) -> double Returns inverse hyperbolic tangent of ``x``. diff --git a/velox/functions/sparksql/Arithmetic.h b/velox/functions/sparksql/Arithmetic.h index fa9e101913d4..4523a89c22c2 100644 --- a/velox/functions/sparksql/Arithmetic.h +++ b/velox/functions/sparksql/Arithmetic.h @@ -280,4 +280,13 @@ struct Log10Function { return true; } }; + +template +struct Atan2Function { + template + FOLLY_ALWAYS_INLINE void call(TInput& result, TInput y, TInput x) { + result = std::atan2(y + 0.0, x + 0.0); + } +}; + } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 08851f9e0d9f..cdddf87bd869 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -95,6 +95,7 @@ void registerArithmeticFunctions(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_sub, prefix + "subtract"); VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_mul, prefix + "multiply"); VELOX_REGISTER_VECTOR_FUNCTION(udf_decimal_div, prefix + "divide"); + registerFunction({prefix + "atan2"}); } } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/ArithmeticTest.cpp b/velox/functions/sparksql/tests/ArithmeticTest.cpp index af43bc329999..2844cb876c4e 100644 --- a/velox/functions/sparksql/tests/ArithmeticTest.cpp +++ b/velox/functions/sparksql/tests/ArithmeticTest.cpp @@ -376,6 +376,14 @@ TEST_F(ArithmeticTest, cot) { EXPECT_EQ(cot(0), 1 / std::tan(0)); } +TEST_F(ArithmeticTest, atan2) { + const auto atan2 = [&](std::optional y, std::optional x) { + return evaluateOnce("atan2(c0, c1)", y, x); + }; + + EXPECT_EQ(atan2(0, 0), 0.0); +} + class LogNTest : public SparkFunctionBaseTest { protected: static constexpr float kInf = std::numeric_limits::infinity(); @@ -400,6 +408,5 @@ TEST_F(LogNTest, log10) { EXPECT_EQ(log10(-1.0), std::nullopt); EXPECT_EQ(log10(kInf), kInf); } - } // namespace } // namespace facebook::velox::functions::sparksql::test