diff --git a/velox/docs/functions/spark/array.rst b/velox/docs/functions/spark/array.rst index f80e13923dc0..168bf0fc06c1 100644 --- a/velox/docs/functions/spark/array.rst +++ b/velox/docs/functions/spark/array.rst @@ -51,6 +51,18 @@ Array Functions SELECT array_min(ARRAY [4.0, float('nan')]); -- 4.0 SELECT array_min(ARRAY [NULL, float('nan')]); -- NaN +.. spark:function:: array_remove(x, element) -> array + + Removes all elements that equal ``element`` from array ``x``. Returns NULL if ``element`` is NULL. + If array ``x`` is empty, returns an empty array. If all elements in array ``x`` are NULL but ``element`` is not NULL, + returns array ``x`` unchanged. :: + + SELECT array_remove(ARRAY [1, 2, 3], 3); -- [1, 2] + SELECT array_remove(ARRAY [2, 1, NULL], 1); -- [2, NULL] + SELECT array_remove(ARRAY [1, 2, NULL], NULL); -- NULL + SELECT array_remove(ARRAY [], 1); -- [] + SELECT array_remove(ARRAY [NULL, NULL], -1); -- [NULL, NULL] + .. spark:function:: array_repeat(element, count) -> array(E) Returns an array containing ``element`` ``count`` times. 
If ``count`` is negative or zero, diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index 01049c0ab59b..72a8ca527568 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -22,6 +22,7 @@ #include "velox/functions/lib/Re2Functions.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/lib/Repeat.h" +#include "velox/functions/prestosql/ArrayFunctions.h" #include "velox/functions/prestosql/DateTimeFunctions.h" #include "velox/functions/prestosql/JsonFunctions.h" #include "velox/functions/prestosql/StringFunctions.h" @@ -52,6 +53,32 @@ extern void registerElementAtFunction( const std::string& name, bool enableCaching); +template +inline void registerArrayRemoveFunctions(const std::string& prefix) { + registerFunction, Array, T>( + {prefix + "array_remove"}); +} + +inline void registerArrayRemoveFunctions(const std::string& prefix) { + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions>(prefix); + registerFunction< + ArrayRemoveFunctionString, + Array, + Array, + Varchar>({prefix + "array_remove"}); +} + static void workAroundRegistrationMacro(const std::string& prefix) { // VELOX_REGISTER_VECTOR_FUNCTION must be invoked in the same namespace as the // vector function definition. 
@@ -82,6 +109,7 @@ static void workAroundRegistrationMacro(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION(udf_not, prefix + "not"); registerIsNullFunction(prefix + "isnull"); registerIsNotNullFunction(prefix + "isnotnull"); + registerArrayRemoveFunctions(prefix); } namespace sparksql {