Skip to content

Commit

Permalink
Add overwrite flag to simple function registration API (#9158)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #9158

The `SimpleFunctionRegistry::registerFunction(aliases, constraints)`
API used to always overwrite the function registry when the function
name and signature already exists. This diff adds an overwrite flag to
this API to control the behavior of overwriting. The overwrite flag is
true by default. The
`SimpleFunctionRegistry::registerFunction(aliases, constraints, overwrite)`
API returns a bool that is true only when all aliases are successfully
registered.

Reviewed By: bikramSingh91

Differential Revision: D55041377

fbshipit-source-id: 76cea41f98de717dd8cea83df7771226d85d8b53
  • Loading branch information
kagamiori authored and facebook-github-bot committed May 13, 2024
1 parent a54929b commit b29d933
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 27 deletions.
11 changes: 8 additions & 3 deletions velox/expression/SimpleFunctionRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,23 @@ SimpleFunctionRegistry& mutableSimpleFunctions() {
return simpleFunctionsInternal();
}

void SimpleFunctionRegistry::registerFunctionInternal(
bool SimpleFunctionRegistry::registerFunctionInternal(
const std::string& name,
const std::shared_ptr<const Metadata>& metadata,
const FunctionFactory& factory) {
const FunctionFactory& factory,
bool overwrite) {
const auto sanitizedName = sanitizeName(name);
registeredFunctions_.withWLock([&](auto& map) {
return registeredFunctions_.withWLock([&](auto& map) {
SignatureMap& signatureMap = map[sanitizedName];
auto& functions = signatureMap[*metadata->signature()];

for (auto it = functions.begin(); it != functions.end(); ++it) {
const auto& otherMetadata = (*it)->getMetadata();

if (metadata->physicalSignatureEquals(otherMetadata)) {
if (!overwrite) {
return false;
}
functions.erase(it);
break;
}
Expand All @@ -60,6 +64,7 @@ void SimpleFunctionRegistry::registerFunctionInternal(

functions.emplace_back(
std::make_unique<const FunctionEntry>(metadata, factory));
return true;
});
}

Expand Down
42 changes: 32 additions & 10 deletions velox/expression/SimpleFunctionRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,32 @@ using FunctionMap = std::unordered_map<std::string, SignatureMap>;

class SimpleFunctionRegistry {
public:
/// Register a UDF with the given aliases and constraints. If an alias already
/// exists and 'overwrite' is true, the existing entry in the function
/// registry is overwritten by the current UDF. If an alias already exists and
/// 'overwrite' is false, the current UDF is not registered with this alias.
/// This method returns true if all 'aliases' are registered successfully. It
/// returns false if any alias in 'aliases' already exists in the registry and
/// is not overwritten.
template <typename UDF>
void registerFunction(
bool registerFunction(
const std::vector<std::string>& aliases,
const std::vector<exec::SignatureVariable>& constraints) {
const std::vector<exec::SignatureVariable>& constraints,
bool overwrite) {
const auto& metadata = singletonUdfMetadata<typename UDF::Metadata>(
UDF::is_default_null_behavior, constraints);
const auto factory = []() { return CreateUdf<UDF>(); };

if (aliases.empty()) {
registerFunctionInternal(metadata->getName(), metadata, factory);
return registerFunctionInternal(
metadata->getName(), metadata, factory, overwrite);
} else {
bool registered = true;
for (const auto& name : aliases) {
registerFunctionInternal(name, metadata, factory);
registered &=
registerFunctionInternal(name, metadata, factory, overwrite);
}
return registered;
}
}

Expand Down Expand Up @@ -142,10 +154,17 @@ class SimpleFunctionRegistry {
return std::make_unique<T>();
}

void registerFunctionInternal(
/// Registers a function with the given name and metadata. If an entry with
/// the name already exists and 'overwrite' is true, the existing entry is
/// overwritten. If an entry with the name already exists and 'overwrite' is
/// false, the function reigstry remain unchanged. This method returns true if
/// the function is successfully registered. It returns false if the function
/// registry remains unchanged.
bool registerFunctionInternal(
const std::string& name,
const std::shared_ptr<const Metadata>& metadata,
const FunctionFactory& factory);
const FunctionFactory& factory,
bool overwrite);

folly::Synchronized<FunctionMap> registeredFunctions_;
};
Expand All @@ -159,13 +178,16 @@ SimpleFunctionRegistry& mutableSimpleFunctions();
/// @param constraints Additional constraints for variables used in function
/// signature. Primarily used to specify rules for calculating precision and
/// scale for decimal result types.
/// @param overwrite If true, overwrites existing entries in the function
/// registry with the same names.
template <typename UDFHolder>
void registerSimpleFunction(
bool registerSimpleFunction(
const std::vector<std::string>& names,
const std::vector<exec::SignatureVariable>& constraints) {
mutableSimpleFunctions()
const std::vector<exec::SignatureVariable>& constraints,
bool overwrite) {
return mutableSimpleFunctions()
.registerFunction<SimpleFunctionAdapterFactoryImpl<UDFHolder>>(
names, constraints);
names, constraints, overwrite);
}

} // namespace facebook::velox::exec
9 changes: 9 additions & 0 deletions velox/expression/tests/SimpleFunctionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,15 @@ TEST_F(SimpleFunctionTest, decimals) {
// Verify overwrite behavior. Register a different function using the same
// name and physical signature as decimal_plus_one. Expect the new function to
// be used for (short) -> short signature.
// Not overwrite function registry.
registerFunction<
DecimalPlusTwoFunction,
ShortDecimal<P1, S1>,
ShortDecimal<P1, S1>>({"decimal_plus_one"}, {}, false);
result = evaluate("decimal_plus_one(c1)", data);
assertEqualVectors(expected, result);

// Overwrite function registry.
registerFunction<
DecimalPlusTwoFunction,
ShortDecimal<P1, S1>,
Expand Down
21 changes: 16 additions & 5 deletions velox/functions/Registerer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,45 @@ struct TempWrapper {
template <template <class...> typename T, typename... TArgs>
using ParameterBinder = TempWrapper<T<exec::VectorExec, TArgs...>>;

// Register a UDF with the given aliases. If an alias already
// exists and 'overwrite' is true, the existing entry in the function
// registry is overwritten by the current UDF. If an alias already exists and
// 'overwrite' is false, the current UDF is not registered with this alias.
// This method returns true if all 'aliases' are registered successfully. It
// returns false if any alias in 'aliases' already exists in the registry and
// is not overwritten.
template <typename Func, typename TReturn, typename... TArgs>
void registerFunction(const std::vector<std::string>& aliases = {}) {
bool registerFunction(
const std::vector<std::string>& aliases = {},
bool overwrite = true) {
using funcClass = typename Func::template udf<exec::VectorExec>;
using holderClass = core::UDFHolder<
funcClass,
exec::VectorExec,
TReturn,
ConstantChecker<TArgs...>,
typename UnwrapConstantType<TArgs>::type...>;
exec::registerSimpleFunction<holderClass>(aliases, {});
return exec::registerSimpleFunction<holderClass>(aliases, {}, overwrite);
}

// New registration function; mostly a copy from the function above, but taking
// the inner "udf" struct directly, instead of the wrapper. We can keep both for
// a while to maintain backwards compatibility, but the idea is to remove the
// one above eventually.
template <template <class> typename Func, typename TReturn, typename... TArgs>
void registerFunction(
bool registerFunction(
const std::vector<std::string>& aliases = {},
const std::vector<exec::SignatureVariable>& constraints = {}) {
const std::vector<exec::SignatureVariable>& constraints = {},
bool overwrite = true) {
using funcClass = Func<exec::VectorExec>;
using holderClass = core::UDFHolder<
funcClass,
exec::VectorExec,
TReturn,
ConstantChecker<TArgs...>,
typename UnwrapConstantType<TArgs>::type...>;
exec::registerSimpleFunction<holderClass>(aliases, constraints);
return exec::registerSimpleFunction<holderClass>(
aliases, constraints, overwrite);
}

} // namespace facebook::velox
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
namespace facebook::velox::aggregate::prestosql {

/// Entery point to register aggregate functions.
/// \param prefix : Prefix for the aggregate functions.
/// \param withCompanionFunctions : Also register companion functions, defaults
/// to true. \param onlyPrestoSignatures : Register only function signatures
/// @param prefix Prefix for the aggregate functions.
/// @param withCompanionFunctions Also register companion functions, defaults
/// to true.
/// @param onlyPrestoSignatures Register only function signatures
/// that are compatible with Presto.
void registerAllAggregateFunctions(
const std::string& prefix = "",
Expand All @@ -31,7 +32,7 @@ void registerAllAggregateFunctions(
bool overwrite = true);

/// Register internal aggregation functions only for testing.
/// \param prefix : Prefix for the aggregate functions.
/// @param prefix Prefix for the aggregate functions.
void registerInternalAggregateFunctions(const std::string& prefix);

} // namespace facebook::velox::aggregate::prestosql
2 changes: 1 addition & 1 deletion velox/functions/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ add_executable(velox_function_registry_test FunctionRegistryTest.cpp)
add_test(NAME velox_function_registry_test COMMAND velox_function_registry_test)

target_link_libraries(velox_function_registry_test velox_function_registry
gmock gtest gtest_main)
velox_functions_test_lib gmock gtest gtest_main)
39 changes: 35 additions & 4 deletions velox/functions/tests/FunctionRegistryTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "velox/functions/FunctionRegistry.h"
#include "velox/functions/Macros.h"
#include "velox/functions/Registerer.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
#include "velox/type/Type.h"

namespace facebook::velox {
Expand Down Expand Up @@ -79,9 +80,18 @@ struct FuncFour {

template <typename T>
struct FuncFive {
FOLLY_ALWAYS_INLINE bool call(
int64_t& /* result */,
const int64_t& /* arg1 */) {
FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) {
result = 5;
return true;
}
};

// FuncSix has the same signature as FuncFive. It's used to test overwrite
// during registration.
template <typename T>
struct FuncSix {
FOLLY_ALWAYS_INLINE bool call(int64_t& result, const int64_t& /* arg1 */) {
result = 6;
return true;
}
};
Expand Down Expand Up @@ -223,7 +233,7 @@ inline void registerTestFunctions() {
}
} // namespace

class FunctionRegistryTest : public ::testing::Test {
class FunctionRegistryTest : public testing::Test {
public:
FunctionRegistryTest() {
registerTestFunctions();
Expand Down Expand Up @@ -588,4 +598,25 @@ TEST_F(FunctionRegistryTest, resolveWithMetadata) {
result = resolveFunctionWithMetadata("non-existent-function", {VARCHAR()});
EXPECT_FALSE(result.has_value());
}

class FunctionRegistryOverwriteTest : public functions::test::FunctionBaseTest {
public:
FunctionRegistryOverwriteTest() {
registerTestFunctions();
}
};

TEST_F(FunctionRegistryOverwriteTest, overwrite) {
ASSERT_TRUE((registerFunction<FuncFive, int64_t, int64_t>({"foo"})));
ASSERT_FALSE(
(registerFunction<FuncSix, int64_t, int64_t>({"foo"}, {}, false)));
ASSERT_TRUE((evaluateOnce<int64_t, int64_t>("foo(c0)", 0) == 5));
ASSERT_TRUE((registerFunction<FuncSix, int64_t, int64_t>({"foo"})));
ASSERT_TRUE((evaluateOnce<int64_t, int64_t>("foo(c0)", 0) == 6));

auto& simpleFunctions = exec::simpleFunctions();
auto signatures = simpleFunctions.getFunctionSignatures("foo");
ASSERT_EQ(signatures.size(), 1);
}

} // namespace facebook::velox

0 comments on commit b29d933

Please sign in to comment.