From 96eaab658fd567a3ddc8e9b904c7d803cd676a23 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Thu, 3 Oct 2024 00:22:57 +0530 Subject: [PATCH] refactor: Base64 APIs as non-throwing APIs --- velox/common/encode/Base64.cpp | 175 +++++++++++------- velox/common/encode/Base64.h | 27 +-- velox/common/encode/tests/Base64Test.cpp | 39 ++-- velox/common/encode/tests/CMakeLists.txt | 2 +- velox/docs/functions/presto/binary.rst | 11 +- velox/functions/prestosql/BinaryFunctions.h | 44 +++-- .../prestosql/tests/BinaryFunctionsTest.cpp | 6 + 7 files changed, 185 insertions(+), 119 deletions(-) diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index 045a01982bcb..b36cb55e8558 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -18,9 +18,7 @@ #include #include #include -#include - -#include "velox/common/base/Exceptions.h" +#include namespace facebook::velox::encoding { @@ -163,10 +161,10 @@ std::string Base64::encodeImpl( const T& input, const Charset& charset, bool includePadding) { - size_t encodedSize = calculateEncodedSize(input.size(), includePadding); + size_t encodedSize{calculateEncodedSize(input.size(), includePadding)}; std::string encodedResult; encodedResult.resize(encodedSize); - encodeImpl(input, charset, includePadding, encodedResult.data()); + (void)encodeImpl(input, charset, includePadding, encodedResult.data()); return encodedResult; } @@ -186,27 +184,31 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool withPadding) { } // static -void Base64::encode(const char* input, size_t inputSize, char* output) { - encodeImpl( +Status Base64::encode(const char* input, size_t inputSize, char* output) { + return encodeImpl( folly::StringPiece(input, inputSize), kBase64Charset, true, output); } // static -void Base64::encodeUrl(const char* input, size_t inputSize, char* output) { - encodeImpl( - folly::StringPiece(input, inputSize), kBase64UrlCharset, true, output); +Status +Base64::encodeUrl(const char* input, size_t inputSize, char* outputBuffer) { + return encodeImpl( + folly::StringPiece(input, inputSize), + kBase64UrlCharset, + true, + outputBuffer); } // static template -void Base64::encodeImpl( +Status Base64::encodeImpl( const T& input, const Charset& charset, bool includePadding, char* outputBuffer) { auto inputSize = input.size(); if (inputSize == 0) { - return; + return Status::OK(); } auto outputPointer = outputBuffer; @@ -246,6 +248,7 @@ void Base64::encodeImpl( } } } + return Status::OK(); } // static @@ -310,8 +313,7 @@ std::string Base64::encode(const folly::IOBuf* inputBuffer) { // static std::string Base64::decode(folly::StringPiece encodedText) { std::string decodedResult; - Base64::decode( - std::make_pair(encodedText.data(), encodedText.size()), decodedResult); + decode(std::make_pair(encodedText.data(), encodedText.size()), decodedResult); return decodedResult; } @@ -320,29 +322,36 @@ void Base64::decode( const std::pair& payload, std::string& decodedOutput) { size_t inputSize = payload.second; - decodedOutput.resize(calculateDecodedSize(payload.first, inputSize)); - decode(payload.first, inputSize, decodedOutput.data(), decodedOutput.size()); + size_t decodedSize; + (void)calculateDecodedSize(payload.first, inputSize, decodedSize); + decodedOutput.resize(decodedSize); + (void)decode( + payload.first, inputSize, decodedOutput.data(), decodedOutput.size()); } // static -void Base64::decode(const char* input, size_t size, char* output) { - size_t expectedOutputSize = size / 4 * 3; - Base64::decode(input, size, output, expectedOutputSize); +void Base64::decode(const char* input, size_t inputSize, char* outputBuffer) { + size_t outputSize; + (void)calculateDecodedSize(input, inputSize, outputSize); + (void)decode(input, inputSize, outputBuffer, outputSize); } // static -uint8_t Base64::base64ReverseLookup( +Status Base64::base64ReverseLookup( char encodedChar, - const Base64::ReverseIndex& reverseIndex) { - auto reverseLookupValue = reverseIndex[static_cast(encodedChar)]; + const ReverseIndex& reverseIndex, + uint8_t& reverseLookupValue) { + reverseLookupValue = reverseIndex[static_cast(encodedChar)]; if (reverseLookupValue >= 0x40) { - VELOX_USER_FAIL("decode() - invalid input string: invalid characters"); + return Status::UserError(fmt::format( + "decode() - invalid input string: invalid character '{}'", + encodedChar)); } - return reverseLookupValue; + return Status::OK(); } // static -size_t Base64::decode( +Status Base64::decode( const char* input, size_t inputSize, char* output, @@ -352,9 +361,13 @@ size_t Base64::decode( } // static -size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) { +Status Base64::calculateDecodedSize( + const char* input, + size_t& inputSize, + size_t& decodedSize) { if (inputSize == 0) { - return 0; + decodedSize = 0; + return Status::OK(); } // Check if the input string is padded @@ -362,53 +375,57 @@ size_t Base64::calculateDecodedSize(const char* input, size_t& inputSize) { // If padded, ensure that the string length is a multiple of the encoded // block size if (inputSize % kEncodedBlockByteSize != 0) { - VELOX_USER_FAIL( + return Status::UserError( "Base64::decode() - invalid input string: " "string length is not a multiple of 4."); } - auto decodedSize = - (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; + decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; auto paddingCount = numPadding(input, inputSize); inputSize -= paddingCount; // Adjust the needed size by deducting the bytes corresponding to the // padding from the calculated size. - return decodedSize - + decodedSize -= ((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) / kEncodedBlockByteSize; + return Status::OK(); } // If not padded, Calculate extra bytes, if any auto extraBytes = inputSize % kEncodedBlockByteSize; - auto decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; + decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; // Adjust the needed size for extra bytes, if present if (extraBytes) { if (extraBytes == 1) { - VELOX_USER_FAIL( + return Status::UserError( "Base64::decode() - invalid input string: " "string length cannot be 1 more than a multiple of 4."); } decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize; } - return decodedSize; + return Status::OK(); } // static -size_t Base64::decodeImpl( +Status Base64::decodeImpl( const char* input, size_t inputSize, char* outputBuffer, size_t outputSize, const ReverseIndex& reverseIndex) { - if (!inputSize) { - return 0; + if (inputSize == 0) { + return Status::OK(); } - auto decodedSize = calculateDecodedSize(input, inputSize); + size_t decodedSize; + auto status = calculateDecodedSize(input, inputSize, decodedSize); + if (!status.ok()) { + return status; + } if (outputSize < decodedSize) { - VELOX_USER_FAIL( + return Status::UserError( "Base64::decode() - invalid output string: " "output string is too small."); } @@ -418,32 +435,57 @@ size_t Base64::decodeImpl( // Each character of the 4 encodes 6 bits of the original, grab each with // the appropriate shifts to rebuild the original and then split that back // into the original 8-bit bytes. - uint32_t decodedBlock = - (base64ReverseLookup(input[0], reverseIndex) << 18) | - (base64ReverseLookup(input[1], reverseIndex) << 12) | - (base64ReverseLookup(input[2], reverseIndex) << 6) | - base64ReverseLookup(input[3], reverseIndex); - outputBuffer[0] = (decodedBlock >> 16) & 0xff; - outputBuffer[1] = (decodedBlock >> 8) & 0xff; - outputBuffer[2] = decodedBlock & 0xff; + uint32_t decodedBlock = 0; + uint8_t reverseLookupValue; + for (int i = 0; i < 4; ++i) { + status = base64ReverseLookup(input[i], reverseIndex, reverseLookupValue); + if (!status.ok()) { + return status; + } + decodedBlock |= reverseLookupValue << (18 - 6 * i); + } + outputBuffer[0] = static_cast((decodedBlock >> 16) & 0xff); + outputBuffer[1] = static_cast((decodedBlock >> 8) & 0xff); + outputBuffer[2] = static_cast(decodedBlock & 0xff); } // Handle the last 2-4 characters. This is similar to the above, but the // last 2 characters may or may not exist. - DCHECK(inputSize >= 2); - uint32_t decodedBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) | - (base64ReverseLookup(input[1], reverseIndex) << 12); - outputBuffer[0] = (decodedBlock >> 16) & 0xff; - if (inputSize > 2) { - decodedBlock |= base64ReverseLookup(input[2], reverseIndex) << 6; - outputBuffer[1] = (decodedBlock >> 8) & 0xff; - if (inputSize > 3) { - decodedBlock |= base64ReverseLookup(input[3], reverseIndex); - outputBuffer[2] = decodedBlock & 0xff; + if (inputSize >= 2) { + uint32_t decodedBlock = 0; + uint8_t reverseLookupValue; + + // Process the first two characters + for (int i = 0; i < 2; ++i) { + status = base64ReverseLookup(input[i], reverseIndex, reverseLookupValue); + if (!status.ok()) { + return status; + } + decodedBlock |= reverseLookupValue << (18 - 6 * i); + } + outputBuffer[0] = static_cast((decodedBlock >> 16) & 0xff); + + if (inputSize > 2) { + status = base64ReverseLookup(input[2], reverseIndex, reverseLookupValue); + if (!status.ok()) { + return status; + } + decodedBlock |= reverseLookupValue << 6; + outputBuffer[1] = static_cast((decodedBlock >> 8) & 0xff); + + if (inputSize > 3) { + status = + base64ReverseLookup(input[3], reverseIndex, reverseLookupValue); + if (!status.ok()) { + return status; + } + decodedBlock |= reverseLookupValue; + outputBuffer[2] = static_cast(decodedBlock & 0xff); + } } } - return decodedSize; + return Status::OK(); } // static @@ -462,19 +504,19 @@ std::string Base64::encodeUrl(const folly::IOBuf* inputBuffer) { } // static -void Base64::decodeUrl( +Status Base64::decodeUrl( const char* input, size_t inputSize, char* outputBuffer, size_t outputSize) { - decodeImpl( + return decodeImpl( input, inputSize, outputBuffer, outputSize, kBase64UrlReverseIndexTable); } // static std::string Base64::decodeUrl(folly::StringPiece encodedText) { std::string decodedOutput; - Base64::decodeUrl( + decodeUrl( std::make_pair(encodedText.data(), encodedText.size()), decodedOutput); return decodedOutput; } @@ -483,15 +525,16 @@ std::string Base64::decodeUrl(folly::StringPiece encodedText) { void Base64::decodeUrl( const std::pair& payload, std::string& decodedOutput) { - size_t decodedSize = (payload.second + 3) / 4 * 3; - decodedOutput.resize(decodedSize, '\0'); - decodedSize = Base64::decodeImpl( + size_t inputSize = payload.second; + size_t decodedSize; + (void)calculateDecodedSize(payload.first, inputSize, decodedSize); + decodedOutput.resize(decodedSize); + (void)decodeImpl( payload.first, payload.second, &decodedOutput[0], - decodedSize, + decodedOutput.size(), kBase64UrlReverseIndexTable); - decodedOutput.resize(decodedSize); } } // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 20c53beae91f..f351d6e340fb 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -23,6 +23,7 @@ #include #include "velox/common/base/GTestMacros.h" +#include "velox/common/base/Status.h" namespace facebook::velox::encoding { @@ -52,7 +53,7 @@ class Base64 { /// Encodes the specified number of characters from the 'input' and writes the /// result to the 'outputBuffer'. The output must have enough space as /// returned by the calculateEncodedSize(). - static void encode(const char* input, size_t inputSize, char* outputBuffer); + static Status encode(const char* input, size_t inputSize, char* outputBuffer); /// Encodes the specified number of characters from the 'input' using URL /// encoding. @@ -67,7 +68,7 @@ class Base64 { /// Encodes the specified number of characters from the 'input' and writes the /// result to the 'outputBuffer' using URL encoding. The output must have /// enough space as returned by the calculateEncodedSize(). - static void + static Status encodeUrl(const char* input, size_t inputSize, char* outputBuffer); /// Decodes the input Base64 encoded string. @@ -86,7 +87,7 @@ class Base64 { /// Decodes the specified number of characters from the 'input' and writes the /// result to the 'outputBuffer'. - static size_t decode( + static Status decode( const char* input, size_t inputSize, char* outputBuffer, @@ -103,7 +104,7 @@ class Base64 { /// Decodes the specified number of characters from the 'input' using URL /// encoding and writes the result to the 'outputBuffer' - static void decodeUrl( + static Status decodeUrl( const char* input, size_t inputSize, char* outputBuffer, @@ -112,9 +113,12 @@ class Base64 { /// Calculates the encoded size based on input 'inputSize'. static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true); - /// Returns the actual size of the decoded data. Removes the padding - /// length from the input data 'inputSize'. - static size_t calculateDecodedSize(const char* input, size_t& inputSize); + /// Calculates the decoded size based on encoded input and adjusts the input + /// size for padding. + static Status calculateDecodedSize( + const char* input, + size_t& inputSize, + size_t& decodedSize); private: // Padding character used in encoding. @@ -137,9 +141,10 @@ class Base64 { // Reverse lookup helper function to get the original index of a Base64 // character. - static uint8_t base64ReverseLookup( + static Status base64ReverseLookup( char encodedChar, - const ReverseIndex& reverseIndex); + const ReverseIndex& reverseIndex, + uint8_t& reverseLookupValue); // Encodes the specified data using the provided charset. template @@ -148,14 +153,14 @@ class Base64 { // Encodes the specified data using the provided charset. template - static void encodeImpl( + static Status encodeImpl( const T& input, const Charset& charset, bool includePadding, char* outputBuffer); // Decodes the specified data using the provided reverse lookup table. - static size_t decodeImpl( + static Status decodeImpl( const char* input, size_t inputSize, char* outputBuffer, diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 9cbbbad47124..ecfbf20a09f2 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -50,43 +50,48 @@ TEST_F(Base64Test, fromBase64) { TEST_F(Base64Test, calculateDecodedSizeProperSize) { size_t encoded_size{0}; + size_t decoded_size{0}; encoded_size = 20; - EXPECT_EQ( - 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size)); + Base64::calculateDecodedSize( + "SGVsbG8sIFdvcmxkIQ==", encoded_size, decoded_size); EXPECT_EQ(18, encoded_size); + EXPECT_EQ(13, decoded_size); encoded_size = 18; - EXPECT_EQ( - 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ", encoded_size)); + Base64::calculateDecodedSize( + "SGVsbG8sIFdvcmxkIQ", encoded_size, decoded_size); EXPECT_EQ(18, encoded_size); + EXPECT_EQ(13, decoded_size); encoded_size = 21; - VELOX_ASSERT_THROW( - Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size), - "Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4."); - - encoded_size = 32; EXPECT_EQ( - 23, + Status::UserError( + "Base64::decode() - invalid input string: string length is not a multiple of 4."), Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size)); + "SGVsbG8sIFdvcmxkIQ===", encoded_size, decoded_size)); + + encoded_size = 32; + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size, decoded_size); EXPECT_EQ(31, encoded_size); + EXPECT_EQ(23, decoded_size); encoded_size = 31; - EXPECT_EQ( - 23, - Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size)); + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size, decoded_size); EXPECT_EQ(31, encoded_size); + EXPECT_EQ(23, decoded_size); encoded_size = 16; - EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size)); + Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size, decoded_size); EXPECT_EQ(14, encoded_size); + EXPECT_EQ(10, decoded_size); encoded_size = 14; - EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size)); + Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size, decoded_size); EXPECT_EQ(14, encoded_size); + EXPECT_EQ(10, decoded_size); } TEST_F(Base64Test, checksPadding) { diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 90c9733ecf22..63f718c24745 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -17,4 +17,4 @@ add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test PUBLIC Folly::folly - PRIVATE velox_encode velox_exception GTest::gtest GTest::gtest_main) + PRIVATE velox_encode velox_status GTest::gtest GTest::gtest_main) diff --git a/velox/docs/functions/presto/binary.rst b/velox/docs/functions/presto/binary.rst index 8b4ddc26832e..07deb3e4b0e9 100644 --- a/velox/docs/functions/presto/binary.rst +++ b/velox/docs/functions/presto/binary.rst @@ -8,26 +8,25 @@ Binary Functions .. function:: from_base64(string) -> varbinary - Decodes a Base64-encoded ``string`` back into its original binary form. - This function is capable of handling both fully padded and non-padded Base64 encoded strings. - Partially padded Base64 strings are not supported and will result in an error. + Decodes a Base64-encoded ``string`` back into its original binary form. + This function is capable of handling both fully padded and non-padded Base64 encoded strings. + Partially padded Base64 strings are not supported and will result in a "UserError" status being returned. Examples -------- Query with padded Base64 string: :: SELECT from_base64('SGVsbG8gV29ybGQ='); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] - Query with non-padded Base64 string: :: SELECT from_base64('SGVsbG8gV29ybGQ'); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100] Query with partial-padded Base64 string: :: - SELECT from_base64('SGVsbG8gV29ybGQgZm9yIHZlbG94IQ='); -- Error : Base64::decode() - invalid input string: string length is not a multiple of 4. + SELECT from_base64('SGVsbG8gV29ybGQgZm9yIHZlbG94IQ='); -- UserError: Base64::decode() - invalid input string: string length is not a multiple of 4. In the above examples, both the fully padded and non-padded Base64 strings ('SGVsbG8gV29ybGQ=' and 'SGVsbG8gV29ybGQ') decode to the binary representation of the text 'Hello World'. - While, partial-padded Base64 string 'SGVsbG8gV29ybGQgZm9yIHZlbG94IQ=' will lead to an velox error. + A partial-padded Base64 string 'SGVsbG8gV29ybGQgZm9yIHZlbG94IQ=' will result in a "UserError" status indicating the Base64 string is invalid. .. function:: from_base64url(string) -> varbinary diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index c9495daff1bb..d35fccefe48b 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -278,11 +278,10 @@ template struct ToBase64Function { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { result.resize(encoding::Base64::calculateEncodedSize(input.size())); - encoding::Base64::encode(input.data(), input.size(), result.data()); + return encoding::Base64::encode(input.data(), input.size(), result.data()); } }; @@ -293,11 +292,16 @@ struct FromBase64Function { // T can be either arg_type or arg_type. These are the // same, but hard-coding one of them might be confusing. template - FOLLY_ALWAYS_INLINE void call(out_type& result, const T& input) { + FOLLY_ALWAYS_INLINE Status call(out_type& result, const T& input) { auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decode( + size_t decodedSize; + auto status = encoding::Base64::calculateDecodedSize( + input.data(), inputSize, decodedSize); + if (!status.ok()) { + return status; + } + result.resize(decodedSize); + return encoding::Base64::decode( input.data(), inputSize, result.data(), result.size()); } }; @@ -305,13 +309,17 @@ struct FromBase64Function { template struct FromBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decodeUrl( + size_t decodedSize; + auto status = encoding::Base64::calculateDecodedSize( + input.data(), inputSize, decodedSize); + if (!status.ok()) { + return status; + } + result.resize(decodedSize); + return encoding::Base64::decodeUrl( input.data(), inputSize, result.data(), result.size()); } }; @@ -320,11 +328,11 @@ template struct ToBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { result.resize(encoding::Base64::calculateEncodedSize(input.size())); - encoding::Base64::encodeUrl(input.data(), input.size(), result.data()); + return encoding::Base64::encodeUrl( + input.data(), input.size(), result.data()); } }; diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp index a1582e9f5eb0..da559d214015 100644 --- a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp @@ -440,6 +440,12 @@ TEST_F(BinaryFunctionsTest, fromBase64) { VELOX_ASSERT_USER_THROW( fromBase64("YQ==="), "Base64::decode() - invalid input string: string length is not a multiple of 4."); + VELOX_ASSERT_USER_THROW( + fromBase64("aG;"), + "decode() - invalid input string: invalid character ';'"); + VELOX_ASSERT_USER_THROW( + fromBase64("YQ?="), + "decode() - invalid input string: invalid character '?'"); // Check encoded strings without padding EXPECT_EQ("a", fromBase64("YQ"));