Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Spark mask function #10264

Closed
Closed
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
95b679a
spark mask function support init commit
gaoyangxiaozhu Jun 20, 2024
176ce92
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jun 20, 2024
69a5717
refactor mask and added ut
gaoyangxiaozhu Jun 25, 2024
b60ad76
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jun 27, 2024
805d6e1
address comments
gaoyangxiaozhu Jun 27, 2024
bbdf3f1
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jun 27, 2024
f75d197
address comments
gaoyangxiaozhu Jun 27, 2024
4138d3c
address comment
gaoyangxiaozhu Jun 28, 2024
0ed5f15
address comments
gaoyangxiaozhu Jul 2, 2024
ab9c523
small change
gaoyangxiaozhu Jul 9, 2024
cf27370
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 9, 2024
930a339
address comment
gaoyangxiaozhu Jul 10, 2024
209229f
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jul 10, 2024
deb156a
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 16, 2024
d1ce4df
refactor spark mask function to simple function
gaoyangxiaozhu Jul 17, 2024
72acae8
address comments
gaoyangxiaozhu Jul 18, 2024
fe09441
small change
gaoyangxiaozhu Jul 19, 2024
4e4597e
address comment
gaoyangxiaozhu Jul 21, 2024
ab3d1b7
address comment
gaoyangxiaozhu Jul 23, 2024
5a58710
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 23, 2024
fd0a41c
address comments
gaoyangxiaozhu Jul 24, 2024
aad59e2
format fix
gaoyangxiaozhu Jul 24, 2024
ecec9f0
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 25, 2024
8d20919
address comment
gaoyangxiaozhu Jul 29, 2024
d37cb88
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jul 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,29 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

SELECT ltrim('ps', 'spark'); -- "ark"

.. spark:function:: mask(string[, upperChar, lowerChar, digitChar, otherChar]) -> string
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved

Returns a masked version of the input ``string``.
``string``: string value to mask.
``upperChar``: A single character STRING used to substitute upper case characters. The default is 'X'. If NULL, upper case characters remain unmasked.
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved
``lowerChar``: A single character STRING used to substitute lower case characters. The default is 'x'. If NULL, lower case characters remain unmasked.
``digitChar``: A single character STRING used to substitute digits. The default is 'n'. If NULL, digits remain unmasked.
``otherChar``: A single character STRING used to substitute any other character. The default is NULL, which leaves these characters unmasked. ::

SELECT mask('abcd-EFGH-8765-4321'); -- "xxxx-XXXX-nnnn-nnnn"
SELECT mask('abcd-EFGH-8765-4321', 'Q'); -- "xxxx-QQQQ-nnnn-nnnn"
SELECT mask('AbCD123-@$#'); -- "XxXXnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q'); -- "QxQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q'); -- "QqQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd'); -- "QqQQddd-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o'); -- "QqQQdddoooo"
SELECT mask('AbCD123-@$#', NULL, 'q', 'd', 'o'); -- "AqCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, 'd', 'o'); -- "AbCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, 'o'); -- "AbCD123oooo"
SELECT mask(NULL, NULL, NULL, NULL, 'o'); -- NULL
SELECT mask(NULL); -- NULL
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, NULL); -- "AbCD123-@$#"

.. spark:function:: overlay(input, replace, pos, len) -> same as input

Replace a substring of ``input`` starting at ``pos`` character with ``replace`` and
Expand Down Expand Up @@ -334,4 +357,4 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

Returns string with all characters changed to uppercase. ::

SELECT upper('SparkSql'); -- SPARKSQL
SELECT upper('SparkSql'); -- SPARKSQL
15 changes: 15 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,21 @@ void registerFunctions(const std::string& prefix) {
int32_t>({prefix + "levenshtein"});
registerFunction<LevenshteinDistanceFunction, int32_t, Varchar, Varchar>(
{prefix + "levenshtein"});

registerFunction<MaskFunction, Varchar, Varchar>({prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar>({prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar, Varchar>(
{prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar, Varchar, Varchar>(
{prefix + "mask"});
registerFunction<
MaskFunction,
Varchar,
Varchar,
Varchar,
Varchar,
Varchar,
Varchar>({prefix + "mask"});
}

} // namespace sparksql
Expand Down
209 changes: 209 additions & 0 deletions velox/functions/sparksql/String.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "velox/functions/UDFOutputString.h"
#include "velox/functions/lib/string/StringCore.h"
#include "velox/functions/lib/string/StringImpl.h"
#include "velox/functions/prestosql/Utf8Utils.h"

namespace facebook::velox::functions::sparksql {

Expand Down Expand Up @@ -1431,4 +1432,212 @@ struct LevenshteinDistanceFunction {
}
};

// mask(string) -> string
// mask(string, upperChar) -> string
// mask(string, upperChar, lowerChar) -> string
// mask(string, upperChar, lowerChar, digitChar) -> string
// mask(string, upperChar, lowerChar, digitChar, otherChar) -> string
//
// Masks the characters of the given string with the provided replacement
// characters. Upper-case characters are replaced with the second argument
// (default 'X'), lower-case characters with the third argument (default 'x'),
// digits with the fourth argument (default 'n'), and all other characters with
// the last argument (default NULL, which leaves them unmasked).
// If a replacement argument is NULL, characters of the corresponding class
// are retained unchanged.
template <typename T>
struct MaskFunction {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to provide fast path for ASCII inputs?

Copy link
Contributor Author

@gaoyangxiaozhu gaoyangxiaozhu Jul 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it looks like leveraging callAscii would bring some benefit, but not much, since we would still need to handle the case where the replacement character arguments are non-ASCII.
I created an issue to do this in a separate PR if needed: #10546

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mbasmanova It seems only callAscii is provided, while a function like callAsciiNullable is needed here. Do you think we need to add that? Thanks.

VELOX_DEFINE_FUNCTION_TYPES(T);

// mask(string): masks with the built-in replacements for upper-case,
// lower-case and digit characters; every other character is copied as is.
// Returns false without touching `result` when the input is NULL, true
// otherwise.
FOLLY_ALWAYS_INLINE bool callNullable(
    out_type<Varchar>& result,
    const arg_type<Varchar>* inputPtr) {
  const bool hasInput = (inputPtr != nullptr);
  if (hasInput) {
    const StringView defaultUpper{kMaskedUpperCase_};
    const StringView defaultLower{kMaskedLowerCase_};
    const StringView defaultDigit{kMaskedDigit_};
    doCall(
        result,
        *inputPtr,
        defaultUpper,
        defaultLower,
        defaultDigit,
        std::nullopt);
  }
  return hasInput;
}

// mask(string, upperChar): the caller chooses the upper-case replacement
// (a NULL pointer leaves upper-case characters unmasked); lower-case and
// digit characters use the built-in replacements and other characters are
// copied as is. Returns false without touching `result` when the input is
// NULL, true otherwise.
FOLLY_ALWAYS_INLINE bool callNullable(
    out_type<Varchar>& result,
    const arg_type<Varchar>* inputPtr,
    const arg_type<Varchar>* upperCharPtr) {
  const bool hasInput = (inputPtr != nullptr);
  if (hasInput) {
    // The replacement is validated only once the input is known to be
    // non-NULL.
    const auto upperMask = getMaskedChar(upperCharPtr);
    const StringView defaultLower{kMaskedLowerCase_};
    const StringView defaultDigit{kMaskedDigit_};
    doCall(
        result, *inputPtr, upperMask, defaultLower, defaultDigit, std::nullopt);
  }
  return hasInput;
}

// mask(string, upperChar, lowerChar): the caller chooses the upper-case and
// lower-case replacements (a NULL pointer leaves that class unmasked); digits
// use the built-in replacement and other characters are copied as is.
// Returns false without touching `result` when the input is NULL, true
// otherwise.
FOLLY_ALWAYS_INLINE bool callNullable(
    out_type<Varchar>& result,
    const arg_type<Varchar>* inputPtr,
    const arg_type<Varchar>* upperCharPtr,
    const arg_type<Varchar>* lowerCharPtr) {
  const bool hasInput = (inputPtr != nullptr);
  if (hasInput) {
    // Replacements are validated only once the input is known to be non-NULL.
    const auto upperMask = getMaskedChar(upperCharPtr);
    const auto lowerMask = getMaskedChar(lowerCharPtr);
    const StringView defaultDigit{kMaskedDigit_};
    doCall(result, *inputPtr, upperMask, lowerMask, defaultDigit, std::nullopt);
  }
  return hasInput;
}

// mask(string, upperChar, lowerChar, digitChar): the caller chooses the
// upper-case, lower-case and digit replacements (a NULL pointer leaves that
// class unmasked); other characters are copied as is. Returns false without
// touching `result` when the input is NULL, true otherwise.
FOLLY_ALWAYS_INLINE bool callNullable(
    out_type<Varchar>& result,
    const arg_type<Varchar>* inputPtr,
    const arg_type<Varchar>* upperCharPtr,
    const arg_type<Varchar>* lowerCharPtr,
    const arg_type<Varchar>* digitCharPtr) {
  const bool hasInput = (inputPtr != nullptr);
  if (hasInput) {
    // Replacements are validated only once the input is known to be non-NULL.
    const auto upperMask = getMaskedChar(upperCharPtr);
    const auto lowerMask = getMaskedChar(lowerCharPtr);
    const auto digitMask = getMaskedChar(digitCharPtr);
    doCall(result, *inputPtr, upperMask, lowerMask, digitMask, std::nullopt);
  }
  return hasInput;
}

// mask(string, upperChar, lowerChar, digitChar, otherChar): the caller
// chooses all four replacements; a NULL pointer for any of them leaves the
// corresponding character class unmasked. Returns false without touching
// `result` when the input is NULL, true otherwise.
FOLLY_ALWAYS_INLINE bool callNullable(
    out_type<Varchar>& result,
    const arg_type<Varchar>* inputPtr,
    const arg_type<Varchar>* upperCharPtr,
    const arg_type<Varchar>* lowerCharPtr,
    const arg_type<Varchar>* digitCharPtr,
    const arg_type<Varchar>* otherCharPtr) {
  const bool hasInput = (inputPtr != nullptr);
  if (hasInput) {
    // Replacements are validated only once the input is known to be non-NULL.
    const auto upperMask = getMaskedChar(upperCharPtr);
    const auto lowerMask = getMaskedChar(lowerCharPtr);
    const auto digitMask = getMaskedChar(digitCharPtr);
    const auto otherMask = getMaskedChar(otherCharPtr);
    doCall(result, *inputPtr, upperMask, lowerMask, digitMask, otherMask);
  }
  return hasInput;
}

private:
void doCall(
out_type<Varchar>& result,
StringView input,
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved
const std::optional<StringView> upperChar,
const std::optional<StringView> lowerChar,
const std::optional<StringView> digitChar,
const std::optional<StringView> otherChar) const {
auto inputBuffer = input.data();
const size_t inputSize = input.size();
result.reserve(inputSize);
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved
auto outputBuffer = result.data();
size_t inputIdx = 0;
size_t outputIdx = 0;
while (inputIdx < inputSize) {
utf8proc_int32_t curCodePoint;
int charByteSize;
rui-mo marked this conversation as resolved.
Show resolved Hide resolved
curCodePoint = utf8proc_codepoint(
&inputBuffer[inputIdx], inputBuffer + inputSize, charByteSize);
if (curCodePoint == -1) {
// That means it is a invalid UTF-8 character for example '\xED',
// treat it as char with size 1.
charByteSize = 1;
}
auto maskedChar = &inputBuffer[inputIdx];
auto maskedCharByteSize = charByteSize;
// Treat invalid UTF-8 character as other char.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the same behavior with Spark? Perhaps document this in string.rst.

Copy link
Contributor Author

@gaoyangxiaozhu gaoyangxiaozhu Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not exactlly same, spark mask function actually has issue when handling invalid UTF-8 character or wide character (4 byte). For example for wide character "😀" which is 4 bytes \uD83D\uDE00, spark would wrongly treat it 2 characters, \uD83D and \uDE00. For invalided UTF-8 character case, for example \xED, spark would treat it as 4 characters - "", "x", "E", "D". That's due to the limitation of spark use java toString.map to iterater the char of the input string which only work well for character has at most 2 bytes (16 bits). Checking spark code - https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala#L322C1-L349C4

For do right thing, in our implement, for those wide character cases or invalid UTF-8 character cases we treat it correctlly as single character, instead of aligning with spark.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Spark contain tests for invalid UTF-8 character and wide character (4 byte) like you mentioned?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for providing above information. As you mentioned, the difference with Spark is on the handling of invalid UTF-8 character and wide character.

I tried in Spark and found its output was inconsistent indeed, e.g., for below case this PR gives yyyyyいいYyYYdddいいい while Spark gives yyyyyいいYyYYdddいいいい where one more is added.

Is it more like undefined behaviors in Spark because there is no test coverage for those cases? And the result from this PR looks more right from my view.

spark.sql("select mask('синяяい绿AbCD123世界🙂', 'Y', 'y', 'd', 'い')").show(false)
+-----------------------------------------+
|mask(синяяい绿AbCD123世界🙂, Y, y, d, い)|
+-----------------------------------------+
|yyyyyいいYyYYdddいいいい |
+-----------------------------------------+

@mbasmanova Do you have any suggestion? Thanks.

Copy link
Contributor Author

@gaoyangxiaozhu gaoyangxiaozhu Jul 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rui-mo , thanks and yes spark actually can't handle well when do the implement for invali UTF-8 character. And as i already mentioned in above comments : spark also can't handle wide-width character correctly.
spark mask function if you check the implement to levererage toString.map(charObject) iteration to replace the char, however, java toString default using UTF-16 if my unserstanding right. And the java char class object use 16 bits array to represents a character that means the spark mask function can only handle weel for BMP char (that is 16 bits), it can't handle well for both invalid UTF-8 character and also wide-character.

for 世界🙂, where 世界 is chinese character which each one has 2 bytes, so it would be replaced correctly with いい, but 🙂 is a wide character which represents with 4 bytes, so spark wrongly treat it as 2 character each one have 2 bytes, causing 🙂 replacing with いい. That's why it finally causes 4 .

Another is for replacing using a wide-character, for example, below case would cause wrong garbled code problem issue in spark

select mask("ABC", "🙂");

image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me know your concern @rui-mo

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to open an issue in Spark to ask whether this is the case and if so to add documentation. Then, we can add the same clause to Velox's documentation for string Spark functions.

https://prestodb.io/docs/current/functions/string.html

@gaoyangxiaozhu I understand your point. Perhaps we can follow this suggestion to ask about it in Spark. How do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @rui-mo , I believe from the view of code implement part of spark mask function, they don't concern for invalid UTF-8 and wide character.

I just raise a issue/ question to stackoverflow https://stackoverflow.com/staging-ground/78781459 and also Spark JIRA issue - https://issues.apache.org/jira/browse/SPARK-48973

let's waitting the spark community's reply.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

utf8proc_propval_t category = utf8proc_category(curCodePoint);
if (isUpperChar(category) && upperChar.has_value()) {
maskedChar = upperChar.value().data();
maskedCharByteSize = upperChar.value().size();
} else if (isLowerChar(category) && lowerChar.has_value()) {
maskedChar = lowerChar.value().data();
maskedCharByteSize = lowerChar.value().size();
} else if (isDigitChar(category) && digitChar.has_value()) {
maskedChar = digitChar.value().data();
maskedCharByteSize = digitChar.value().size();
} else if (
!isUpperChar(category) && !isLowerChar(category) &&
!isDigitChar(category) && otherChar.has_value()) {
maskedChar = otherChar.value().data();
maskedCharByteSize = otherChar.value().size();
}

for (auto i = 0; i < maskedCharByteSize; i++) {
outputBuffer[outputIdx++] = maskedChar[i];
}

inputIdx += charByteSize;
}
result.resize(outputIdx);
}

// Returns true when `category` is the Unicode general category Lu
// (upper-case letter). The category is a small value type that is only read,
// so it is taken by value; this also accepts const values and temporaries.
bool isUpperChar(utf8proc_propval_t category) const {
  return category == UTF8PROC_CATEGORY_LU;
}

// Returns true when `category` is the Unicode general category Ll
// (lower-case letter). The category is a small value type that is only read,
// so it is taken by value; this also accepts const values and temporaries.
bool isLowerChar(utf8proc_propval_t category) const {
  return category == UTF8PROC_CATEGORY_LL;
}

// Returns true when `category` is the Unicode general category Nd
// (decimal digit). The category is a small value type that is only read, so
// it is taken by value; this also accepts const values and temporaries.
bool isDigitChar(utf8proc_propval_t category) const {
  return category == UTF8PROC_CATEGORY_ND;
}

std::optional<StringView> getMaskedChar(const arg_type<Varchar>* maskChar) {
if (maskChar) {
auto maskCharData = maskChar->data();
auto maskCharSize = maskChar->size();
if (maskCharSize == 1) {
return StringView{maskCharData};
}

VELOX_USER_CHECK_NE(
maskCharSize, 0, "Length of replacing char should be 1");
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved

// Calculates the byte length of the first unicode character, and compares
// it with the length of replacing character. Inequality indicates the
// replacing character includes more than one unicode characters.
int size;
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved
auto codePoint = utf8proc_codepoint(
&maskCharData[0], maskCharData + maskCharSize, size);
VELOX_USER_CHECK_EQ(
maskCharSize, size, "Length of replacing char should be 1");
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved

return StringView(maskCharData, maskCharSize);
}
return std::nullopt;
}

static constexpr std::string_view kMaskedUpperCase_{"X"};
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved
static constexpr std::string_view kMaskedLowerCase_{"x"};
static constexpr std::string_view kMaskedDigit_{"n"};
};

} // namespace facebook::velox::functions::sparksql
Loading
Loading