Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 27, 2024
1 parent d4d7241 commit 0f36e3c
Showing 1 changed file with 277 additions and 16 deletions.
293 changes: 277 additions & 16 deletions cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/
#pragma once
#include <memory>
#include <string_view>
#include <stack>
#include <string_view>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
Expand All @@ -33,12 +33,14 @@
#include <Parsers/IParser.h>
#include <Parsers/Lexer.h>
#include <Parsers/TokenIterator.h>
#include <base/find_symbols.h>
#include <base/range.h>
#include <Poco/Logger.h>
#include <Poco/StringTokenizer.h>
#include <Common/Exception.h>
#include <Common/JSONParsers/DummyJSONParser.h>
#include <Common/JSONParsers/SimdJSONParser.h>
#include <Common/StringUtils.h>
#include <Common/logger_useful.h>

namespace DB
Expand Down Expand Up @@ -214,7 +216,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction
private:
DB::ContextPtr context;

size_t normalizeJson(std::string_view & json, char * dst) const
size_t normalizeJson(const std::string_view & json, char * dst) const
{
const char * json_chars = json.data();
const size_t json_size = json.size();
Expand Down Expand Up @@ -242,6 +244,277 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction
return new_json_size;
}

template<typename JSONParser>
bool safeParseJson(std::string_view str, JSONParser & parser, JSONParser::Element & doc) const
{
#if 0
// '{"a":"b"}acc' is invalid in simdjson, but it OK for spark.
if(!parser.parse(str, doc)) [[unlikely]]
{
// This is should not be a normal case, so we only try to normalize a json str when it fails.
char dst[str.size()];
size_t len = normalizeJson(str, dst);
auto res = parser.parse(std::string_view(dst, len), doc);
LOG_ERROR(getLogger("GetJsonObject"), "xxx failed to parse: {}, len: {}. adjust to:\n{}, res: {}", str, str.size(), std::string_view(dst, len), res);
return res;
}
LOG_ERROR(getLogger("GetJsonObject"), "xxx parse ok by simdjson");
return true;
#else
std::vector<char> buf;
buf.resize(str.size(), 0);
char * buf_pos = buf.data();
const char * pos = normalizeJsonStr(str.data(), str.data() + str.size(), buf_pos);
if (!pos)
{
LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalize json string failed");
return false;
}
std::string n_str(buf.data(), buf_pos - buf.data());
LOG_ERROR(getLogger("GetJsonObject"), "xxx new str: {}/ {}", str, n_str);
return parser.parse(n_str, doc);
#endif
}

inline static void copyToDst(char * & p, char c)
{
*p = c;
p++;
}

inline static void copyToDst(char * & p, const char * src, size_t len)
{
memcpy(p, src, len);
p += len;
}

inline static bool isExpectedChar(char c, const char * pos, const char * end)
{
return pos && pos < end && *pos == c;
}

static const char * normalizeWhitespace(const char * pos, const char * end, char * & dst)
{
const auto * start_pos = pos;
while(pos && pos < end)
{
if (isWhitespaceASCII(*pos))
pos++;
else
break;
}
if (pos != start_pos)
copyToDst(dst, start_pos, pos - start_pos);
return pos;
}

static const char * normalizeObject(const char * pos, const char * end, char * & dst)
{
if (!isExpectedChar('{', pos, end)) [[unlikely]]
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not object start");
return nullptr;
}
pos += 1;
copyToDst(dst, '{');

bool has_more = false;
while(pos && pos < end && *pos != '}')
{
has_more = false;
pos = normalizeWhitespace(pos, end, dst);

pos = normalizeString(pos, end, dst);

pos = normalizeColon(pos, end, dst);
if (!pos)
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not :");
break;
}

switch(*pos)
{
case '{': {
pos = normalizeObject(pos, end, dst);
break;
}
case '"': {
pos = normalizeString(pos, end, dst);
break;
}
case '[': {
pos = normalizeArray(pos, end, dst);
break;
}
default: {
pos = normalizeField(pos, end, dst);
break;
}
}

if (!isExpectedChar(',', pos, end))
break;
pos = normalizeComma(pos, end, dst);
has_more = true;
}

if (!isExpectedChar('}', pos, end) || has_more)
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not object end");
return nullptr;
}
pos += 1;
copyToDst(dst, '}');
return normalizeWhitespace(pos, end, dst);
}

static const char * normalizeArray(const char * pos, const char * end, char * & dst)
{
if (!isExpectedChar('[', pos, end))
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeArray. not [");
return nullptr;
}
pos += 1;
copyToDst(dst, '[');

pos = normalizeWhitespace(pos, end, dst);

bool has_more = false;
while (pos && pos < end && *pos != ']')
{
has_more = false;
switch(*pos)
{
case '{': {
pos = normalizeObject(pos, end, dst);
break;
}
case '"': {
pos = normalizeString(pos, end, dst);
break;
}
case '[': {
pos = normalizeArray(pos, end, dst);
break;
}
default: {
pos = normalizeField(pos, end, dst);
break;
}
}
if (!isExpectedChar(',', pos, end))
break;
pos = normalizeComma(pos, end, dst);
has_more = true;
}

if (!isExpectedChar(']', pos, end) || has_more)
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeArray. not ]");
return nullptr;
}
pos += 1;
copyToDst(dst, ']');
return normalizeWhitespace(pos, end, dst);
}

inline static const char * normalizeString(const char * pos, const char * end, char * & dst)
{
const auto * start_pos = pos;
if (!isExpectedChar('"', pos, end))
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeString. not \"");
return nullptr;
}
pos += 1;

pos = find_first_symbols<'"'>(pos, end);
while (*(pos - 1) == '\\' && *pos == '"' && pos < end)
{
pos += 1;
pos = find_first_symbols<'"'>(pos, end);
}
if (!isExpectedChar('"', pos, end))
return nullptr;
pos += 1;
// LOG_DEBUG(getLogger("1"), "xxx string value: {}", std::string_view(start_pos, pos - start_pos));

size_t n = 0;
for (; start_pos != pos; ++start_pos)
{
if ((*start_pos >= 0x00 && *start_pos <= 0x1f) || *start_pos == 0x7f)
{
if (n)
{
copyToDst(dst, start_pos - n, n);
n = 0;
}
continue;
}
else
{
n += 1;
}
}
if (n)
copyToDst(dst, start_pos - n, n);

return normalizeWhitespace(pos, end, dst);
}

inline static const char * normalizeField(const char * pos, const char * end, char * & dst)
{
const auto * start_pos = pos;
pos = find_first_symbols<',', '}', ']'>(pos, end);
if (pos >= end)
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeField. not field");
return nullptr;
}
copyToDst(dst, start_pos, pos - start_pos);
return pos;;
}

inline static const char * normalizeColon(const char * pos, const char * end, char * & dst)
{
pos = normalizeWhitespace(pos, end, dst);
if (!isExpectedChar(':', pos, end))
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeColon. not :");
return nullptr;
}
pos += 1;
copyToDst(dst, ':');
return normalizeWhitespace(pos, end, dst);
}

inline static const char * normalizeComma(const char * pos, const char * end, char * & dst)
{
pos = normalizeWhitespace(pos, end, dst);
if (!isExpectedChar(',', pos, end))
{
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeComma. not ,");
return nullptr;
}
pos += 1;
copyToDst(dst, ',');
return normalizeWhitespace(pos, end, dst);
}

inline static const char * normalizeJsonStr(const char * pos, const char * end, char * & dst)
{
pos = normalizeWhitespace(pos, end, dst);
if (!pos || pos >= end)
return nullptr;
if (*pos == '[')
return normalizeArray(pos, end, dst);
else if (*pos == '{')
return normalizeObject(pos, end, dst);
return nullptr;
}

template <typename JSONParser, typename Impl>
DB::ColumnPtr innerExecuteImpl(const DB::ColumnsWithTypeAndName & arguments) const
{
Expand Down Expand Up @@ -318,13 +591,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction
if (col_json_const)
{
std::string_view json{reinterpret_cast<const char *>(chars.data()), offsets[0] - 1};
document_ok = parser.parse(json, document);
if (!document_ok)
{
char dst[json.size()];
size_t size = normalizeJson(json, dst);
document_ok = parser.parse(std::string_view(dst, size), document);
}
document_ok = safeParseJson(json, parser, document);
}

size_t tuple_size = tuple_columns.size();
Expand All @@ -340,13 +607,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction
if (!col_json_const)
{
std::string_view json{reinterpret_cast<const char *>(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1};
document_ok = parser.parse(json, document);
if (!document_ok)
{
char dst[json.size()];
size_t size = normalizeJson(json, dst);
document_ok = parser.parse(std::string_view(dst, size), document);
}
document_ok = safeParseJson(json, parser, document);
}
if (document_ok)
{
Expand Down

0 comments on commit 0f36e3c

Please sign in to comment.