Skip to content

Commit

Permalink
Add word_stem() implementation
Browse files Browse the repository at this point in the history
Add the Snowball libstemmer library as a dependency and use it to
implement word_stem() as a scalar UDF.

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang committed Apr 4, 2024
1 parent efb7e77 commit bedcb88
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 1 deletion.
49 changes: 49 additions & 0 deletions CMake/resolve_dependency_modules/stemmer.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Resolves the Snowball libstemmer dependency: downloads the C source tarball,
# builds it with make as an external project, and exposes the resulting
# archive as the imported static library target `stemmer`.
include_guard(GLOBAL)

# Release to download, the SHA-256 checksum of its tarball, and its URL.
set(VELOX_STEMMER_VERSION 2.2.0)
set(VELOX_STEMMER_BUILD_SHA256_CHECKSUM
b941d9fe9cf36b4e2f8d3873cd4d8b8775bd94867a1df8d8c001bb8b688377c3)
set(VELOX_STEMMER_SOURCE_URL
"https://snowballstem.org/dist/libstemmer_c-${VELOX_STEMMER_VERSION}.tar.gz"
)

# NOTE(review): resolve_dependency_url is defined outside this file.
# Presumably it finalizes VELOX_STEMMER_SOURCE_URL and puts the checksum into
# the "<ALGO>=<hash>" form that URL_HASH below requires -- confirm.
resolve_dependency_url(STEMMER)

message(STATUS "Building stemmer from source")
# The upstream tarball ships a Makefile, so make is mandatory; configuration
# fails here if it cannot be found.
find_program(MAKE_PROGRAM make REQUIRED)

# Root directory for the download, sources and build of libstemmer.
set(STEMMER_PREFIX "${CMAKE_BINARY_DIR}/_deps/libstemmer")

# FetchContent cannot be used because libstemmer does not provide a CMake
# build. There is no configure or install step: make is run inside the source
# tree and the archive is consumed from there. BUILD_BYPRODUCTS names the
# archive that the build step is expected to produce.
ExternalProject_Add(
libstemmer
PREFIX ${STEMMER_PREFIX}
SOURCE_DIR ${STEMMER_PREFIX}/src/libstemmer
URL ${VELOX_STEMMER_SOURCE_URL}
URL_HASH ${VELOX_STEMMER_BUILD_SHA256_CHECKSUM}
BUILD_IN_SOURCE TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ${MAKE_PROGRAM}
INSTALL_COMMAND ""
BUILD_BYPRODUCTS ${STEMMER_PREFIX}/src/libstemmer/${CMAKE_STATIC_LIBRARY_PREFIX}stemmer${CMAKE_STATIC_LIBRARY_SUFFIX})

# BUILD_IN_SOURCE is set, so BINARY_DIR is expected to be the same directory
# as SOURCE_DIR above (see the CMake ExternalProject documentation).
ExternalProject_Get_Property(libstemmer BINARY_DIR)

# Imported target that consumers link against; it points at the archive built
# by the external project.
add_library(stemmer STATIC IMPORTED)
set_target_properties(stemmer PROPERTIES IMPORTED_LOCATION ${BINARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}stemmer${CMAKE_STATIC_LIBRARY_SUFFIX})

# Adds libstemmer's include directory at directory scope rather than on the
# imported target.
include_directories(${BINARY_DIR}/include)
# Make sure the external project is built before anything that uses `stemmer`.
add_dependencies(stemmer libstemmer)
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,10 @@ endif()
set_source(xsimd)
resolve_dependency(xsimd 10.0.0)

# Snowball libstemmer: the source is pinned to BUNDLED before the dependency
# is resolved.
set(stemmer_SOURCE BUNDLED)
set_source(stemmer)
resolve_dependency(stemmer)

if(VELOX_BUILD_TESTING)
set(BUILD_TESTING ON)
include(CTest) # include after project() but before add_subdirectory()
Expand Down
3 changes: 2 additions & 1 deletion velox/functions/prestosql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ target_link_libraries(
velox_type_tz
velox_presto_types
velox_functions_util
Folly::folly)
Folly::folly
stemmer)

set_property(TARGET velox_functions_prestosql_impl PROPERTY JOB_POOL_COMPILE
high_memory_pool)
Expand Down
109 changes: 109 additions & 0 deletions velox/functions/prestosql/StringFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
*/
#pragma once

#include <libstemmer.h>

#include <map>
#include <memory>
#include <stdexcept>
#include <string>

#include "velox/functions/Udf.h"
#include "velox/functions/lib/string/StringCore.h"
#include "velox/functions/lib/string/StringImpl.h"
Expand Down Expand Up @@ -494,4 +496,111 @@ struct LevenshteinDistanceFunction {
}
};

/// Wrap the sbstemmer library and use its sb_stemmer_stem
/// to get word stem
class Stemmer {
private:
sb_stemmer* sbstemmer;
Stemmer(sb_stemmer* stemmer) : sbstemmer(stemmer) {}

public:
~Stemmer() {
sb_stemmer_delete(sbstemmer);
}

/// Get a Stemmer from the the map stored in thread local storage
/// or create a new one if it doesn't exist. Return NULL if the
/// specified lang is not supported.
static Stemmer* getStemmer(const char* lang) {
thread_local std::map<std::string, Stemmer*> stemmers;
if (auto found = stemmers.find(lang); found != stemmers.end()) {
return found->second;
}
// Only support ASCII and UTF-8
if (auto sbstemmer = sb_stemmer_new(lang, "UTF_8")) {
auto rev = new Stemmer(sbstemmer);
stemmers[lang] = rev;
return rev;
}
return NULL;
}

/// Get the word stem or NULL if out of memory
const char* stem(const std::string& input) {
return (const char*)(sb_stemmer_stem(
sbstemmer,
reinterpret_cast<unsigned char const*>(input.c_str()),
input.length()));
}
};

/// word_stem function
/// word_stem(word) -> varchar
///     Returns the stem of the word in the English language.
/// word_stem(word, lang) -> varchar
///     Returns the stem of the word in the specified language.
///
/// The input is lower-cased before it is stemmed. Throws
/// std::invalid_argument if the language is not supported and
/// std::runtime_error if the stemmer reports an out-of-memory condition.
///
/// It uses the Snowball stemmer library to calculate the stem.
/// https://snowballstem.org
/// Snowball provides a Java implementation, which is used in Presto, as well
/// as a C implementation. Therefore, both Presto and Prestissimo
/// would have the same word stem results.
template <typename T>
struct WordStemFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  // NOTE(review): doCall() assigns the result from the buffer returned by the
  // stemmer, not from the input argument, so this flag looks unnecessary.
  // Kept unchanged; confirm before removing.
  static constexpr int32_t reuse_strings_from_arg = 0;

  // ASCII input always produces ASCII result.
  static constexpr bool is_default_ascii_behavior = true;

  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input) {
    doCall<false>(result, input);
  }

  FOLLY_ALWAYS_INLINE void callAscii(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input) {
    doCall<true>(result, input);
  }

  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& lang) {
    // 'lang' is a string view whose data() is not guaranteed to be
    // NUL-terminated; copy exactly size() bytes before using it as a C
    // string.
    const std::string langName(lang.data(), lang.size());
    doCall<false>(result, input, langName.c_str());
  }

  FOLLY_ALWAYS_INLINE void callAscii(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const arg_type<Varchar>& lang) {
    // See call(): bound the language name by its size before treating it as
    // a C string.
    const std::string langName(lang.data(), lang.size());
    doCall<true>(result, input, langName.c_str());
  }

  /// Shared implementation of all overloads.
  ///
  /// @tparam isAscii Selects the ASCII or the general lower-casing routine.
  /// @param result Receives the stem.
  /// @param input Word to stem.
  /// @param lang NUL-terminated language name; defaults to "en".
  template <bool isAscii>
  FOLLY_ALWAYS_INLINE void doCall(
      out_type<Varchar>& result,
      const arg_type<Varchar>& input,
      const char* lang = "en") {
    auto* stemmer = Stemmer::getStemmer(lang);
    if (stemmer == nullptr) {
      // language is not supported
      throw std::invalid_argument(
          "Unknown stemmer language: \"" + std::string(lang) + "\"");
    }

    // The word is lower-cased before it is handed to the stemmer.
    std::string lowerOutput;
    stringImpl::lower<isAscii>(lowerOutput, input);
    const char* stem = stemmer->stem(lowerOutput);
    if (stem == nullptr) {
      throw std::runtime_error("out of memory");
    }
    result = stem;
  }
};

} // namespace facebook::velox::functions
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,10 @@ void registerStringFunctions(const std::string& prefix) {
{prefix + "strrpos"});
registerFunction<StrRPosFunction, int64_t, Varchar, Varchar, int64_t>(
{prefix + "strrpos"});

// word_stem function
registerFunction<WordStemFunction, Varchar, Varchar>({prefix + "word_stem"});
registerFunction<WordStemFunction, Varchar, Varchar, Varchar>(
{prefix + "word_stem"});
}
} // namespace facebook::velox::functions
58 changes: 58 additions & 0 deletions velox/functions/prestosql/tests/StringFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ class StringFunctionsTest : public FunctionBaseTest {
const std::string& left,
const std::string& right);

std::string wordStemWithLang(
const std::string& word,
const std::string& lang);

std::string wordStem(const std::string& word);

using replace_input_test_t = std::vector<std::pair<
std::tuple<std::string, std::string, std::string>,
std::string>>;
Expand Down Expand Up @@ -1204,6 +1210,58 @@ TEST_F(StringFunctionsTest, invalidLevenshteinDistance) {
"The combined inputs size exceeded max Levenshtein distance combined input size");
}

// Evaluates word_stem(c0, c1) once with 'word' and 'lang' as the two
// arguments and returns the resulting value.
std::string StringFunctionsTest::wordStemWithLang(
    const std::string& word,
    const std::string& lang) {
  auto stem = evaluateOnce<std::string>(
      "word_stem(c0, c1)", std::optional(word), std::optional(lang));
  return stem.value();
}

// Evaluates the single-argument word_stem(c0) once with 'word' as the
// argument and returns the resulting value.
std::string StringFunctionsTest::wordStem(const std::string& word) {
  auto stem = evaluateOnce<std::string>("word_stem(c0)", std::optional(word));
  return stem.value();
}

/// Test cases borrowed from Presto Java:
/// https://github.com/prestodb/presto/blob/master/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestWordStemFunction.java
TEST_F(StringFunctionsTest, asciiWordStem) {
  // {word, expected stem} for the single-argument form.
  const std::vector<std::pair<std::string, std::string>> singleArgCases = {
      {"", ""},
      {"x", "x"},
      {"abc", "abc"},
      {"generally", "general"},
      {"useful", "use"},
      {"runs", "run"},
      {"run", "run"},
  };
  for (const auto& [word, expected] : singleArgCases) {
    EXPECT_EQ(wordStem(word), expected);
  }

  // {word, language, expected stem} for the two-argument form.
  const std::vector<std::tuple<std::string, std::string, std::string>>
      withLangCases = {
          {"authorized", "en", "author"},
          {"accessories", "en", "accessori"},
          {"intensifying", "en", "intensifi"},
          {"resentment", "en", "resent"},
          {"faithfulness", "en", "faith"},
          {"continuerait", "fr", "continu"},
          {"torpedearon", "es", "torped"},
          {"quilomtricos", "pt", "quilomtr"},
          {"pronunziare", "it", "pronunz"},
          {"auferstnde", "de", "auferstnd"},
      };
  for (const auto& [word, lang, expected] : withLangCases) {
    EXPECT_EQ(wordStemWithLang(word, lang), expected);
  }
}

// An unsupported language code ("xx") must make word_stem fail with an error
// message that names the offending language.
TEST_F(StringFunctionsTest, invalidWordStemLang) {
VELOX_ASSERT_THROW(
wordStemWithLang("hello", "xx"), "Unknown stemmer language: \"xx\"");
}

// Stemming of words that contain non-ASCII characters.
TEST_F(StringFunctionsTest, unicodeWordStem) {
  // "tr": the word starts with an uppercase 'K' (U+004B) and contains
  // U+0131; the expected stem is all lowercase ASCII.
  const std::string trWord =
      "\u004b\u0069\u0074\u0061\u0062\u0131\u006d\u0131\u007a\u0064\u0131";
  EXPECT_EQ(wordStemWithLang(trWord, "tr"), "kitap");

  // "ru": both the word and the expected stem are written in Cyrillic.
  const std::string ruWord =
      "\u0432\u0435\u0441\u0435\u043d\u043d\u0438\u0439";
  const auto ruStem = wordStemWithLang(ruWord, "ru");
  EXPECT_EQ(ruStem, "\u0432\u0435\u0441\u0435\u043d");
}

void StringFunctionsTest::testReplaceInPlace(
const std::vector<std::pair<std::string, std::string>>& tests,
const std::string& search,
Expand Down

0 comments on commit bedcb88

Please sign in to comment.