From eb4883a3f21fea0f456955b322f15eb08f289654 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 27 Aug 2024 20:18:39 +0800 Subject: [PATCH] different results from get_json_object with comparison to vanilla --- .../text-data/abnormal-json/data.txt | 2 + .../hive/GlutenClickHouseHiveTableSuite.scala | 3 + .../Functions/SparkFunctionGetJsonObject.h | 304 +++++++++++++++--- 3 files changed, 273 insertions(+), 36 deletions(-) diff --git a/backends-clickhouse/src/test/resources/text-data/abnormal-json/data.txt b/backends-clickhouse/src/test/resources/text-data/abnormal-json/data.txt index 7f6edd8bf8e4e..906dd83e05552 100644 --- a/backends-clickhouse/src/test/resources/text-data/abnormal-json/data.txt +++ b/backends-clickhouse/src/test/resources/text-data/abnormal-json/data.txt @@ -1 +1,3 @@ 1{"data": {"id": "Qu001cڜu00cƼ","v": 5}} +2{"data": {"id": "Qu001cڜu00c}Ƽ","v": 5}} +3{"data": {"id": "Qu001cڜu00c\\\"Ƽ","v": 5}} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala index cc91556133433..f165d7aef69c2 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala @@ -877,12 +877,15 @@ class GlutenClickHouseHiveTableSuite val select_sql_5 = "select id, get_json_object(data, 'v112') from test_tbl_3337" val select_sql_6 = "select id, get_json_object(data, '$.id') from test_tbl_3337 where id = 123" + val select_sql_7 = + "select id, get_json_object(data, '$.id') from test_tbl_3337" compareResultsAgainstVanillaSpark(select_sql_1, compareResult = true, _ => {}) compareResultsAgainstVanillaSpark(select_sql_2, compareResult = true, _ => {}) compareResultsAgainstVanillaSpark(select_sql_3, compareResult = true, _ => {}) compareResultsAgainstVanillaSpark(select_sql_4, compareResult = true, _ => {}) compareResultsAgainstVanillaSpark(select_sql_5, compareResult = true, _ => {}) compareResultsAgainstVanillaSpark(select_sql_6, compareResult = true, _ => {}) + compareResultsAgainstVanillaSpark(select_sql_7, compareResult = true, _ => {}) spark.sql("DROP TABLE test_tbl_3337") } diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h b/cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h index 5d73c52af4993..cad5ba7d4fb87 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h @@ -16,8 +16,8 @@ */ #pragma once #include -#include #include +#include #include #include #include @@ -33,12 +33,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include namespace DB @@ -66,6 +68,255 @@ struct GetJsonObject static constexpr auto name{"get_json_object"}; }; +class JSONTextNormalizer +{ +public: + // simd json will fail to parse the json text on some cases, see #7014, #3750, #3337, #5303 + // To keep the result same with vanilla, we normalize the json string when simd json fails. + // It returns null when normalize the json text fail, otherwise returns a position among `pos` + // and `end` which points to the whole json object end. + // `dst` refer to a memory buffer that is used to store the normalization result. + static const char * normalize(const char * pos, const char * end, char * & dst) + { + pos = normalizeWhitespace(pos, end, dst); + if (!pos || pos >= end) + return nullptr; + if (*pos == '[') + return normalizeArray(pos, end, dst); + else if (*pos == '{') + return normalizeObject(pos, end, dst); + return nullptr; + } +private: + inline static void copyToDst(char * & p, char c) + { + *p = c; + p++; + } + + inline static void copyToDst(char * & p, const char * src, size_t len) + { + memcpy(p, src, len); + p += len; + } + + inline static bool isExpectedChar(char c, const char * pos, const char * end) + { + return pos && pos < end && *pos == c; + } + + inline static const char * normalizeWhitespace(const char * pos, const char * end, char * & dst) + { + const auto * start_pos = pos; + while(pos && pos < end) + { + if (isWhitespaceASCII(*pos)) + pos++; + else + break; + } + if (pos != start_pos) + copyToDst(dst, start_pos, pos - start_pos); + return pos; + } + + inline static const char * normalizeComma(const char * pos, const char * end, char * & dst) + { + pos = normalizeWhitespace(pos, end, dst); + if (!isExpectedChar(',', pos, end)) [[unlikely]] + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeComma. not ,"); + return nullptr; + } + pos += 1; + copyToDst(dst, ','); + return normalizeWhitespace(pos, end, dst); + } + + inline static const char * normalizeColon(const char * pos, const char * end, char * & dst) + { + pos = normalizeWhitespace(pos, end, dst); + if (!isExpectedChar(':', pos, end)) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeColon. not :"); + return nullptr; + } + pos += 1; + copyToDst(dst, ':'); + return normalizeWhitespace(pos, end, dst); + } + + inline static const char * normalizeField(const char * pos, const char * end, char * & dst) + { + const auto * start_pos = pos; + pos = find_first_symbols<',', '}', ']'>(pos, end); + if (pos >= end) [[unlikely]] + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeField. not field"); + return nullptr; + } + copyToDst(dst, start_pos, pos - start_pos); + return pos;; + } + + inline static const char * normalizeString(const char * pos, const char * end, char * & dst) + { + const auto * start_pos = pos; + if (!isExpectedChar('"', pos, end)) [[unlikely]] + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeString. not \""); + return nullptr; + } + pos += 1; + + pos = find_first_symbols<'"'>(pos, end); + while (*(pos - 1) == '\\' && *pos == '"' && pos < end) + { + pos += 1; + pos = find_first_symbols<'"'>(pos, end); + } + if (!isExpectedChar('"', pos, end)) + return nullptr; + pos += 1; + // LOG_DEBUG(getLogger("1"), "xxx string value: {}", std::string_view(start_pos, pos - start_pos)); + + size_t n = 0; + for (; start_pos != pos; ++start_pos) + { + if ((*start_pos >= 0x00 && *start_pos <= 0x1f) || *start_pos == 0x7f) + { + if (n) + { + copyToDst(dst, start_pos - n, n); + n = 0; + } + continue; + } + else + { + n += 1; + } + } + if (n) + copyToDst(dst, start_pos - n, n); + + return normalizeWhitespace(pos, end, dst); + } + + inline static const char * normalizeArray(const char * pos, const char * end, char * & dst) + { + if (!isExpectedChar('[', pos, end)) [[unlikely]] + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeArray. not ["); + return nullptr; + } + pos += 1; + copyToDst(dst, '['); + + pos = normalizeWhitespace(pos, end, dst); + + bool has_more = false; + while (pos && pos < end && *pos != ']') + { + has_more = false; + switch(*pos) + { + case '{': { + pos = normalizeObject(pos, end, dst); + break; + } + case '"': { + pos = normalizeString(pos, end, dst); + break; + } + case '[': { + pos = normalizeArray(pos, end, dst); + break; + } + default: { + pos = normalizeField(pos, end, dst); + break; + } + } + if (!isExpectedChar(',', pos, end)) + break; + pos = normalizeComma(pos, end, dst); + has_more = true; + } + + if (!isExpectedChar(']', pos, end) || has_more) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeArray. not ]"); + return nullptr; + } + pos += 1; + copyToDst(dst, ']'); + return normalizeWhitespace(pos, end, dst); + } + + inline static const char * normalizeObject(const char * pos, const char * end, char * & dst) + { + if (!isExpectedChar('{', pos, end)) [[unlikely]] + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not object start"); + return nullptr; + } + pos += 1; + copyToDst(dst, '{'); + + bool has_more = false; + while(pos && pos < end && *pos != '}') + { + has_more = false; + pos = normalizeWhitespace(pos, end, dst); + + pos = normalizeString(pos, end, dst); + + pos = normalizeColon(pos, end, dst); + if (!pos) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not :"); + break; + } + + switch(*pos) + { + case '{': { + pos = normalizeObject(pos, end, dst); + break; + } + case '"': { + pos = normalizeString(pos, end, dst); + break; + } + case '[': { + pos = normalizeArray(pos, end, dst); + break; + } + default: { + pos = normalizeField(pos, end, dst); + break; + } + } + + if (!isExpectedChar(',', pos, end)) + break; + pos = normalizeComma(pos, end, dst); + has_more = true; + } + + if (!isExpectedChar('}', pos, end) || has_more) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not object end"); + return nullptr; + } + pos += 1; + copyToDst(dst, '}'); + return normalizeWhitespace(pos, end, dst); + } + +}; + template class GetJsonObjectImpl { @@ -116,6 +367,7 @@ class GetJsonObjectImpl if (elements[0].isNull()) return false; nullable_col_str.getNullMapData().push_back(0); + if (elements[0].isString()) { auto str = elements[0].getString(); @@ -200,6 +452,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction #if USE_SIMDJSON if (context->getSettingsRef().allow_simdjson) { + LOG_ERROR(getLogger("GetJsonObject"), "xxxx use simd json"); return innerExecuteImpl< DB::SimdJSONParser, GetJsonObjectImpl>>( @@ -214,32 +467,23 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction private: DB::ContextPtr context; - size_t normalizeJson(std::string_view & json, char * dst) const + template + bool safeParseJson(std::string_view str, JSONParser & parser, JSONParser::Element & doc) const { - const char * json_chars = json.data(); - const size_t json_size = json.size(); - std::stack tmp; - size_t new_json_size = 0; - for (size_t i = 0; i <= json_size; ++i) + if (!parser.parse(str, doc)) [[unlikely]] { - if ((*(json_chars + i) >= 0x00 && *(json_chars + i) <= 0x1F) || *(json_chars + i) == 0x7F) - continue; - else + std::vector buf; + buf.resize(str.size(), 0); + char * buf_pos = buf.data(); + const char * pos = JSONTextNormalizer::normalize(str.data(), str.data() + str.size(), buf_pos); + if (!pos) { - char ch = *(json_chars + i); - dst[new_json_size++] = ch; - if (ch == '{') - tmp.push('{'); - else if (ch == '}') - { - if (!tmp.empty() && tmp.top() == '{') - tmp.pop(); - } - if (tmp.empty()) - break; + return false; } + std::string n_str(buf.data(), buf_pos - buf.data()); + return parser.parse(n_str, doc); } - return new_json_size; + return true; } template @@ -318,13 +562,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction if (col_json_const) { std::string_view json{reinterpret_cast(chars.data()), offsets[0] - 1}; - document_ok = parser.parse(json, document); - if (!document_ok) - { - char dst[json.size()]; - size_t size = normalizeJson(json, dst); - document_ok = parser.parse(std::string_view(dst, size), document); - } + document_ok = safeParseJson(json, parser, document); } size_t tuple_size = tuple_columns.size(); @@ -340,13 +578,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction if (!col_json_const) { std::string_view json{reinterpret_cast(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1}; - document_ok = parser.parse(json, document); - if (!document_ok) - { - char dst[json.size()]; - size_t size = normalizeJson(json, dst); - document_ok = parser.parse(std::string_view(dst, size), document); - } + document_ok = safeParseJson(json, parser, document); } if (document_ok) {