diff --git a/CMake/resolve_dependency_modules/stemmer.cmake b/CMake/resolve_dependency_modules/stemmer.cmake new file mode 100644 index 0000000000000..076e83bcaf3b5 --- /dev/null +++ b/CMake/resolve_dependency_modules/stemmer.cmake @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +include_guard(GLOBAL) + +set(VELOX_STEMMER_VERSION 2.2.0) +set(VELOX_STEMMER_BUILD_SHA256_CHECKSUM + b941d9fe9cf36b4e2f8d3873cd4d8b8775bd94867a1df8d8c001bb8b688377c3) +set(VELOX_STEMMER_SOURCE_URL + "https://snowballstem.org/dist/libstemmer_c-${VELOX_STEMMER_VERSION}.tar.gz" +) + +resolve_dependency_url(STEMMER) + +message(STATUS "Building stemmer from source") +find_program(MAKE_PROGRAM make REQUIRED) + +set(STEMMER_PREFIX "${CMAKE_BINARY_DIR}/_deps/libstemmer") + +# We can not use FetchContent as libstemmer does not use cmake +ExternalProject_Add( + libstemmer + PREFIX ${STEMMER_PREFIX} + SOURCE_DIR ${STEMMER_PREFIX}/src/libstemmer + URL ${VELOX_STEMMER_SOURCE_URL} + URL_HASH ${VELOX_STEMMER_BUILD_SHA256_CHECKSUM} + BUILD_IN_SOURCE TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND ${MAKE_PROGRAM} + INSTALL_COMMAND "" + BUILD_BYPRODUCTS ${STEMMER_PREFIX}/src/libstemmer/${CMAKE_STATIC_LIBRARY_PREFIX}stemmer${CMAKE_STATIC_LIBRARY_SUFFIX}) + +ExternalProject_Get_Property(libstemmer BINARY_DIR) + +add_library(stemmer STATIC IMPORTED) +set_target_properties(stemmer PROPERTIES IMPORTED_LOCATION 
${BINARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}stemmer${CMAKE_STATIC_LIBRARY_SUFFIX}) + +include_directories(${BINARY_DIR}/include) +add_dependencies(stemmer libstemmer) \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c84456ef0119..0ae32faba2277 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -542,6 +542,10 @@ endif() set_source(xsimd) resolve_dependency(xsimd 10.0.0) +set(stemmer_SOURCE BUNDLED) +set_source(stemmer) +resolve_dependency(stemmer) + if(VELOX_BUILD_TESTING) set(BUILD_TESTING ON) include(CTest) # include after project() but before add_subdirectory() diff --git a/velox/functions/prestosql/CMakeLists.txt b/velox/functions/prestosql/CMakeLists.txt index 3a8008be601b9..38cfaa8655f4a 100644 --- a/velox/functions/prestosql/CMakeLists.txt +++ b/velox/functions/prestosql/CMakeLists.txt @@ -74,7 +74,8 @@ target_link_libraries( velox_type_tz velox_presto_types velox_functions_util - Folly::folly) + Folly::folly + stemmer) set_property(TARGET velox_functions_prestosql_impl PROPERTY JOB_POOL_COMPILE high_memory_pool) diff --git a/velox/functions/prestosql/StringFunctions.h b/velox/functions/prestosql/StringFunctions.h index 7afaea4e43072..bc08767f91fac 100644 --- a/velox/functions/prestosql/StringFunctions.h +++ b/velox/functions/prestosql/StringFunctions.h @@ -15,6 +15,8 @@ */ #pragma once +#include + #include "velox/functions/Udf.h" #include "velox/functions/lib/string/StringCore.h" #include "velox/functions/lib/string/StringImpl.h" @@ -494,4 +496,111 @@ struct LevenshteinDistanceFunction { } }; +/// Wrap the sbstemmer library and use its sb_stemmer_stem +/// to get word stem +class Stemmer { + private: + sb_stemmer* sbstemmer; + Stemmer(sb_stemmer* stemmer) : sbstemmer(stemmer) {} + + public: + ~Stemmer() { + sb_stemmer_delete(sbstemmer); + } + + /// Get a Stemmer from the map stored in thread local storage + /// or create a new one if it doesn't exist. Return NULL if the + /// specified lang is not supported. NOTE(review): lang is handed to sb_stemmer_new as a C string, so it must be null-terminated - confirm callers guarantee this. 
+ static Stemmer* getStemmer(const char* lang) { + thread_local std::map stemmers; + if (auto found = stemmers.find(lang); found != stemmers.end()) { + return found->second; + } + // Only support ASCII and UTF-8. NOTE(review): the Stemmer created below is cached for the thread and never deleted in the code shown here. + if (auto sbstemmer = sb_stemmer_new(lang, "UTF_8")) { + auto rev = new Stemmer(sbstemmer); + stemmers[lang] = rev; + return rev; + } + return NULL; + } + + /// Get the word stem or NULL if out of memory. NOTE(review): the returned pointer presumably refers to a buffer owned by sbstemmer - confirm its lifetime against the libstemmer documentation. + const char* stem(const std::string& input) { + return (const char*)(sb_stemmer_stem( + sbstemmer, + reinterpret_cast(input.c_str()), + input.length())); + } +}; + +/// word_stem function +/// word_stem(word) -> varchar +/// return the stem of the word in the English language +/// word_stem(word, lang) -> varchar +/// return the stem of the word in the specified language +/// +/// It uses the snowball stemmer library to calculate the stem. +/// https://snowballstem.org +/// It provides a Java implementation, which is used in Presto, as well +/// as a C implementation. Therefore, both Presto and Prestissimo +/// would have the same word stem results. +template +struct WordStemFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + // Results refer to strings in the first argument. NOTE(review): doCall assigns the stemmer's output to result rather than a slice of the input - confirm this flag is intended. + static constexpr int32_t reuse_strings_from_arg = 0; + + // ASCII input always produces ASCII result. 
+ static constexpr bool is_default_ascii_behavior = true; + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& input) { + return doCall(result, input); + } + + FOLLY_ALWAYS_INLINE void callAscii( + out_type& result, + const arg_type& input) { + return doCall(result, input); + } + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& input, + const arg_type& lang) { + return doCall(result, input, lang.data()); + } + + FOLLY_ALWAYS_INLINE void callAscii( + out_type& result, + const arg_type& input, + const arg_type& lang) { + return doCall(result, input, lang.data()); + } + + template + FOLLY_ALWAYS_INLINE void doCall( + out_type& result, + const arg_type& input, + const char* lang = "en") { + auto stemmer = Stemmer::getStemmer(lang); + if (!stemmer) { + // language is not supported + throw std::invalid_argument( + "Unknown stemmer language: \"" + std::string(lang) + "\""); + } + + std::string lowerOutput; + stringImpl::lower(lowerOutput, input); + auto rev = stemmer->stem(lowerOutput); + if (rev == NULL) { + throw std::runtime_error("out of memory"); + } + result = rev; + } +}; + } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp index f1e36049e92d1..85f7a4c3d62da 100644 --- a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp @@ -127,5 +127,10 @@ void registerStringFunctions(const std::string& prefix) { {prefix + "strrpos"}); registerFunction( {prefix + "strrpos"}); + + // word_stem function + registerFunction({prefix + "word_stem"}); + registerFunction( + {prefix + "word_stem"}); } } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/tests/StringFunctionsTest.cpp b/velox/functions/prestosql/tests/StringFunctionsTest.cpp index 4c05965fb56a3..67908a6f68039 
100644 --- a/velox/functions/prestosql/tests/StringFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/StringFunctionsTest.cpp @@ -305,6 +305,12 @@ class StringFunctionsTest : public FunctionBaseTest { const std::string& left, const std::string& right); + std::string wordStemWithLang( + const std::string& word, + const std::string& lang); + + std::string wordStem(const std::string& word); + using replace_input_test_t = std::vector, std::string>>; @@ -1204,6 +1210,58 @@ TEST_F(StringFunctionsTest, invalidLevenshteinDistance) { "The combined inputs size exceeded max Levenshtein distance combined input size"); } +std::string StringFunctionsTest::wordStemWithLang( + const std::string& word, + const std::string& lang) { + return evaluateOnce( + "word_stem(c0, c1)", std::optional(word), std::optional(lang)) + .value(); +} + +std::string StringFunctionsTest::wordStem(const std::string& word) { + return evaluateOnce("word_stem(c0)", std::optional(word)) + .value(); +} + +/// Borrow test cases from Presto Java +/// https://github.com/prestodb/presto/blob/master/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestWordStemFunction.java +TEST_F(StringFunctionsTest, asciiWordStem) { + EXPECT_EQ(wordStem(""), ""); + EXPECT_EQ(wordStem("x"), "x"); + EXPECT_EQ(wordStem("abc"), "abc"); + EXPECT_EQ(wordStem("generally"), "general"); + EXPECT_EQ(wordStem("useful"), "use"); + EXPECT_EQ(wordStem("runs"), "run"); + EXPECT_EQ(wordStem("run"), "run"); + EXPECT_EQ(wordStemWithLang("authorized", "en"), "author"); + EXPECT_EQ(wordStemWithLang("accessories", "en"), "accessori"); + EXPECT_EQ(wordStemWithLang("intensifying", "en"), "intensifi"); + EXPECT_EQ(wordStemWithLang("resentment", "en"), "resent"); + EXPECT_EQ(wordStemWithLang("faithfulness", "en"), "faith"); + EXPECT_EQ(wordStemWithLang("continuerait", "fr"), "continu"); + EXPECT_EQ(wordStemWithLang("torpedearon", "es"), "torped"); + EXPECT_EQ(wordStemWithLang("quilomtricos", "pt"), "quilomtr"); + 
EXPECT_EQ(wordStemWithLang("pronunziare", "it"), "pronunz"); + EXPECT_EQ(wordStemWithLang("auferstnde", "de"), "auferstnd"); +} + +TEST_F(StringFunctionsTest, invalidWordStemLang) { + VELOX_ASSERT_THROW( + wordStemWithLang("hello", "xx"), "Unknown stemmer language: \"xx\""); +} + +TEST_F(StringFunctionsTest, unicodeWordStem) { + EXPECT_EQ( + wordStemWithLang( + "\u004b\u0069\u0074\u0061\u0062\u0131\u006d\u0131\u007a\u0064\u0131", + "tr"), + "kitap"); + EXPECT_EQ( + wordStemWithLang( + "\u0432\u0435\u0441\u0435\u043d\u043d\u0438\u0439", "ru"), + "\u0432\u0435\u0441\u0435\u043d"); +} + void StringFunctionsTest::testReplaceInPlace( const std::vector>& tests, const std::string& search,