Skip to content

Commit

Permalink
velox split function support pattern of string type
Browse files Browse the repository at this point in the history
  • Loading branch information
unigof authored and jackylee-ch committed Apr 12, 2024
1 parent b1252f2 commit 01ecc1e
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 259 deletions.
11 changes: 7 additions & 4 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -184,18 +184,21 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

SELECT rtrim('kr', 'spark'); -- "spa"

.. spark:function:: split(string, delimiter) -> array(string)
.. spark:function:: split(string, regex) -> array(string)
Splits ``string`` on ``delimiter`` and returns an array. ::
Returns an array by splitting ``string`` as many times as possible.
The delimiter is any substring that matches ``regex``, a regular expression supported by RE2.
This is equivalent to ``split(string, regex, -1)``, i.e. ``limit`` is -1. ::

SELECT split('oneAtwoBthreeC', '[ABC]'); -- ["one","two","three",""]
SELECT split('one', ''); -- ["o", "n", "e", ""]
SELECT split('one', '1'); -- ["one"]

.. spark:function:: split(string, delimiter, limit) -> array(string)
.. spark:function:: split(string, regex, limit) -> array(string)
:noindex:

Splits ``string`` on ``delimiter`` and returns an array of size at most ``limit``. ::
Splits ``string`` on ``regex`` and returns an array of size at most ``limit`` when ``limit`` is positive.
If ``limit`` is zero or negative, ``string`` will be split as many times as possible. ::

SELECT split('oneAtwoBthreeC', '[ABC]', -1); -- ["one","two","three",""]
SELECT split('oneAtwoBthreeC', '[ABC]', 0); -- ["one", "two", "three", ""]
Expand Down
125 changes: 125 additions & 0 deletions velox/functions/lib/Re2Functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,83 @@ class Re2ExtractAll final : public exec::VectorFunction {
mutable ReCache cache_;
};

/// Splits the string found at 'row' of 'inputStrs' around every match of 're'
/// and writes the pieces, in order, as a single array into 'resultWriter'.
///
/// Semantics:
///  - Each non-empty match ends the current piece; the matched text itself is
///    not part of any piece.
///  - The text after the last match is always emitted, even when it is empty,
///    so "a_" split on "_" yields ["a", ""]. A string without any match yields
///    a one-element array holding the whole string.
///  - A zero-width match (e.g. an empty pattern) splits between characters
///    without dropping any of them: 'one' split on '' yields
///    ["o", "n", "e", ""]. A zero-width match at offset 0 does not produce a
///    leading empty piece.
///
/// @param resultWriter Array writer that receives one array for 'row'.
/// @param re Compiled regular expression used as the delimiter.
/// @param inputStrs Decoded input strings.
/// @param row Row to read from 'inputStrs' and to write in 'resultWriter'.
/// @param groups Scratch storage for match results; must hold at least one
/// element. Only the whole-match group (index 0) is used.
///
/// The pieces are written with setNoCopy(), i.e. they refer to the bytes of
/// the input string. The caller must keep the input string buffers alive for
/// as long as the result is used.
void re2SplitAll(
    exec::VectorWriter<Array<Varchar>>& resultWriter,
    const RE2& re,
    const exec::LocalDecodedVector& inputStrs,
    const int row,
    std::vector<re2::StringPiece>& groups) {
  resultWriter.setOffset(row);

  auto& arrayWriter = resultWriter.current();

  const StringView str = inputStrs->valueAt<StringView>(row);
  const re2::StringPiece input = toStringPiece(str);

  // Offset of the first byte that has not been emitted as part of a piece.
  size_t pieceStart = 0;
  // Offset at which the next search starts. It can run ahead of 'pieceStart'
  // after a zero-width match, because the search must move forward while the
  // character that was stepped over still belongs to the pending piece.
  size_t searchPos = 0;

  // The explicit bound keeps us from asking RE2 to search past the end of the
  // input after a zero-width match at the very end.
  while (searchPos <= input.size() &&
         re.Match(
             input,
             searchPos,
             input.size(),
             RE2::UNANCHORED,
             groups.data(),
             1)) {
    const re2::StringPiece fullMatch = groups[0];
    const size_t matchBegin = fullMatch.data() - input.data();
    const size_t matchEnd = matchBegin + fullMatch.size();

    if (UNLIKELY(fullMatch.size() == 0)) {
      // Zero-width match: step over one whole character so the loop makes
      // progress. Continuation bytes (10xxxxxx) are skipped so a multi-byte
      // character is never cut in half. NOTE(review): assumes UTF-8 input.
      searchPos = matchEnd + 1;
      while (searchPos < input.size() &&
             (static_cast<unsigned char>(input.data()[searchPos]) & 0xC0) ==
                 0x80) {
        ++searchPos;
      }
      // A zero-width match at the very beginning would only contribute a
      // leading empty piece; skip it.
      if (matchBegin == 0) {
        continue;
      }
    } else {
      searchPos = matchEnd;
    }

    const re2::StringPiece piece =
        input.substr(pieceStart, matchBegin - pieceStart);
    arrayWriter.add_item().setNoCopy(StringView(piece.data(), piece.size()));
    pieceStart = matchEnd;
  }

  // Whatever follows the last delimiter is the final piece; it is empty when
  // the input ends with a delimiter and the whole input when nothing matched.
  const re2::StringPiece remaining = input.substr(pieceStart);
  arrayWriter.add_item().setNoCopy(
      StringView(remaining.data(), remaining.size()));

  resultWriter.commit();
}

/// Vector function that splits every input string around the matches of a
/// regular expression which is compiled once, at construction time, from a
/// constant pattern.
class Re2SplitAllConstantPattern final : public exec::VectorFunction {
 public:
  /// Compiles 'pattern'. RE2::Quiet is passed as the RE2 options, and
  /// checkForBadPattern() is called on the compiled expression to validate it.
  /// The constructor is explicit so that a StringView is never converted into
  /// a function object implicitly.
  explicit Re2SplitAllConstantPattern(StringView pattern)
      : re_(toStringPiece(pattern), RE2::Quiet) {
    checkForBadPattern(re_);
  }

  /// Splits args[0] for every selected row and stores the resulting
  /// array(varchar) in 'resultRef'. Only args[0] is read here; the pattern was
  /// captured at construction time.
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& /* outputType */,
      exec::EvalCtx& context,
      VectorPtr& resultRef) const final {
    BaseVector::ensureWritable(
        rows, ARRAY(VARCHAR()), context.pool(), resultRef);
    exec::VectorWriter<Array<Varchar>> resultWriter;
    resultWriter.init(*resultRef->as<ArrayVector>());

    exec::LocalDecodedVector inputStrs(context, *args[0], rows);
    // Scratch storage for the match; only the whole-match group is requested.
    FOLLY_DECLARE_REUSED(groups, std::vector<re2::StringPiece>);
    groups.resize(1);

    context.applyToSelectedNoThrow(rows, [&](vector_size_t row) {
      re2SplitAll(resultWriter, re_, inputStrs, row, groups);
    });

    resultWriter.finish();

    // The pieces were written without copying, so the result elements must
    // share ownership of the input's string buffers.
    resultRef->as<ArrayVector>()
        ->elements()
        ->asFlatVector<StringView>()
        ->acquireSharedStringBuffers(inputStrs->base());
  }

 private:
  // Delimiter regex compiled from the constant pattern.
  RE2 re_;
};

template <bool (*Fn)(StringView, const RE2&)>
std::shared_ptr<exec::VectorFunction> makeRe2MatchImpl(
const std::string& name,
Expand Down Expand Up @@ -1935,4 +2012,52 @@ re2ExtractAllSignatures() {
};
}

// Factory for the regex split function.
//
// Validates the call shape and builds a Re2SplitAllConstantPattern from the
// constant pattern argument:
//  - exactly two arguments are required;
//  - both arguments must be VARCHAR;
//  - the second argument (the pattern) must be a non-null constant.
// Any violation raises a user error via VELOX_USER_CHECK*.
//
// 'name' is only used in error messages. The query config is unused.
std::shared_ptr<exec::VectorFunction> makeRe2SplitAll(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
auto numArgs = inputArgs.size();
VELOX_USER_CHECK_EQ(
numArgs, 2, "{} requires 2 arguments, but got {}", name, numArgs);

VELOX_USER_CHECK(
inputArgs[0].type->isVarchar(),
"{} requires first argument of type VARCHAR, but got {}",
name,
inputArgs[0].type->toString());

VELOX_USER_CHECK(
inputArgs[1].type->isVarchar(),
"{} requires second argument of type VARCHAR, but got {}",
name,
inputArgs[1].type->toString());

// constantValue is null when the argument is not a constant; a constant NULL
// pattern is rejected as well.
BaseVector* constantPattern = inputArgs[1].constantValue.get();
VELOX_USER_CHECK(
constantPattern != nullptr && !constantPattern->isNullAt(0),
"{} requires second argument of constant, but got {}",
name,
inputArgs[1].type->toString());

auto pattern = constantPattern->as<ConstantVector<StringView>>()->valueAt(0);

// Constructing the function compiles and validates the pattern. A failure is
// not propagated from here; the exception is captured and handed to an
// AlwaysFailingVectorFunction instead (presumably rethrown when the function
// is evaluated - confirm against exec::AlwaysFailingVectorFunction).
try {
return std::make_shared<Re2SplitAllConstantPattern>(pattern);
} catch (...) {
return std::make_shared<exec::AlwaysFailingVectorFunction>(
std::current_exception());
}
}

/// Returns the signatures accepted by the regex split function. There is a
/// single one: (varchar, constant varchar) -> array(varchar).
std::vector<std::shared_ptr<exec::FunctionSignature>> re2SplitAllSignatures() {
  std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
  // The input string may vary per row; the pattern has to be a constant.
  signatures.emplace_back(exec::FunctionSignatureBuilder()
                              .returnType("array(varchar)")
                              .argumentType("varchar")
                              .constantArgumentType("varchar")
                              .build());
  return signatures;
}

} // namespace facebook::velox::functions
14 changes: 14 additions & 0 deletions velox/functions/lib/Re2Functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,20 @@ std::shared_ptr<exec::VectorFunction> makeRe2ExtractAll(

std::vector<std::shared_ptr<exec::FunctionSignature>> re2ExtractAllSignatures();

/// re2SplitAll(string, pattern) → array<string>
///
/// Creates a vector function that splits string around every match of
/// pattern, as many times as possible, and returns the pieces as an array.
/// pattern is a regular expression.
///
/// Exactly two VARCHAR arguments are required and pattern must be a non-null
/// constant; otherwise a user error is thrown.
/// An invalid pattern does not make this factory throw: the construction
/// error is captured in an exec::AlwaysFailingVectorFunction, which is
/// returned instead (presumably failing when the function is evaluated).
/// If the pattern does not match, the result is an array whose only element
/// is the original string.
std::shared_ptr<exec::VectorFunction> makeRe2SplitAll(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& config);

/// Returns the signatures supported by re2SplitAll:
/// (varchar, constant varchar) -> array(varchar).
std::vector<std::shared_ptr<exec::FunctionSignature>> re2SplitAllSignatures();

/// regexp_replace(string, pattern, replacement) -> string
/// regexp_replace(string, pattern) -> string
///
Expand Down
126 changes: 126 additions & 0 deletions velox/functions/lib/tests/Re2FunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class Re2FunctionsTest : public test::FunctionBaseTest {
exec::registerStatefulVectorFunction(
"re2_extract_all", re2ExtractAllSignatures(), makeRe2ExtractAll);
exec::registerStatefulVectorFunction("like", likeSignatures(), makeLike);
exec::registerStatefulVectorFunction(
"re2_split_all", re2SplitAllSignatures(), makeRe2SplitAll);
}

protected:
Expand Down Expand Up @@ -85,6 +87,11 @@ class Re2FunctionsTest : public test::FunctionBaseTest {
return output;
}

/// Evaluates re2_split_all(c0, '<pattern>') over 'inputs' (std::nullopt is a
/// null row) and asserts that the result equals 'output', where std::nullopt
/// denotes a null array.
void testRe2SplitAll(
const std::vector<std::optional<std::string>>& inputs,
const std::string& pattern,
const std::vector<std::optional<std::vector<std::string>>>& output);

void testLike(
const std::string& input,
const std::string& pattern,
Expand Down Expand Up @@ -1471,5 +1478,124 @@ TEST_F(Re2FunctionsTest, limit) {
ASSERT_NO_THROW(evaluate("regexp_like(c0, c2)", data));
}

/// Runs re2_split_all(c0, '<pattern>') over 'inputs' and compares the result
/// with 'output'. An empty optional in 'inputs' becomes a null input row; an
/// empty optional in 'output' stands for a null result array.
void Re2FunctionsTest::testRe2SplitAll(
    const std::vector<std::optional<std::string>>& inputs,
    const std::string& pattern,
    const std::vector<std::optional<std::vector<std::string>>>& output) {
  // Input column c0.
  auto inputVector = makeFlatVector<StringView>(
      inputs.size(),
      [&inputs](vector_size_t row) {
        return inputs[row] ? StringView(*inputs[row]) : StringView();
      },
      [&inputs](vector_size_t row) { return !inputs[row].has_value(); });

  // The pattern is spliced into the expression as a constant string literal.
  const std::string expression =
      std::string("re2_split_all(c0") + ", '" + pattern + "'" + ")";
  auto actual =
      evaluate<ArrayVector>(expression, makeRowVector({inputVector}));

  // Expected array vector, built row by row from 'output'.
  auto expected = makeArrayVector<StringView>(
      output.size(),
      [&output](vector_size_t row) {
        return output[row] ? output[row]->size() : 0;
      },
      [&output](vector_size_t row, vector_size_t idx) {
        return output[row] ? StringView(output[row]->at(idx)) : StringView("");
      },
      [&output](vector_size_t row) { return !output[row].has_value(); });

  assertEqualVectors(expected, actual);
}

// Splitting on patterns made of a single, possibly escaped, character.
TEST_F(Re2FunctionsTest, regexSpiltAllSingleCharPattern) {
// Literal '_': a delimiter at the start or end of the input produces an empty
// string, and surrounding spaces are kept as part of the pieces.
testRe2SplitAll({"abc_ta"}, {"_"}, {{{"abc", "ta"}}});
testRe2SplitAll({"_abc_ta_"}, {"_"}, {{{"", "abc", "ta", ""}}});
testRe2SplitAll({"abc_ta "}, {"_"}, {{{"abc", "ta "}}});
testRe2SplitAll({" abc_ta "}, {"_"}, {{{" abc", "ta "}}});

// Unescaped '.' matches any character, so every character is a delimiter and
// an input of n characters yields n + 1 empty strings.
testRe2SplitAll({"abc"}, {"."}, {{{"", "", "", ""}}});
testRe2SplitAll({"abc "}, {"."}, {{{"", "", "", "", ""}}});
testRe2SplitAll({" abc "}, {"."}, {{{"", "", "", "", "", ""}}});

// Escaped '.' only matches a literal dot; without one the input is returned
// unchanged as the single element.
testRe2SplitAll({"abc"}, {"\\."}, {{{"abc"}}});
testRe2SplitAll({"abc "}, {"\\."}, {{{"abc "}}});
testRe2SplitAll({" abc "}, {"\\."}, {{{" abc "}}});

// Escaped '|' matches a literal pipe.
testRe2SplitAll({"abt|sc"}, {"\\|"}, {{{"abt", "sc"}}});
testRe2SplitAll({"|abc| "}, {"\\|"}, {{{"", "abc", " "}}});
testRe2SplitAll({" |ab|c | "}, {"\\|"}, {{{" ", "ab", "c ", " "}}});
}

// Splitting on literal multi-character delimiters.
TEST_F(Re2FunctionsTest, regexSpiltAllSequenceCharPattern) {
// Delimiter in the middle, delimiter longer than the input (no match), and a
// delimiter that starts with a space.
testRe2SplitAll({"dafefaatb"}, {"fe"}, {{{"da", "faatb"}}});
testRe2SplitAll({"abc_ta"}, {"abc_ta_t"}, {{{"abc_ta"}}});
testRe2SplitAll({"abc dt dat"}, {" dt"}, {{{"abc", " dat"}}});

// Delimiter occurring several times, including at the start and at the end.
testRe2SplitAll({"absdfghabiefjab"}, {"ab"}, {{{"", "sdfgh", "iefj", ""}}});
testRe2SplitAll(
{" absdfgha biefjab "}, {"ab"}, {{{" ", "sdfgha biefj", " "}}});
}

// Splitting on real regular expressions: groups, character classes and
// quantifiers.
TEST_F(Re2FunctionsTest, regexSpiltAllRegexSequencePattern) {
// Several input rows evaluated in one batch against a pattern with capturing
// groups; the whole match acts as the delimiter.
const std::vector<std::optional<std::string>> inputs = {
" 123a 2b 14m ", "123a 2b 14m", "123a2b14m"};
const std::string constantPattern = "(\\d+)([a-z]+)";
const std::vector<std::optional<std::vector<std::string>>> expectedOutputs = {
{{" ", " ", " ", " "}},
{{"", " ", " ", ""}},
{{"", "", "", ""}}};

testRe2SplitAll(inputs, constantPattern, expectedOutputs);

// Trailing delimiter, empty input, inputs without any match, a pattern that
// consumes the whole input, and a non-ASCII input without any match.
testRe2SplitAll({"aa2bb3cc4"}, {"[1-9]+"}, {{{"aa", "bb", "cc", ""}}});
testRe2SplitAll({""}, {"[0-9]+"}, {{{""}}});
testRe2SplitAll({"abcde"}, {"[0-9]+"}, {{{"abcde"}}});
testRe2SplitAll({"abcde"}, {"\\d+"}, {{{"abcde"}}});
testRe2SplitAll({"23544"}, {"\\w+"}, {{{"", ""}}});
testRe2SplitAll({"(╯°□°)╯︵ ┻━┻"}, {"[0-9]+"}, {{{"(╯°□°)╯︵ ┻━┻"}}});
}

// Splitting inputs that mix multi-byte (CJK) characters with ASCII text.
TEST_F(Re2FunctionsTest, regexSplitAllNonAscii) {
// Non-ASCII delimiter inside a mixed input.
testRe2SplitAll(
{"\u82f9\u679c\u9999\u8549\u0076\u0065\u006c\u006f\u0078\u6a58\u5b50"},
{"\u9999\u8549"},
{{{"\u82f9\u679c", "\u0076\u0065\u006c\u006f\u0078\u6a58\u5b50"}}});

// ASCII delimiter ("velox") written with universal character names.
testRe2SplitAll(
{"\u82f9\u679c\u9999\u8549\u0076\u0065\u006c\u006f\u0078\u6a58\u5b50"},
{"\u0076\u0065\u006c\u006f\u0078"},
{{{"\u82f9\u679c\u9999\u8549", "\u6a58\u5b50"}}});

// ASCII delimiter at the very end of the input produces a trailing empty
// string.
testRe2SplitAll(
{"\u6d4b\u8bd5\u0076\u0065\u006c\u006f\u0078"},
{"velox"},
{{{"\u6d4b\u8bd5", ""}}});

// Same as above, but a trailing space follows the delimiter.
testRe2SplitAll(
{"\u6d4b\u8bd5\u0076\u0065\u006c\u006f\u0078\u0020"},
{"velox"},
{{{"\u6d4b\u8bd5", " "}}});

// Non-ASCII delimiter at the very end of the input.
testRe2SplitAll(
{"\u0076\u0065\u006c\u006f\u0078\u6d4b\u8bd5"},
{"\u6d4b\u8bd5"},
{{{"velox", ""}}});

// Same scenarios with the characters written literally in the source.
testRe2SplitAll({"苹果香蕉velox橘子 "}, {"velox"}, {{{"苹果香蕉", "橘子 "}}});

testRe2SplitAll({"苹果香蕉velox橘子 "}, {"橘子"}, {{{"苹果香蕉velox", " "}}});
}

} // namespace
} // namespace facebook::velox::functions
1 change: 0 additions & 1 deletion velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ add_library(
RegisterArithmetic.cpp
RegisterCompare.cpp
Size.cpp
SplitFunctions.cpp
String.cpp
UnscaledValueFunction.cpp)

Expand Down
3 changes: 2 additions & 1 deletion velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ void registerFunctions(const std::string& prefix) {
prefix + "regexp_extract", re2ExtractSignatures(), makeRegexExtract);
exec::registerStatefulVectorFunction(
prefix + "rlike", re2SearchSignatures(), makeRLike);
VELOX_REGISTER_VECTOR_FUNCTION(udf_regexp_split, prefix + "split");
exec::registerStatefulVectorFunction(
prefix + "split", re2SplitAllSignatures(), makeRe2SplitAll);

exec::registerStatefulVectorFunction(
prefix + "least",
Expand Down
Loading

0 comments on commit 01ecc1e

Please sign in to comment.