From af8dbedc6fb010678f074bc1b1f7c38029b736af Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Mon, 19 Aug 2024 13:51:25 +0530 Subject: [PATCH] Convert Base64 as a non-throwing API --- velox/common/encode/Base64.cpp | 79 +++++++----- velox/common/encode/Base64.h | 12 +- velox/common/encode/tests/Base64Test.cpp | 158 +++++++++++++++++------ velox/common/encode/tests/CMakeLists.txt | 7 +- 4 files changed, 168 insertions(+), 88 deletions(-) diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index ee65b24861f4c..dfb6bc88f0e90 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -20,8 +20,6 @@ #include #include -#include "velox/common/base/Exceptions.h" - namespace facebook::velox::encoding { // Constants defining the size in bytes of binary and encoded blocks for Base64 @@ -326,13 +324,16 @@ void Base64::decode(std::string_view input, size_t size, char* output) { // static uint8_t Base64::base64ReverseLookup( - char character, - const Base64::ReverseIndex& reverseIndex) { - auto lookupValue = reverseIndex[(uint8_t)character]; - if (lookupValue >= 0x40) { - VELOX_USER_FAIL("decode() - invalid input string: invalid characters"); + char p, + const Base64::ReverseIndex& reverseIndex, + Status& status) { + auto curr = reverseIndex[(uint8_t)p]; + if (curr >= 0x40) { + status = Status::UserError( + "Base64::decode() - invalid input string: contains invalid characters."); + return 0; // Placeholder only: 0 is also a valid index, so the failure is reported via status } - return lookupValue; + return curr; } // static @@ -399,7 +400,7 @@ Status Base64::decodeImpl( char* output, size_t outputSize, const Base64::ReverseIndex& reverseIndex) { - if (!inputSize) { + if (inputSize == 0) { return Status::OK(); } @@ -415,36 +416,44 @@ Status Base64::decodeImpl( "Base64::decode() - invalid output string: output string is too small."); } - // Handle full groups of 4 characters - for (; inputSize > 4; inputSize -= 4, input.remove_prefix(4), output += 3) { - // 
Each character of the 4 encodes 6 bits of the original, grab each with - // the appropriate shifts to rebuild the original and then split that back - // into the original 8-bit bytes. - uint32_t currentBlock = - (base64ReverseLookup(input[0], reverseIndex) << 18) | - (base64ReverseLookup(input[1], reverseIndex) << 12) | - (base64ReverseLookup(input[2], reverseIndex) << 6) | - base64ReverseLookup(input[3], reverseIndex); - output[0] = (currentBlock >> 16) & 0xff; - output[1] = (currentBlock >> 8) & 0xff; - output[2] = currentBlock & 0xff; + const char* inputPtr = input.data(); + char* outputPtr = output; + Status lookupStatus; + + // Process full blocks of 4 characters + size_t fullBlockCount = inputSize / 4; + for (size_t i = 0; i < fullBlockCount; ++i) { + uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus); + uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus); + uint8_t val2 = base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus); + uint8_t val3 = base64ReverseLookup(inputPtr[3], reverseIndex, lookupStatus); + + uint32_t currentBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3; + outputPtr[0] = static_cast((currentBlock >> 16) & 0xFF); + outputPtr[1] = static_cast((currentBlock >> 8) & 0xFF); + outputPtr[2] = static_cast(currentBlock & 0xFF); + + inputPtr += 4; + outputPtr += 3; } - // Handle the last 2-4 characters. This is similar to the above, but the - // last 2 characters may or may not exist. 
- DCHECK(inputSize >= 2); - uint32_t currentBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) | - (base64ReverseLookup(input[1], reverseIndex) << 12); - output[0] = (currentBlock >> 16) & 0xff; - if (inputSize > 2) { - currentBlock |= base64ReverseLookup(input[2], reverseIndex) << 6; - output[1] = (currentBlock >> 8) & 0xff; - if (inputSize > 3) { - currentBlock |= base64ReverseLookup(input[3], reverseIndex); - output[2] = currentBlock & 0xff; + // Handle the last block (2-3 characters) + size_t remaining = inputSize % 4; + if (remaining > 1) { + uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus); + uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus); + uint32_t currentBlock = (val0 << 18) | (val1 << 12); + outputPtr[0] = static_cast((currentBlock >> 16) & 0xFF); + + if (remaining == 3) { + uint8_t val2 = + base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus); + currentBlock |= (val2 << 6); + outputPtr[1] = static_cast((currentBlock >> 8) & 0xFF); } } - + if (!lookupStatus.ok()) + return lookupStatus; return Status::OK(); } diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 023abd031a88c..efeb42054ee4a 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -121,19 +121,18 @@ class Base64 { // Counts the number of padding characters in encoded input. static inline size_t numPadding(std::string_view input, size_t inputSize) { - size_t padding = 0; + size_t numPadding{0}; while (inputSize > 0 && input[inputSize - 1] == kPadding) { - padding++; + numPadding++; inputSize--; } - return padding; + return numPadding; } // Performs a reverse lookup in the reverse index to retrieve the original // index of a character in the base. 
- static uint8_t base64ReverseLookup( - char character, - const ReverseIndex& reverseIndex); + static uint8_t + base64ReverseLookup(char p, const ReverseIndex& reverseIndex, Status& status); // Encodes the specified input using the provided charset. template @@ -158,6 +157,7 @@ class Base64 { VELOX_FRIEND_TEST(Base64Test, isPadded); VELOX_FRIEND_TEST(Base64Test, numPadding); + VELOX_FRIEND_TEST(Base64Test, testDecodeImpl); }; } // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 7eaaccfa4ab53..e83e0aeaa6b82 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -17,13 +17,26 @@ #include "velox/common/encode/Base64.h" #include -#include "velox/common/base/Exceptions.h" #include "velox/common/base/Status.h" #include "velox/common/base/tests/GTestUtils.h" namespace facebook::velox::encoding { -class Base64Test : public ::testing::Test {}; +class Base64Test : public ::testing::Test { + protected: + void checkDecodedSize( + const std::string& encodedString, + size_t expectedEncodedSize, + size_t expectedDecodedSize) { + size_t encodedSize = expectedEncodedSize; + size_t decodedSize = 0; + EXPECT_EQ( + Status::OK(), + Base64::calculateDecodedSize(encodedString, encodedSize, decodedSize)); + EXPECT_EQ(expectedEncodedSize, encodedSize); + EXPECT_EQ(expectedDecodedSize, decodedSize); + } +}; TEST_F(Base64Test, fromBase64) { EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ==")); @@ -43,49 +56,23 @@ TEST_F(Base64Test, fromBase64) { } TEST_F(Base64Test, calculateDecodedSizeProperSize) { - size_t encoded_size{0}; - size_t decoded_size{0}; - - encoded_size = 20; - Base64::calculateDecodedSize( - "SGVsbG8sIFdvcmxkIQ==", encoded_size, decoded_size); - EXPECT_EQ(18, encoded_size); - EXPECT_EQ(13, decoded_size); - - encoded_size = 18; - Base64::calculateDecodedSize( - "SGVsbG8sIFdvcmxkIQ", encoded_size, decoded_size); - EXPECT_EQ(18, 
encoded_size); - EXPECT_EQ(13, decoded_size); - - encoded_size = 21; + checkDecodedSize("SGVsbG8sIFdvcmxkIQ==", 18, 13); + checkDecodedSize("SGVsbG8sIFdvcmxkIQ", 18, 13); + checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 31, 23); + checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 23); + checkDecodedSize("MTIzNDU2Nzg5MA==", 14, 10); + checkDecodedSize("MTIzNDU2Nzg5MA", 14, 10); +} + +TEST_F(Base64Test, calculateDecodedSizeImproperSize) { + size_t encodedSize{21}; + size_t decodedSize; + EXPECT_EQ( Status::UserError( "Base64::decode() - invalid input string: string length is not a multiple of 4."), Base64::calculateDecodedSize( - "SGVsbG8sIFdvcmxkIQ===", encoded_size, decoded_size)); - - encoded_size = 32; - Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size, decoded_size); - EXPECT_EQ(31, encoded_size); - EXPECT_EQ(23, decoded_size); - - encoded_size = 31; - Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size, decoded_size); - EXPECT_EQ(31, encoded_size); - EXPECT_EQ(23, decoded_size); - - encoded_size = 16; - Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size, decoded_size); - EXPECT_EQ(14, encoded_size); - EXPECT_EQ(10, decoded_size); - - encoded_size = 14; - Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size, decoded_size); - EXPECT_EQ(14, encoded_size); - EXPECT_EQ(10, decoded_size); + "SGVsbG8sIFdvcmxkIQ===", encodedSize, decodedSize)); } TEST_F(Base64Test, isPadded) { @@ -98,4 +85,93 @@ TEST_F(Base64Test, numPadding) { EXPECT_EQ(1, Base64::numPadding("ABC=", 4)); EXPECT_EQ(2, Base64::numPadding("AB==", 4)); } + +TEST_F(Base64Test, testDecodeImpl) { + constexpr const Base64::ReverseIndex reverseTable = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, + 255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 
59, 60, 61, 255, 255, + 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + 49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255}; + + auto testDecode = [&](const std::string_view input, + char* output1, + size_t outputSize, + Status expectedStatus) { + EXPECT_EQ( + Base64::decodeImpl( + input, input.size(), output1, outputSize, reverseTable), + expectedStatus); + }; + + // Predefine buffer sizes and reuse. 
+ char output1[20] = {}; + char output2[1] = {}; + char output3[1] = {}; + + // Invalid characters in the input string + testDecode( + "SGVsbG8gd29ybGQ$", + output1, + sizeof(output1), + Status::UserError( + "Base64::decode() - invalid input string: contains invalid characters.")); + + // All characters are padding characters + testDecode("====", output1, sizeof(output1), Status::OK()); + + // Invalid input size + testDecode( + "S", + output1, + sizeof(output1), + Status::UserError( + "Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4.")); + + // Valid input without padding characters + testDecode("SGVsbG8gd29ybGQ", output1, sizeof(output1), Status::OK()); + EXPECT_STREQ(output1, "Hello world"); + + // Empty input string + testDecode("", output2, sizeof(output2), Status::OK()); + EXPECT_STREQ(output2, ""); + + // Invalid input size + testDecode( + "SGVsbG8gd29ybGQ===", + output1, + sizeof(output1), + Status::UserError( + "Base64::decode() - invalid input string: string length is not a multiple of 4.")); + + // whitespace in the input string + testDecode( + " SGVsb G8gd2 9ybGQ= ", + output1, + sizeof(output1), + Status::UserError( + "Base64::decode() - invalid input string: contains invalid characters.")); + + // insufficient buffer size + testDecode( + " SGVsb G8gd2 9ybGQ= ", + output3, + sizeof(output3), + Status::UserError( + "Base64::decode() - invalid output string: output string is too small.")); +} + } // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 7920e80b2c2cf..266e091d02914 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -16,9 +16,4 @@ add_executable(velox_common_encode_test Base64Test.cpp) add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test - PRIVATE - velox_encode - velox_status - velox_exception - 
GTest::gtest - GTest::gtest_main) + PRIVATE velox_encode velox_status GTest::gtest GTest::gtest_main)