From 006efc589d233b4566dc63218a1730056e882c1a Mon Sep 17 00:00:00 2001 From: zhli1142015 Date: Mon, 2 Dec 2024 17:48:12 +0800 Subject: [PATCH] feat: Add from_json Spark function --- velox/docs/functions/spark/json.rst | 15 + velox/functions/sparksql/Register.cpp | 4 + .../sparksql/specialforms/CMakeLists.txt | 3 +- .../sparksql/specialforms/FromJson.cpp | 602 ++++++++++++++++++ .../sparksql/specialforms/FromJson.h | 36 ++ velox/functions/sparksql/tests/CMakeLists.txt | 1 + .../functions/sparksql/tests/FromJsonTest.cpp | 250 ++++++++ 7 files changed, 910 insertions(+), 1 deletion(-) create mode 100644 velox/functions/sparksql/specialforms/FromJson.cpp create mode 100644 velox/functions/sparksql/specialforms/FromJson.h create mode 100644 velox/functions/sparksql/tests/FromJsonTest.cpp diff --git a/velox/docs/functions/spark/json.rst b/velox/docs/functions/spark/json.rst index 8004873986880..aba634cb5f94f 100644 --- a/velox/docs/functions/spark/json.rst +++ b/velox/docs/functions/spark/json.rst @@ -11,3 +11,18 @@ JSON Functions SELECT json_object_keys(''); -- NULL SELECT json_object_keys(1); -- NULL SELECT json_object_keys('"hello"'); -- NULL + +.. spark:function:: from_json(jsonString) -> [json object] + + Casting a JSON text to a supported type returns the value represented by this + JSON text. The JSON text must represent a valid value of the type it is casted + to, or null will be returned. Casting to ARRAY and MAP is supported when the + element type of the array is one of the supported types, or when the key type of + the map is VARCHAR and value type of the map is one of the supported types. When + casting from JSON to ROW, only JSON object are supported. Cast from JSON object + to ROW uses case sensitive match for the JSON keys. + Behaviors of the casts are shown with the examples below::: + + SELECT from_json('{"a": 1}', 'ROW(a INT)'); -- {a=1} + SELECT from_json('["name", "age", "id"]', 'array'); -- ['name', 'age', 'id'] + SELECT from_json('{"a": 1, "b": 2}', 'map'); -- {a=1, b=2} diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp index c753eca82baad..a3bf680866e60 100644 --- a/velox/functions/sparksql/Register.cpp +++ b/velox/functions/sparksql/Register.cpp @@ -55,6 +55,7 @@ #include "velox/functions/sparksql/Uuid.h" #include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h" #include "velox/functions/sparksql/specialforms/DecimalRound.h" +#include "velox/functions/sparksql/specialforms/FromJson.h" #include "velox/functions/sparksql/specialforms/MakeDecimal.h" #include "velox/functions/sparksql/specialforms/SparkCastExpr.h" @@ -152,6 +153,9 @@ void registerAllSpecialFormGeneralFunctions() { exec::registerFunctionCallToSpecialForm( AtLeastNNonNullsCallToSpecialForm::kAtLeastNNonNulls, std::make_unique()); + exec::registerFunctionCallToSpecialForm( + FromJsonCallToSpecialForm::kFromJson, + std::make_unique()); } namespace { diff --git a/velox/functions/sparksql/specialforms/CMakeLists.txt b/velox/functions/sparksql/specialforms/CMakeLists.txt index e141e0074bc84..8360c79b5dd0e 100644 --- a/velox/functions/sparksql/specialforms/CMakeLists.txt +++ b/velox/functions/sparksql/specialforms/CMakeLists.txt @@ -15,10 +15,11 @@ velox_add_library( velox_functions_spark_specialforms AtLeastNNonNulls.cpp + FromJson.cpp DecimalRound.cpp MakeDecimal.cpp SparkCastExpr.cpp SparkCastHooks.cpp) velox_link_libraries(velox_functions_spark_specialforms fmt::fmt - velox_expression) + velox_functions_json velox_expression) diff --git a/velox/functions/sparksql/specialforms/FromJson.cpp b/velox/functions/sparksql/specialforms/FromJson.cpp new file mode 100644 index 0000000000000..ad7e6628909d6 --- /dev/null +++ b/velox/functions/sparksql/specialforms/FromJson.cpp @@ -0,0 +1,602 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/sparksql/specialforms/FromJson.h" + +#include +#include + +#include "velox/expression/CastExpr.h" +#include "velox/expression/EvalCtx.h" +#include "velox/expression/PeeledEncoding.h" +#include "velox/expression/ScopedVarSetter.h" +#include "velox/expression/SpecialForm.h" +#include "velox/expression/VectorWriters.h" +#include "velox/functions/prestosql/json/SIMDJsonUtil.h" +#include "velox/type/Type.h" + +using namespace facebook::velox::exec; + +namespace facebook::velox::functions::sparksql { +namespace { + +template +struct ParseJsonTypedImpl { + template + static simdjson::error_code + apply(Input input, exec::GenericWriter& writer, bool isRoot) { + return KindDispatcher::apply(input, writer, isRoot); + } + + private: + // Dummy is needed because full/explicit specialization is not allowed inside + // class. + template + struct KindDispatcher { + static simdjson::error_code apply(Input, exec::GenericWriter&, bool) { + VELOX_NYI( + "Casting from JSON to {} is not supported.", TypeTraits::name); + return simdjson::error_code::UNEXPECTED_ERROR; // Make compiler happy. + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type()); + std::string_view s; + switch (type) { + case simdjson::ondemand::json_type::string: { + SIMDJSON_ASSIGN_OR_RAISE(s, value.get_string()); + break; + } + case simdjson::ondemand::json_type::number: + case simdjson::ondemand::json_type::boolean: + s = value.raw_json_token(); + break; + default: + return simdjson::INCORRECT_TYPE; + } + writer.castTo().append(s); + return simdjson::SUCCESS; + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type()); + auto& w = writer.castTo(); + switch (type) { + case simdjson::ondemand::json_type::boolean: { + SIMDJSON_ASSIGN_OR_RAISE(w, value.get_bool()); + break; + } + default: + return simdjson::INCORRECT_TYPE; + } + return simdjson::SUCCESS; + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + return castJsonToInt(value, writer); + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + return castJsonToInt(value, writer); + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + return castJsonToInt(value, writer); + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + return castJsonToInt(value, writer); + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + return castJsonToFloatingPoint(value, writer); + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + return castJsonToFloatingPoint(value, writer); + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + auto& writerTyped = writer.castTo>(); + auto& elementType = writer.type()->childAt(0); + SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type()); + if (type == simdjson::ondemand::json_type::array) { + SIMDJSON_ASSIGN_OR_RAISE(auto array, value.get_array()); + for (auto elementResult : array) { + SIMDJSON_ASSIGN_OR_RAISE(auto element, elementResult); + // If casting to array of JSON, nulls in array elements should become + // the JSON text "null". + if (element.is_null()) { + writerTyped.add_null(); + } else { + SIMDJSON_TRY(VELOX_DYNAMIC_TYPE_DISPATCH( + ParseJsonTypedImpl::apply, + elementType->kind(), + element, + writerTyped.add_item(), + false)); + } + } + } else if (elementType->kind() == TypeKind::ROW && isRoot) { + SIMDJSON_TRY(VELOX_DYNAMIC_TYPE_DISPATCH( + ParseJsonTypedImpl::apply, + elementType->kind(), + value, + writerTyped.add_item(), + false)); + } else { + return simdjson::INCORRECT_TYPE; + } + return simdjson::SUCCESS; + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + auto& writerTyped = writer.castTo>(); + auto& keyType = writer.type()->childAt(0); + auto& valueType = writer.type()->childAt(1); + SIMDJSON_ASSIGN_OR_RAISE(auto object, value.get_object()); + for (auto fieldResult : object) { + SIMDJSON_ASSIGN_OR_RAISE(auto field, fieldResult); + SIMDJSON_ASSIGN_OR_RAISE(auto key, field.unescaped_key(true)); + // If casting to map of JSON values, nulls in map values should become + // the JSON text "null". + if (field.value().is_null()) { + writerTyped.add_null().castTo().append(key); + } else { + auto writers = writerTyped.add_item(); + std::get<0>(writers).castTo().append(key); + SIMDJSON_TRY(VELOX_DYNAMIC_TYPE_DISPATCH( + ParseJsonTypedImpl::apply, + valueType->kind(), + field.value(), + std::get<1>(writers), + false)); + } + } + return simdjson::SUCCESS; + } + }; + + template + struct KindDispatcher { + static simdjson::error_code + apply(Input value, exec::GenericWriter& writer, bool isRoot) { + auto& rowType = writer.type()->asRow(); + auto& writerTyped = writer.castTo(); + if (value.type().error() != ::simdjson::SUCCESS) { + writerTyped.set_null_at(0); + return simdjson::SUCCESS; + } + auto type = value.type().value_unsafe(); + if (type == simdjson::ondemand::json_type::object) { + SIMDJSON_ASSIGN_OR_RAISE(auto object, value.get_object()); + + folly::F14FastMap fieldIndices; + const auto size = rowType.size(); + for (auto i = 0; i < size; ++i) { + std::string key = rowType.nameOf(i); + fieldIndices[key] = i; + } + + std::string key; + for (auto fieldResult : object) { + if (fieldResult.error() != ::simdjson::SUCCESS) { + continue; + } + auto field = fieldResult.value_unsafe(); + if (!field.value().is_null()) { + SIMDJSON_ASSIGN_OR_RAISE(key, field.unescaped_key(true)); + + auto it = fieldIndices.find(key); + if (it != fieldIndices.end()) { + const auto index = it->second; + it->second = -1; + + auto res = VELOX_DYNAMIC_TYPE_DISPATCH( + ParseJsonTypedImpl::apply, + rowType.childAt(index)->kind(), + field.value(), + writerTyped.get_writer_at(index), + false); + if (res != simdjson::SUCCESS) { + writerTyped.set_null_at(index); + } + } + } + } + + for (const auto& [key, index] : fieldIndices) { + if (index >= 0) { + writerTyped.set_null_at(index); + } + } + } else { + if (isRoot) { + writerTyped.set_null_at(0); + return simdjson::SUCCESS; + } else { + return simdjson::INCORRECT_TYPE; + } + } + return simdjson::SUCCESS; + } + }; + + static simdjson::simdjson_result rawJson( + Input value, + simdjson::ondemand::json_type type) { + switch (type) { + case simdjson::ondemand::json_type::array: { + SIMDJSON_ASSIGN_OR_RAISE(auto array, value.get_array()); + return array.raw_json(); + } + case simdjson::ondemand::json_type::object: { + SIMDJSON_ASSIGN_OR_RAISE(auto object, value.get_object()); + return object.raw_json(); + } + default: + return value.raw_json_token(); + } + } + + template + static simdjson::error_code castJsonToInt( + Input value, + exec::GenericWriter& writer) { + SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type()); + switch (type) { + case simdjson::ondemand::json_type::number: { + SIMDJSON_ASSIGN_OR_RAISE(auto num, value.get_number()); + switch (num.get_number_type()) { + case simdjson::ondemand::number_type::floating_point_number: + return simdjson::INCORRECT_TYPE; + case simdjson::ondemand::number_type::signed_integer: + return convertIfInRange(num.get_int64(), writer); + case simdjson::ondemand::number_type::unsigned_integer: + return simdjson::NUMBER_OUT_OF_RANGE; + case simdjson::ondemand::number_type::big_integer: + VELOX_UNREACHABLE(); // value.get_number() would have failed + // already. + } + break; + } + default: + return simdjson::INCORRECT_TYPE; + } + return simdjson::SUCCESS; + } + + template + static simdjson::error_code castJsonToFloatingPoint( + Input value, + exec::GenericWriter& writer) { + SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type()); + switch (type) { + case simdjson::ondemand::json_type::number: { + SIMDJSON_ASSIGN_OR_RAISE(auto num, value.get_double()); + return convertIfInRange(num, writer); + } + case simdjson::ondemand::json_type::string: { + SIMDJSON_ASSIGN_OR_RAISE(auto s, value.get_string()); + constexpr T kNaN = std::numeric_limits::quiet_NaN(); + constexpr T kInf = std::numeric_limits::infinity(); + if (s == "NaN") { + writer.castTo() = kNaN; + } else if (s == "+INF" || s == "+Infinity" || s == "Infinity") { + writer.castTo() = kInf; + } else if (s == "-INF" || s == "-Infinity") { + writer.castTo() = -kInf; + } else { + return simdjson::INCORRECT_TYPE; + } + break; + } + default: + return simdjson::INCORRECT_TYPE; + } + return simdjson::SUCCESS; + } + + template + static simdjson::error_code convertIfInRange( + From x, + exec::GenericWriter& writer) { + static_assert(std::is_signed_v && std::is_signed_v); + if constexpr (sizeof(To) < sizeof(From)) { + constexpr From kMin = std::numeric_limits::lowest(); + constexpr From kMax = std::numeric_limits::max(); + if (!(kMin <= x && x <= kMax)) { + return simdjson::NUMBER_OUT_OF_RANGE; + } + } + writer.castTo() = x; + return simdjson::SUCCESS; + } +}; + +template +simdjson::error_code parseJsonOneRow( + simdjson::padded_string_view input, + exec::VectorWriter& writer) { + SIMDJSON_ASSIGN_OR_RAISE(auto doc, simdjsonParse(input)); + if (doc.is_null()) { + writer.commitNull(); + } else { + SIMDJSON_TRY(ParseJsonTypedImpl::apply( + doc, writer.current(), true)); + writer.commit(true); + } + return simdjson::SUCCESS; +} + +class FromJsonExpr : public SpecialForm { + public: + /// @param type The target type of the cast expression + /// @param expr The expression to gerenate input + /// @param trackCpuUsage Whether to track CPU usage + FromJsonExpr(TypePtr type, ExprPtr&& expr, bool trackCpuUsage) + : SpecialForm( + type, + std::vector({expr}), + FromJsonCallToSpecialForm::kFromJson, + false /* supportsFlatNoNullsFastPath */, + trackCpuUsage) { + if (!isSupportedType(type)) { + VELOX_UNSUPPORTED("Unsupported type {}.", type->toString()); + } + simdjsonErrorsToExceptions(errors_); + } + + void evalSpecialForm( + const SelectivityVector& rows, + EvalCtx& context, + VectorPtr& result) override { + VectorPtr input; + inputs_[0]->eval(rows, context, input); + auto toType = std::const_pointer_cast(type_); + apply(rows, input, context, toType, result); + // Return 'input' back to the vector pool in 'context' so it can be reused. + context.releaseVector(input); + } + + private: + void computePropagatesNulls() override { + propagatesNulls_ = false; + } + + // Peal data. + void apply( + const SelectivityVector& rows, + const VectorPtr& input, + exec::EvalCtx& context, + const TypePtr& toType, + VectorPtr& result) { + LocalSelectivityVector remainingRows(context, rows); + + context.deselectErrors(*remainingRows); + + LocalDecodedVector decoded(context, *input, *remainingRows); + auto* rawNulls = decoded->nulls(remainingRows.get()); + + if (rawNulls) { + remainingRows->deselectNulls( + rawNulls, remainingRows->begin(), remainingRows->end()); + } + + VectorPtr localResult; + if (!remainingRows->hasSelections()) { + localResult = + BaseVector::createNullConstant(toType, rows.end(), context.pool()); + } else if (decoded->isIdentityMapping()) { + applyPeeled( + *remainingRows, *decoded->base(), context, toType, localResult); + } else { + withContextSaver([&](ContextSaver& saver) { + LocalSelectivityVector newRowsHolder(*context.execCtx()); + + LocalDecodedVector localDecoded(context); + std::vector peeledVectors; + auto peeledEncoding = PeeledEncoding::peel( + {input}, *remainingRows, localDecoded, true, peeledVectors); + VELOX_CHECK_EQ(peeledVectors.size(), 1); + if (peeledVectors[0]->isLazy()) { + peeledVectors[0] = + peeledVectors[0]->as()->loadedVectorShared(); + } + auto newRows = + peeledEncoding->translateToInnerRows(*remainingRows, newRowsHolder); + // Save context and set the peel. + context.saveAndReset(saver, *remainingRows); + context.setPeeledEncoding(peeledEncoding); + applyPeeled(*newRows, *peeledVectors[0], context, toType, localResult); + + localResult = context.getPeeledEncoding()->wrap( + toType, context.pool(), localResult, *remainingRows); + }); + } + context.moveOrCopyResult(localResult, *remainingRows, result); + context.releaseVector(localResult); + + // If there are nulls or rows that encountered errors in the input, add + // nulls to the result at the same rows. + VELOX_CHECK_NOT_NULL(result); + if (rawNulls || context.errors()) { + EvalCtx::addNulls( + rows, remainingRows->asRange().bits(), context, toType, result); + } + } + + void applyPeeled( + const SelectivityVector& rows, + const BaseVector& input, + exec::EvalCtx& context, + const TypePtr& toType, + VectorPtr& result) { + context.ensureWritable(rows, toType, result); + switch (result->typeKind()) { + case TypeKind::ARRAY: { + parseJson(input, context, rows, *result); + break; + } + case TypeKind::MAP: { + parseJson(input, context, rows, *result); + break; + } + case TypeKind::ROW: { + parseJson(input, context, rows, *result); + break; + } + default: + VELOX_UNSUPPORTED("INVALID_JSON_SCHEMA"); + } + } + + template + void parseJson( + const BaseVector& input, + exec::EvalCtx& context, + const SelectivityVector& rows, + BaseVector& result) const { + // Result is guaranteed to be a flat writable vector. + auto* flatResult = result.as::type>(); + exec::VectorWriter writer; + writer.init(*flatResult); + // Input is guaranteed to be in flat or constant encodings when passed in. + auto* inputVector = input.as>(); + size_t maxSize = 0; + rows.applyToSelected([&](auto row) { + if (inputVector->isNullAt(row)) { + return; + } + auto& input = inputVector->valueAt(row); + maxSize = std::max(maxSize, input.size()); + }); + paddedInput_.resize(maxSize + simdjson::SIMDJSON_PADDING); + context.applyToSelectedNoThrow(rows, [&](auto row) { + writer.setOffset(row); + if (inputVector->isNullAt(row)) { + writer.commitNull(); + return; + } + auto& input = inputVector->valueAt(row); + memcpy(paddedInput_.data(), input.data(), input.size()); + simdjson::padded_string_view paddedInput( + paddedInput_.data(), input.size(), paddedInput_.size()); + if (auto error = parseJsonOneRow(paddedInput, writer)) { + writer.commitNull(); + } + }); + writer.finish(); + } + + bool isSupportedType(const TypePtr& other, bool isRootType = true) const { + switch (other->kind()) { + case TypeKind::ARRAY: + return isSupportedType(other->childAt(0), false); + case TypeKind::ROW: + for (const auto& child : other->as().children()) { + if (!isSupportedType(child, false)) { + return false; + } + } + return true; + case TypeKind::MAP: + return ( + other->childAt(0)->kind() == TypeKind::VARCHAR && + isSupportedType(other->childAt(1), false)); + case TypeKind::BOOLEAN: + case TypeKind::BIGINT: + case TypeKind::INTEGER: + case TypeKind::SMALLINT: + case TypeKind::TINYINT: + case TypeKind::DOUBLE: + case TypeKind::REAL: + case TypeKind::VARCHAR: { + if (other->isDate() || other->isDecimal()) { + return false; + } + return !isRootType; + } + default: + return false; + } + } + + std::exception_ptr errors_[simdjson::NUM_ERROR_CODES]; + mutable std::string paddedInput_; +}; +} // namespace + +TypePtr FromJsonCallToSpecialForm::resolveType( + const std::vector& /*argTypes*/) { + VELOX_FAIL("from_json function does not support type resolution."); +} + +exec::ExprPtr FromJsonCallToSpecialForm::constructSpecialForm( + const TypePtr& type, + std::vector&& args, + bool trackCpuUsage, + const core::QueryConfig& /*config*/) { + VELOX_USER_CHECK_EQ(args.size(), 1, "from_json expects one argument."); + VELOX_USER_CHECK_EQ( + args[0]->type()->kind(), + TypeKind::VARCHAR, + "The first argument of from_json should be of varchar type."); + + return std::make_shared( + type, std::move(args[0]), trackCpuUsage); +} +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/specialforms/FromJson.h b/velox/functions/sparksql/specialforms/FromJson.h new file mode 100644 index 0000000000000..073db7c0898a1 --- /dev/null +++ b/velox/functions/sparksql/specialforms/FromJson.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "velox/expression/FunctionCallToSpecialForm.h" + +namespace facebook::velox::functions::sparksql { +class FromJsonCallToSpecialForm : public exec::FunctionCallToSpecialForm { + public: + // Throws not supported exception. + TypePtr resolveType(const std::vector& argTypes) override; + + /// @brief Returns an expression for from_json special form. The expression + /// is a regular expression based on a custom VectorFunction implementation. + exec::ExprPtr constructSpecialForm( + const TypePtr& type, + std::vector&& args, + bool trackCpuUsage, + const core::QueryConfig& config) override; + + static constexpr const char* kFromJson = "from_json"; +}; +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt index d17a9e0ed8550..a5b72238194b3 100644 --- a/velox/functions/sparksql/tests/CMakeLists.txt +++ b/velox/functions/sparksql/tests/CMakeLists.txt @@ -32,6 +32,7 @@ add_executable( DecimalRoundTest.cpp DecimalUtilTest.cpp ElementAtTest.cpp + FromJsonTest.cpp HashTest.cpp InTest.cpp JsonObjectKeysTest.cpp diff --git a/velox/functions/sparksql/tests/FromJsonTest.cpp b/velox/functions/sparksql/tests/FromJsonTest.cpp new file mode 100644 index 0000000000000..5e72ce0399465 --- /dev/null +++ b/velox/functions/sparksql/tests/FromJsonTest.cpp @@ -0,0 +1,250 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" + +using namespace facebook::velox::test; + +namespace facebook::velox::functions::sparksql::test { +namespace { +constexpr float kNaNFloat = std::numeric_limits::quiet_NaN(); +constexpr float kInfFloat = std::numeric_limits::infinity(); +constexpr double kNaNDouble = std::numeric_limits::quiet_NaN(); +constexpr double kInfDouble = std::numeric_limits::infinity(); + +class FromJsonTest : public SparkFunctionBaseTest { + protected: + core::CallTypedExprPtr createFromJson(const TypePtr& outputType) { + std::vector inputs = { + std::make_shared(VARCHAR(), "c0")}; + + return std::make_shared( + outputType, std::move(inputs), "from_json"); + } + + void testFromJson(const VectorPtr& input, const VectorPtr& expected) { + auto expr = createFromJson(expected->type()); + testEncodings(expr, {input}, expected); + } +}; + +TEST_F(FromJsonTest, basicStruct) { + auto expected = makeFlatVector({1, 2, 3}); + auto input = makeFlatVector( + {R"({"a": 1})", R"({"a": 2})", R"({"a": 3})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, basicArray) { + auto expected = makeArrayVector({{1}, {2}, {}}); + auto input = makeFlatVector({R"([1])", R"([2])", R"([])"}); + testFromJson(input, expected); + + auto rowVector = makeRowVector({"a"}, {makeFlatVector({1, 2, 2})}); + std::vector offsets; + offsets.push_back(0); + offsets.push_back(1); + offsets.push_back(2); + auto arrayVector = makeArrayVector(offsets, rowVector); + input = makeFlatVector( + {R"({"a": 1})", R"([{"a": 2}])", R"([{"a": 2}])"}); + testFromJson(input, arrayVector); +} + +TEST_F(FromJsonTest, basicMap) { + auto expected = makeMapVector( + {{{"a", 1}}, {{"b", 2}}, {{"c", 3}}, {{"3", 3}}}); + auto input = makeFlatVector( + {R"({"a": 1})", R"({"b": 2})", R"({"c": 3})", R"({"3": 3})"}); + testFromJson(input, expected); +} + +TEST_F(FromJsonTest, basicBool) { + auto expected = makeNullableFlatVector( + {true, false, std::nullopt, std::nullopt, std::nullopt}); + auto input = makeFlatVector( + {R"({"a": true})", + R"({"a": false})", + R"({"a": 1})", + R"({"a": 0.0})", + R"({"a": "true"})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, basicTinyInt) { + auto expected = makeNullableFlatVector( + {1, std::nullopt, std::nullopt, std::nullopt, std::nullopt}); + auto input = makeFlatVector( + {R"({"a": 1})", + R"({"a": -129})", + R"({"a": 128})", + R"({"a": 1.0})", + R"({"a": "1"})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, basicSmallInt) { + auto expected = makeNullableFlatVector( + {1, std::nullopt, std::nullopt, std::nullopt, std::nullopt}); + auto input = makeFlatVector( + {R"({"a": 1})", + R"({"a": -32769})", + R"({"a": 32768})", + R"({"a": 1.0})", + R"({"a": "1"})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, basicInt) { + auto expected = makeNullableFlatVector( + {1, std::nullopt, std::nullopt, std::nullopt, std::nullopt}); + auto input = makeFlatVector( + {R"({"a": 1})", + R"({"a": -2147483649})", + R"({"a": 2147483648})", + R"({"a": 2.0})", + R"({"a": "3"})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, basicBigInt) { + auto expected = + makeNullableFlatVector({1, std::nullopt, std::nullopt}); + auto input = makeFlatVector( + {R"({"a": 1})", R"({"a": 2.0})", R"({"a": "3"})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, basicFloat) { + auto expected = makeNullableFlatVector( + {1.0, + 2.0, + std::nullopt, + kNaNFloat, + -kInfFloat, + -kInfFloat, + kInfFloat, + kInfFloat, + kInfFloat}); + auto input = makeFlatVector( + {R"({"a": 1})", + R"({"a": 2.0})", + R"({"a": "3"})", + R"({"a": "NaN"})", + R"({"a": "-Infinity"})", + R"({"a": "-INF"})", + R"({"a": "+Infinity"})", + R"({"a": "Infinity"})", + R"({"a": "+INF"})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, basicDouble) { + auto expected = makeNullableFlatVector( + {1.0, + 2.0, + std::nullopt, + kNaNDouble, + -kInfDouble, + -kInfDouble, + kInfDouble, + kInfDouble, + kInfDouble}); + auto input = makeFlatVector( + {R"({"a": 1})", + R"({"a": 2.0})", + R"({"a": "3"})", + R"({"a": "NaN"})", + R"({"a": "-Infinity"})", + R"({"a": "-INF"})", + R"({"a": "+Infinity"})", + R"({"a": "Infinity"})", + R"({"a": "+INF"})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, basicString) { + auto expected = makeNullableFlatVector({"1", "2.0", "true"}); + auto input = makeFlatVector( + {R"({"a": 1})", R"({"a": 2.0})", R"({"a": "true"})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, keyCaseSensitive) { + auto expected1 = makeNullableFlatVector({1, 2, 4}); + auto expected2 = makeNullableFlatVector({3, 4, 5}); + auto input = makeFlatVector( + {R"({"a": 1, "A": 3})", R"({"a": 2, "A": 4})", R"({"a": 4, "A": 5})"}); + testFromJson(input, makeRowVector({"a", "A"}, {expected1, expected2})); +} + +TEST_F(FromJsonTest, nullOnFailure) { + auto expected = makeNullableFlatVector({1, std::nullopt, 3}); + auto input = + makeFlatVector({R"({"a": 1})", R"({"a" 2})", R"({"a": 3})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, structEmptyArray) { + auto expected = makeNullableFlatVector({std::nullopt, 2, 3}); + auto input = + makeFlatVector({R"([])", R"({"a": 2})", R"({"a": 3})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, structEmptyStruct) { + auto expected = makeNullableFlatVector({std::nullopt, 2, 3}); + auto input = + makeFlatVector({R"({ })", R"({"a": 2})", R"({"a": 3})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, structWrongSchema) { + auto expected = makeNullableFlatVector({std::nullopt, 2, 3}); + auto input = makeFlatVector( + {R"({"b": 2})", R"({"a": 2})", R"({"a": 3})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, structWrongData) { + auto expected = makeNullableFlatVector({std::nullopt, 2, 3}); + auto input = makeFlatVector( + {R"({"a": 2.1})", R"({"a": 2})", R"({"a": 3})"}); + testFromJson(input, makeRowVector({"a"}, {expected})); +} + +TEST_F(FromJsonTest, invalidType) { + auto primitiveTypeOutput = makeFlatVector({2, 2, 3}); + auto dateTypeOutput = makeFlatVector({2, 2, 3}, DATE()); + auto decimalOutput = makeFlatVector({2, 2, 3}, DECIMAL(16, 7)); + auto mapOutput = + makeMapVector({{{1, 1}}, {{2, 2}}, {{3, 3}}}); + auto input = makeFlatVector({R"(2)", R"({2)", R"({3)"}); + VELOX_ASSERT_USER_THROW( + testFromJson(input, primitiveTypeOutput), "Unsupported type BIGINT."); + VELOX_ASSERT_USER_THROW( + testFromJson(input, makeRowVector({"a"}, {dateTypeOutput})), + "Unsupported type ROW."); + VELOX_ASSERT_USER_THROW( + testFromJson(input, makeRowVector({"a"}, {decimalOutput})), + "Unsupported type ROW"); + VELOX_ASSERT_USER_THROW( + testFromJson(input, mapOutput), "Unsupported type MAP."); +} + +} // namespace +} // namespace facebook::velox::functions::sparksql::test