diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index da4e9cdbfcfdd..ee65b24861f4c 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -117,22 +117,22 @@ static_assert( // Searches for a character within a charset up to a certain index. constexpr bool findCharacterInCharset( const Base64::Charset& charset, - uint8_t idx, - const char c) { - return idx < charset.size() && - ((charset[idx] == c) || findCharacterInCharset(charset, idx + 1, c)); + uint8_t index, + const char character) { + return index < charset.size() && + ((charset[index] == character) || + findCharacterInCharset(charset, index + 1, character)); } -// Checks the consistency of a reverse index mapping for a given character -// set. +// Checks the consistency of a reverse index mapping for a given character set. constexpr bool checkReverseIndex( - uint8_t idx, + uint8_t index, const Base64::Charset& charset, const Base64::ReverseIndex& reverseIndex) { - return (reverseIndex[idx] == 255 - ? !findCharacterInCharset(charset, 0, static_cast(idx)) - : (charset[reverseIndex[idx]] == idx)) && - (idx > 0 ? checkReverseIndex(idx - 1, charset, reverseIndex) : true); + return (reverseIndex[index] == 255 + ? !findCharacterInCharset(charset, 0, static_cast(index)) + : (charset[reverseIndex[index]] == index)) && + (index > 0 ? checkReverseIndex(index - 1, charset, reverseIndex) : true); } // Verify that for every entry in kBase64ReverseIndexTable, the corresponding @@ -156,101 +156,98 @@ static_assert( // "kBase64UrlReverseIndexTable has incorrect entries."); // Implementation of Base64 encoding and decoding functions. +// static template -/* static */ std::string Base64::encodeImpl( - const T& data, +std::string Base64::encodeImpl( + const T& input, const Base64::Charset& charset, - bool include_pad) { - size_t outlen = calculateEncodedSize(data.size(), include_pad); - std::string out; - out.resize(outlen); - encodeImpl(data, charset, include_pad, out.data()); - return out; + bool includePadding) { + size_t outputLength = calculateEncodedSize(input.size(), includePadding); + std::string output; + output.resize(outputLength); + encodeImpl(input, charset, includePadding, output.data()); + return output; } // static -size_t Base64::calculateEncodedSize(size_t size, bool withPadding) { - if (size == 0) { +size_t Base64::calculateEncodedSize(size_t inputSize, bool includePadding) { + if (inputSize == 0) { return 0; } // Calculate the output size assuming that we are including padding. - size_t encodedSize = ((size + 2) / 3) * 4; - if (!withPadding) { - // If the padding was not requested, subtract the padding bytes. - encodedSize -= (3 - (size % 3)) % 3; + size_t encodedSize = ((inputSize + 2) / 3) * 4; + if (!includePadding) { + encodedSize -= (3 - (inputSize % 3)) % 3; } return encodedSize; } // static -void Base64::encode(const char* data, size_t len, char* output) { - encodeImpl(folly::StringPiece(data, len), kBase64Charset, true, output); +Status Base64::encode(std::string_view input, char* output) { + return encodeImpl(input, kBase64Charset, true, output); } // static -void Base64::encodeUrl(const char* data, size_t len, char* output) { - encodeImpl(folly::StringPiece(data, len), kBase64UrlCharset, true, output); +Status Base64::encodeUrl(std::string_view input, char* output) { + return encodeImpl(input, kBase64UrlCharset, true, output); } +// static template -/* static */ void Base64::encodeImpl( - const T& data, +Status Base64::encodeImpl( + const T& input, const Base64::Charset& charset, - bool include_pad, - char* out) { - auto len = data.size(); - if (len == 0) { - return; + bool includePadding, + char* output) { + auto inputSize = input.size(); + if (inputSize == 0) { + return Status::OK(); } - auto wp = out; - auto it = data.begin(); + auto outputPtr = output; + auto dataIterator = input.begin(); - // For each group of 3 bytes (24 bits) in the input, split that into - // 4 groups of 6 bits and encode that using the supplied charset lookup - for (; len > 2; len -= 3) { - uint32_t curr = uint8_t(*it++) << 16; - curr |= uint8_t(*it++) << 8; - curr |= uint8_t(*it++); + for (; inputSize > 2; inputSize -= 3) { + uint32_t currentBlock = uint8_t(*dataIterator++) << 16; + currentBlock |= uint8_t(*dataIterator++) << 8; + currentBlock |= uint8_t(*dataIterator++); - *wp++ = charset[(curr >> 18) & 0x3f]; - *wp++ = charset[(curr >> 12) & 0x3f]; - *wp++ = charset[(curr >> 6) & 0x3f]; - *wp++ = charset[curr & 0x3f]; + *outputPtr++ = charset[(currentBlock >> 18) & 0x3f]; + *outputPtr++ = charset[(currentBlock >> 12) & 0x3f]; + *outputPtr++ = charset[(currentBlock >> 6) & 0x3f]; + *outputPtr++ = charset[currentBlock & 0x3f]; } - if (len > 0) { - // We have either 1 or 2 input bytes left. Encode this similar to the - // above (assuming 0 for all other bytes). Optionally append the '=' - // character if it is requested. - uint32_t curr = uint8_t(*it++) << 16; - *wp++ = charset[(curr >> 18) & 0x3f]; - if (len > 1) { - curr |= uint8_t(*it) << 8; - *wp++ = charset[(curr >> 12) & 0x3f]; - *wp++ = charset[(curr >> 6) & 0x3f]; - if (include_pad) { - *wp = kPadding; + if (inputSize > 0) { + uint32_t currentBlock = uint8_t(*dataIterator++) << 16; + *outputPtr++ = charset[(currentBlock >> 18) & 0x3f]; + if (inputSize > 1) { + currentBlock |= uint8_t(*dataIterator) << 8; + *outputPtr++ = charset[(currentBlock >> 12) & 0x3f]; + *outputPtr++ = charset[(currentBlock >> 6) & 0x3f]; + if (includePadding) { + *outputPtr = kPadding; } } else { - *wp++ = charset[(curr >> 12) & 0x3f]; - if (include_pad) { - *wp++ = kPadding; - *wp = kPadding; + *outputPtr++ = charset[(currentBlock >> 12) & 0x3f]; + if (includePadding) { + *outputPtr++ = kPadding; + *outputPtr = kPadding; } } } + return Status::OK(); } // static -std::string Base64::encode(folly::StringPiece text) { +std::string Base64::encode(std::string_view text) { return encodeImpl(text, kBase64Charset, true); } // static -std::string Base64::encode(const char* data, size_t len) { - return encode(folly::StringPiece(data, len)); +std::string Base64::encode(std::string_view input, size_t /*len*/) { + return encodeImpl(input, kBase64Charset, true); } namespace { @@ -265,7 +262,7 @@ class IOBufWrapper { private: class Iterator { public: - explicit Iterator(const folly::IOBuf* data) : cs_(data) {} + explicit Iterator(const folly::IOBuf* data) : cursor_(data) {} Iterator& operator++(int32_t) { // This is a noop since reading from the Cursor has already moved the @@ -275,11 +272,11 @@ class IOBufWrapper { uint8_t operator*() { // This will read _and_ increment - return cs_.read(); + return cursor_.read(); } private: - folly::io::Cursor cs_; + folly::io::Cursor cursor_; }; public: @@ -300,182 +297,191 @@ class IOBufWrapper { } // namespace // static -std::string Base64::encode(const folly::IOBuf* data) { - return encodeImpl(IOBufWrapper(data), kBase64Charset, true); +std::string Base64::encode(const folly::IOBuf* input) { + return encodeImpl(IOBufWrapper(input), kBase64Charset, true); } // static -std::string Base64::decode(folly::StringPiece encoded) { +std::string Base64::decode(std::string_view encoded) { std::string output; - Base64::decode(std::make_pair(encoded.data(), encoded.size()), output); + Base64::decode(encoded, output); return output; } // static -void Base64::decode( - const std::pair& payload, - std::string& output) { - size_t inputSize = payload.second; - output.resize(calculateDecodedSize(payload.first, inputSize)); - decode(payload.first, inputSize, output.data(), output.size()); +void Base64::decode(std::string_view input, std::string& output) { + size_t inputSize{input.size()}; + size_t decodedSize; + + calculateDecodedSize(input, inputSize, decodedSize); + output.resize(decodedSize); + decode(input.data(), inputSize, output.data(), output.size()); } // static -void Base64::decode(const char* data, size_t size, char* output) { - size_t out_len = size / 4 * 3; - Base64::decode(data, size, output, out_len); +void Base64::decode(std::string_view input, size_t size, char* output) { + size_t outputLength = size / 4 * 3; + Base64::decode(input, size, output, outputLength); } // static uint8_t Base64::base64ReverseLookup( - char p, + char character, const Base64::ReverseIndex& reverseIndex) { - auto curr = reverseIndex[(uint8_t)p]; - if (curr >= 0x40) { + auto lookupValue = reverseIndex[(uint8_t)character]; + if (lookupValue >= 0x40) { VELOX_USER_FAIL("decode() - invalid input string: invalid characters"); } - return curr; + return lookupValue; } // static -size_t -Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { - return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable); +Status Base64::decode( + std::string_view input, + size_t inputSize, + char* output, + size_t outputSize) { + return decodeImpl( + input, inputSize, output, outputSize, kBase64ReverseIndexTable); } // static -size_t Base64::calculateDecodedSize(const char* data, size_t& size) { - if (size == 0) { - return 0; +Status Base64::calculateDecodedSize( + std::string_view input, + size_t& inputSize, + size_t& decodedSize) { + if (inputSize == 0) { + decodedSize = 0; + return Status::OK(); } // Check if the input data is padded - if (isPadded(data, size)) { + if (isPadded(input, inputSize)) { // If padded, ensure that the string length is a multiple of the encoded // block size - if (size % kEncodedBlockByteSize != 0) { - VELOX_USER_FAIL( - "Base64::decode() - invalid input string: " - "string length is not a multiple of 4."); + if (inputSize % kEncodedBlockByteSize != 0) { + return Status::UserError( + "Base64::decode() - invalid input string: string length is not a multiple of 4."); } - auto needed = (size * kBinaryBlockByteSize) / kEncodedBlockByteSize; - auto padding = numPadding(data, size); - size -= padding; + decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; + auto padding = numPadding(input, inputSize); + inputSize -= padding; // Adjust the needed size by deducting the bytes corresponding to the // padding from the calculated size. - return needed - + decodedSize -= ((padding * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) / kEncodedBlockByteSize; + return Status::OK(); } - // If not padded, Calculate extra bytes, if any - auto extra = size % kEncodedBlockByteSize; - auto needed = (size / kEncodedBlockByteSize) * kBinaryBlockByteSize; + + // If not padded, calculate extra bytes, if any + auto extraBytes = inputSize % kEncodedBlockByteSize; + decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; // Adjust the needed size for extra bytes, if present - if (extra) { - if (extra == 1) { - VELOX_USER_FAIL( - "Base64::decode() - invalid input string: " - "string length cannot be 1 more than a multiple of 4."); + if (extraBytes) { + if (extraBytes == 1) { + return Status::UserError( + "Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4."); } - needed += (extra * kBinaryBlockByteSize) / kEncodedBlockByteSize; + decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize; } - return needed; + return Status::OK(); } // static -size_t Base64::decodeImpl( - const char* src, - size_t src_len, - char* dst, - size_t dst_len, +Status Base64::decodeImpl( + std::string_view input, + size_t inputSize, + char* output, + size_t outputSize, const Base64::ReverseIndex& reverseIndex) { - if (!src_len) { - return 0; + if (!inputSize) { + return Status::OK(); + } + + size_t decodedSize; + // Calculate decoded size and check for status + auto status = calculateDecodedSize(input, inputSize, decodedSize); + if (!status.ok()) { + return status; } - auto needed = calculateDecodedSize(src, src_len); - if (dst_len < needed) { - VELOX_USER_FAIL( - "Base64::decode() - invalid output string: " - "output string is too small."); + if (outputSize < decodedSize) { + return Status::UserError( + "Base64::decode() - invalid output string: output string is too small."); } // Handle full groups of 4 characters - for (; src_len > 4; src_len -= 4, src += 4, dst += 3) { - // Each character of the 4 encode 6 bits of the original, grab each with + for (; inputSize > 4; inputSize -= 4, input.remove_prefix(4), output += 3) { + // Each character of the 4 encodes 6 bits of the original, grab each with // the appropriate shifts to rebuild the original and then split that back - // into the original 8 bit bytes. - uint32_t last = (base64ReverseLookup(src[0], reverseIndex) << 18) | - (base64ReverseLookup(src[1], reverseIndex) << 12) | - (base64ReverseLookup(src[2], reverseIndex) << 6) | - base64ReverseLookup(src[3], reverseIndex); - dst[0] = (last >> 16) & 0xff; - dst[1] = (last >> 8) & 0xff; - dst[2] = last & 0xff; + // into the original 8-bit bytes. + uint32_t currentBlock = + (base64ReverseLookup(input[0], reverseIndex) << 18) | + (base64ReverseLookup(input[1], reverseIndex) << 12) | + (base64ReverseLookup(input[2], reverseIndex) << 6) | + base64ReverseLookup(input[3], reverseIndex); + output[0] = (currentBlock >> 16) & 0xff; + output[1] = (currentBlock >> 8) & 0xff; + output[2] = currentBlock & 0xff; } - // Handle the last 2-4 characters. This is similar to the above, but the + // Handle the last 2-4 characters. This is similar to the above, but the // last 2 characters may or may not exist. - DCHECK(src_len >= 2); - uint32_t last = (base64ReverseLookup(src[0], reverseIndex) << 18) | - (base64ReverseLookup(src[1], reverseIndex) << 12); - dst[0] = (last >> 16) & 0xff; - if (src_len > 2) { - last |= base64ReverseLookup(src[2], reverseIndex) << 6; - dst[1] = (last >> 8) & 0xff; - if (src_len > 3) { - last |= base64ReverseLookup(src[3], reverseIndex); - dst[2] = last & 0xff; + DCHECK(inputSize >= 2); + uint32_t currentBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) | + (base64ReverseLookup(input[1], reverseIndex) << 12); + output[0] = (currentBlock >> 16) & 0xff; + if (inputSize > 2) { + currentBlock |= base64ReverseLookup(input[2], reverseIndex) << 6; + output[1] = (currentBlock >> 8) & 0xff; + if (inputSize > 3) { + currentBlock |= base64ReverseLookup(input[3], reverseIndex); + output[2] = currentBlock & 0xff; } } - return needed; -} - -// static -std::string Base64::encodeUrl(folly::StringPiece text) { - return encodeImpl(text, kBase64UrlCharset, false); + return Status::OK(); } // static -std::string Base64::encodeUrl(const char* data, size_t len) { - return encodeUrl(folly::StringPiece(data, len)); +std::string Base64::encodeUrl(std::string_view input) { + return encodeImpl(input, kBase64UrlCharset, false); } // static -std::string Base64::encodeUrl(const folly::IOBuf* data) { - return encodeImpl(IOBufWrapper(data), kBase64UrlCharset, false); +std::string Base64::encodeUrl(const folly::IOBuf* input) { + return encodeImpl(IOBufWrapper(input), kBase64UrlCharset, false); } // static -void Base64::decodeUrl( - const char* src, - size_t src_len, - char* dst, - size_t dst_len) { - decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable); +Status Base64::decodeUrl( + std::string_view input, + size_t inputSize, + char* output, + size_t outputSize) { + return decodeImpl( + input, inputSize, output, outputSize, kBase64UrlReverseIndexTable); } // static -std::string Base64::decodeUrl(folly::StringPiece encoded) { +std::string Base64::decodeUrl(std::string_view input) { std::string output; - Base64::decodeUrl(std::make_pair(encoded.data(), encoded.size()), output); + Base64::decodeUrl(input, output); return output; } // static -void Base64::decodeUrl( - const std::pair& payload, - std::string& output) { - size_t out_len = (payload.second + 3) / 4 * 3; +void Base64::decodeUrl(std::string_view input, std::string& output) { + size_t out_len = (input.size() + 3) / 4 * 3; output.resize(out_len, '\0'); - out_len = Base64::decodeImpl( - payload.first, - payload.second, + Base64::decodeImpl( + input.data(), + input.size(), &output[0], out_len, kBase64UrlReverseIndexTable); diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 13004175379a6..023abd031a88c 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -16,14 +16,13 @@ #pragma once #include -#include -#include #include #include #include #include "velox/common/base/GTestMacros.h" +#include "velox/common/base/Status.h" namespace facebook::velox::encoding { @@ -41,119 +40,124 @@ class Base64 { /// within the encoding base. using ReverseIndex = std::array; - /// Padding character used in encoding. - static const char kPadding = '='; - - /// Encodes the specified number of characters from the 'data'. - static std::string encode(const char* data, size_t len); + /// Encodes the specified number of characters from the 'input'. + static std::string encode(std::string_view input, size_t len); /// Encodes the specified text. - static std::string encode(folly::StringPiece text); + static std::string encode(std::string_view text); - /// Encodes the specified IOBuf data. - static std::string encode(const folly::IOBuf* text); + /// Encodes the specified IOBuf input. + static std::string encode(const folly::IOBuf* input); /// Returns encoded size for the input of the specified size. - static size_t calculateEncodedSize(size_t size, bool withPadding = true); + static size_t calculateEncodedSize( + size_t inputSize, + bool includePadding = true); - /// Encodes the specified number of characters from the 'data' and writes the - /// result to the 'output'. The output must have enough space, e.g. as - /// returned by the calculateEncodedSize(). - static void encode(const char* data, size_t size, char* output); + /// Encodes the specified number of characters from the 'input' and writes the + /// result to the 'output'. The output must have enough space, e.g., as + /// returned by calculateEncodedSize(). + static Status encode(std::string_view input, char* output); /// Decodes the specified encoded text. - static std::string decode(folly::StringPiece encoded); + static std::string decode(std::string_view encoded); /// Returns the actual size of the decoded data. Will also remove the padding - /// length from the input data 'size'. - static size_t calculateDecodedSize(const char* data, size_t& size); + /// length from the 'inputSize'. + static Status calculateDecodedSize( + std::string_view input, + size_t& inputSize, + size_t& decodedSize); - /// Decodes the specified number of characters from the 'data' and writes the - /// result to the 'output'. The output must have enough space, e.g. as - /// returned by the calculateDecodedSize(). - static void decode(const char* data, size_t size, char* output); + /// Decodes the specified number of characters from the 'input' and writes the + /// result to the 'output'. The output must have enough space, e.g., as + /// returned by calculateDecodedSize(). + static void decode(std::string_view input, size_t inputSize, char* output); - static void decode( - const std::pair& payload, - std::string& output); + static void decode(std::string_view input, std::string& output); - /// Encodes the specified number of characters from the 'data' and writes the + /// Encodes the specified number of characters from the 'input' and writes the /// result to the 'output' using URL encoding. The output must have enough - /// space as returned by the calculateEncodedSize(). - static void encodeUrl(const char* data, size_t size, char* output); + /// space as returned by calculateEncodedSize(). + static Status encodeUrl(std::string_view input, char* output); - /// Encodes the specified number of characters from the 'data' using URL - /// encoding. - static std::string encodeUrl(const char* data, size_t len); - - /// Encodes the specified IOBuf data using URL encoding. - static std::string encodeUrl(const folly::IOBuf* data); + /// Encodes the specified IOBuf input using URL encoding. + static std::string encodeUrl(const folly::IOBuf* input); /// Encodes the specified text using URL encoding. - static std::string encodeUrl(folly::StringPiece text); + static std::string encodeUrl(std::string_view input); - /// Decodes the specified URL encoded payload and writes the result to the + /// Decodes the specified URL encoded input and writes the result to the /// 'output'. - static void decodeUrl( - const std::pair& payload, - std::string& output); + static void decodeUrl(std::string_view input, std::string& output); /// Decodes the specified URL encoded text. - static std::string decodeUrl(folly::StringPiece text); - - /// Decodes the specified number of characters from the 'src' and writes the - /// result to the 'dst'. - static size_t - decode(const char* src, size_t src_len, char* dst, size_t dst_len); - - /// Decodes the specified number of characters from the 'src' using URL - /// encoding and writes the result to the 'dst'. - static void - decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len); + static std::string decodeUrl(std::string_view input); + + /// Decodes the specified number of characters from the 'input' and writes the + /// result to the 'output'. + static Status decode( + std::string_view input, + size_t inputSize, + char* output, + size_t outputSize); + + /// Decodes the specified number of characters from the 'input' using URL + /// encoding and writes the result to the 'output'. + static Status decodeUrl( + std::string_view input, + size_t inputSize, + char* output, + size_t outputSize); private: - /// Checks if there is padding in encoded data. - static inline bool isPadded(const char* data, size_t len) { - return (len > 0 && data[len - 1] == kPadding); + // Padding character used in encoding. + static const char kPadding = '='; + + // Checks if there is padding in encoded input. + static inline bool isPadded(std::string_view input, size_t inputSize) { + return (inputSize > 0 && input[inputSize - 1] == kPadding); } - /// Counts the number of padding characters in encoded data. - static inline size_t numPadding(const char* src, size_t len) { - size_t numPadding{0}; - while (len > 0 && src[len - 1] == kPadding) { - numPadding++; - len--; + // Counts the number of padding characters in encoded input. + static inline size_t numPadding(std::string_view input, size_t inputSize) { + size_t padding = 0; + while (inputSize > 0 && input[inputSize - 1] == kPadding) { + padding++; + inputSize--; } - return numPadding; + return padding; } - /// Performs a reverse lookup in the reverse index to retrieve the original - /// index of a character in the base. - static uint8_t base64ReverseLookup(char p, const ReverseIndex& reverseIndex); + // Performs a reverse lookup in the reverse index to retrieve the original + // index of a character in the base. + static uint8_t base64ReverseLookup( + char character, + const ReverseIndex& reverseIndex); - /// Encodes the specified data using the provided charset. + // Encodes the specified input using the provided charset. template static std::string - encodeImpl(const T& data, const Charset& charset, bool include_pad); + encodeImpl(const T& input, const Charset& charset, bool includePadding); - /// Encodes the specified data using the provided charset. + // Encodes the specified input using the provided charset. template - static void encodeImpl( - const T& data, + static Status encodeImpl( + const T& input, const Charset& charset, - bool include_pad, - char* out); - - /// Decodes the specified data using the provided reverse lookup table. - static size_t decodeImpl( - const char* src, - size_t src_len, - char* dst, - size_t dst_len, - const ReverseIndex& table); - - VELOX_FRIEND_TEST(Base64Test, checksPadding); - VELOX_FRIEND_TEST(Base64Test, countsPaddingCorrectly); + bool includePadding, + char* output); + + // Decodes the specified input using the provided reverse lookup table. + static Status decodeImpl( + std::string_view input, + size_t inputSize, + char* output, + size_t outputSize, + const ReverseIndex& reverseIndex); + + VELOX_FRIEND_TEST(Base64Test, isPadded); + VELOX_FRIEND_TEST(Base64Test, numPadding); }; } // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 9cbbbad471245..7eaaccfa4ab53 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -18,6 +18,7 @@ #include #include "velox/common/base/Exceptions.h" +#include "velox/common/base/Status.h" #include "velox/common/base/tests/GTestUtils.h" namespace facebook::velox::encoding { @@ -25,76 +26,74 @@ namespace facebook::velox::encoding { class Base64Test : public ::testing::Test {}; TEST_F(Base64Test, fromBase64) { - EXPECT_EQ( - "Hello, World!", - Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ=="))); + EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ==")); EXPECT_EQ( "Base64 encoding is fun.", - Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4="))); - EXPECT_EQ( - "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ="))); - EXPECT_EQ( - "1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA=="))); + Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=")); + EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ=")); + EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA==")); // Check encoded strings without padding - EXPECT_EQ( - "Hello, World!", - Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ"))); + EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ")); EXPECT_EQ( "Base64 encoding is fun.", - Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4"))); - EXPECT_EQ( - "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ"))); - EXPECT_EQ("1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA"))); + Base64::decode("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4")); + EXPECT_EQ("Simple text", Base64::decode("U2ltcGxlIHRleHQ")); + EXPECT_EQ("1234567890", Base64::decode("MTIzNDU2Nzg5MA")); } TEST_F(Base64Test, calculateDecodedSizeProperSize) { size_t encoded_size{0}; + size_t decoded_size{0}; encoded_size = 20; - EXPECT_EQ( - 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size)); + Base64::calculateDecodedSize( + "SGVsbG8sIFdvcmxkIQ==", encoded_size, decoded_size); EXPECT_EQ(18, encoded_size); + EXPECT_EQ(13, decoded_size); encoded_size = 18; - EXPECT_EQ( - 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ", encoded_size)); + Base64::calculateDecodedSize( + "SGVsbG8sIFdvcmxkIQ", encoded_size, decoded_size); EXPECT_EQ(18, encoded_size); + EXPECT_EQ(13, decoded_size); encoded_size = 21; - VELOX_ASSERT_THROW( - Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size), - "Base64::decode() - invalid input string: string length cannot be 1 more than a multiple of 4."); - - encoded_size = 32; EXPECT_EQ( - 23, + Status::UserError( + "Base64::decode() - invalid input string: string length is not a multiple of 4."), Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size)); + "SGVsbG8sIFdvcmxkIQ===", encoded_size, decoded_size)); + + encoded_size = 32; + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size, decoded_size); EXPECT_EQ(31, encoded_size); + EXPECT_EQ(23, decoded_size); encoded_size = 31; - EXPECT_EQ( - 23, - Base64::calculateDecodedSize( - "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size)); + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size, decoded_size); EXPECT_EQ(31, encoded_size); + EXPECT_EQ(23, decoded_size); encoded_size = 16; - EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size)); + Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size, decoded_size); EXPECT_EQ(14, encoded_size); + EXPECT_EQ(10, decoded_size); encoded_size = 14; - EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size)); + Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size, decoded_size); EXPECT_EQ(14, encoded_size); + EXPECT_EQ(10, decoded_size); } -TEST_F(Base64Test, checksPadding) { +TEST_F(Base64Test, isPadded) { EXPECT_TRUE(Base64::isPadded("ABC=", 4)); EXPECT_FALSE(Base64::isPadded("ABC", 3)); } -TEST_F(Base64Test, countsPaddingCorrectly) { +TEST_F(Base64Test, numPadding) { EXPECT_EQ(0, Base64::numPadding("ABC", 3)); EXPECT_EQ(1, Base64::numPadding("ABC=", 4)); EXPECT_EQ(2, Base64::numPadding("AB==", 4)); diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 90c9733ecf22e..7920e80b2c2cf 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -16,5 +16,9 @@ add_executable(velox_common_encode_test Base64Test.cpp) add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test - PUBLIC Folly::folly - PRIVATE velox_encode velox_exception GTest::gtest GTest::gtest_main) + PRIVATE + velox_encode + velox_status + velox_exception + GTest::gtest + GTest::gtest_main) diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index ce153ee349fc4..46f44fc151192 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -278,11 +278,12 @@ template struct ToBase64Function { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { + std::string_view inputString( + reinterpret_cast(input.data()), input.size()); result.resize(encoding::Base64::calculateEncodedSize(input.size())); - encoding::Base64::encode(input.data(), input.size(), result.data()); + return encoding::Base64::encode(inputString, result.data()); } }; @@ -293,11 +294,16 @@ struct FromBase64Function { // T can be either arg_type or arg_type. These are the // same, but hard-coding one of them might be confusing. template - FOLLY_ALWAYS_INLINE void call(out_type& result, const T& input) { + FOLLY_ALWAYS_INLINE Status call(out_type& result, const T& input) { auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decode( + size_t decodedSize; + auto status = encoding::Base64::calculateDecodedSize( + input.data(), inputSize, decodedSize); + if (!status.ok()) { + return status; + } + result.resize(decodedSize); + return encoding::Base64::decode( input.data(), inputSize, result.data(), result.size()); } }; @@ -305,13 +311,17 @@ struct FromBase64Function { template struct FromBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decodeUrl( + size_t decodedSize; + auto status = encoding::Base64::calculateDecodedSize( + input.data(), inputSize, decodedSize); + if (!status.ok()) { + return status; + } + result.resize(decodedSize); + return encoding::Base64::decodeUrl( input.data(), inputSize, result.data(), result.size()); } }; @@ -320,11 +330,12 @@ template struct ToBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { + FOLLY_ALWAYS_INLINE Status + call(out_type& result, const arg_type& input) { + std::string_view inputString( + reinterpret_cast(input.data()), input.size()); result.resize(encoding::Base64::calculateEncodedSize(input.size())); - encoding::Base64::encodeUrl(input.data(), input.size(), result.data()); + return encoding::Base64::encodeUrl(inputString, result.data()); } };