Add Spark mask function #10264

Closed
95b679a
spark mask function support init commit
gaoyangxiaozhu Jun 20, 2024
176ce92
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jun 20, 2024
69a5717
refactor mask and added ut
gaoyangxiaozhu Jun 25, 2024
b60ad76
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jun 27, 2024
805d6e1
address comments
gaoyangxiaozhu Jun 27, 2024
bbdf3f1
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jun 27, 2024
f75d197
address comments
gaoyangxiaozhu Jun 27, 2024
4138d3c
address comment
gaoyangxiaozhu Jun 28, 2024
0ed5f15
address comments
gaoyangxiaozhu Jul 2, 2024
ab9c523
small change
gaoyangxiaozhu Jul 9, 2024
cf27370
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 9, 2024
930a339
address comment
gaoyangxiaozhu Jul 10, 2024
209229f
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jul 10, 2024
deb156a
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 16, 2024
d1ce4df
refactor spark mask function to simple function
gaoyangxiaozhu Jul 17, 2024
72acae8
address comments
gaoyangxiaozhu Jul 18, 2024
fe09441
small change
gaoyangxiaozhu Jul 19, 2024
4e4597e
address comment
gaoyangxiaozhu Jul 21, 2024
ab3d1b7
address comment
gaoyangxiaozhu Jul 23, 2024
5a58710
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 23, 2024
fd0a41c
address comments
gaoyangxiaozhu Jul 24, 2024
aad59e2
format fix
gaoyangxiaozhu Jul 24, 2024
ecec9f0
Merge branch 'facebookincubator:main' into gayangya/spark_mask
gaoyangxiaozhu Jul 25, 2024
8d20919
address comment
gaoyangxiaozhu Jul 29, 2024
d37cb88
Merge branch 'gayangya/spark_mask' of https://github.com/gayangya/vel…
gaoyangxiaozhu Jul 29, 2024
26 changes: 25 additions & 1 deletion velox/docs/functions/spark/string.rst
@@ -317,4 +317,28 @@ Unless specified otherwise, all functions return NULL if at least one of the arg

Returns string with all characters changed to uppercase. ::

SELECT upper('SparkSql'); -- SPARKSQL
SELECT upper('SparkSql'); -- SPARKSQL

.. spark:function:: mask(string, upperChar, lowerChar, digitChar, otherChar) -> string

Returns a masked version of the input ``string``.
This can be useful for creating copies of tables with sensitive information removed.
``string``: The string value to mask.
``upperChar``: A single-character STRING literal used to substitute uppercase characters. The default is 'X'. If ``upperChar`` is NULL, uppercase characters remain unmasked.
``lowerChar``: A single-character STRING literal used to substitute lowercase characters. The default is 'x'. If ``lowerChar`` is NULL, lowercase characters remain unmasked.
``digitChar``: A single-character STRING literal used to substitute digits. The default is 'n'. If ``digitChar`` is NULL, digits remain unmasked.
``otherChar``: A single-character STRING literal used to substitute any other character. The default is NULL, which leaves these characters unmasked. ::

SELECT mask('abcd-EFGH-8765-4321'); -- "xxxx-XXXX-nnnn-nnnn"
SELECT mask('abcd-EFGH-8765-4321', 'Q'); -- "xxxx-QQQQ-nnnn-nnnn"
SELECT mask('AbCD123-@$#'); -- "XxXXnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q'); -- "QxQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q'); -- "QqQQnnn-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd'); -- "QqQQddd-@$#"
SELECT mask('AbCD123-@$#', 'Q', 'q', 'd', 'o'); -- "QqQQdddoooo"
SELECT mask('AbCD123-@$#', NULL, 'q', 'd', 'o'); -- "AqCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, 'd', 'o'); -- "AbCDdddoooo"
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, 'o'); -- "AbCD123oooo"
SELECT mask(NULL, NULL, NULL, NULL, 'o'); -- NULL
SELECT mask(NULL); -- NULL
SELECT mask('AbCD123-@$#', NULL, NULL, NULL, NULL); -- "AbCD123-@$#"
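The documented semantics can be sketched outside Velox as a small standalone function. This is only an illustration under stated assumptions: `maskAscii` is a hypothetical name that does not appear in the PR, it handles ASCII input only, and it models a SQL NULL replacement as `std::nullopt`. It is not the implementation in `Mask.cpp`.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <string_view>

// Hypothetical standalone helper, not part of the PR. Masks ASCII input
// byte by byte. std::nullopt stands in for a SQL NULL replacement and
// means "leave this character class unmasked".
std::string maskAscii(
    std::string_view input,
    std::optional<char> upperChar = 'X',
    std::optional<char> lowerChar = 'x',
    std::optional<char> digitChar = 'n',
    std::optional<char> otherChar = std::nullopt) {
  std::string output;
  output.reserve(input.size());
  for (char c : input) {
    std::optional<char> replacement;
    if (c >= 'A' && c <= 'Z') {
      replacement = upperChar;
    } else if (c >= 'a' && c <= 'z') {
      replacement = lowerChar;
    } else if (c >= '0' && c <= '9') {
      replacement = digitChar;
    } else {
      replacement = otherChar;
    }
    // Keep the original character when no replacement is configured.
    output.push_back(replacement.value_or(c));
  }
  return output;
}
```

With the defaults, `maskAscii("abcd-EFGH-8765-4321")` follows the first documented example, and passing `std::nullopt` for a class corresponds to passing NULL in SQL. A NULL input string, which the function maps to a NULL result, is outside the scope of this sketch.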
3 changes: 2 additions & 1 deletion velox/functions/sparksql/CMakeLists.txt
@@ -33,7 +33,8 @@ add_library(
Size.cpp
SplitFunctions.cpp
String.cpp
UnscaledValueFunction.cpp)
UnscaledValueFunction.cpp
Mask.cpp)

# GCC 12 has a bug where it does not respect "pragma ignore" directives and ends
# up failing compilation in an openssl header included by a hash-related
250 changes: 250 additions & 0 deletions velox/functions/sparksql/Mask.cpp
@@ -0,0 +1,250 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cctype>
#include <optional>
#include <string_view>
#include <utility>

#include "velox/expression/DecodedArgs.h"
#include "velox/expression/VectorFunction.h"
#include "velox/expression/VectorWriters.h"
#include "velox/functions/lib/Re2Functions.h"

namespace facebook::velox::functions::sparksql {
namespace {
class MaskFunction final : public exec::VectorFunction {
static constexpr std::string_view maskedUpperCase{"X"};
static constexpr std::string_view maskedLowerCase{"x"};
static constexpr std::string_view maskedDigit{"n"};

public:
MaskFunction() {}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& /* outputType */,
exec::EvalCtx& context,
VectorPtr& result) const override {
// Get the decoded vectors out of arguments.
exec::DecodedArgs decodedArgs(rows, args, context);
DecodedVector* strings = decodedArgs.at(0);
DecodedVector* upperChars = args.size() >= 2 ? decodedArgs.at(1) : nullptr;
DecodedVector* lowerChars = args.size() >= 3 ? decodedArgs.at(2) : nullptr;
DecodedVector* digitChars = args.size() >= 4 ? decodedArgs.at(3) : nullptr;
DecodedVector* otherChars = args.size() >= 5 ? decodedArgs.at(4) : nullptr;
BaseVector::ensureWritable(rows, VARCHAR(), context.pool(), result);
auto* results = result->as<FlatVector<StringView>>();
// Optimization for the (flat, const, const, const, const) case.
if (strings->isIdentityMapping() &&
(upperChars == nullptr || upperChars->isConstantMapping()) &&
(lowerChars == nullptr || lowerChars->isConstantMapping()) &&
(digitChars == nullptr || digitChars->isConstantMapping()) &&
(otherChars == nullptr || otherChars->isConstantMapping())) {
// TODO: Enable in-place masking if possible.
const auto* rawStrings = strings->data<StringView>();
const auto upperChar = (upperChars == nullptr)
? std::optional<StringView>{StringView{maskedUpperCase}}
: (args[1]->containsNullAt(0)
? std::nullopt
: std::optional<StringView>{
upperChars->valueAt<StringView>(0)});
const auto lowerChar = (lowerChars == nullptr)
? std::optional<StringView>{StringView{maskedLowerCase}}
: (args[2]->containsNullAt(0)
? std::nullopt
: std::optional<StringView>{
lowerChars->valueAt<StringView>(0)});
const auto digitChar = (digitChars == nullptr)
? std::optional<StringView>{StringView{maskedDigit}}
: (args[3]->containsNullAt(0)
? std::nullopt
: std::optional<StringView>{
digitChars->valueAt<StringView>(0)});
const auto otherChar =
(otherChars == nullptr || args[4]->containsNullAt(0))
? std::nullopt
: std::optional(otherChars->valueAt<StringView>(0));

rows.applyToSelected([&](vector_size_t row) {
if (args[0]->isNullAt(row)) {
results->setNull(row, true);
return;
}
auto proxy = exec::StringWriter<>(results, row);
applyInner(
rawStrings[row],
upperChar,
lowerChar,
digitChar,
otherChar,
row,
proxy);
proxy.finalize();
});
} else {
// The rest of the cases are handled through this general path and no
// direct access.
rows.applyToSelected([&](vector_size_t row) {
if (args[0]->isNullAt(row)) {
results->setNull(row, true);
return;
}
auto proxy = exec::StringWriter<>(results, row);
const auto upperChar = (upperChars == nullptr)
? std::optional<StringView>{StringView{maskedUpperCase}}
: (args[1]->containsNullAt(row)
? std::nullopt
: std::optional<StringView>{
upperChars->valueAt<StringView>(row)});
const auto lowerChar = (lowerChars == nullptr)
? std::optional<StringView>{StringView{maskedLowerCase}}
: (args[2]->containsNullAt(row)
? std::nullopt
: std::optional<StringView>{
lowerChars->valueAt<StringView>(row)});
const auto digitChar = (digitChars == nullptr)
? std::optional<StringView>{StringView{maskedDigit}}
: (args[3]->containsNullAt(row)
? std::nullopt
: std::optional<StringView>{
digitChars->valueAt<StringView>(row)});
const auto otherChar =
(otherChars == nullptr || args[4]->containsNullAt(row))
? std::nullopt
: std::optional(otherChars->valueAt<StringView>(row));
applyInner(
strings->valueAt<StringView>(row),
upperChar,
lowerChar,
digitChar,
otherChar,
row,
proxy);
proxy.finalize();
});
}
}

inline void applyInner(
StringView input,
const std::optional<StringView> upperChar,
const std::optional<StringView> lowerChar,
const std::optional<StringView> digitChar,
const std::optional<StringView> otherChar,
vector_size_t row,
facebook::velox::exec::StringWriter<false>& result) const {
const auto inputSize = input.size();
auto inputBuffer = input.data();
result.reserve(inputSize);
auto outputBuffer = result.data();

auto hasMaskedUpperChar = false;
auto hasMaskedLowerChar = false;
auto hasMaskedDigitChar = false;
auto hasMaskedOtherChar = false;
auto maskedUpperChar = "";
auto maskedLowerChar = "";
auto maskedDigitChar = "";
auto maskedOtherChar = "";
if (upperChar.has_value()) {
VELOX_USER_CHECK(
upperChar.value().size() == 1, "Length of upperChar should be 1");
maskedUpperChar = upperChar.value().data();
hasMaskedUpperChar = true;
}
if (lowerChar.has_value()) {
VELOX_USER_CHECK(
lowerChar.value().size() == 1, "Length of lowerChar should be 1");
maskedLowerChar = lowerChar.value().data();
hasMaskedLowerChar = true;
}
if (digitChar.has_value()) {
VELOX_USER_CHECK(
digitChar.value().size() == 1, "Length of digitChar should be 1");
maskedDigitChar = digitChar.value().data();
hasMaskedDigitChar = true;
}
if (otherChar.has_value()) {
VELOX_USER_CHECK(
otherChar.value().size() == 1, "Length of otherChar should be 1");
maskedOtherChar = otherChar.value().data();
hasMaskedOtherChar = true;
}

for (size_t i = 0; i < inputSize; i++) {
unsigned char p = inputBuffer[i];
if (isupper(p)) {
outputBuffer[i] = hasMaskedUpperChar ? maskedUpperChar[0] : p;
} else if (islower(p)) {
outputBuffer[i] = hasMaskedLowerChar ? maskedLowerChar[0] : p;
} else if (isdigit(p)) {
outputBuffer[i] = hasMaskedDigitChar ? maskedDigitChar[0] : p;
} else {
outputBuffer[i] = hasMaskedOtherChar ? maskedOtherChar[0] : p;
}
}
result.resize(inputSize);
}
};

std::shared_ptr<exec::VectorFunction> createMask(
const std::string& /*name*/,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/) {
return std::make_shared<MaskFunction>();
}

std::vector<std::shared_ptr<exec::FunctionSignature>> signatures() {
std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType("varchar")
.argumentType("varchar")
.build());
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.build());
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.build());
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.build());
signatures.emplace_back(exec::FunctionSignatureBuilder()
.returnType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.argumentType("varchar")
.build());
return signatures;
}
} // namespace

VELOX_DECLARE_STATEFUL_VECTOR_FUNCTION_WITH_METADATA(
mask,
signatures(),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
createMask);
} // namespace facebook::velox::functions::sparksql
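The three-way resolution that `apply` repeats for each replacement argument (argument absent, so use the default; argument NULL, so leave the class unmasked; otherwise use the given single character) could be factored into one helper. The sketch below is a suggestion only, written against the standard library rather than Velox types. `resolveMaskChar` and its parameters are hypothetical names, and `std::invalid_argument` stands in for `VELOX_USER_CHECK`.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <string_view>

// Hypothetical helper, not part of the PR.
// `provided` says whether the SQL call passed this argument at all.
// `value` is std::nullopt when the argument was passed as SQL NULL.
// `defaultChar` is the replacement used when the argument is absent;
// std::nullopt there means "unmasked by default" (as for otherChar).
std::optional<char> resolveMaskChar(
    bool provided,
    std::optional<std::string_view> value,
    std::optional<char> defaultChar) {
  if (!provided) {
    return defaultChar;
  }
  if (!value.has_value()) {
    // Explicit NULL: characters of this class stay as they are.
    return std::nullopt;
  }
  if (value->size() != 1) {
    throw std::invalid_argument("Length of mask character should be 1");
  }
  return value->front();
}
```

A helper of this shape would also let the constant fast path validate each replacement once before the row loop instead of once per row inside `applyInner`. Whether that fits the PR's intended structure is a judgment for the author and reviewers.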
1 change: 1 addition & 0 deletions velox/functions/sparksql/Register.cpp
@@ -247,6 +247,7 @@ void registerFunctions(const std::string& prefix) {
exec::registerStatefulVectorFunction(
prefix + "like", likeSignatures(), makeLike);
VELOX_REGISTER_VECTOR_FUNCTION(udf_regexp_split, prefix + "split");
VELOX_REGISTER_VECTOR_FUNCTION(mask, prefix + "mask");

exec::registerStatefulVectorFunction(
prefix + "least",
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
@@ -37,6 +37,7 @@ add_executable(
MakeDecimalTest.cpp
MakeTimestampTest.cpp
MapTest.cpp
MaskTest.cpp
MightContainTest.cpp
MonotonicallyIncreasingIdTest.cpp
RandTest.cpp