Skip to content

Commit

Permalink
Fix: different results from get_json_object compared to vanilla Spark
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 28, 2024
1 parent d4d7241 commit 5d23fb4
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
1{"data": {"id": "Qu001cڜu00cƼ","v": 5}}
2{"data": {"id": "Qu001cڜu00c}Ƽ","v": 5}}
3{"data": {"id": "Qu001cڜu00c\\\"Ƽ","v": 5}}1
4{"data": {"id": "12323\\","v": 5}}123
5{"data": {"id": "12323\"","v": 5}}123
6{"data": {"id": "12323\\\\","v": 5}}123
Original file line number Diff line number Diff line change
Expand Up @@ -877,12 +877,15 @@ class GlutenClickHouseHiveTableSuite
val select_sql_5 = "select id, get_json_object(data, 'v112') from test_tbl_3337"
val select_sql_6 =
"select id, get_json_object(data, '$.id') from test_tbl_3337 where id = 123"
val select_sql_7 =
"select id, get_json_object(data, '$.id') from test_tbl_3337"
compareResultsAgainstVanillaSpark(select_sql_1, compareResult = true, _ => {})
compareResultsAgainstVanillaSpark(select_sql_2, compareResult = true, _ => {})
compareResultsAgainstVanillaSpark(select_sql_3, compareResult = true, _ => {})
compareResultsAgainstVanillaSpark(select_sql_4, compareResult = true, _ => {})
compareResultsAgainstVanillaSpark(select_sql_5, compareResult = true, _ => {})
compareResultsAgainstVanillaSpark(select_sql_6, compareResult = true, _ => {})
compareResultsAgainstVanillaSpark(select_sql_7, compareResult = true, _ => {})

spark.sql("DROP TABLE test_tbl_3337")
}
Expand Down
309 changes: 273 additions & 36 deletions cpp-ch/local-engine/Functions/SparkFunctionGetJsonObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#pragma once
#include <memory>
#include <string_view>
#include <stack>
#include <Columns/ColumnNullable.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
Expand All @@ -33,12 +32,14 @@
#include <Parsers/IParser.h>
#include <Parsers/Lexer.h>
#include <Parsers/TokenIterator.h>
#include <base/find_symbols.h>
#include <base/range.h>
#include <Poco/Logger.h>
#include <Poco/StringTokenizer.h>
#include <Common/Exception.h>
#include <Common/JSONParsers/DummyJSONParser.h>
#include <Common/JSONParsers/SimdJSONParser.h>
#include <Common/StringUtils.h>
#include <Common/logger_useful.h>

namespace DB
Expand Down Expand Up @@ -66,6 +67,260 @@ struct GetJsonObject
static constexpr auto name{"get_json_object"};
};

class JSONTextNormalizer
{
public:
    /// simdjson fails to parse some json texts that vanilla Spark accepts,
    /// see #7014, #3750, #3337, #5303. To keep the result identical to
    /// vanilla, the json text is normalized here when simdjson fails on the
    /// raw input.
    ///
    /// Returns nullptr when normalization fails; otherwise returns a position
    /// among `pos` and `end` which points just past the whole json document.
    /// `dst` refers to a caller-provided memory buffer (at least `end - pos`
    /// bytes) used to store the normalization result; it is advanced past
    /// every byte written.
    static const char * normalize(const char * pos, const char * end, char *& dst)
    {
        pos = normalizeWhitespace(pos, end, dst);
        if (!pos || pos >= end)
            return nullptr;
        // only whole documents (object or array) are handled
        if (*pos == '[')
            return normalizeArray(pos, end, dst);
        else if (*pos == '{')
            return normalizeObject(pos, end, dst);
        return nullptr;
    }

private:
    /// Append one character to the output buffer.
    inline static void copyToDst(char *& p, char c)
    {
        *p = c;
        p++;
    }

    /// Append `len` bytes from `src` to the output buffer.
    inline static void copyToDst(char *& p, const char * src, size_t len)
    {
        memcpy(p, src, len);
        p += len;
    }

    /// True when `pos` is a non-null, in-range position holding exactly `c`.
    inline static bool isExpectedChar(char c, const char * pos, const char * end) { return pos && pos < end && *pos == c; }

    /// Copy a (possibly empty) run of ASCII whitespace verbatim.
    /// Accepts and propagates a null `pos`.
    inline static const char * normalizeWhitespace(const char * pos, const char * end, char *& dst)
    {
        const auto * start_pos = pos;
        while (pos && pos < end)
        {
            if (isWhitespaceASCII(*pos))
                pos++;
            else
                break;
        }
        if (pos != start_pos)
            copyToDst(dst, start_pos, pos - start_pos);
        return pos;
    }

    /// Consume an element separator: optional whitespace, ',', optional whitespace.
    inline static const char * normalizeComma(const char * pos, const char * end, char *& dst)
    {
        pos = normalizeWhitespace(pos, end, dst);
        if (!isExpectedChar(',', pos, end)) [[unlikely]]
        {
            // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeComma. not ,");
            return nullptr;
        }
        pos += 1;
        copyToDst(dst, ',');
        return normalizeWhitespace(pos, end, dst);
    }

    /// Consume a key/value separator: optional whitespace, ':', optional whitespace.
    inline static const char * normalizeColon(const char * pos, const char * end, char *& dst)
    {
        pos = normalizeWhitespace(pos, end, dst);
        if (!isExpectedChar(':', pos, end))
        {
            // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeColon. not :");
            return nullptr;
        }
        pos += 1;
        copyToDst(dst, ':');
        return normalizeWhitespace(pos, end, dst);
    }

    /// Copy a scalar field (number, bool, null, ...) verbatim up to the next
    /// ',', '}' or ']'. The terminator itself is not consumed. The field's
    /// content is not validated here; an invalid scalar is rejected later by
    /// the json parser itself.
    inline static const char * normalizeField(const char * pos, const char * end, char *& dst)
    {
        const auto * start_pos = pos;
        pos = find_first_symbols<',', '}', ']'>(pos, end);
        if (pos >= end) [[unlikely]]
        {
            // a scalar must be followed by a separator or closing bracket
            // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeField. not field");
            return nullptr;
        }
        copyToDst(dst, start_pos, pos - start_pos);
        return pos;
    }

    /// Copy a double-quoted string, dropping raw ASCII control characters
    /// (0x00-0x1f and 0x7f) that simdjson rejects. Escape sequences
    /// (e.g. '\"', '\\') are passed through untouched; bytes >= 0x80
    /// (multi-byte UTF-8) are kept.
    inline static const char * normalizeString(const char * pos, const char * end, char *& dst)
    {
        const auto * start_pos = pos;
        if (!isExpectedChar('"', pos, end)) [[unlikely]]
        {
            // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeString. not \"");
            return nullptr;
        }
        pos += 1;

        // locate the closing quote, skipping over any escaped character
        do
        {
            pos = find_first_symbols<'\\', '"'>(pos, end);
            if (pos != end && *pos == '\\')
            {
                // escape characters. e.g. '\"', '\\'
                pos += 2;
                if (pos >= end)
                    return nullptr;
            }
            else
                break;
        } while (pos != end);

        pos = find_first_symbols<'"'>(pos, end);
        if (!isExpectedChar('"', pos, end))
            return nullptr;
        pos += 1;

        // copy [start_pos, pos) in contiguous chunks of `n` kept bytes,
        // skipping control characters
        size_t n = 0;
        for (; start_pos != pos; ++start_pos)
        {
            if ((*start_pos >= 0x00 && *start_pos <= 0x1f) || *start_pos == 0x7f)
            {
                if (n)
                {
                    copyToDst(dst, start_pos - n, n);
                    n = 0;
                }
                continue;
            }
            else
            {
                n += 1;
            }
        }
        if (n)
            copyToDst(dst, start_pos - n, n);

        return normalizeWhitespace(pos, end, dst);
    }

    /// Normalize a json array: '[' elements ']'. Rejects trailing commas
    /// (tracked by `has_more`) and unterminated arrays.
    static const char * normalizeArray(const char * pos, const char * end, char *& dst)
    {
        if (!isExpectedChar('[', pos, end)) [[unlikely]]
        {
            // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeArray. not [");
            return nullptr;
        }
        pos += 1;
        copyToDst(dst, '[');

        pos = normalizeWhitespace(pos, end, dst);

        bool has_more = false;
        while (pos && pos < end && *pos != ']')
        {
            has_more = false;
            switch (*pos)
            {
                case '{': {
                    pos = normalizeObject(pos, end, dst);
                    break;
                }
                case '"': {
                    pos = normalizeString(pos, end, dst);
                    break;
                }
                case '[': {
                    pos = normalizeArray(pos, end, dst);
                    break;
                }
                default: {
                    pos = normalizeField(pos, end, dst);
                    break;
                }
            }
            if (!isExpectedChar(',', pos, end))
                break;
            pos = normalizeComma(pos, end, dst);
            has_more = true;
        }

        if (!isExpectedChar(']', pos, end) || has_more)
        {
            // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeArray. not ]");
            return nullptr;
        }
        pos += 1;
        copyToDst(dst, ']');
        return normalizeWhitespace(pos, end, dst);
    }

    /// Normalize a json object: '{' "key" ':' value ... '}'. Rejects trailing
    /// commas (tracked by `has_more`) and unterminated objects.
    static const char * normalizeObject(const char * pos, const char * end, char *& dst)
    {
        if (!isExpectedChar('{', pos, end)) [[unlikely]]
        {
            // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not object start");
            return nullptr;
        }
        pos += 1;
        copyToDst(dst, '{');

        bool has_more = false;
        while (pos && pos < end && *pos != '}')
        {
            has_more = false;
            pos = normalizeWhitespace(pos, end, dst);

            pos = normalizeString(pos, end, dst);

            pos = normalizeColon(pos, end, dst);
            if (!pos)
            {
                // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not :");
                break;
            }
            // normalizeColon may legally return pos == end on truncated input
            // such as '{"k":'; guard it before dereferencing. The final '}'
            // check below then reports the failure.
            if (pos >= end) [[unlikely]]
                break;

            switch (*pos)
            {
                case '{': {
                    pos = normalizeObject(pos, end, dst);
                    break;
                }
                case '"': {
                    pos = normalizeString(pos, end, dst);
                    break;
                }
                case '[': {
                    pos = normalizeArray(pos, end, dst);
                    break;
                }
                default: {
                    pos = normalizeField(pos, end, dst);
                    break;
                }
            }

            if (!isExpectedChar(',', pos, end))
                break;
            pos = normalizeComma(pos, end, dst);
            has_more = true;
        }

        if (!isExpectedChar('}', pos, end) || has_more)
        {
            // LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalizeObject. not object end");
            return nullptr;
        }
        pos += 1;
        copyToDst(dst, '}');
        return normalizeWhitespace(pos, end, dst);
    }
};

template <typename JSONParser, typename JSONStringSerializer>
class GetJsonObjectImpl
{
Expand Down Expand Up @@ -116,6 +371,7 @@ class GetJsonObjectImpl
if (elements[0].isNull())
return false;
nullable_col_str.getNullMapData().push_back(0);

if (elements[0].isString())
{
auto str = elements[0].getString();
Expand Down Expand Up @@ -214,32 +470,25 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction
private:
DB::ContextPtr context;

size_t normalizeJson(std::string_view & json, char * dst) const
template<typename JSONParser>
bool safeParseJson(std::string_view str, JSONParser & parser, JSONParser::Element & doc) const
{
const char * json_chars = json.data();
const size_t json_size = json.size();
std::stack<char> tmp;
size_t new_json_size = 0;
for (size_t i = 0; i <= json_size; ++i)
if (!parser.parse(str, doc)) [[unlikely]]
{
if ((*(json_chars + i) >= 0x00 && *(json_chars + i) <= 0x1F) || *(json_chars + i) == 0x7F)
continue;
else
std::vector<char> buf;
buf.resize(str.size(), 0);
char * buf_pos = buf.data();
const char * pos = JSONTextNormalizer::normalize(str.data(), str.data() + str.size(), buf_pos);
if (!pos)
{
char ch = *(json_chars + i);
dst[new_json_size++] = ch;
if (ch == '{')
tmp.push('{');
else if (ch == '}')
{
if (!tmp.empty() && tmp.top() == '{')
tmp.pop();
}
if (tmp.empty())
break;
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalize failed");
return false;
}
std::string n_str(buf.data(), buf_pos - buf.data());
// LOG_DEBUG(getLogger("GetJsonObject"), "xxx normalize {} to {}", str, n_str);
return parser.parse(n_str, doc);
}
return new_json_size;
return true;
}

template <typename JSONParser, typename Impl>
Expand Down Expand Up @@ -318,13 +567,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction
if (col_json_const)
{
std::string_view json{reinterpret_cast<const char *>(chars.data()), offsets[0] - 1};
document_ok = parser.parse(json, document);
if (!document_ok)
{
char dst[json.size()];
size_t size = normalizeJson(json, dst);
document_ok = parser.parse(std::string_view(dst, size), document);
}
document_ok = safeParseJson(json, parser, document);
}

size_t tuple_size = tuple_columns.size();
Expand All @@ -340,13 +583,7 @@ class FlattenJSONStringOnRequiredFunction : public DB::IFunction
if (!col_json_const)
{
std::string_view json{reinterpret_cast<const char *>(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1};
document_ok = parser.parse(json, document);
if (!document_ok)
{
char dst[json.size()];
size_t size = normalizeJson(json, dst);
document_ok = parser.parse(std::string_view(dst, size), document);
}
document_ok = safeParseJson(json, parser, document);
}
if (document_ok)
{
Expand Down
Loading

0 comments on commit 5d23fb4

Please sign in to comment.