Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Spark mask function #10264

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
95b679a
spark mask function support init commit
gaoyangxiaozhu Jun 20, 2024
176ce92
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jun 20, 2024
69a5717
refactor mask and added ut
gaoyangxiaozhu Jun 25, 2024
b60ad76
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jun 27, 2024
805d6e1
address comments
gaoyangxiaozhu Jun 27, 2024
bbdf3f1
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jun 27, 2024
f75d197
address comments
gaoyangxiaozhu Jun 27, 2024
4138d3c
address comment
gaoyangxiaozhu Jun 28, 2024
0ed5f15
address comments
gaoyangxiaozhu Jul 2, 2024
ab9c523
small change
gaoyangxiaozhu Jul 9, 2024
cf27370
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 9, 2024
930a339
address comment
gaoyangxiaozhu Jul 10, 2024
209229f
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jul 10, 2024
deb156a
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 16, 2024
d1ce4df
refactor spark mask function to simple function
gaoyangxiaozhu Jul 17, 2024
72acae8
address comments
gaoyangxiaozhu Jul 18, 2024
fe09441
small change
gaoyangxiaozhu Jul 19, 2024
4e4597e
address comment
gaoyangxiaozhu Jul 21, 2024
ab3d1b7
address comment
gaoyangxiaozhu Jul 23, 2024
5a58710
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 23, 2024
fd0a41c
address comments
gaoyangxiaozhu Jul 24, 2024
aad59e2
format fix
gaoyangxiaozhu Jul 24, 2024
ecec9f0
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 25, 2024
8d20919
address comment
gaoyangxiaozhu Jul 29, 2024
d37cb88
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jul 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,30 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

SELECT ltrim('ps', 'spark'); -- "ark"

.. spark:function:: mask(string[, upperChar, lowerChar, digitChar, otherChar]) -> string
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved

Returns a masked version of the input ``string``.
``string``: string value to mask.
``upperChar``: A single character string used to substitute upper case characters. The default is 'X'. If NULL, upper case characters remain unmasked.
``lowerChar``: A single character string used to substitute lower case characters. The default is 'x'. If NULL, lower case characters remain unmasked.
``digitChar``: A single character string used to substitute digits. The default is 'n'. If NULL, digits remain unmasked.
``otherChar``: A single character string used to substitute any other character. The default is NULL, which leaves these characters unmasked.
Each invalid UTF-8 byte present in the input string is treated as a single character of the "other" category. ::

SELECT mask('abcd-EFGH-8765-4321'); -- "xxxx-XXXX-nnnn-nnnn"
SELECT mask('abcd-EFGH-8765-4321', 'Q'); -- "xxxx-QQQQ-nnnn-nnnn"
SELECT mask('AbCD123-@$#'); -- "XxXXnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q'); -- "QxQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q'); -- "QqQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd'); -- "QqQQddd-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o'); -- "QqQQdddoooo"
SELECT mask('AbCD123-@$#', NULL, 'q', 'd', 'o'); -- "AqCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, 'd', 'o'); -- "AbCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, 'o'); -- "AbCD123oooo"
SELECT mask(NULL, NULL, NULL, NULL, 'o'); -- NULL
SELECT mask(NULL); -- NULL
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, NULL); -- "AbCD123-@$#"

.. spark:function:: overlay(input, replace, pos, len) -> same as input

Replace a substring of ``input`` starting at ``pos`` character with ``replace`` and
Expand Down Expand Up @@ -334,4 +358,4 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

Returns string with all characters changed to uppercase. ::

SELECT upper('SparkSql'); -- SPARKSQL
SELECT upper('SparkSql'); -- SPARKSQL
228 changes: 228 additions & 0 deletions velox/functions/sparksql/MaskFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/functions/prestosql/Utf8Utils.h"

namespace facebook::velox::functions::sparksql {

// mask(string) -> string
// mask(string, upperChar) -> string
// mask(string, upperChar, lowerChar) -> string
// mask(string, upperChar, lowerChar, digitChar) -> string
// mask(string, upperChar, lowerChar, digitChar, otherChar) -> string
//
// Masks the characters of the given string value with the provided specific
// characters respectively. Upper-case characters are replaced with the second
// argument. Default value is 'X'. Lower-case characters are replaced with the
// third argument. Default value is 'x'. Digit characters are replaced with the
// fourth argument. Default value is 'n'. Other characters are replaced with the
// last argument. Default value is NULL and the original character is retained.
// If the provided nth argument is NULL, the related original character is
// retained.
template <typename T>
struct MaskFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void call(
out_type<Varchar>& result,
const arg_type<Varchar>& input) {
doCall(
result,
std::string_view(input),
kMaskedUpperCase_,
kMaskedLowerCase_,
kMaskedDigit_,
std::nullopt);
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<Varchar>& result,
const arg_type<Varchar>* inputPtr,
const arg_type<Varchar>* upperCharPtr) {
if (inputPtr == nullptr) {
return false;
}

doCall(
result,
std::string_view(*inputPtr),
getMaskedChar(upperCharPtr),
kMaskedLowerCase_,
kMaskedDigit_,
std::nullopt);
return true;
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<Varchar>& result,
const arg_type<Varchar>* inputPtr,
const arg_type<Varchar>* upperCharPtr,
const arg_type<Varchar>* lowerCharPtr) {
if (inputPtr == nullptr) {
return false;
}

doCall(
result,
std::string_view(*inputPtr),
getMaskedChar(upperCharPtr),
getMaskedChar(lowerCharPtr),
kMaskedDigit_,
std::nullopt);
return true;
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<Varchar>& result,
const arg_type<Varchar>* inputPtr,
const arg_type<Varchar>* upperCharPtr,
const arg_type<Varchar>* lowerCharPtr,
const arg_type<Varchar>* digitCharPtr) {
if (inputPtr == nullptr) {
return false;
}

doCall(
result,
std::string_view(*inputPtr),
getMaskedChar(upperCharPtr),
getMaskedChar(lowerCharPtr),
getMaskedChar(digitCharPtr),
std::nullopt);
return true;
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<Varchar>& result,
const arg_type<Varchar>* inputPtr,
const arg_type<Varchar>* upperCharPtr,
const arg_type<Varchar>* lowerCharPtr,
const arg_type<Varchar>* digitCharPtr,
const arg_type<Varchar>* otherCharPtr) {
if (inputPtr == nullptr) {
return false;
}

doCall(
result,
std::string_view(*inputPtr),
getMaskedChar(upperCharPtr),
getMaskedChar(lowerCharPtr),
getMaskedChar(digitCharPtr),
getMaskedChar(otherCharPtr));
return true;
}

private:
void doCall(
out_type<Varchar>& result,
std::string_view input,
const std::optional<std::string_view> upperChar,
const std::optional<std::string_view> lowerChar,
const std::optional<std::string_view> digitChar,
const std::optional<std::string_view> otherChar) const {
auto inputBuffer = input.data();
const size_t inputSize = input.size();
result.reserve(inputSize * 4);
auto outputBuffer = result.data();
size_t inputIdx = 0;
size_t outputIdx = 0;
while (inputIdx < inputSize) {
int charByteSize;
auto curCodePoint = utf8proc_codepoint(
&inputBuffer[inputIdx], inputBuffer + inputSize, charByteSize);
if (curCodePoint == -1) {
// That means it is a invalid UTF-8 character for example '\xED',
// treat it as char with size 1.
charByteSize = 1;
}
auto maskedChar = &inputBuffer[inputIdx];
auto maskedCharByteSize = charByteSize;
// Treat invalid UTF-8 character as other char.
utf8proc_propval_t category = utf8proc_category(curCodePoint);
if (isUpperChar(category) && upperChar.has_value()) {
maskedChar = upperChar.value().data();
maskedCharByteSize = upperChar.value().size();
} else if (isLowerChar(category) && lowerChar.has_value()) {
maskedChar = lowerChar.value().data();
maskedCharByteSize = lowerChar.value().size();
} else if (isDigitChar(category) && digitChar.has_value()) {
maskedChar = digitChar.value().data();
maskedCharByteSize = digitChar.value().size();
} else if (
!isUpperChar(category) && !isLowerChar(category) &&
!isDigitChar(category) && otherChar.has_value()) {
maskedChar = otherChar.value().data();
maskedCharByteSize = otherChar.value().size();
}

for (auto i = 0; i < maskedCharByteSize; i++) {
outputBuffer[outputIdx++] = maskedChar[i];
}

inputIdx += charByteSize;
}
result.resize(outputIdx);
}

bool isUpperChar(utf8proc_propval_t& category) const {
return category == UTF8PROC_CATEGORY_LU;
}

bool isLowerChar(utf8proc_propval_t& category) const {
return category == UTF8PROC_CATEGORY_LL;
}

bool isDigitChar(utf8proc_propval_t& category) const {
return category == UTF8PROC_CATEGORY_ND;
}

std::optional<std::string_view> getMaskedChar(
const arg_type<Varchar>* maskChar) {
if (maskChar) {
auto maskCharData = maskChar->data();
auto maskCharSize = maskChar->size();
if (maskCharSize == 1) {
return std::string_view(maskCharData);
}

VELOX_USER_CHECK_NE(
maskCharSize,
0,
"Replacement string must contain a single character and cannot be empty.");

// Calculates the byte length of the first unicode character, and compares
// it with the length of replacing character. Inequality indicates the
// replacing character includes more than one unicode characters.
int size;
auto codePoint = utf8proc_codepoint(
&maskCharData[0], maskCharData + maskCharSize, size);
VELOX_USER_CHECK_EQ(
maskCharSize,
size,
"Replacement string must contain a single character and cannot be empty.");

return std::string_view(maskCharData, maskCharSize);
}
return std::nullopt;
}

static constexpr std::string_view kMaskedUpperCase_{"X"};
static constexpr std::string_view kMaskedLowerCase_{"x"};
static constexpr std::string_view kMaskedDigit_{"n"};
};
} // namespace facebook::velox::functions::sparksql
16 changes: 16 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "velox/functions/sparksql/In.h"
#include "velox/functions/sparksql/JsonObjectKeys.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MaskFunction.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/MonotonicallyIncreasingId.h"
#include "velox/functions/sparksql/RaiseError.h"
Expand Down Expand Up @@ -481,6 +482,21 @@ void registerFunctions(const std::string& prefix) {
int32_t>({prefix + "levenshtein"});
registerFunction<LevenshteinDistanceFunction, int32_t, Varchar, Varchar>(
{prefix + "levenshtein"});

registerFunction<MaskFunction, Varchar, Varchar>({prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar>({prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar, Varchar>(
{prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar, Varchar, Varchar>(
{prefix + "mask"});
registerFunction<
MaskFunction,
Varchar,
Varchar,
Varchar,
Varchar,
Varchar,
Varchar>({prefix + "mask"});
}

} // namespace sparksql
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ add_executable(
MakeDecimalTest.cpp
MakeTimestampTest.cpp
MapTest.cpp
MaskTest.cpp
MightContainTest.cpp
MonotonicallyIncreasingIdTest.cpp
RaiseErrorTest.cpp
Expand Down
Loading
Loading