Skip to content

Commit

Permalink
Convert Base64 as a non-throwing API
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Sep 23, 2024
1 parent 936300e commit af8dbed
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 88 deletions.
79 changes: 44 additions & 35 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
#include <folly/io/Cursor.h>
#include <stdint.h>

#include "velox/common/base/Exceptions.h"

namespace facebook::velox::encoding {

// Constants defining the size in bytes of binary and encoded blocks for Base64
Expand Down Expand Up @@ -326,13 +324,16 @@ void Base64::decode(std::string_view input, size_t size, char* output) {

// static
uint8_t Base64::base64ReverseLookup(
char character,
const Base64::ReverseIndex& reverseIndex) {
auto lookupValue = reverseIndex[(uint8_t)character];
if (lookupValue >= 0x40) {
VELOX_USER_FAIL("decode() - invalid input string: invalid characters");
char p,
const Base64::ReverseIndex& reverseIndex,
Status& status) {
auto curr = reverseIndex[(uint8_t)p];
if (curr >= 0x40) {
status = Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters.");
return 0; // Return 0 or any other error code indicating failure
}
return lookupValue;
return curr;
}

// static
Expand Down Expand Up @@ -399,7 +400,7 @@ Status Base64::decodeImpl(
char* output,
size_t outputSize,
const Base64::ReverseIndex& reverseIndex) {
if (!inputSize) {
if (inputSize == 0) {
return Status::OK();
}

Expand All @@ -415,36 +416,44 @@ Status Base64::decodeImpl(
"Base64::decode() - invalid output string: output string is too small.");
}

// Handle full groups of 4 characters
for (; inputSize > 4; inputSize -= 4, input.remove_prefix(4), output += 3) {
// Each character of the 4 encodes 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8-bit bytes.
uint32_t currentBlock =
(base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12) |
(base64ReverseLookup(input[2], reverseIndex) << 6) |
base64ReverseLookup(input[3], reverseIndex);
output[0] = (currentBlock >> 16) & 0xff;
output[1] = (currentBlock >> 8) & 0xff;
output[2] = currentBlock & 0xff;
const char* inputPtr = input.data();
char* outputPtr = output;
Status lookupStatus;

// Process full blocks of 4 characters
size_t fullBlockCount = inputSize / 4;
for (size_t i = 0; i < fullBlockCount; ++i) {
uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus);
uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus);
uint8_t val2 = base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus);
uint8_t val3 = base64ReverseLookup(inputPtr[3], reverseIndex, lookupStatus);

uint32_t currentBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3;
outputPtr[0] = static_cast<char>((currentBlock >> 16) & 0xFF);
outputPtr[1] = static_cast<char>((currentBlock >> 8) & 0xFF);
outputPtr[2] = static_cast<char>(currentBlock & 0xFF);

inputPtr += 4;
outputPtr += 3;
}

// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(inputSize >= 2);
uint32_t currentBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12);
output[0] = (currentBlock >> 16) & 0xff;
if (inputSize > 2) {
currentBlock |= base64ReverseLookup(input[2], reverseIndex) << 6;
output[1] = (currentBlock >> 8) & 0xff;
if (inputSize > 3) {
currentBlock |= base64ReverseLookup(input[3], reverseIndex);
output[2] = currentBlock & 0xff;
// Handle the last block (2-3 characters)
size_t remaining = inputSize % 4;
if (remaining > 1) {
uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus);
uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus);
uint32_t currentBlock = (val0 << 18) | (val1 << 12);
outputPtr[0] = static_cast<char>((currentBlock >> 16) & 0xFF);

if (remaining == 3) {
uint8_t val2 =
base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus);
currentBlock |= (val2 << 6);
outputPtr[1] = static_cast<char>((currentBlock >> 8) & 0xFF);
}
}

if (!lookupStatus.ok())
return lookupStatus;
return Status::OK();
}

Expand Down
12 changes: 6 additions & 6 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,18 @@ class Base64 {

// Counts the number of padding characters in encoded input.
static inline size_t numPadding(std::string_view input, size_t inputSize) {
size_t padding = 0;
size_t numPadding{0};
while (inputSize > 0 && input[inputSize - 1] == kPadding) {
padding++;
numPadding++;
inputSize--;
}
return padding;
return numPadding;
}

// Performs a reverse lookup in the reverse index to retrieve the original
// index of a character in the base.
static uint8_t base64ReverseLookup(
char character,
const ReverseIndex& reverseIndex);
static uint8_t
base64ReverseLookup(char p, const ReverseIndex& reverseIndex, Status& status);

// Encodes the specified input using the provided charset.
template <class T>
Expand All @@ -158,6 +157,7 @@ class Base64 {

VELOX_FRIEND_TEST(Base64Test, isPadded);
VELOX_FRIEND_TEST(Base64Test, numPadding);
VELOX_FRIEND_TEST(Base64Test, testDecodeImpl);
};

} // namespace facebook::velox::encoding
158 changes: 117 additions & 41 deletions velox/common/encode/tests/Base64Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,26 @@
#include "velox/common/encode/Base64.h"

#include <gtest/gtest.h>
#include "velox/common/base/Exceptions.h"
#include "velox/common/base/Status.h"
#include "velox/common/base/tests/GTestUtils.h"

namespace facebook::velox::encoding {

class Base64Test : public ::testing::Test {};
class Base64Test : public ::testing::Test {
protected:
void checkDecodedSize(
const std::string& encodedString,
size_t expectedEncodedSize,
size_t expectedDecodedSize) {
size_t encodedSize = expectedEncodedSize;
size_t decodedSize = 0;
EXPECT_EQ(
Status::OK(),
Base64::calculateDecodedSize(encodedString, encodedSize, decodedSize));
EXPECT_EQ(expectedEncodedSize, encodedSize);
EXPECT_EQ(expectedDecodedSize, decodedSize);
}
};

TEST_F(Base64Test, fromBase64) {
EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ=="));
Expand All @@ -43,49 +56,23 @@ TEST_F(Base64Test, fromBase64) {
}

TEST_F(Base64Test, calculateDecodedSizeProperSize) {
size_t encoded_size{0};
size_t decoded_size{0};

encoded_size = 20;
Base64::calculateDecodedSize(
"SGVsbG8sIFdvcmxkIQ==", encoded_size, decoded_size);
EXPECT_EQ(18, encoded_size);
EXPECT_EQ(13, decoded_size);

encoded_size = 18;
Base64::calculateDecodedSize(
"SGVsbG8sIFdvcmxkIQ", encoded_size, decoded_size);
EXPECT_EQ(18, encoded_size);
EXPECT_EQ(13, decoded_size);

encoded_size = 21;
checkDecodedSize("SGVsbG8sIFdvcmxkIQ==", 18, 13);
checkDecodedSize("SGVsbG8sIFdvcmxkIQ", 18, 13);
checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 31, 23);
checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 23);
checkDecodedSize("MTIzNDU2Nzg5MA==", 14, 10);
checkDecodedSize("MTIzNDU2Nzg5MA", 14, 10);
}

TEST_F(Base64Test, calculateDecodedSizeImproperSize) {
size_t encodedSize{21};
size_t decodedSize;

EXPECT_EQ(
Status::UserError(
"Base64::decode() - invalid input string: string length is not a multiple of 4."),
Base64::calculateDecodedSize(
"SGVsbG8sIFdvcmxkIQ===", encoded_size, decoded_size));

encoded_size = 32;
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size, decoded_size);
EXPECT_EQ(31, encoded_size);
EXPECT_EQ(23, decoded_size);

encoded_size = 31;
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size, decoded_size);
EXPECT_EQ(31, encoded_size);
EXPECT_EQ(23, decoded_size);

encoded_size = 16;
Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size, decoded_size);
EXPECT_EQ(14, encoded_size);
EXPECT_EQ(10, decoded_size);

encoded_size = 14;
Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size, decoded_size);
EXPECT_EQ(14, encoded_size);
EXPECT_EQ(10, decoded_size);
"SGVsbG8sIFdvcmxkIQ===", encodedSize, decodedSize));
}

TEST_F(Base64Test, isPadded) {
Expand All @@ -98,4 +85,93 @@ TEST_F(Base64Test, numPadding) {
EXPECT_EQ(1, Base64::numPadding("ABC=", 4));
EXPECT_EQ(2, Base64::numPadding("AB==", 4));
}

TEST_F(Base64Test, testDecodeImpl) {
constexpr const Base64::ReverseIndex reverseTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255,
255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255,
255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

auto testDecode = [&](const std::string_view input,
char* output1,
size_t outputSize,
Status expectedStatus) {
EXPECT_EQ(
Base64::decodeImpl(
input, input.size(), output1, outputSize, reverseTable),
expectedStatus);
};

// Predefine buffer sizes and reuse.
char output1[20] = {};
char output2[1] = {};
char output3[1] = {};

// Invalid characters in the input string
testDecode(
"SGVsbG8gd29ybGQ$",
output1,
sizeof(output1),
Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters."));

// All characters are padding characters
testDecode("====", output1, sizeof(output1), Status::OK());

// Invalid input size
testDecode(
"S",
output1,
sizeof(output1),
Status::UserError(
"Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4."));

// Valid input without padding characters
testDecode("SGVsbG8gd29ybGQ", output1, sizeof(output1), Status::OK());
EXPECT_STREQ(output1, "Hello world");

// Empty input string
testDecode("", output2, sizeof(output2), Status::OK());
EXPECT_STREQ(output2, "");

// Invalid input size
testDecode(
"SGVsbG8gd29ybGQ===",
output1,
sizeof(output1),
Status::UserError(
"Base64::decode() - invalid input string: string length is not a multiple of 4."));

// whiltespaces in the input string
testDecode(
" SGVsb G8gd2 9ybGQ= ",
output1,
sizeof(output1),
Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters."));

// insufficient buffer size
testDecode(
" SGVsb G8gd2 9ybGQ= ",
output3,
sizeof(output3),
Status::UserError(
"Base64::decode() - invalid output string: output string is too small."));
}

} // namespace facebook::velox::encoding
7 changes: 1 addition & 6 deletions velox/common/encode/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,4 @@ add_executable(velox_common_encode_test Base64Test.cpp)
add_test(velox_common_encode_test velox_common_encode_test)
target_link_libraries(
velox_common_encode_test
PRIVATE
velox_encode
velox_status
velox_exception
GTest::gtest
GTest::gtest_main)
PRIVATE velox_encode velox_status GTest::gtest GTest::gtest_main)

0 comments on commit af8dbed

Please sign in to comment.