diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index 41a9cf62e8d4..d733553ce4a9 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -282,21 +282,19 @@ struct ToBase64Function { } }; -template +template struct FromBase64Function { - VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( - out_type& result, - const arg_type& input) { - try { - auto inputSize = input.size(); - result.resize( - encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decode( - input.data(), inputSize, result.data(), result.size()); - } catch (const encoding::Base64Exception& e) { - VELOX_USER_FAIL(e.what()); - } + VELOX_DEFINE_FUNCTION_TYPES(TExec); + + // T can be either arg_type or arg_type. These are the + // same, but hard-coding one of them might be confusing. + template + FOLLY_ALWAYS_INLINE void call(out_type& result, const T& input) { + auto inputSize = input.size(); + result.resize( + encoding::Base64::calculateDecodedSize(input.data(), inputSize)); + encoding::Base64::decode( + input.data(), inputSize, result.data(), result.size()); } }; diff --git a/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp b/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp index 5aa8fc682e50..6f098ebadc51 100644 --- a/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp @@ -45,8 +45,12 @@ void registerSimpleFunctions(const std::string& prefix) { registerFunction({prefix + "from_hex"}); registerFunction( {prefix + "to_base64"}); + registerFunction( {prefix + "from_base64"}); + registerFunction( + {prefix + "from_base64"}); + registerFunction( {prefix + "to_base64url"}); registerFunction( diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp index fa1e841ef985..72ef47e22b10 100644 --- a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp @@ -411,7 +411,20 @@ TEST_F(BinaryFunctionsTest, toBase64Url) { TEST_F(BinaryFunctionsTest, fromBase64) { const auto fromBase64 = [&](std::optional value) { - return evaluateOnce("from_base64(c0)", value); + // from_base64 allows VARCHAR and VARBINARY inputs. + auto result = + evaluateOnce("from_base64(c0)", VARCHAR(), value); + auto otherResult = + evaluateOnce("from_base64(c0)", VARBINARY(), value); + + VELOX_CHECK_EQ(result.has_value(), otherResult.has_value()); + + if (!result.has_value()) { + return result; + } + + VELOX_CHECK_EQ(result.value(), otherResult.value()); + return result; }; EXPECT_EQ(std::nullopt, fromBase64(std::nullopt)); @@ -424,8 +437,12 @@ TEST_F(BinaryFunctionsTest, fromBase64) { "Hello World from Velox!", fromBase64("SGVsbG8gV29ybGQgZnJvbSBWZWxveCE=")); - EXPECT_THROW(fromBase64("YQ="), VeloxUserError); - EXPECT_THROW(fromBase64("YQ==="), VeloxUserError); + VELOX_ASSERT_USER_THROW( + fromBase64("YQ="), + "Base64::decode() - invalid input string: string length is not a multiple of 4."); + VELOX_ASSERT_USER_THROW( + fromBase64("YQ==="), + "Base64::decode() - invalid input string: string length is not a multiple of 4."); // Check encoded strings without padding EXPECT_EQ("a", fromBase64("YQ"));