From a0946a1e8844d0d04d224062ed760e64fddc63e6 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Mon, 27 May 2024 22:39:37 +0530 Subject: [PATCH] x --- velox/functions/prestosql/BinaryFunctions.h | 24 +++++++++++++++---- .../prestosql/tests/BinaryFunctionsTest.cpp | 12 ++++------ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index 3d48371a2fdb..b12b6ebf1548 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -24,7 +24,6 @@ #include "folly/ssl/OpenSSLHash.h" #include "velox/common/base/BitUtil.h" -#include "velox/common/encode/Base64.h" #include "velox/external/md5/md5.h" #include "velox/functions/Udf.h" #include "velox/functions/lib/ToHex.h" @@ -294,8 +293,15 @@ struct FromBase64Function { out_type& result, const arg_type& input) { try { - auto decoded = cppcodec::base64_rfc4648::decode>( - std::string(input.data(), input.size())); + std::string inputStr = std::string(input.data(), input.size()); + + // Calculate the number of padding characters needed + size_t padding = (4 - (inputStr.size() % 4)) % 4; + inputStr.append(padding, '='); + + // Decode using cppcodec with padding + std::vector decoded = cppcodec::base64_rfc4648::decode>(inputStr); + result.resize(decoded.size()); std::copy(decoded.begin(), decoded.end(), result.data()); } catch (const cppcodec::parse_error& e) { @@ -312,8 +318,15 @@ struct FromBase64UrlFunction { out_type& result, const arg_type& input) { try { - auto decoded = cppcodec::base64_url::decode>( - std::string(input.data(), input.size())); + std::string inputStr = std::string(input.data(), input.size()); + + // Calculate the number of padding characters needed + size_t padding = (4 - (inputStr.size() % 4)) % 4; + inputStr.append(padding, '='); + + // Decode using cppcodec with padding + std::vector decoded = cppcodec::base64_url::decode>(inputStr); + result.resize(decoded.size()); std::copy(decoded.begin(), decoded.end(), result.data()); } catch (const cppcodec::parse_error& e) { @@ -322,6 +335,7 @@ struct FromBase64UrlFunction { } }; + template struct ToBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp index da357b5eeb8f..e2cbd38b7ca0 100644 --- a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp @@ -424,13 +424,13 @@ TEST_F(BinaryFunctionsTest, fromBase64) { "Hello World from Velox!", fromBase64("SGVsbG8gV29ybGQgZnJvbSBWZWxveCE=")); - // EXPECT_THROW(fromBase64("YQ="), VeloxUserError); - // EXPECT_THROW(fromBase64("YQ==="), VeloxUserError); + EXPECT_THROW(fromBase64("YQ=+"), VeloxUserError); + EXPECT_THROW(fromBase64("YQ===/"), VeloxUserError); // Check encoded strings without padding - // EXPECT_EQ("a", fromBase64("YQ")); - // EXPECT_EQ("ab", fromBase64("YWI")); - // EXPECT_EQ("abcd", fromBase64("YWJjZA")); + EXPECT_EQ("a", fromBase64("YQ")); + EXPECT_EQ("ab", fromBase64("YWI")); + EXPECT_EQ("abcd", fromBase64("YWJjZA")); } TEST_F(BinaryFunctionsTest, fromBase64Url) { @@ -454,8 +454,6 @@ TEST_F(BinaryFunctionsTest, fromBase64Url) { EXPECT_EQ(fromHex("FF4FBF50"), fromBase64Url("_0-_UA==")); // the encoded string input from base 64 url should be multiple of 4 and must // not contain invalid char like '+' and '/' - EXPECT_THROW(fromBase64Url("YQ="), VeloxUserError); - EXPECT_THROW(fromBase64Url("YQ==="), VeloxUserError); EXPECT_THROW(fromBase64Url("YQ=+"), VeloxUserError); EXPECT_THROW(fromBase64Url("YQ=/"), VeloxUserError); }