Skip to content

Commit

Permalink
Update word_stem impl to address comments
Browse files Browse the repository at this point in the history
- doc update
- separate the impl to a new header file
- separate the test to a new cpp file
- apply code conventions

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang committed Apr 10, 2024
1 parent e927a8d commit 4546f1e
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 169 deletions.
33 changes: 31 additions & 2 deletions velox/docs/functions/presto/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,40 @@ String Functions

.. function:: word_stem(word) -> varchar

Returns the stem of ``word`` in the English language.
Returns the stem of ``word`` in the English language. If the ``word`` is not an English word,
the ``word`` in lowercase is returned.

.. function:: word_stem(word, lang) -> varchar

Returns the stem of ``word`` in the ``lang`` language.
Returns the stem of ``word`` in the ``lang`` language. This function supports the following languages:

===========  ================
lang         Language
===========  ================
``ca``       ``Catalan``
``da``       ``Danish``
``de``       ``German``
``en``       ``English``
``es``       ``Spanish``
``eu``       ``Basque``
``fi``       ``Finnish``
``fr``       ``French``
``hu``       ``Hungarian``
``hy``       ``Armenian``
``ir``       ``Irish``
``it``       ``Italian``
``lt``       ``Lithuanian``
``nl``       ``Dutch``
``no``       ``Norwegian``
``pt``       ``Portuguese``
``ro``       ``Romanian``
``ru``       ``Russian``
``sv``       ``Swedish``
``tr``       ``Turkish``
===========  ================

If the specified ``lang`` is not supported, this function throws a user error.


Unicode Functions
-----------------
Expand Down
109 changes: 0 additions & 109 deletions velox/functions/prestosql/StringFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
#pragma once

#include <libstemmer.h>

#include <map>
#include <memory>
#include <string>

#include "velox/functions/Udf.h"
#include "velox/functions/lib/string/StringCore.h"
#include "velox/functions/lib/string/StringImpl.h"
Expand Down Expand Up @@ -496,111 +494,4 @@ struct LevenshteinDistanceFunction {
}
};

/// Wrap the sbstemmer library and use its sb_stemmer_stem
/// to get word stem
class Stemmer {
private:
sb_stemmer* sbstemmer;
Stemmer(sb_stemmer* stemmer) : sbstemmer(stemmer) {}

public:
~Stemmer() {
sb_stemmer_delete(sbstemmer);
}

/// Get a Stemmer from the the map stored in thread local storage
/// or create a new one if it doesn't exist. Return NULL if the
/// specified lang is not supported.
static Stemmer* getStemmer(const char* lang) {
thread_local std::map<std::string, Stemmer*> stemmers;
if (auto found = stemmers.find(lang); found != stemmers.end()) {
return found->second;
}
// Only support ASCII and UTF-8
if (auto sbstemmer = sb_stemmer_new(lang, "UTF_8")) {
auto rev = new Stemmer(sbstemmer);
stemmers[lang] = rev;
return rev;
}
return NULL;
}

/// Get the word stem or NULL if out of memory
const char* stem(const std::string& input) {
return (const char*)(sb_stemmer_stem(
sbstemmer,
reinterpret_cast<unsigned char const*>(input.c_str()),
input.length()));
}
};

/// word_stem function
/// word_stem(word) -> varchar
///     Returns the stem of the word in the English language.
/// word_stem(word, lang) -> varchar
///     Returns the stem of the word in the specified language.
///
/// It uses the snowball stemmer library to calculate the stem.
/// https://snowballstem.org
/// Snowball provides a Java implementation, which is used in Presto, as well
/// as a C implementation. Therefore, both Presto and Prestissimo are expected
/// to produce the same word stem results.
template <typename T>
struct WordStemFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  // Results refer to strings in the first argument.
  // NOTE(review): doCall() assigns the stemmer's own output buffer to the
  // result rather than a slice of the input - confirm this flag is needed.
  static constexpr int32_t reuse_strings_from_arg = 0;

  // ASCII input always produces ASCII result.
  static constexpr bool is_default_ascii_behavior = true;

  /// word_stem(word): stems 'input' with the default language ("en").
  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input) {
    doCall<false>(result, input);
  }

  /// ASCII-only variant of word_stem(word).
  FOLLY_ALWAYS_INLINE void callAscii(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input) {
    doCall<true>(result, input);
  }

  /// word_stem(word, lang): stems 'input' with the stemmer for 'lang'.
  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& lang) {
    // A Varchar argument is a pointer + size and is not guaranteed to be
    // NUL-terminated, so build a bounded std::string instead of passing
    // lang.data() as a C string.
    doCall<false>(result, input, std::string(lang.data(), lang.size()));
  }

  /// ASCII-only variant of word_stem(word, lang).
  FOLLY_ALWAYS_INLINE void callAscii(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& lang) {
    // See call(): lang.data() must not be treated as a C string.
    doCall<true>(result, input, std::string(lang.data(), lang.size()));
  }

  /// Shared implementation: lower-cases 'input' (ASCII or UTF-8 path chosen
  /// by 'isAscii'), stems it and assigns the stem to 'result'.
  ///
  /// @throws std::invalid_argument if no stemmer exists for 'lang'.
  /// @throws std::runtime_error if the stemmer returns a null stem.
  template <bool isAscii>
  FOLLY_ALWAYS_INLINE void doCall(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const std::string& lang = "en") {
    auto* stemmer = Stemmer::getStemmer(lang.c_str());
    if (stemmer == nullptr) {
      // language is not supported
      throw std::invalid_argument(
          "Unknown stemmer language: \"" + lang + "\"");
    }

    std::string lowerOutput;
    stringImpl::lower<isAscii>(lowerOutput, input);
    const char* stem = stemmer->stem(lowerOutput);
    if (stem == nullptr) {
      throw std::runtime_error("out of memory");
    }
    result = stem;
  }
};

} // namespace facebook::velox::functions
131 changes: 131 additions & 0 deletions velox/functions/prestosql/WordStem.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <libstemmer.h>

#include "velox/functions/Udf.h"
#include "velox/functions/lib/string/StringImpl.h"

namespace facebook::velox::functions {

namespace {
/// Wrap the sbstemmer library and use its sb_stemmer_stem
/// to get word stem
class Stemmer {
private:
sb_stemmer* sbStemmer_;
Stemmer(sb_stemmer* stemmer) : sbStemmer_(stemmer) {}

public:
~Stemmer() {
sb_stemmer_delete(sbStemmer_);
}

/// Get a Stemmer from the the map stored in thread local storage
/// or create a new one if it doesn't exist. Return nullptr if the
/// specified lang is not supported.
static Stemmer* getStemmer(const char* lang) {
thread_local std::map<std::string, std::unique_ptr<Stemmer>> stemmers;
if (auto found = stemmers.find(lang); found != stemmers.end()) {
return found->second.get();
}
Stemmer* stemmer = nullptr;
// Only support ASCII and UTF-8
if (auto sbStemmer = sb_stemmer_new(lang, "UTF_8")) {
stemmer = new Stemmer(sbStemmer);
stemmers[lang] = std::unique_ptr<Stemmer>(stemmer);
}
return stemmer;
}

/// Get the word stem or NULL if out of memory
const char* stem(const std::string& input) {
return (const char*)(sb_stemmer_stem(
sbStemmer_,
reinterpret_cast<unsigned char const*>(input.c_str()),
input.length()));
}
};
} // namespace

/// word_stem function
/// word_stem(word) -> varchar
///     Returns the stem of the word in the English language.
/// word_stem(word, lang) -> varchar
///     Returns the stem of the word in the specified language.
///
/// It uses the snowball stemmer library to calculate the stem.
/// https://snowballstem.org
/// Snowball provides a Java implementation, which is used in Presto, as well
/// as a C implementation. Therefore, both Presto and Prestissimo are expected
/// to produce the same word stem results.
template <typename TExec>
struct WordStemFunction {
  VELOX_DEFINE_FUNCTION_TYPES(TExec);

  // Results refer to strings in the first argument.
  // NOTE(review): doCall() assigns the stemmer's own output buffer to the
  // result rather than a slice of the input - confirm this flag is needed.
  static constexpr int32_t reuse_strings_from_arg = 0;

  // ASCII input always produces ASCII result.
  static constexpr bool is_default_ascii_behavior = true;

  /// word_stem(word): stems 'input' with the default language ("en").
  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input) {
    doCall<false>(result, input);
  }

  /// ASCII-only variant of word_stem(word).
  FOLLY_ALWAYS_INLINE void callAscii(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input) {
    doCall<true>(result, input);
  }

  /// word_stem(word, lang): stems 'input' with the stemmer for 'lang'.
  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& lang) {
    // A Varchar argument is a pointer + size and is not guaranteed to be
    // NUL-terminated, so build a bounded std::string instead of passing
    // lang.data() as a C string.
    doCall<false>(result, input, std::string(lang.data(), lang.size()));
  }

  /// ASCII-only variant of word_stem(word, lang).
  FOLLY_ALWAYS_INLINE void callAscii(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& lang) {
    // See call(): lang.data() must not be treated as a C string.
    doCall<true>(result, input, std::string(lang.data(), lang.size()));
  }

  /// Shared implementation: lower-cases 'input' (ASCII or UTF-8 path chosen
  /// by 'isAscii'), stems it and assigns the stem to 'result'.
  /// Fails with a user error if no stemmer exists for 'lang', and with a
  /// check failure if the stemmer returns a null stem.
  template <bool isAscii>
  FOLLY_ALWAYS_INLINE void doCall(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const std::string& lang = "en") {
    auto* stemmer = Stemmer::getStemmer(lang.c_str());
    if (stemmer == nullptr) {
      // language is not supported
      VELOX_USER_FAIL("Unknown stemmer language: \"{}\"", lang);
    }

    std::string lowerOutput;
    stringImpl::lower<isAscii>(lowerOutput, input);
    const char* stem = stemmer->stem(lowerOutput);
    VELOX_CHECK_NOT_NULL(
        stem, "Stemmer library returned a NULL (out-of-memory)");
    result = stem;
  }
};
} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "velox/functions/prestosql/SplitPart.h"
#include "velox/functions/prestosql/SplitToMap.h"
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/prestosql/WordStem.h"

namespace facebook::velox::functions {

Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ add_executable(
URLFunctionsTest.cpp
Utf8Test.cpp
WidthBucketArrayTest.cpp
WordStemTest.cpp
ZipTest.cpp
ZipWithTest.cpp)

Expand Down
58 changes: 0 additions & 58 deletions velox/functions/prestosql/tests/StringFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,6 @@ class StringFunctionsTest : public FunctionBaseTest {
const std::string& left,
const std::string& right);

std::string wordStemWithLang(
const std::string& word,
const std::string& lang);

std::string wordStem(const std::string& word);

using replace_input_test_t = std::vector<std::pair<
std::tuple<std::string, std::string, std::string>,
std::string>>;
Expand Down Expand Up @@ -1210,58 +1204,6 @@ TEST_F(StringFunctionsTest, invalidLevenshteinDistance) {
"The combined inputs size exceeded max Levenshtein distance combined input size");
}

// Evaluates word_stem(c0, c1) once with 'word' and 'lang' as non-null inputs
// and returns the unwrapped result.
std::string StringFunctionsTest::wordStemWithLang(
    const std::string& word,
    const std::string& lang) {
  auto stemmed = evaluateOnce<std::string>(
      "word_stem(c0, c1)", std::optional(word), std::optional(lang));
  return stemmed.value();
}

// Evaluates the one-argument word_stem(c0) once with 'word' as a non-null
// input and returns the unwrapped result.
std::string StringFunctionsTest::wordStem(const std::string& word) {
  auto stemmed =
      evaluateOnce<std::string>("word_stem(c0)", std::optional(word));
  return stemmed.value();
}

/// Test cases borrowed from Presto Java:
/// https://github.com/prestodb/presto/blob/master/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestWordStemFunction.java
TEST_F(StringFunctionsTest, asciiWordStem) {
  // {word, expected stem} for the one-argument form.
  const std::vector<std::pair<std::string, std::string>> defaultLangCases = {
      {"", ""},
      {"x", "x"},
      {"abc", "abc"},
      {"generally", "general"},
      {"useful", "use"},
      {"runs", "run"},
      {"run", "run"},
  };
  for (const auto& [word, expected] : defaultLangCases) {
    EXPECT_EQ(wordStem(word), expected);
  }

  // {word, lang, expected stem} for the two-argument form.
  const std::vector<std::tuple<std::string, std::string, std::string>>
      explicitLangCases = {
          {"authorized", "en", "author"},
          {"accessories", "en", "accessori"},
          {"intensifying", "en", "intensifi"},
          {"resentment", "en", "resent"},
          {"faithfulness", "en", "faith"},
          {"continuerait", "fr", "continu"},
          {"torpedearon", "es", "torped"},
          {"quilomtricos", "pt", "quilomtr"},
          {"pronunziare", "it", "pronunz"},
          {"auferstnde", "de", "auferstnd"},
      };
  for (const auto& [word, lang, expected] : explicitLangCases) {
    EXPECT_EQ(wordStemWithLang(word, lang), expected);
  }
}

// Asking for the language code "xx" must fail with an error whose message
// names the rejected language.
TEST_F(StringFunctionsTest, invalidWordStemLang) {
VELOX_ASSERT_THROW(
wordStemWithLang("hello", "xx"), "Unknown stemmer language: \"xx\"");
}

// Non-ASCII inputs, written as universal character names.
TEST_F(StringFunctionsTest, unicodeWordStem) {
  // Turkish: the code points spell "Kitabımızdı" (upper-case first letter,
  // dotless i).
  const std::string turkishWord =
      "\u004b\u0069\u0074\u0061\u0062\u0131\u006d\u0131\u007a\u0064\u0131";
  EXPECT_EQ(wordStemWithLang(turkishWord, "tr"), "kitap");

  // Russian: the code points spell "весенний"; the expected stem is "весен".
  const std::string russianWord =
      "\u0432\u0435\u0441\u0435\u043d\u043d\u0438\u0439";
  EXPECT_EQ(
      wordStemWithLang(russianWord, "ru"), "\u0432\u0435\u0441\u0435\u043d");
}

void StringFunctionsTest::testReplaceInPlace(
const std::vector<std::pair<std::string, std::string>>& tests,
const std::string& search,
Expand Down
Loading

0 comments on commit 4546f1e

Please sign in to comment.