Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Spark mask function #10264

Closed
Closed
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
95b679a
spark mask function support init commit
gaoyangxiaozhu Jun 20, 2024
176ce92
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jun 20, 2024
69a5717
refactor mask and added ut
gaoyangxiaozhu Jun 25, 2024
b60ad76
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jun 27, 2024
805d6e1
address comments
gaoyangxiaozhu Jun 27, 2024
bbdf3f1
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jun 27, 2024
f75d197
address comments
gaoyangxiaozhu Jun 27, 2024
4138d3c
address comment
gaoyangxiaozhu Jun 28, 2024
0ed5f15
address comments
gaoyangxiaozhu Jul 2, 2024
ab9c523
small change
gaoyangxiaozhu Jul 9, 2024
cf27370
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 9, 2024
930a339
address comment
gaoyangxiaozhu Jul 10, 2024
209229f
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jul 10, 2024
deb156a
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 16, 2024
d1ce4df
refactor spark mask function to simple function
gaoyangxiaozhu Jul 17, 2024
72acae8
address comments
gaoyangxiaozhu Jul 18, 2024
fe09441
small change
gaoyangxiaozhu Jul 19, 2024
4e4597e
address comment
gaoyangxiaozhu Jul 21, 2024
ab3d1b7
address comment
gaoyangxiaozhu Jul 23, 2024
5a58710
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 23, 2024
fd0a41c
address comments
gaoyangxiaozhu Jul 24, 2024
aad59e2
format fix
gaoyangxiaozhu Jul 24, 2024
ecec9f0
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 25, 2024
8d20919
address comment
gaoyangxiaozhu Jul 29, 2024
d37cb88
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jul 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,29 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

SELECT ltrim('ps', 'spark'); -- "ark"

.. spark:function:: mask(string[, upperChar, lowerChar, digitChar, otherChar]) -> string
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved

Returns a masked version of the input ``string``.
``string``: string value to mask.
``upperChar``: A single character STRING used to substitute upper case characters. The default is 'X'. If NULL, upper case characters remain unmasked.
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved
``lowerChar``: A single character STRING used to substitute lower case characters. The default is 'x'. If NULL, lower case characters remain unmasked.
``digitChar``: A single character STRING used to substitute digits. The default is 'n'. If NULL, digits remain unmasked.
``otherChar``: A single character STRING used to substitute any other character. The default is NULL, which leaves these characters unmasked. ::

SELECT mask('abcd-EFGH-8765-4321'); -- "xxxx-XXXX-nnnn-nnnn"
SELECT mask('abcd-EFGH-8765-4321', 'Q'); -- "xxxx-QQQQ-nnnn-nnnn"
SELECT mask('AbCD123-@$#'); -- "XxXXnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q'); -- "QxQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q'); -- "QqQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd'); -- "QqQQddd-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o'); -- "QqQQdddoooo"
SELECT mask('AbCD123-@$#', NULL, 'q', 'd', 'o'); -- "AqCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, 'd', 'o'); -- "AbCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, 'o'); -- "AbCD123oooo"
SELECT mask(NULL, NULL, NULL, NULL, 'o'); -- NULL
SELECT mask(NULL); -- NULL
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, NULL); -- "AbCD123-@$#"

.. spark:function:: overlay(input, replace, pos, len) -> same as input

Replace a substring of ``input`` starting at ``pos`` character with ``replace`` and
Expand Down Expand Up @@ -326,4 +349,4 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

Returns string with all characters changed to uppercase. ::

SELECT upper('SparkSql'); -- SPARKSQL
SELECT upper('SparkSql'); -- SPARKSQL
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ add_library(
LeastGreatest.cpp
MakeTimestamp.cpp
Map.cpp
Mask.cpp
RegexFunctions.cpp
Register.cpp
RegisterArithmetic.cpp
Expand Down
226 changes: 226 additions & 0 deletions velox/functions/sparksql/Mask.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <utility>

#include "velox/expression/DecodedArgs.h"
#include "velox/expression/VectorFunction.h"
#include "velox/expression/VectorWriters.h"
#include "velox/functions/lib/Re2Functions.h"

namespace facebook::velox::functions::sparksql {
namespace {
gaoyangxiaozhu marked this conversation as resolved.
Show resolved Hide resolved

// mask(string) -> string
// mask(string, upperChar) -> string
// mask(string, upperChar, lowerChar) -> string
// mask(string, upperChar, lowerChar, digitChar) -> string
// mask(string, upperChar, lowerChar, digitChar, otherChar) -> string
//
// Masks the characters of the given string value with the provided specific
// characters respectively:
//   - Upper-case characters are replaced with `upperChar` (default 'X').
//   - Lower-case characters are replaced with `lowerChar` (default 'x').
//   - Digit characters are replaced with `digitChar` (default 'n').
//   - All other characters are replaced with `otherChar` (default NULL).
// If a replacement argument is NULL (explicitly, or by default for
// `otherChar`), characters of that class are retained unchanged. A NULL input
// string produces a NULL result. Every non-NULL replacement argument must be
// exactly one byte long, otherwise a user error is raised.
//
// NOTE(review): classification and replacement are done byte by byte using
// isupper/islower/isdigit, so every byte of a multi-byte UTF-8 character is
// handled independently (typically as "other"). Confirm whether code-point
// aware masking is required for Spark compatibility.
class MaskFunction final : public exec::VectorFunction {
 public:
  void apply(
      const SelectivityVector& rows,
      std::vector<VectorPtr>& args,
      const TypePtr& /* outputType */,
      exec::EvalCtx& context,
      VectorPtr& result) const override {
    // Get the decoded vectors out of arguments. Replacement arguments that
    // were not supplied are represented by nullptr.
    exec::DecodedArgs decodedArgs(rows, args, context);
    DecodedVector* strings = decodedArgs.at(0);
    DecodedVector* upperChars = args.size() >= 2 ? decodedArgs.at(1) : nullptr;
    DecodedVector* lowerChars = args.size() >= 3 ? decodedArgs.at(2) : nullptr;
    DecodedVector* digitChars = args.size() >= 4 ? decodedArgs.at(3) : nullptr;
    DecodedVector* otherChars = args.size() >= 5 ? decodedArgs.at(4) : nullptr;

    BaseVector::ensureWritable(rows, VARCHAR(), context.pool(), result);
    auto* flatResult = result->as<FlatVector<StringView>>();

    // Resolves the replacement character for one character class at `index`.
    //   - Argument not supplied  -> `defaultChar`.
    //   - Argument is NULL       -> std::nullopt (class is left unmasked).
    //   - Otherwise              -> the single byte of the argument.
    // Only the decoded vector is consulted, so `args` is never indexed past
    // the number of arguments actually passed.
    auto getMaskedChar = [](DecodedVector* inputChars,
                            vector_size_t index,
                            std::optional<char> defaultChar,
                            const char* charType) -> std::optional<char> {
      if (inputChars == nullptr) {
        return defaultChar;
      }
      if (inputChars->isNullAt(index)) {
        return std::nullopt;
      }
      const auto inputCharStr = inputChars->valueAt<StringView>(index);
      VELOX_USER_CHECK(
          inputCharStr.size() == 1, "Length of {} should be 1", charType);
      return inputCharStr.data()[0];
    };

    auto isAbsentOrConstant = [](const DecodedVector* decoded) {
      return decoded == nullptr || decoded->isConstantMapping();
    };

    // Fast path for the (flat, const, const, const, const) case: the
    // replacement characters are resolved and validated once for all rows.
    if (strings->isIdentityMapping() && isAbsentOrConstant(upperChars) &&
        isAbsentOrConstant(lowerChars) && isAbsentOrConstant(digitChars) &&
        isAbsentOrConstant(otherChars)) {
      const auto* rawStrings = strings->data<StringView>();
      // With a constant mapping the decoded vector is read at index 0, as in
      // the original code.
      const auto upperChar =
          getMaskedChar(upperChars, 0, maskedUpperCase_, "upperChar");
      const auto lowerChar =
          getMaskedChar(lowerChars, 0, maskedLowerCase_, "lowerChar");
      const auto digitChar =
          getMaskedChar(digitChars, 0, maskedDigit_, "digitChar");
      const auto otherChar =
          getMaskedChar(otherChars, 0, std::nullopt, "otherChar");
      rows.applyToSelected([&](vector_size_t row) {
        if (strings->isNullAt(row)) {
          flatResult->setNull(row, true);
          return;
        }
        auto proxy = exec::StringWriter<>(flatResult, row);
        applyInner(
            rawStrings[row],
            upperChar,
            lowerChar,
            digitChar,
            otherChar,
            row,
            proxy);
        proxy.finalize();
      });
    } else {
      // General path: replacement characters may differ per row.
      rows.applyToSelected([&](vector_size_t row) {
        // Replacement arguments are validated before the NULL check on the
        // input string, so an invalid replacement raises even for NULL input.
        const auto upperChar =
            getMaskedChar(upperChars, row, maskedUpperCase_, "upperChar");
        const auto lowerChar =
            getMaskedChar(lowerChars, row, maskedLowerCase_, "lowerChar");
        const auto digitChar =
            getMaskedChar(digitChars, row, maskedDigit_, "digitChar");
        const auto otherChar =
            getMaskedChar(otherChars, row, std::nullopt, "otherChar");
        if (strings->isNullAt(row)) {
          flatResult->setNull(row, true);
          return;
        }
        auto proxy = exec::StringWriter<>(flatResult, row);
        applyInner(
            strings->valueAt<StringView>(row),
            upperChar,
            lowerChar,
            digitChar,
            otherChar,
            row,
            proxy);
        proxy.finalize();
      });
    }
  }

  // Writes the masked form of `input` into `result`. The output always has
  // the same number of bytes as the input. Each byte is classified as
  // upper-case, lower-case, digit or other and replaced by the corresponding
  // replacement character when that replacement has a value; otherwise the
  // byte is copied through unchanged. The caller is responsible for calling
  // result.finalize().
  void applyInner(
      StringView input,
      const std::optional<char> upperChar,
      const std::optional<char> lowerChar,
      const std::optional<char> digitChar,
      const std::optional<char> otherChar,
      vector_size_t /*row*/,
      exec::StringWriter<false>& result) const {
    const size_t inputSize = input.size();
    const char* inputBuffer = input.data();
    result.reserve(inputSize);
    char* outputBuffer = result.data();

    for (size_t i = 0; i < inputSize; ++i) {
      // The <cctype> classifiers require a value representable as unsigned
      // char, hence the conversion before classification.
      const unsigned char current = inputBuffer[i];
      // Exactly one class applies to every byte; "other" is the fallback.
      const std::optional<char>& replacement = isupper(current) ? upperChar
          : islower(current)                                   ? lowerChar
          : isdigit(current)                                   ? digitChar
                                                               : otherChar;
      outputBuffer[i] = replacement.has_value() ? replacement.value()
                                                : static_cast<char>(current);
    }
    result.resize(inputSize);
  }

 private:
  // Default replacement characters used when the corresponding argument is
  // not supplied.
  static constexpr char maskedUpperCase_ = 'X';
  static constexpr char maskedLowerCase_ = 'x';
  static constexpr char maskedDigit_ = 'n';
};

// Factory used when registering the `mask` vector function. The function
// name and query config are not used. Fails with a user error when no
// arguments are supplied; otherwise returns a new MaskFunction instance.
std::shared_ptr<exec::VectorFunction> createMask(
    const std::string& /*name*/,
    const std::vector<exec::VectorFunctionArg>& inputArgs,
    const core::QueryConfig& /*config*/) {
  // The input string is mandatory; the replacement characters are optional.
  VELOX_USER_CHECK_GE(inputArgs.size(), 1);
  std::shared_ptr<exec::VectorFunction> maskFunction =
      std::make_shared<MaskFunction>();
  return maskFunction;
}

// Builds the supported signatures of `mask`: one to five varchar arguments
// (the input string followed by up to four replacement characters), each
// returning varchar. Signatures are emitted in increasing order of arity.
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
  constexpr int kMaxArguments = 5;
  std::vector<std::shared_ptr<exec::FunctionSignature>> allSignatures;
  allSignatures.reserve(kMaxArguments);
  for (int numArguments = 1; numArguments <= kMaxArguments; ++numArguments) {
    exec::FunctionSignatureBuilder builder;
    builder.returnType("varchar");
    for (int i = 0; i < numArguments; ++i) {
      builder.argumentType("varchar");
    }
    allSignatures.emplace_back(builder.build());
  }
  return allSignatures;
}
} // namespace

// Declares the `mask` vector function using signatures() and the createMask
// factory. Default null behavior is turned off because a NULL replacement
// argument does not make the result NULL: MaskFunction::apply() inspects NULL
// arguments itself and leaves the corresponding character class unmasked.
VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION_WITH_METADATA(
mask,
signatures(),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
createMask);
} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ void registerFunctions(const std::string& prefix) {
exec::registerStatefulVectorFunction(
prefix + "like", likeSignatures(), makeLike);
VELOX_REGISTER_VECTOR_FUNCTION(udf_regexp_split, prefix + "split");
VELOX_REGISTER_VECTOR_FUNCTION(mask, prefix + "mask");

exec::registerStatefulVectorFunction(
prefix + "least",
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ add_executable(
MakeDecimalTest.cpp
MakeTimestampTest.cpp
MapTest.cpp
MaskTest.cpp
MightContainTest.cpp
MonotonicallyIncreasingIdTest.cpp
RaiseErrorTest.cpp
Expand Down
Loading
Loading