From 861b80d9335908ed44c6248fff6cb9ea5857ccf4 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Thu, 3 Oct 2024 00:22:57 +0530 Subject: [PATCH] Refactor Base64 APIs as non-throwing APIs --- velox/common/encode/Base64.cpp | 163 ++++++++++++-------- velox/common/encode/Base64.h | 25 +-- velox/common/encode/tests/Base64Test.cpp | 39 +++-- velox/common/encode/tests/CMakeLists.txt | 2 +- velox/docs/functions/presto/binary.rst | 11 +- velox/functions/prestosql/BinaryFunctions.h | 44 +++--- 6 files changed, 167 insertions(+), 117 deletions(-) diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index 6ee390de80a5..f3c1f3605f39 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -18,9 +18,7 @@ #include #include #include -#include - -#include "velox/common/base/Exceptions.h" +#include namespace facebook::velox::encoding { @@ -163,22 +161,22 @@ std::string Base64::encodeImpl( const T& input, const Charset& charset, bool includePadding) { - size_t encodedSize = calculateEncodedSize(input.size(), includePadding); + const size_t encodedSize{calculateEncodedSize(input.size(), includePadding)}; std::string encodedResult; encodedResult.resize(encodedSize); - encodeImpl(input, charset, includePadding, encodedResult.data()); + (void)encodeImpl(input, charset, includePadding, encodedResult.data()); return encodedResult; } // static -size_t Base64::calculateEncodedSize(size_t inputSize, bool withPadding) { +size_t Base64::calculateEncodedSize(size_t inputSize, bool includePadding) { if (inputSize == 0) { return 0; } // Calculate the output size assuming that we are including padding. size_t encodedSize = ((inputSize + 2) / 3) * 4; - if (!withPadding) { + if (!includePadding) { // If the padding was not requested, subtract the padding bytes. encodedSize -= (3 - (inputSize % 3)) % 3; } @@ -186,27 +184,31 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool withPadding) { } // static -void Base64::encode(const char* input, size_t inputSize, char* output) { - encodeImpl( +Status Base64::encode(const char* input, size_t inputSize, char* output) { + return encodeImpl( folly::StringPiece(input, inputSize), kBase64Charset, true, output); } // static -void Base64::encodeUrl(const char* input, size_t inputSize, char* output) { - encodeImpl( - folly::StringPiece(input, inputSize), kBase64UrlCharset, true, output); +Status +Base64::encodeUrl(const char* input, size_t inputSize, char* outputBuffer) { + return encodeImpl( + folly::StringPiece(input, inputSize), + kBase64UrlCharset, + true, + outputBuffer); } // static template -void Base64::encodeImpl( +Status Base64::encodeImpl( const T& input, - const Charset& charset, + const Base64::Charset& charset, bool includePadding, char* outputBuffer) { auto inputSize = input.size(); if (inputSize == 0) { - return; + return Status::OK(); } auto outputPointer = outputBuffer; @@ -215,9 +217,9 @@ void Base64::encodeImpl( // For each group of 3 bytes (24 bits) in the input, split that into // 4 groups of 6 bits and encode that using the supplied charset lookup for (; inputSize > 2; inputSize -= 3) { - uint32_t inputBlock = uint8_t(*inputIterator++) << 16; - inputBlock |= uint8_t(*inputIterator++) << 8; - inputBlock |= uint8_t(*inputIterator++); + uint32_t inputBlock = static_cast(*inputIterator++) << 16; + inputBlock |= static_cast(*inputIterator++) << 8; + inputBlock |= static_cast(*inputIterator++); *outputPointer++ = charset[(inputBlock >> 18) & 0x3f]; *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; @@ -229,10 +231,10 @@ void Base64::encodeImpl( // We have either 1 or 2 input bytes left. Encode this similar to the // above (assuming 0 for all other bytes). Optionally append the '=' // character if it is requested. - uint32_t inputBlock = uint8_t(*inputIterator++) << 16; + uint32_t inputBlock = static_cast(*inputIterator++) << 16; *outputPointer++ = charset[(inputBlock >> 18) & 0x3f]; if (inputSize > 1) { - inputBlock |= uint8_t(*inputIterator) << 8; + inputBlock |= static_cast(*inputIterator) << 8; *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; *outputPointer++ = charset[(inputBlock >> 6) & 0x3f]; if (includePadding) { @@ -246,6 +248,7 @@ void Base64::encodeImpl( } } } + return Status::OK(); } // static @@ -322,29 +325,35 @@ void Base64::decode( const std::pair& payload, std::string& decodedOutput) { size_t inputSize = payload.second; - decodedOutput.resize(calculateDecodedSize(payload.first, inputSize)); - decode(payload.first, inputSize, decodedOutput.data(), decodedOutput.size()); + size_t decodedSize; + (void)calculateDecodedSize(payload.first, inputSize, decodedSize); + decodedOutput.resize(decodedSize); + (void)decode( + payload.first, inputSize, decodedOutput.data(), decodedOutput.size()); } // static -void Base64::decode(const char* input, size_t size, char* output) { - size_t expectedOutputSize = size / 4 * 3; - Base64::decode(input, size, output, expectedOutputSize); +void Base64::decode(const char* input, size_t inputSize, char* outputBuffer) { + size_t outputSize; + (void)calculateDecodedSize(input, inputSize, outputSize); + (void)Base64::decode(input, inputSize, outputBuffer, outputSize); } // static uint8_t Base64::base64ReverseLookup( char encodedChar, - const Base64::ReverseIndex& reverseIndex) { - auto reverseLookupValue = reverseIndex[(uint8_t)encodedChar]; + const Base64::ReverseIndex& reverseIndex, + Status& status) { + auto reverseLookupValue = reverseIndex[static_cast(encodedChar)]; if (reverseLookupValue >= 0x40) { - VELOX_USER_FAIL("decode() - invalid input string: invalid characters"); + status = Status::UserError( + "decode() - invalid input string: invalid characters"); } return reverseLookupValue; } // static -size_t Base64::decode( +Status Base64::decode( const char* input, size_t inputSize, char* output, @@ -354,9 +363,13 @@ size_t Base64::decode( } // static -size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) { +Status Base64::calculateDecodedSize( + const char* input, + size_t& inputSize, + size_t& decodedSize) { if (inputSize == 0) { - return 0; + decodedSize = 0; + return Status::OK(); } // Check if the input string is padded @@ -364,88 +377,107 @@ size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) { // If padded, ensure that the string length is a multiple of the encoded // block size if (inputSize % kEncodedBlockByteSize != 0) { - VELOX_USER_FAIL( + return Status::UserError( "Base64::decode() - invalid input string: " "string length is not a multiple of 4."); } - auto decodedSize = - (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; + decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; auto paddingCount = numPadding(input, inputSize); inputSize -= paddingCount; // Adjust the needed size by deducting the bytes corresponding to the // padding from the calculated size. - return decodedSize - + decodedSize -= ((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) / kEncodedBlockByteSize; + return Status::OK(); } // If not padded, Calculate extra bytes, if any auto extraBytes = inputSize % kEncodedBlockByteSize; - auto decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; + decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; // Adjust the needed size for extra bytes, if present if (extraBytes) { if (extraBytes == 1) { - VELOX_USER_FAIL( + return Status::UserError( "Base64::decode() - invalid input string: " "string length cannot be 1 more than a multiple of 4."); } decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize; } - return decodedSize; + return Status::OK(); } // static -size_t Base64::decodeImpl( +Status Base64::decodeImpl( const char* input, size_t inputSize, char* outputBuffer, size_t outputSize, const ReverseIndex& reverseIndex) { - if (!inputSize) { - return 0; + if (inputSize == 0) { + return Status::OK(); } - auto decodedSize = calculateDecodedSize(input, inputSize); + size_t decodedSize; + auto status = calculateDecodedSize(input, inputSize, decodedSize); + if (!status.ok()) { + return status; + } if (outputSize < decodedSize) { - VELOX_USER_FAIL( + return Status::UserError( "Base64::decode() - invalid output string: " "output string is too small."); } + Status lookupStatus; // Handle full groups of 4 characters for (; inputSize > 4; inputSize -= 4, input += 4, outputBuffer += 3) { // Each character of the 4 encodes 6 bits of the original, grab each with // the appropriate shifts to rebuild the original and then split that back // into the original 8-bit bytes. uint32_t decodedBlock = - (base64ReverseLookup(input[0], reverseIndex) << 18) | - (base64ReverseLookup(input[1], reverseIndex) << 12) | - (base64ReverseLookup(input[2], reverseIndex) << 6) | - base64ReverseLookup(input[3], reverseIndex); - outputBuffer[0] = (decodedBlock >> 16) & 0xff; - outputBuffer[1] = (decodedBlock >> 8) & 0xff; - outputBuffer[2] = decodedBlock & 0xff; + (base64ReverseLookup(input[0], reverseIndex, lookupStatus) << 18) | + (base64ReverseLookup(input[1], reverseIndex, lookupStatus) << 12) | + (base64ReverseLookup(input[2], reverseIndex, lookupStatus) << 6) | + base64ReverseLookup(input[3], reverseIndex, lookupStatus); + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputBuffer[0] = static_cast((decodedBlock >> 16) & 0xff); + outputBuffer[1] = static_cast((decodedBlock >> 8) & 0xff); + outputBuffer[2] = static_cast(decodedBlock & 0xff); } // Handle the last 2-4 characters. This is similar to the above, but the // last 2 characters may or may not exist. DCHECK(inputSize >= 2); - uint32_t decodedBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) | - (base64ReverseLookup(input[1], reverseIndex) << 12); - outputBuffer[0] = (decodedBlock >> 16) & 0xff; + uint32_t decodedBlock = + (base64ReverseLookup(input[0], reverseIndex, lookupStatus) << 18) | + (base64ReverseLookup(input[1], reverseIndex, lookupStatus) << 12); + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputBuffer[0] = static_cast((decodedBlock >> 16) & 0xff); if (inputSize > 2) { - decodedBlock |= base64ReverseLookup(input[2], reverseIndex) << 6; - outputBuffer[1] = (decodedBlock >> 8) & 0xff; + decodedBlock |= base64ReverseLookup(input[2], reverseIndex, lookupStatus) + << 6; + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputBuffer[1] = static_cast((decodedBlock >> 8) & 0xff); if (inputSize > 3) { - decodedBlock |= base64ReverseLookup(input[3], reverseIndex); - outputBuffer[2] = decodedBlock & 0xff; + decodedBlock |= base64ReverseLookup(input[3], reverseIndex, lookupStatus); + if (!lookupStatus.ok()) { + return lookupStatus; + } + outputBuffer[2] = static_cast(decodedBlock & 0xff); } } - return decodedSize; + return Status::OK(); } // static @@ -464,12 +496,12 @@ std::string Base64::encodeUrl(const folly::IOBuf* inputBuffer) { } // static -void Base64::decodeUrl( +Status Base64::decodeUrl( const char* input, size_t inputSize, char* outputBuffer, size_t outputSize) { - decodeImpl( + return decodeImpl( input, inputSize, outputBuffer, outputSize, kBase64UrlReverseIndexTable); } @@ -485,15 +517,16 @@ std::string Base64::decodeUrl(folly::StringPiece encodedText) { void Base64::decodeUrl( const std::pair& payload, std::string& decodedOutput) { - size_t decodedSize = (payload.second + 3) / 4 * 3; - decodedOutput.resize(decodedSize, '\0'); - decodedSize = Base64::decodeImpl( + size_t inputSize = payload.second; + size_t decodedSize; + (void)calculateDecodedSize(payload.first, inputSize, decodedSize); + decodedOutput.resize(decodedSize); + (void)Base64::decodeImpl( payload.first, payload.second, &decodedOutput[0], - decodedSize, + decodedOutput.size(), kBase64UrlReverseIndexTable); - decodedOutput.resize(decodedSize); } } // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 9eac6899c0a0..c9a9df9220ab 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -16,11 +16,12 @@ #pragma once -#include -#include #include #include +#include +#include #include "velox/common/base/GTestMacros.h" +#include "velox/common/base/Status.h" namespace facebook::velox::encoding { @@ -47,7 +48,7 @@ class Base64 { /// Encodes the specified number of characters from the 'input' and writes the /// result to the 'outputBuffer'. The output must have enough space as /// returned by the calculateEncodedSize(). - static void encode(const char* input, size_t inputSize, char* outputBuffer); + static Status encode(const char* input, size_t inputSize, char* outputBuffer); /// Encodes the specified number of characters from the 'input' using URL /// encoding. @@ -59,7 +60,7 @@ class Base64 { /// Encodes the specified number of characters from the 'input' and writes the /// result to the 'outputBuffer' using URL encoding. The output must have /// enough space as returned by the calculateEncodedSize(). - static void + static Status encodeUrl(const char* input, size_t inputSize, char* outputBuffer); /// Decodes the input Base64 encoded string. @@ -75,7 +76,7 @@ class Base64 { static void decode(const char* input, size_t inputSize, char* outputBuffer); /// Decodes the specified number of characters from the 'input' and writes the /// result to the 'outputBuffer'. - static size_t decode( + static Status decode( const char* input, size_t inputSize, char* outputBuffer, @@ -90,7 +91,7 @@ class Base64 { std::string& output); /// Decodes the specified number of characters from the 'input' using URL /// encoding and writes the result to the 'outputBuffer' - static void decodeUrl( + static Status decodeUrl( const char* input, size_t inputSize, char* outputBuffer, @@ -101,7 +102,10 @@ class Base64 { /// Calculates the decoded size based on encoded input and adjusts the input /// size for padding. - static size_t calculateDecodedSize(const char* input, size_t& inputSize); + static Status calculateDecodedSize( + const char* input, + size_t& inputSize, + size_t& decodedSize); private: // Padding character used in encoding. @@ -126,7 +130,8 @@ class Base64 { // character. static uint8_t base64ReverseLookup( char encodedChar, - const ReverseIndex& reverseIndex); + const ReverseIndex& reverseIndex, + Status& status); // Encodes the specified data using the provided charset. template @@ -135,14 +140,14 @@ class Base64 { // Encodes the specified data using the provided charset. template - static void encodeImpl( + static Status encodeImpl( const T& input, const Charset& charset, bool includePadding, char* outputBuffer); // Decodes the specified data using the provided reverse lookup table. - static size_t decodeImpl( + static Status decodeImpl( const char* input, size_t inputSize, char* outputBuffer, diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 9cbbbad47124..ecfbf20a09f2 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -50,43 +50,48 @@ TEST_F(Base64Test, fromBase64) { TEST_F(Base64Test, calculateDecodedSizeProperSize) { size_t encoded_size{0}; + size_t decoded_size{0}; encoded_size = 20; - EXPECT_EQ( - 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size)); + Base64::calculateDecodedSize( + "SGVsbG8sIFdvcmxkIQ==", encoded_size, decoded_size); EXPECT_EQ(18, encoded_size); + EXPECT_EQ(13, decoded_size); encoded_size = 18; - EXPECT_EQ( - 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ", encoded_size)); + Base64::calculateDecodedSize( + "SGVsbG8sIFdvcmxkIQ", encoded_size, decoded_size); EXPECT_EQ(18, encoded_size); + EXPECT_EQ(13, decoded_size); encoded_size = 21; - VELOX_ASSERT_THROW( - Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size), - "Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4."); - - encoded_size = 32; EXPECT_EQ( - 23, + Status::UserError( + "Base64::decode() - invalid input string: string length is not a multiple of 4."), Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size)); + "SGVsbG8sIFdvcmxkIQ===", encoded_size, decoded_size)); + + encoded_size = 32; + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size, decoded_size); EXPECT_EQ(31, encoded_size); + EXPECT_EQ(23, decoded_size); encoded_size = 31; - EXPECT_EQ( - 23, - Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size)); + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size, decoded_size); EXPECT_EQ(31, encoded_size); + EXPECT_EQ(23, decoded_size); encoded_size = 16; - EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size)); + Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size, decoded_size); EXPECT_EQ(14, encoded_size); + EXPECT_EQ(10, decoded_size); encoded_size = 14; - EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size)); + Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size, decoded_size); EXPECT_EQ(14, encoded_size); + EXPECT_EQ(10, decoded_size); } TEST_F(Base64Test, checksPadding) { diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 90c9733ecf22..63f718c24745 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -17,4 +17,4 @@ add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test PUBLIC Folly::folly - PRIVATE velox_encode velox_exception GTest::gtest GTest::gtest_main) + PRIVATE velox_encode velox_status GTest::gtest GTest::gtest_main) diff --git a/velox/docs/functions/presto/binary.rst b/velox/docs/functions/presto/binary.rst index 8b4ddc26832e..07deb3e4b0e9 100644 --- a/velox/docs/functions/presto/binary.rst +++ b/velox/docs/functions/presto/binary.rst @@ -8,26 +8,25 @@ Binary Functions .. function:: from_base64(string) -> varbinary - Decodes a Base64-encoded ``string`` back into its original binary form. - This function is capable of handling both fully padded and non-padded Base64 encoded strings. - Partially padded Base64 strings are not supported and will result in an error. + Decodes a Base64-encoded ``string`` back into its original binary form. + This function is capable of handling both fully padded and non-padded Base64 encoded strings. + Partially padded Base64 strings are not supported and will result in a "UserError" status being returned. Examples -------- Query with padded Base64 string: :: SELECT from_base64('SGVsbG8gV29ybGQ='); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] - Query with non-padded Base64 string: :: SELECT from_base64('SGVsbG8gV29ybGQ'); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] Query with partial-padded Base64 string: :: - SELECT from_base64('SGVsbG8gV29ybGQgZm9yIHZlbG94IQ='); -- Error : Base64::decode() - invalid input string: string length is not a multiple of 4. + SELECT from_base64('SGVsbG8gV29ybGQgZm9yIHZlbG94IQ='); -- UserError: Base64::decode() - invalid input string: string length is not a multiple of 4. In the above examples, both the fully padded and non-padded Base64 strings ('SGVsbG8gV29ybGQ=' and 'SGVsbG8gV29ybGQ') decode to the binary representation of the text 'Hello World'. - While, partial-padded Base64 string 'SGVsbG8gV29ybGQgZm9yIHZlbG94IQ=' will lead to an velox error. + A partial-padded Base64 string 'SGVsbG8gV29ybGQgZm9yIHZlbG94IQ=' will result in a "UserError" status indicating the Base64 string is invalid. .. function:: from_base64url(string) -> varbinary diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index ce153ee349fc..ea151078a26f 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -278,11 +278,10 @@ template struct ToBase64Function { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { result.resize(encoding::Base64::calculateEncodedSize(input.size())); - encoding::Base64::encode(input.data(), input.size(), result.data()); + return encoding::Base64::encode(input.data(), input.size(), result.data()); } }; @@ -293,11 +292,16 @@ struct FromBase64Function { // T can be either arg_type or arg_type. These are the // same, but hard-coding one of them might be confusing. template - FOLLY_ALWAYS_INLINE void call(out_type& result, const T& input) { + FOLLY_ALWAYS_INLINE Status call(out_type& result, const T& input) { auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decode( + size_t decodedSize; + auto status = encoding::Base64::calculateDecodedSize( + input.data(), inputSize, decodedSize); + if (!status.ok()) { + return status; + } + result.resize(decodedSize); + return encoding::Base64::decode( input.data(), inputSize, result.data(), result.size()); } }; @@ -305,13 +309,17 @@ struct FromBase64Function { template struct FromBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decodeUrl( + size_t decodedSize; + auto status = encoding::Base64::calculateDecodedSize( + input.data(), inputSize, decodedSize); + if (!status.ok()) { + return status; + } + result.resize(decodedSize); + return encoding::Base64::decodeUrl( input.data(), inputSize, result.data(), result.size()); } }; @@ -320,11 +328,11 @@ template struct ToBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { result.resize(encoding::Base64::calculateEncodedSize(input.size())); - encoding::Base64::encodeUrl(input.data(), input.size(), result.data()); + return encoding::Base64::encodeUrl( + input.data(), input.size(), result.data()); } };