Skip to content

Commit

Permalink
Add Spark mask function (#10264)
Browse files Browse the repository at this point in the history
Summary:
A function returns a masked version of the input string.

Spark documentation: https://spark.apache.org/docs/latest/api/sql/#mask
Spark implementation: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala#L103
Spark tests: https://github.com/apache/spark/blob/0db5bdecfa6cbfff1be7690bb783a858989776b9/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala#L5677

Fixes #10263

Pull Request resolved: #10264

Reviewed By: kagamiori

Differential Revision: D60386594

Pulled By: Yuhta

fbshipit-source-id: 6af6c2c89ae281897effa9ec184ed5c51e14e286
  • Loading branch information
gaoyangxiaozhu authored and facebook-github-bot committed Jul 29, 2024
1 parent a89e4f1 commit 73ca922
Show file tree
Hide file tree
Showing 5 changed files with 568 additions and 1 deletion.
26 changes: 25 additions & 1 deletion velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,30 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

SELECT ltrim('ps', 'spark'); -- "ark"

.. spark:function:: mask(string[, upperChar, lowerChar, digitChar, otherChar]) -> string
Returns a masked version of the input ``string``.
``string``: string value to mask.
``upperChar``: A single character string used to substitute upper case characters. The default is 'X'. If NULL, upper case characters remain unmasked.
``lowerChar``: A single character string used to substitute lower case characters. The default is 'x'. If NULL, lower case characters remain unmasked.
``digitChar``: A single character string used to substitute digits. The default is 'n'. If NULL, digits remain unmasked.
``otherChar``: A single character string used to substitute any other character. The default is NULL, which leaves these characters unmasked.
Any invalid UTF-8 characters present in the input string will be treated as a single other character. ::

SELECT mask('abcd-EFGH-8765-4321'); -- "xxxx-XXXX-nnnn-nnnn"
SELECT mask('abcd-EFGH-8765-4321', 'Q'); -- "xxxx-QQQQ-nnnn-nnnn"
SELECT mask('AbCD123-@$#'); -- "XxXXnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q'); -- "QxQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q'); -- "QqQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd'); -- "QqQQddd-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o'); -- "QqQQdddoooo"
SELECT mask('AbCD123-@$#', NULL, 'q', 'd', 'o'); -- "AqCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, 'd', 'o'); -- "AbCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, 'o'); -- "AbCD123oooo"
SELECT mask(NULL, NULL, NULL, NULL, 'o'); -- NULL
SELECT mask(NULL); -- NULL
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, NULL); -- "AbCD123-@$#"

.. spark:function:: overlay(input, replace, pos, len) -> same as input
Replace a substring of ``input`` starting at ``pos`` character with ``replace`` and
Expand Down Expand Up @@ -334,4 +358,4 @@ Unless specified otherwise, all functions return NULL if at least one of the arg
Returns string with all characters changed to uppercase. ::

SELECT upper('SparkSql'); -- SPARKSQL
SELECT upper('SparkSql'); -- SPARKSQL
228 changes: 228 additions & 0 deletions velox/functions/sparksql/MaskFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/functions/prestosql/Utf8Utils.h"

namespace facebook::velox::functions::sparksql {

// mask(string) -> string
// mask(string, upperChar) -> string
// mask(string, upperChar, lowerChar) -> string
// mask(string, upperChar, lowerChar, digitChar) -> string
// mask(string, upperChar, lowerChar, digitChar, otherChar) -> string
//
// Masks the characters of the given string value with the provided specific
// characters respectively. Upper-case characters are replaced with the second
// argument. Default value is 'X'. Lower-case characters are replaced with the
// third argument. Default value is 'x'. Digit characters are replaced with the
// fourth argument. Default value is 'n'. Other characters are replaced with the
// last argument. Default value is NULL and the original character is retained.
// If the provided nth argument is NULL, the related original character is
// retained.
template <typename T>
struct MaskFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void call(
out_type<Varchar>& result,
const arg_type<Varchar>& input) {
doCall(
result,
std::string_view(input),
kMaskedUpperCase_,
kMaskedLowerCase_,
kMaskedDigit_,
std::nullopt);
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<Varchar>& result,
const arg_type<Varchar>* inputPtr,
const arg_type<Varchar>* upperCharPtr) {
if (inputPtr == nullptr) {
return false;
}

doCall(
result,
std::string_view(*inputPtr),
getMaskedChar(upperCharPtr),
kMaskedLowerCase_,
kMaskedDigit_,
std::nullopt);
return true;
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<Varchar>& result,
const arg_type<Varchar>* inputPtr,
const arg_type<Varchar>* upperCharPtr,
const arg_type<Varchar>* lowerCharPtr) {
if (inputPtr == nullptr) {
return false;
}

doCall(
result,
std::string_view(*inputPtr),
getMaskedChar(upperCharPtr),
getMaskedChar(lowerCharPtr),
kMaskedDigit_,
std::nullopt);
return true;
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<Varchar>& result,
const arg_type<Varchar>* inputPtr,
const arg_type<Varchar>* upperCharPtr,
const arg_type<Varchar>* lowerCharPtr,
const arg_type<Varchar>* digitCharPtr) {
if (inputPtr == nullptr) {
return false;
}

doCall(
result,
std::string_view(*inputPtr),
getMaskedChar(upperCharPtr),
getMaskedChar(lowerCharPtr),
getMaskedChar(digitCharPtr),
std::nullopt);
return true;
}

FOLLY_ALWAYS_INLINE bool callNullable(
out_type<Varchar>& result,
const arg_type<Varchar>* inputPtr,
const arg_type<Varchar>* upperCharPtr,
const arg_type<Varchar>* lowerCharPtr,
const arg_type<Varchar>* digitCharPtr,
const arg_type<Varchar>* otherCharPtr) {
if (inputPtr == nullptr) {
return false;
}

doCall(
result,
std::string_view(*inputPtr),
getMaskedChar(upperCharPtr),
getMaskedChar(lowerCharPtr),
getMaskedChar(digitCharPtr),
getMaskedChar(otherCharPtr));
return true;
}

private:
void doCall(
out_type<Varchar>& result,
std::string_view input,
const std::optional<std::string_view> upperChar,
const std::optional<std::string_view> lowerChar,
const std::optional<std::string_view> digitChar,
const std::optional<std::string_view> otherChar) const {
auto inputBuffer = input.data();
const size_t inputSize = input.size();
result.reserve(inputSize * 4);
auto outputBuffer = result.data();
size_t inputIdx = 0;
size_t outputIdx = 0;
while (inputIdx < inputSize) {
int charByteSize;
auto curCodePoint = utf8proc_codepoint(
&inputBuffer[inputIdx], inputBuffer + inputSize, charByteSize);
if (curCodePoint == -1) {
// That means it is a invalid UTF-8 character for example '\xED',
// treat it as char with size 1.
charByteSize = 1;
}
auto maskedChar = &inputBuffer[inputIdx];
auto maskedCharByteSize = charByteSize;
// Treat invalid UTF-8 character as other char.
utf8proc_propval_t category = utf8proc_category(curCodePoint);
if (isUpperChar(category) && upperChar.has_value()) {
maskedChar = upperChar.value().data();
maskedCharByteSize = upperChar.value().size();
} else if (isLowerChar(category) && lowerChar.has_value()) {
maskedChar = lowerChar.value().data();
maskedCharByteSize = lowerChar.value().size();
} else if (isDigitChar(category) && digitChar.has_value()) {
maskedChar = digitChar.value().data();
maskedCharByteSize = digitChar.value().size();
} else if (
!isUpperChar(category) && !isLowerChar(category) &&
!isDigitChar(category) && otherChar.has_value()) {
maskedChar = otherChar.value().data();
maskedCharByteSize = otherChar.value().size();
}

for (auto i = 0; i < maskedCharByteSize; i++) {
outputBuffer[outputIdx++] = maskedChar[i];
}

inputIdx += charByteSize;
}
result.resize(outputIdx);
}

bool isUpperChar(utf8proc_propval_t& category) const {
return category == UTF8PROC_CATEGORY_LU;
}

bool isLowerChar(utf8proc_propval_t& category) const {
return category == UTF8PROC_CATEGORY_LL;
}

bool isDigitChar(utf8proc_propval_t& category) const {
return category == UTF8PROC_CATEGORY_ND;
}

std::optional<std::string_view> getMaskedChar(
const arg_type<Varchar>* maskChar) {
if (maskChar) {
auto maskCharData = maskChar->data();
auto maskCharSize = maskChar->size();
if (maskCharSize == 1) {
return std::string_view(maskCharData);
}

VELOX_USER_CHECK_NE(
maskCharSize,
0,
"Replacement string must contain a single character and cannot be empty.");

// Calculates the byte length of the first unicode character, and compares
// it with the length of replacing character. Inequality indicates the
// replacing character includes more than one unicode characters.
int size;
auto codePoint = utf8proc_codepoint(
&maskCharData[0], maskCharData + maskCharSize, size);
VELOX_USER_CHECK_EQ(
maskCharSize,
size,
"Replacement string must contain a single character and cannot be empty.");

return std::string_view(maskCharData, maskCharSize);
}
return std::nullopt;
}

static constexpr std::string_view kMaskedUpperCase_{"X"};
static constexpr std::string_view kMaskedLowerCase_{"x"};
static constexpr std::string_view kMaskedDigit_{"n"};
};
} // namespace facebook::velox::functions::sparksql
16 changes: 16 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "velox/functions/sparksql/In.h"
#include "velox/functions/sparksql/JsonObjectKeys.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/MaskFunction.h"
#include "velox/functions/sparksql/MightContain.h"
#include "velox/functions/sparksql/MonotonicallyIncreasingId.h"
#include "velox/functions/sparksql/RaiseError.h"
Expand Down Expand Up @@ -481,6 +482,21 @@ void registerFunctions(const std::string& prefix) {
int32_t>({prefix + "levenshtein"});
registerFunction<LevenshteinDistanceFunction, int32_t, Varchar, Varchar>(
{prefix + "levenshtein"});

registerFunction<MaskFunction, Varchar, Varchar>({prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar>({prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar, Varchar>(
{prefix + "mask"});
registerFunction<MaskFunction, Varchar, Varchar, Varchar, Varchar, Varchar>(
{prefix + "mask"});
registerFunction<
MaskFunction,
Varchar,
Varchar,
Varchar,
Varchar,
Varchar,
Varchar>({prefix + "mask"});
}

} // namespace sparksql
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ add_executable(
MakeDecimalTest.cpp
MakeTimestampTest.cpp
MapTest.cpp
MaskTest.cpp
MightContainTest.cpp
MonotonicallyIncreasingIdTest.cpp
RaiseErrorTest.cpp
Expand Down
Loading

0 comments on commit 73ca922

Please sign in to comment.