From 0f36e3c959672a3b4fb7bad7e6a611533d05ce4d Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Tue, 27 Aug 2024 20:18:39 +0800 Subject: [PATCH] wip --- /*
 * NOTE(review): what this text is
 *   A `git format-patch` mail ("[PATCH] wip") for
 *   cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h whose line breaks have been
 *   collapsed. Several tokens look stripped: the header names after `#include`, the parameter
 *   list after `template`, the element type of `std::vector`, the target type of
 *   `reinterpret_cast`. Presumably everything written in angle brackets was lost in
 *   extraction -- confirm against the original commit before applying or editing.
 *
 * What the patch changes (all inside FlattenJSONStringOnRequiredFunction):
 *   - normalizeJson() takes `const std::string_view &` instead of `std::string_view &`.
 *   - Adds safeParseJson(str, parser, doc). Its active (#else) branch rewrites `str` with
 *     normalizeJsonStr() into a scratch buffer of str.size() bytes, returns false when that
 *     fails, and otherwise returns parser.parse() of the rewritten copy. The `#if 0` branch
 *     (parse first, fall back to normalizeJson() only on failure) is compiled out.
 *   - The two call sites shown in innerExecuteImpl() call safeParseJson() in place of the old
 *     "parser.parse(); on failure normalizeJson() and parse again" sequence.
 *   - Adds static helpers that copy an accepted JSON prefix of [pos, end) into `dst`,
 *     advancing `dst` as they write:
 *       copyToDst            append one char, or `len` bytes, to dst.
 *       isExpectedChar       true when pos is non-null, pos < end and the byte at pos equals c.
 *       normalizeWhitespace  copy a run of isWhitespaceASCII() bytes; returns where the run ends.
 *       normalizeObject      '{' key ':' value [',' ...] '}', then trailing whitespace.
 *       normalizeArray       '[' value [',' ...] ']', then trailing whitespace.
 *                            In both, a value is dispatched on its first char: '{', '"', '[',
 *                            anything else goes to normalizeField.
 *       normalizeString      a double-quoted string, then trailing whitespace; bytes 0x00-0x1f
 *                            and 0x7f inside the string are dropped instead of copied.
 *       normalizeField       everything up to the next ',', '}' or ']' copied verbatim.
 *       normalizeColon       ':' with whitespace on both sides.
 *       normalizeComma       ',' with whitespace on both sides.
 *       normalizeJsonStr     entry point; accepts only input whose first non-whitespace char
 *                            is '[' or '{'.
 *     The object/array/string/field/colon/comma helpers return the position after what they
 *     consumed, or nullptr on failure. A trailing ',' before the closing bracket is rejected
 *     through the `has_more` flag.
 *     Input that follows the top-level object or array (the `acc` in '{"a":"b"}acc') is not
 *     copied, and safeParseJson() parses only the bytes that were written to the buffer.
 *
 * Review findings to address before this stops being "wip":
 *   1. normalizeString(): the loop condition `*(pos - 1) == '\\' && *pos == '"' && pos < end`
 *      reads the byte at pos before testing `pos < end`. When the quote search stops at `end`
 *      and the last input byte is a backslash, the byte at `end` is read (assumes
 *      find_first_symbols() returns `end` when nothing matches -- TODO confirm). The bounds
 *      test should come first.
 *   2. normalizeString(): a quote counts as escaped whenever the previous byte is a backslash,
 *      so a value that ends in an escaped backslash ("a\\") is not closed at its real closing
 *      quote.
 *   3. normalizeObject(): `switch(*pos)` runs directly after normalizeColon(), which returns
 *      normalizeWhitespace()'s result and can therefore return `end`. Only `!pos` is checked,
 *      so input that stops after the colon, such as {"a": , reads the byte at `end`.
 *   4. normalizeObject(): whitespace between '{' and '}' is not skipped before the loop test,
 *      so `{ }` enters the loop, fails in normalizeString() and is rejected, whereas
 *      normalizeArray() skips whitespace first and accepts `[ ]`.
 *   5. safeParseJson(): the active branch emits LOG_ERROR with an "xxx" prefix for every input
 *      that normalizes successfully, and normalizes before every parse. This looks like
 *      debugging scaffolding -- confirm the intended log level and whether the parse-first
 *      path should be restored.
 *   6. `char dst[str.size()]` (in the `#if 0` branch and in the removed call-site code) is a
 *      variable-length array, which is a compiler extension in C++.
 */.../Functions/SparkFunctionGetJsonObject.h | 293 +++++++++++++++++- 1 file changed, 277 insertions(+), 16 deletions(-) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h b/cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h index 5d73c52af4993..d6d8b7dcc917a 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h +++ b/cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h @@ -16,8 +16,8 @@ */ #pragma once #include -#include #include +#include #include #include #include @@ -33,12 +33,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include namespace DB @@ -214,7 +216,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction private: DB::ContextPtr context; - size_t normalizeJson(std::string_view & json, char * dst) const + size_t normalizeJson(const std::string_view & json, char * dst) const { const char * json_chars = json.data(); const size_t json_size = json.size(); @@ -242,6 +244,277 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction return new_json_size; } + template + bool safeParseJson(std::string_view str, JSONParser & parser, JSONParser::Element & doc) const + { +#if 0 + // '{"a":"b"}acc' is invalid in simdjson, but it OK for spark. + if(!parser.parse(str, doc)) [[unlikely]] + { + // This is should not be a normal case, so we only try to normalize a json str when it fails. + char dst[str.size()]; + size_t len = normalizeJson(str, dst); + auto res = parser.parse(std::string_view(dst, len), doc); + LOG_ERROR(getLogger("GetJsonObject"), "xxx failed to parse: {}, len: {}. 
adjust to:\n{}, res: {}", str, str.size(), std::string_view(dst, len), res); + return res; + } + LOG_ERROR(getLogger("GetJsonObject"), "xxx parse ok by simdjson"); + return true; +#else + std::vector buf; + buf.resize(str.size(), 0); + char * buf_pos = buf.data(); + const char * pos = normalizeJsonStr(str.data(), str.data() + str.size(), buf_pos); + if (!pos) + { + LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalize json string failed"); + return false; + } + std::string n_str(buf.data(), buf_pos - buf.data()); + LOG_ERROR(getLogger("GetJsonObject"), "xxx new str: {}/ {}", str, n_str); + return parser.parse(n_str, doc); +#endif + } + + inline static void copyToDst(char * & p, char c) + { + *p = c; + p++; + } + + inline static void copyToDst(char * & p, const char * src, size_t len) + { + memcpy(p, src, len); + p += len; + } + + inline static bool isExpectedChar(char c, const char * pos, const char * end) + { + return pos && pos < end && *pos == c; + } + + static const char * normalizeWhitespace(const char * pos, const char * end, char * & dst) + { + const auto * start_pos = pos; + while(pos && pos < end) + { + if (isWhitespaceASCII(*pos)) + pos++; + else + break; + } + if (pos != start_pos) + copyToDst(dst, start_pos, pos - start_pos); + return pos; + } + + static const char * normalizeObject(const char * pos, const char * end, char * & dst) + { + if (!isExpectedChar('{', pos, end)) [[unlikely]] + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not object start"); + return nullptr; + } + pos += 1; + copyToDst(dst, '{'); + + bool has_more = false; + while(pos && pos < end && *pos != '}') + { + has_more = false; + pos = normalizeWhitespace(pos, end, dst); + + pos = normalizeString(pos, end, dst); + + pos = normalizeColon(pos, end, dst); + if (!pos) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. 
not :"); + break; + } + + switch(*pos) + { + case '{': { + pos = normalizeObject(pos, end, dst); + break; + } + case '"': { + pos = normalizeString(pos, end, dst); + break; + } + case '[': { + pos = normalizeArray(pos, end, dst); + break; + } + default: { + pos = normalizeField(pos, end, dst); + break; + } + } + + if (!isExpectedChar(',', pos, end)) + break; + pos = normalizeComma(pos, end, dst); + has_more = true; + } + + if (!isExpectedChar('}', pos, end) || has_more) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not object end"); + return nullptr; + } + pos += 1; + copyToDst(dst, '}'); + return normalizeWhitespace(pos, end, dst); + } + + static const char * normalizeArray(const char * pos, const char * end, char * & dst) + { + if (!isExpectedChar('[', pos, end)) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeArray. not ["); + return nullptr; + } + pos += 1; + copyToDst(dst, '['); + + pos = normalizeWhitespace(pos, end, dst); + + bool has_more = false; + while (pos && pos < end && *pos != ']') + { + has_more = false; + switch(*pos) + { + case '{': { + pos = normalizeObject(pos, end, dst); + break; + } + case '"': { + pos = normalizeString(pos, end, dst); + break; + } + case '[': { + pos = normalizeArray(pos, end, dst); + break; + } + default: { + pos = normalizeField(pos, end, dst); + break; + } + } + if (!isExpectedChar(',', pos, end)) + break; + pos = normalizeComma(pos, end, dst); + has_more = true; + } + + if (!isExpectedChar(']', pos, end) || has_more) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeArray. not ]"); + return nullptr; + } + pos += 1; + copyToDst(dst, ']'); + return normalizeWhitespace(pos, end, dst); + } + + inline static const char * normalizeString(const char * pos, const char * end, char * & dst) + { + const auto * start_pos = pos; + if (!isExpectedChar('"', pos, end)) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeString. 
not \""); + return nullptr; + } + pos += 1; + + pos = find_first_symbols<'"'>(pos, end); + while (*(pos - 1) == '\\' && *pos == '"' && pos < end) + { + pos += 1; + pos = find_first_symbols<'"'>(pos, end); + } + if (!isExpectedChar('"', pos, end)) + return nullptr; + pos += 1; + // LOG_DEBUG(getLogger("1"), "xxx string value: {}", std::string_view(start_pos, pos - start_pos)); + + size_t n = 0; + for (; start_pos != pos; ++start_pos) + { + if ((*start_pos >= 0x00 && *start_pos <= 0x1f) || *start_pos == 0x7f) + { + if (n) + { + copyToDst(dst, start_pos - n, n); + n = 0; + } + continue; + } + else + { + n += 1; + } + } + if (n) + copyToDst(dst, start_pos - n, n); + + return normalizeWhitespace(pos, end, dst); + } + + inline static const char * normalizeField(const char * pos, const char * end, char * & dst) + { + const auto * start_pos = pos; + pos = find_first_symbols<',', '}', ']'>(pos, end); + if (pos >= end) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeField. not field"); + return nullptr; + } + copyToDst(dst, start_pos, pos - start_pos); + return pos;; + } + + inline static const char * normalizeColon(const char * pos, const char * end, char * & dst) + { + pos = normalizeWhitespace(pos, end, dst); + if (!isExpectedChar(':', pos, end)) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeColon. not :"); + return nullptr; + } + pos += 1; + copyToDst(dst, ':'); + return normalizeWhitespace(pos, end, dst); + } + + inline static const char * normalizeComma(const char * pos, const char * end, char * & dst) + { + pos = normalizeWhitespace(pos, end, dst); + if (!isExpectedChar(',', pos, end)) + { + // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeComma. 
not ,"); + return nullptr; + } + pos += 1; + copyToDst(dst, ','); + return normalizeWhitespace(pos, end, dst); + } + + inline static const char * normalizeJsonStr(const char * pos, const char * end, char * & dst) + { + pos = normalizeWhitespace(pos, end, dst); + if (!pos || pos >= end) + return nullptr; + if (*pos == '[') + return normalizeArray(pos, end, dst); + else if (*pos == '{') + return normalizeObject(pos, end, dst); + return nullptr; + } + template DB::ColumnPtr innerExecuteImpl(const DB::ColumnsWithTypeAndName & arguments) const { @@ -318,13 +591,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction if (col_json_const) { std::string_view json{reinterpret_cast(chars.data()), offsets[0] - 1}; - document_ok = parser.parse(json, document); - if (!document_ok) - { - char dst[json.size()]; - size_t size = normalizeJson(json, dst); - document_ok = parser.parse(std::string_view(dst, size), document); - } + document_ok = safeParseJson(json, parser, document); } size_t tuple_size = tuple_columns.size(); @@ -340,13 +607,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction if (!col_json_const) { std::string_view json{reinterpret_cast(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1}; - document_ok = parser.parse(json, document); - if (!document_ok) - { - char dst[json.size()]; - size_t size = normalizeJson(json, dst); - document_ok = parser.parse(std::string_view(dst, size), document); - } + document_ok = safeParseJson(json, parser, document); } if (document_ok) {