From d79dfe444219340eec075bdb4d5fe5d4f815c7c6 Mon Sep 17 00:00:00 2001 From: Yihong Wang Date: Wed, 10 Apr 2024 10:59:06 -0700 Subject: [PATCH] Update word_stem impl to address comments - doc update - separate the impl to a new header file - separate the test to a new cpp file - apply code convensions Signed-off-by: Yihong Wang --- velox/docs/functions/presto/string.rst | 33 ++++- velox/functions/prestosql/StringFunctions.h | 109 --------------- velox/functions/prestosql/WordStem.h | 128 ++++++++++++++++++ .../StringFunctionsRegistration.cpp | 1 + .../functions/prestosql/tests/CMakeLists.txt | 1 + .../prestosql/tests/StringFunctionsTest.cpp | 58 -------- .../prestosql/tests/WordStemTest.cpp | 78 +++++++++++ 7 files changed, 239 insertions(+), 169 deletions(-) create mode 100644 velox/functions/prestosql/WordStem.h create mode 100644 velox/functions/prestosql/tests/WordStemTest.cpp diff --git a/velox/docs/functions/presto/string.rst b/velox/docs/functions/presto/string.rst index 59568471fb618..44fb82cef4fab 100644 --- a/velox/docs/functions/presto/string.rst +++ b/velox/docs/functions/presto/string.rst @@ -257,11 +257,40 @@ String Functions .. function:: word_stem(word) -> varchar - Returns the stem of ``word`` in the English language. + Returns the stem of ``word`` in the English language. If the ``word`` is not an English word, + the ``word`` in lowercase is returned. .. function:: word_stem(word, lang) -> varchar - Returns the stem of ``word`` in the ``lang`` language. + Returns the stem of ``word`` in the ``lang`` language. This function supports the following languages: + + =========== ================ + lang Language + =========== ================ + ``ca`` ``Catalan`` + ``da`` ``Danish`` + ``de`` ``German`` + ``en`` ``English`` + ``es`` ``Spanish`` + ``eu`` ``Basque`` + ``fi`` ``Finnish`` + ``fr`` ``French`` + ``hu`` ``Hungarian`` + ``hy`` ``Armenian`` + ``ir`` ``Irish`` + ``it`` ``Italian`` + ``lt`` ``Lithuanian`` + ``nl`` ``Dutch`` + ``no`` ``Norwegian`` + ``pt`` ``Portuguese`` + ``ro`` ``Romanian`` + ``ru`` ``Russian`` + ``sv`` ``Swedish`` + ``tr`` ``Turkish`` + =========== ================ + + If the specified ``lang`` is not supported, this function throws a user error. + Unicode Functions ----------------- diff --git a/velox/functions/prestosql/StringFunctions.h b/velox/functions/prestosql/StringFunctions.h index bc08767f91fac..7afaea4e43072 100644 --- a/velox/functions/prestosql/StringFunctions.h +++ b/velox/functions/prestosql/StringFunctions.h @@ -15,8 +15,6 @@ */ #pragma once -#include - #include "velox/functions/Udf.h" #include "velox/functions/lib/string/StringCore.h" #include "velox/functions/lib/string/StringImpl.h" @@ -496,111 +494,4 @@ struct LevenshteinDistanceFunction { } }; -/// Wrap the sbstemmer library and use its sb_stemmer_stem -/// to get word stem -class Stemmer { - private: - sb_stemmer* sbstemmer; - Stemmer(sb_stemmer* stemmer) : sbstemmer(stemmer) {} - - public: - ~Stemmer() { - sb_stemmer_delete(sbstemmer); - } - - /// Get a Stemmer from the the map stored in thread local storage - /// or create a new one if it doesn't exist. Return NULL if the - /// specified lang is not supported. - static Stemmer* getStemmer(const char* lang) { - thread_local std::map stemmers; - if (auto found = stemmers.find(lang); found != stemmers.end()) { - return found->second; - } - // Only support ASCII and UTF-8 - if (auto sbstemmer = sb_stemmer_new(lang, "UTF_8")) { - auto rev = new Stemmer(sbstemmer); - stemmers[lang] = rev; - return rev; - } - return NULL; - } - - /// Get the word stem or NULL if out of memory - const char* stem(const std::string& input) { - return (const char*)(sb_stemmer_stem( - sbstemmer, - reinterpret_cast(input.c_str()), - input.length())); - } -}; - -/// word_stem function -/// word_stem(word) -> varchar -/// return the stem of the word in the English language -/// word_stem(word, lang) -> varchar -/// return the stem of the word in the specificed language -/// -/// It uses the snowball stemmer library to calculate the stem. -/// https://snowballstem.org -/// It provides Java implementation which is used in Presto as well -/// as C implementation. Therefore, both Presto and Prestimissio -/// would have the same word stem results. -template -struct WordStemFunction { - VELOX_DEFINE_FUNCTION_TYPES(T); - - // Results refer to strings in the first argument. - static constexpr int32_t reuse_strings_from_arg = 0; - - // ASCII input always produces ASCII result. - static constexpr bool is_default_ascii_behavior = true; - - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { - return doCall(result, input); - } - - FOLLY_ALWAYS_INLINE void callAscii( - out_type& result, - const arg_type& input) { - return doCall(result, input); - } - - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input, - const arg_type& lang) { - return doCall(result, input, lang.data()); - } - - FOLLY_ALWAYS_INLINE void callAscii( - out_type& result, - const arg_type& input, - const arg_type& lang) { - return doCall(result, input, lang.data()); - } - - template - FOLLY_ALWAYS_INLINE void doCall( - out_type& result, - const arg_type& input, - const char* lang = "en") { - auto stemmer = Stemmer::getStemmer(lang); - if (!stemmer) { - // language is not supported - throw std::invalid_argument( - "Unknown stemmer language: \"" + std::string(lang) + "\""); - } - - std::string lowerOutput; - stringImpl::lower(lowerOutput, input); - auto rev = stemmer->stem(lowerOutput); - if (rev == NULL) { - throw std::runtime_error("out of memory"); - } - result = rev; - } -}; - } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/WordStem.h b/velox/functions/prestosql/WordStem.h new file mode 100644 index 0000000000000..aac83da79ef8b --- /dev/null +++ b/velox/functions/prestosql/WordStem.h @@ -0,0 +1,128 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "velox/functions/Udf.h" +#include "velox/functions/lib/string/StringImpl.h" + +namespace facebook::velox::functions { + +namespace { +/// Wrap the sbstemmer library and use its sb_stemmer_stem +/// to get word stem +class Stemmer { + private: + sb_stemmer* sbStemmer_; + Stemmer(sb_stemmer* stemmer) : sbStemmer_(stemmer) {} + + public: + ~Stemmer() { + sb_stemmer_delete(sbStemmer_); + } + + /// Get a Stemmer from the the map stored in thread local storage + /// or create a new one if it doesn't exist. Return nullptr if the + /// specified lang is not supported. + static Stemmer* getStemmer(const char* lang) { + thread_local std::map> stemmers; + if (auto found = stemmers.find(lang); found != stemmers.end()) { + return found->second.get(); + } + Stemmer* stemmer = nullptr; + // Only support ASCII and UTF-8 + if (auto sbStemmer = sb_stemmer_new(lang, "UTF_8")) { + stemmer = new Stemmer(sbStemmer); + stemmers[lang] = std::unique_ptr(stemmer); + } + return stemmer; + } + + /// Get the word stem or NULL if out of memory + const char* stem(const std::string& input) { + return (const char*)(sb_stemmer_stem( + sbStemmer_, + reinterpret_cast(input.c_str()), + input.length())); + } +}; +} // namespace + +/// word_stem function +/// word_stem(word) -> varchar +/// return the stem of the word in the English language +/// word_stem(word, lang) -> varchar +/// return the stem of the word in the specificed language +/// +/// It uses the snowball stemmer library to calculate the stem. +/// https://snowballstem.org +/// It provides Java implementation which is used in Presto as well +/// as C implementation. Therefore, both Presto and Prestimissio +/// would have the same word stem results. +template +struct WordStemFunction { + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + // ASCII input always produces ASCII result. + static constexpr bool is_default_ascii_behavior = true; + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& input) { + return doCall(result, input); + } + + FOLLY_ALWAYS_INLINE void callAscii( + out_type& result, + const arg_type& input) { + return doCall(result, input); + } + + FOLLY_ALWAYS_INLINE void call( + out_type& result, + const arg_type& input, + const arg_type& lang) { + return doCall(result, input, lang.data()); + } + + FOLLY_ALWAYS_INLINE void callAscii( + out_type& result, + const arg_type& input, + const arg_type& lang) { + return doCall(result, input, lang.data()); + } + + template + FOLLY_ALWAYS_INLINE void doCall( + out_type& result, + const arg_type& input, + const char* lang = "en") { + auto stemmer = Stemmer::getStemmer(lang); + if (!stemmer) { + // language is not supported + VELOX_USER_FAIL("Unknown stemmer language: \"{}\"", lang); + } + + std::string lowerOutput; + stringImpl::lower(lowerOutput, input); + auto stem = stemmer->stem(lowerOutput); + VELOX_CHECK_NOT_NULL( + stem, "Stemmer library returned a NULL (out-of-memory)") + result = stem; + } +}; +} // namespace facebook::velox::functions \ No newline at end of file diff --git a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp index 85f7a4c3d62da..4ceaf4edab946 100644 --- a/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/StringFunctionsRegistration.cpp @@ -19,6 +19,7 @@ #include "velox/functions/prestosql/SplitPart.h" #include "velox/functions/prestosql/SplitToMap.h" #include "velox/functions/prestosql/StringFunctions.h" +#include "velox/functions/prestosql/WordStem.h" namespace facebook::velox::functions { diff --git a/velox/functions/prestosql/tests/CMakeLists.txt b/velox/functions/prestosql/tests/CMakeLists.txt index d10283a85c00b..7f22424d641dc 100644 --- a/velox/functions/prestosql/tests/CMakeLists.txt +++ b/velox/functions/prestosql/tests/CMakeLists.txt @@ -97,6 +97,7 @@ add_executable( URLFunctionsTest.cpp Utf8Test.cpp WidthBucketArrayTest.cpp + WordStemTest.cpp ZipTest.cpp ZipWithTest.cpp) diff --git a/velox/functions/prestosql/tests/StringFunctionsTest.cpp b/velox/functions/prestosql/tests/StringFunctionsTest.cpp index 67908a6f68039..4c05965fb56a3 100644 --- a/velox/functions/prestosql/tests/StringFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/StringFunctionsTest.cpp @@ -305,12 +305,6 @@ class StringFunctionsTest : public FunctionBaseTest { const std::string& left, const std::string& right); - std::string wordStemWithLang( - const std::string& word, - const std::string& lang); - - std::string wordStem(const std::string& word); - using replace_input_test_t = std::vector, std::string>>; @@ -1210,58 +1204,6 @@ TEST_F(StringFunctionsTest, invalidLevenshteinDistance) { "The combined inputs size exceeded max Levenshtein distance combined input size"); } -std::string StringFunctionsTest::wordStemWithLang( - const std::string& word, - const std::string& lang) { - return evaluateOnce( - "word_stem(c0, c1)", std::optional(word), std::optional(lang)) - .value(); -} - -std::string StringFunctionsTest::wordStem(const std::string& word) { - return evaluateOnce("word_stem(c0)", std::optional(word)) - .value(); -} - -/// Borrow test cases from Presto Java -/// https://github.com/prestodb/presto/blob/master/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestWordStemFunction.java -TEST_F(StringFunctionsTest, asciiWordStem) { - EXPECT_EQ(wordStem(""), ""); - EXPECT_EQ(wordStem("x"), "x"); - EXPECT_EQ(wordStem("abc"), "abc"); - EXPECT_EQ(wordStem("generally"), "general"); - EXPECT_EQ(wordStem("useful"), "use"); - EXPECT_EQ(wordStem("runs"), "run"); - EXPECT_EQ(wordStem("run"), "run"); - EXPECT_EQ(wordStemWithLang("authorized", "en"), "author"); - EXPECT_EQ(wordStemWithLang("accessories", "en"), "accessori"); - EXPECT_EQ(wordStemWithLang("intensifying", "en"), "intensifi"); - EXPECT_EQ(wordStemWithLang("resentment", "en"), "resent"); - EXPECT_EQ(wordStemWithLang("faithfulness", "en"), "faith"); - EXPECT_EQ(wordStemWithLang("continuerait", "fr"), "continu"); - EXPECT_EQ(wordStemWithLang("torpedearon", "es"), "torped"); - EXPECT_EQ(wordStemWithLang("quilomtricos", "pt"), "quilomtr"); - EXPECT_EQ(wordStemWithLang("pronunziare", "it"), "pronunz"); - EXPECT_EQ(wordStemWithLang("auferstnde", "de"), "auferstnd"); -} - -TEST_F(StringFunctionsTest, invalidWordStemLang) { - VELOX_ASSERT_THROW( - wordStemWithLang("hello", "xx"), "Unknown stemmer language: \"xx\""); -} - -TEST_F(StringFunctionsTest, unicodeWordStem) { - EXPECT_EQ( - wordStemWithLang( - "\u004b\u0069\u0074\u0061\u0062\u0131\u006d\u0131\u007a\u0064\u0131", - "tr"), - "kitap"); - EXPECT_EQ( - wordStemWithLang( - "\u0432\u0435\u0441\u0435\u043d\u043d\u0438\u0439", "ru"), - "\u0432\u0435\u0441\u0435\u043d"); -} - void StringFunctionsTest::testReplaceInPlace( const std::vector>& tests, const std::string& search, diff --git a/velox/functions/prestosql/tests/WordStemTest.cpp b/velox/functions/prestosql/tests/WordStemTest.cpp new file mode 100644 index 0000000000000..348cba579b36f --- /dev/null +++ b/velox/functions/prestosql/tests/WordStemTest.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" + +using namespace facebook::velox::functions::test; + +class WordStemTest : public FunctionBaseTest { + protected: + std::string wordStemWithLang( + const std::string& word, + const std::string& lang) { + return evaluateOnce( + "word_stem(c0, c1)", std::optional(word), std::optional(lang)) + .value(); + } + + std::string wordStem(const std::string& word) { + return evaluateOnce("word_stem(c0)", std::optional(word)) + .value(); + } +}; + +/// Borrow test cases from Presto Java +/// https://github.com/prestodb/presto/blob/master/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestWordStemFunction.java +TEST_F(WordStemTest, asciiWord) { + EXPECT_EQ(wordStem(""), ""); + EXPECT_EQ(wordStem("x"), "x"); + EXPECT_EQ(wordStem("abc"), "abc"); + EXPECT_EQ(wordStem("generally"), "general"); + EXPECT_EQ(wordStem("useful"), "use"); + EXPECT_EQ(wordStem("runs"), "run"); + EXPECT_EQ(wordStem("run"), "run"); + EXPECT_EQ(wordStemWithLang("authorized", "en"), "author"); + EXPECT_EQ(wordStemWithLang("accessories", "en"), "accessori"); + EXPECT_EQ(wordStemWithLang("intensifying", "en"), "intensifi"); + EXPECT_EQ(wordStemWithLang("resentment", "en"), "resent"); + EXPECT_EQ(wordStemWithLang("faithfulness", "en"), "faith"); + EXPECT_EQ(wordStemWithLang("continuerait", "fr"), "continu"); + EXPECT_EQ(wordStemWithLang("torpedearon", "es"), "torped"); + EXPECT_EQ(wordStemWithLang("quilomtricos", "pt"), "quilomtr"); + EXPECT_EQ(wordStemWithLang("pronunziare", "it"), "pronunz"); + EXPECT_EQ(wordStemWithLang("auferstnde", "de"), "auferstnd"); +} + +TEST_F(WordStemTest, invalidLang) { + VELOX_ASSERT_THROW( + wordStemWithLang("hello", "xx"), "Unknown stemmer language: \"xx\""); +} + +TEST_F(WordStemTest, unicodeWord) { + EXPECT_EQ( + wordStemWithLang( + "\u004b\u0069\u0074\u0061\u0062\u0131\u006d\u0131\u007a\u0064\u0131", + "tr"), + "kitap"); + EXPECT_EQ( + wordStemWithLang( + "\u0432\u0435\u0441\u0435\u043d\u043d\u0438\u0439", "ru"), + "\u0432\u0435\u0441\u0435\u043d"); +} \ No newline at end of file