diff --git a/velox/functions/lib/Utf8Utils.cpp b/velox/functions/lib/Utf8Utils.cpp index 2b027b6f2caf..dded518c6b61 100644 --- a/velox/functions/lib/Utf8Utils.cpp +++ b/velox/functions/lib/Utf8Utils.cpp @@ -173,4 +173,75 @@ tryGetUtf8CharLength(const char* input, int64_t size, int32_t& codePoint) { return -1; } +bool hasInvalidUTF8(const char* input, int32_t len) { + for (size_t inputIndex = 0; inputIndex < len;) { + if (IS_ASCII(input[inputIndex])) { + // Ascii + inputIndex++; + } else { + // Unicode + int32_t codePoint; + auto charLength = + tryGetUtf8CharLength(input + inputIndex, len - inputIndex, codePoint); + if (charLength < 0) { + return true; + } + inputIndex += charLength; + } + } + + return false; +} + +size_t replaceInvalidUTF8Characters( + char* outputBuffer, + const char* input, + int32_t len) { + size_t inputIndex = 0; + size_t outputIndex = 0; + + while (inputIndex < len) { + if (IS_ASCII(input[inputIndex])) { + outputBuffer[outputIndex++] = input[inputIndex++]; + } else { + // Unicode + int32_t codePoint; + const auto charLength = + tryGetUtf8CharLength(input + inputIndex, len - inputIndex, codePoint); + if (charLength > 0) { + std::memcpy(outputBuffer + outputIndex, input + inputIndex, charLength); + outputIndex += charLength; + inputIndex += charLength; + } else { + size_t replaceCharactersToWriteOut = inputIndex < len - 1 && + isMultipleInvalidSequences(input, inputIndex) + ? -charLength + : 1; + const auto& replacementCharacterString = + kReplacementCharacterStrings[replaceCharactersToWriteOut - 1]; + std::memcpy( + outputBuffer + outputIndex, + replacementCharacterString.data(), + replacementCharacterString.size()); + outputIndex += replacementCharacterString.size(); + inputIndex += -charLength; + } + } + } + + return outputIndex; +} + +template <> +void replaceInvalidUTF8Characters( + std::string& out, + const char* input, + int32_t len) { + auto maxLen = len * kReplacementCharacterStrings[0].size(); + out.resize(maxLen); + auto outputBuffer = out.data(); + auto outputIndex = replaceInvalidUTF8Characters(outputBuffer, input, len); + out.resize(outputIndex); +} + } // namespace facebook::velox::functions diff --git a/velox/functions/lib/Utf8Utils.h b/velox/functions/lib/Utf8Utils.h index 3c8950ee369a..781f3db6ccc7 100644 --- a/velox/functions/lib/Utf8Utils.h +++ b/velox/functions/lib/Utf8Utils.h @@ -23,6 +23,8 @@ namespace facebook::velox::functions { +#define IS_ASCII(x) !((x) & 0x80) + /// This function is not part of the original utf8proc. /// Tries to get the length of UTF-8 encoded code point. A /// positive return value means the UTF-8 sequence is valid, and @@ -86,4 +88,75 @@ FOLLY_ALWAYS_INLINE int validateAndGetNextUtf8Length( /// -1 for invalid UTF-8 first byte. int firstByteCharLength(const char* u_input); +/// Invalid character replacement matrix. +constexpr std::array kReplacementCharacterStrings{ + "\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", + "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd"}; + +/// Returns true if there are multiple UTF-8 invalid sequences. +template +FOLLY_ALWAYS_INLINE bool isMultipleInvalidSequences( + const T& inputBuffer, + size_t inputIndex) { + return + // 0xe0 followed by a value less than 0xe0 or 0xf0 followed by a + // value less than 0x90 is considered an overlong encoding. + (inputBuffer[inputIndex] == '\xe0' && + (inputBuffer[inputIndex + 1] & 0xe0) == 0x80) || + (inputBuffer[inputIndex] == '\xf0' && + (inputBuffer[inputIndex + 1] & 0xf0) == 0x80) || + // 0xf4 followed by a byte >= 0x90 looks valid to + // tryGetUtf8CharLength, but is actually outside the range of valid + // code points. + (inputBuffer[inputIndex] == '\xf4' && + (inputBuffer[inputIndex + 1] & 0xf0) != 0x80) || + // The bytes 0xf5-0xff, 0xc0, and 0xc1 look like the start of + // multi-byte code points to tryGetUtf8CharLength, but are not part of + // any valid code point. + (unsigned char)inputBuffer[inputIndex] > 0xf4 || + inputBuffer[inputIndex] == '\xc0' || inputBuffer[inputIndex] == '\xc1'; +} + +/// Returns true only if invalid UTF-8 is present in the input string. +bool hasInvalidUTF8(const char* input, int32_t len); + +/// Replaces invalid UTF-8 characters with replacement characters similar to +/// that produced by Presto java. The function requires that output have +/// sufficient capacity for the output string. +/// @param out Pointer to output string +/// @param input Pointer to input string +/// @param len Length of input string +/// @return number of bytes written +size_t +replaceInvalidUTF8Characters(char* output, const char* input, int32_t len); + +/// Replaces invalid UTF-8 characters with replacement characters similar to +/// that produced by Presto java. The function will allocate 1 byte for each +/// orininal character plus extra 2 bytes for each maximal subpart of an +/// ill-formed subsequence for an upper bound of 3x size of the input string. +/// @param out Reference to output string +/// @param input Pointer to input string +/// @param len Length of input string +template +void replaceInvalidUTF8Characters( + TOutString& out, + const char* input, + int32_t len) { + auto maxLen = len * kReplacementCharacterStrings[0].size(); + out.reserve(maxLen); + auto outputBuffer = out.data(); + auto outputIndex = replaceInvalidUTF8Characters(outputBuffer, input, len); + out.resize(outputIndex); +} + +template <> +void replaceInvalidUTF8Characters( + std::string& out, + const char* input, + int32_t len); + } // namespace facebook::velox::functions diff --git a/velox/functions/lib/tests/Utf8Test.cpp b/velox/functions/lib/tests/Utf8Test.cpp index 48a463a2a08e..24c0ddbbdcec 100644 --- a/velox/functions/lib/tests/Utf8Test.cpp +++ b/velox/functions/lib/tests/Utf8Test.cpp @@ -104,5 +104,48 @@ TEST(Utf8Test, tryCharLength) { ASSERT_EQ(-1, tryCharLength({0xBF})); } +TEST(UTF8Test, validUtf8) { + auto tryHasInvalidUTF8 = [](const std::vector& bytes) { + return hasInvalidUTF8( + reinterpret_cast(bytes.data()), bytes.size()); + }; + + ASSERT_FALSE(tryHasInvalidUTF8({0x5c, 0x19, 0x7A})); + ASSERT_TRUE(tryHasInvalidUTF8({0x5c, 0x19, 0x7A, 0xBF})); + ASSERT_TRUE(tryHasInvalidUTF8({0x64, 0x65, 0x1A, 0b11100000, 0x81, 0xBF})); +} + +TEST(UTF8Test, replaceInvalidUTF8Characters) { + auto testReplaceInvalidUTF8Chars = [](const std::string& input, + const std::string& expected) { + std::string output; + replaceInvalidUTF8Characters(output, input.data(), input.size()); + ASSERT_EQ(expected, output); + }; + + // Good case + testReplaceInvalidUTF8Chars("Hello World", "Hello World"); + // Bad encoding + testReplaceInvalidUTF8Chars("hello \xBF world", "hello � world"); + // Bad encoding with 3 byte char + testReplaceInvalidUTF8Chars("hello \xe0\x94\x83 world", "hello ��� world"); + // Bad encoding with 4 byte char + testReplaceInvalidUTF8Chars( + "hello \xf0\x80\x80\x80\x80 world", "hello ����� world"); + + // Overlong 4 byte utf8 character. + testReplaceInvalidUTF8Chars( + "hello \xef\xbf\xbd\xef\xbf\xbd world", "hello �� world"); + + // Test invalid byte 0xC0 + testReplaceInvalidUTF8Chars( + "hello \xef\xbf\xbd\xef\xbf\xbd world", "hello �� world"); + + // Test long 4 byte utf8 with continuation byte + testReplaceInvalidUTF8Chars( + "hello \xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd world", + "hello ���� world"); +} + } // namespace } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/JsonFunctions.cpp b/velox/functions/prestosql/JsonFunctions.cpp index 01ea6a627fe1..078fe35e3fc9 100644 --- a/velox/functions/prestosql/JsonFunctions.cpp +++ b/velox/functions/prestosql/JsonFunctions.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/expression/VectorFunction.h" +#include "velox/functions/lib/Utf8Utils.h" #include "velox/functions/prestosql/json/JsonStringUtil.h" #include "velox/functions/prestosql/json/SIMDJsonUtil.h" #include "velox/functions/prestosql/types/JsonType.h" @@ -166,15 +167,23 @@ class JsonParseFunction : public exec::VectorFunction { const auto& arg = args[0]; if (arg->isConstantEncoding()) { auto value = arg->as>()->valueAt(0); - paddedInput_.resize(value.size() + simdjson::SIMDJSON_PADDING); - memcpy(paddedInput_.data(), value.data(), value.size()); - auto escapeSize = escapedStringSize(value.data(), value.size()); + auto size = value.size(); + if (FOLLY_UNLIKELY(hasInvalidUTF8(value.data(), value.size()))) { + size = replaceInvalidUTF8Characters( + paddedInput_.data(), value.data(), value.size()); + paddedInput_.resize(size + simdjson::SIMDJSON_PADDING); + } else { + paddedInput_.resize(size + simdjson::SIMDJSON_PADDING); + memcpy(paddedInput_.data(), value.data(), size); + } + + auto escapeSize = escapedStringSize(value.data(), size); auto buffer = AlignedBuffer::allocate(escapeSize, context.pool()); BufferTracker bufferTracker{buffer}; JsonViews jsonViews; - if (auto error = parse(value.size(), jsonViews)) { + if (auto error = parse(size, jsonViews)) { context.setErrors(rows, errors_[error]); return; } @@ -219,8 +228,18 @@ class JsonParseFunction : public exec::VectorFunction { rows.applyToSelected([&](auto row) { JsonViews jsonViews; auto value = flatInput->valueAt(row); - memcpy(paddedInput_.data(), value.data(), value.size()); - if (auto error = parse(value.size(), jsonViews)) { + auto size = value.size(); + if (FOLLY_UNLIKELY(hasInvalidUTF8(value.data(), size))) { + size = replaceInvalidUTF8Characters( + paddedInput_.data(), value.data(), size); + if (maxSize < size) { + paddedInput_.resize(size + simdjson::SIMDJSON_PADDING); + } + } else { + memcpy(paddedInput_.data(), value.data(), size); + } + + if (auto error = parse(size, jsonViews)) { context.setVeloxExceptionError(row, errors_[error]); } else { auto canonicalString = bufferTracker.getCanonicalString(jsonViews); diff --git a/velox/functions/prestosql/URIParser.h b/velox/functions/prestosql/URIParser.h index 6a20c39b7550..1417d6f19b2d 100644 --- a/velox/functions/prestosql/URIParser.h +++ b/velox/functions/prestosql/URIParser.h @@ -16,6 +16,7 @@ #pragma once #include +#include "velox/functions/lib/Utf8Utils.h" #include "velox/type/StringView.h" namespace facebook::velox::functions { @@ -51,29 +52,6 @@ bool parseUri(const StringView& uriStr, URI& uri); /// false and pos is unchanged. bool tryConsumeIPV6Address(const char* str, const size_t len, int32_t& pos); -template -FOLLY_ALWAYS_INLINE bool isMultipleInvalidSequences( - const T& inputBuffer, - size_t inputIndex) { - return - // 0xe0 followed by a value less than 0xe0 or 0xf0 followed by a - // value less than 0x90 is considered an overlong encoding. - (inputBuffer[inputIndex] == '\xe0' && - (inputBuffer[inputIndex + 1] & 0xe0) == 0x80) || - (inputBuffer[inputIndex] == '\xf0' && - (inputBuffer[inputIndex + 1] & 0xf0) == 0x80) || - // 0xf4 followed by a byte >= 0x90 looks valid to - // tryGetUtf8CharLength, but is actually outside the range of valid - // code points. - (inputBuffer[inputIndex] == '\xf4' && - (inputBuffer[inputIndex + 1] & 0xf0) != 0x80) || - // The bytes 0xf5-0xff, 0xc0, and 0xc1 look like the start of - // multi-byte code points to tryGetUtf8CharLength, but are not part of - // any valid code point. - (unsigned char)inputBuffer[inputIndex] > 0xf4 || - inputBuffer[inputIndex] == '\xc0' || inputBuffer[inputIndex] == '\xc1'; -} - /// Find an extract the value for the parameter with key `param` from the query /// portion of a URI `query`. `query` should already be decoded if necessary. template diff --git a/velox/functions/prestosql/URLFunctions.h b/velox/functions/prestosql/URLFunctions.h index 74a5892ea874..e96be9223ec0 100644 --- a/velox/functions/prestosql/URLFunctions.h +++ b/velox/functions/prestosql/URLFunctions.h @@ -23,6 +23,8 @@ namespace facebook::velox::functions { namespace detail { + +/// Encoded replacement character strings. constexpr std::array kEncodedReplacementCharacterStrings = {"%EF%BF%BD", "%EF%BF%BD%EF%BF%BD", @@ -30,13 +32,6 @@ constexpr std::array kEncodedReplacementCharacterStrings = "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD", "%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD%EF%BF%BD"}; -constexpr std::array kDecodedReplacementCharacterStrings{ - "\xef\xbf\xbd", - "\xef\xbf\xbd\xef\xbf\xbd", - "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", - "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", - "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd", - "\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd\xef\xbf\xbd"}; FOLLY_ALWAYS_INLINE unsigned char toHex(unsigned char c) { return c < 10 ? (c + '0') : (c + 'A' - 10); @@ -178,7 +173,7 @@ FOLLY_ALWAYS_INLINE void urlUnescape( } else if (charLength < 0) { // This isn't the start of a valid UTF-8 character, write out the // replacement character. - const auto& replacementString = kDecodedReplacementCharacterStrings[0]; + const auto& replacementString = kReplacementCharacterStrings[0]; std::memcpy( outputBuffer, replacementString.data(), replacementString.length()); outputBuffer += replacementString.length(); @@ -216,8 +211,8 @@ FOLLY_ALWAYS_INLINE void urlUnescape( size_t charLength = outputBuffer - charStart; size_t replaceCharactersToWriteOut = isMultipleInvalidSequences(charStart, 0) ? charLength : 1; - const auto& replacementString = kDecodedReplacementCharacterStrings - [replaceCharactersToWriteOut - 1]; + const auto& replacementString = + kReplacementCharacterStrings[replaceCharactersToWriteOut - 1]; outputBuffer = charStart; std::memcpy( diff --git a/velox/functions/prestosql/tests/JsonFunctionsTest.cpp b/velox/functions/prestosql/tests/JsonFunctionsTest.cpp index e1ced085effe..9bdfa9581cf8 100644 --- a/velox/functions/prestosql/tests/JsonFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/JsonFunctionsTest.cpp @@ -235,6 +235,9 @@ TEST_F(JsonFunctionsTest, jsonParse) { R"("Items for D \ud835\udc52\ud835\udcc1 ")", R"("Items for D \uD835\uDC52\uD835\uDCC1 ")"); + // Test bad unicode characters + testJsonParse("\"Hello \xc0\xaf World\"", "\"Hello �� World\""); + VELOX_ASSERT_THROW( jsonParse(R"({"k1":})"), "The JSON document has an improper structure"); VELOX_ASSERT_THROW(