Skip to content

Commit

Permalink
Add support for canonicalization of JSON.
Browse files Browse the repository at this point in the history
  • Loading branch information
kgpai committed Oct 30, 2024
1 parent 92779f9 commit 4b8671b
Show file tree
Hide file tree
Showing 4 changed files with 319 additions and 24 deletions.
228 changes: 213 additions & 15 deletions velox/functions/prestosql/JsonFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,154 @@
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"
#include "velox/functions/prestosql/json/JsonStringUtil.h"
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/type/Conversions.h"

namespace facebook::velox::functions {

namespace {
constexpr const char kArrayStart = '[';
constexpr const char kArrayEnd = ']';
constexpr const char kSeparator = ',';
constexpr const char kObjectStart = '{';
constexpr const char kObjectEnd = '}';
constexpr const char kObjectKeySeparator = ':';
constexpr const char kQuote = '\"';

/// Class to keep track of json strings being written
/// in to a buffer. The size of the backing buffer must be known during
/// construction time.
class BufferTracker {
public:
BufferTracker(size_t bufferSize, memory::MemoryPool* pool)
: curPos_(0), currentViewStart_(0) {
buffer_ = AlignedBuffer::allocate<char>(bufferSize, pool);
bufPtr_ = buffer_->asMutable<char>();
}

/// Utility function to trim and escape a string view input.
/// Trims whitespace and escapes utf characters before writing to buffer.
void trimEscapeWriteToBuffer(StringView input) {
auto trimmed = velox::util::trimWhiteSpace(input.data(), input.size());
auto curBufPtr = getCurrentBufferPtr();
auto bytesWritten =
escapeString(trimmed.data(), trimmed.size(), curBufPtr, true);
incrementCounter(bytesWritten);
}

/// Utility function to write char to buffer.
void writeChar(char value) {
auto curBufPtr = getCurrentBufferPtr();
*curBufPtr++ = value;
curPos_++;
}

/// Returns current string view against the buffer.
StringView getStringView() {
return StringView(bufPtr_ + currentViewStart_, curPos_);
}

/// Sets current view to the end of the previous string.
/// Should be called only after getStringView ,
/// as after this call the previous view is lost.
void startNewString() {
currentViewStart_ += curPos_;
curPos_ = 0;
}

/// Returns the underlying buffer where the json strings are saved.
BufferPtr getUnderlyingBuffer() {
return buffer_;
}

private:
inline char* getCurrentBufferPtr() {
return bufPtr_ + currentViewStart_ + curPos_;
}

void incrementCounter(size_t increment) {
VELOX_CHECK_LE(
curPos_ + currentViewStart_ + increment, buffer_->capacity());
curPos_ += increment;
}

BufferPtr buffer_;
size_t curPos_;
size_t currentViewStart_;
char* bufPtr_;
};

class JsonView {
public:
virtual void canonicalize(BufferTracker& buffer) = 0;
};

using JsonViewPtr = std::shared_ptr<JsonView>;

struct JsonLeafView : public JsonView {
JsonLeafView(const StringView view) : view_(view){};

void canonicalize(BufferTracker& buffer) override {
buffer.trimEscapeWriteToBuffer(view_);
}

private:
const StringView view_;
};

struct JsonArrayView : public JsonView {
JsonArrayView(const std::vector<JsonViewPtr> array) : array_(array){};

void canonicalize(BufferTracker& buffer) override {
buffer.writeChar(kArrayStart);
for (auto i = 0; i < array_.size(); i++) {
array_[i]->canonicalize(buffer);
if (i < array_.size() - 1) {
buffer.writeChar(kSeparator);
}
}
buffer.writeChar(kArrayEnd);
}

private:
const std::vector<JsonViewPtr> array_;
};

struct JsonObjView : public JsonView {
JsonObjView(std::vector<std::pair<StringView, JsonViewPtr>> objFields)
: objFields_(objFields){};

void canonicalize(BufferTracker& buffer) override {
std::sort(objFields_.begin(), objFields_.end(), [](auto& a, auto& b) {
return a.first < b.first;
});

buffer.writeChar(kObjectStart);

for (auto i = 0; i < objFields_.size(); i++) {
auto field = objFields_[i];
buffer.writeChar(kQuote);
buffer.trimEscapeWriteToBuffer(field.first);
buffer.writeChar(kQuote);
buffer.writeChar(kObjectKeySeparator);

field.second->canonicalize(buffer);
if (i < objFields_.size() - 1) {
buffer.writeChar(kSeparator);
}
}

buffer.writeChar(kObjectEnd);
}

private:
std::vector<std::pair<StringView, JsonViewPtr>> objFields_;
};

} // namespace

namespace {
class JsonFormatFunction : public exec::VectorFunction {
public:
Expand Down Expand Up @@ -84,38 +227,63 @@ class JsonParseFunction : public exec::VectorFunction {
auto value = arg->as<ConstantVector<StringView>>()->valueAt(0);
paddedInput_.resize(value.size() + simdjson::SIMDJSON_PADDING);
memcpy(paddedInput_.data(), value.data(), value.size());
if (auto error = parse(value.size())) {
auto escapeSize =
escapedStringSize(paddedInput_.data(), paddedInput_.length(), true);
BufferTracker bufferTracker{escapeSize, context.pool()};

JsonViewPtr jsonView;

if (auto error = parse(value.size(), jsonView)) {
context.setErrors(rows, errors_[error]);
return;
}
localResult = std::make_shared<ConstantVector<StringView>>(
context.pool(), rows.end(), false, JSON(), std::move(value));

jsonView->canonicalize(bufferTracker);
auto canonicalString = bufferTracker.getStringView();
localResult = BaseVector::createConstant(
JSON(), canonicalString, rows.end(), context.pool());

} else {
auto flatInput = arg->asFlatVector<StringView>();
BufferPtr stringViews = AlignedBuffer::allocate<StringView>(
rows.end(), context.pool(), StringView());
auto rawStringViews = stringViews->asMutable<StringView>();

auto stringBuffers = flatInput->stringBuffers();
VELOX_CHECK_LE(rows.end(), flatInput->size());

size_t maxSize = 0;
size_t totalOutputSize = 0;
rows.applyToSelected([&](auto row) {
auto value = flatInput->valueAt(row);
maxSize = std::max(maxSize, value.size());
totalOutputSize += escapedStringSize(value.data(), value.size(), true);
});

paddedInput_.resize(maxSize + simdjson::SIMDJSON_PADDING);
BufferTracker buffer{totalOutputSize, context.pool()};

JsonViewPtr jsonView;

rows.applyToSelected([&](auto row) {
auto value = flatInput->valueAt(row);
memcpy(paddedInput_.data(), value.data(), value.size());
if (auto error = parse(value.size())) {
if (auto error = parse(value.size(), jsonView)) {
context.setVeloxExceptionError(row, errors_[error]);
} else {
jsonView->canonicalize(buffer);
auto canonicalString = buffer.getStringView();
rawStringViews[row] = canonicalString;
buffer.startNewString();
}
});

localResult = std::make_shared<FlatVector<StringView>>(
context.pool(),
JSON(),
nullptr,
rows.end(),
flatInput->values(),
std::move(stringBuffers));
stringViews,
std::vector<BufferPtr>{buffer.getUnderlyingBuffer()});
}

context.moveOrCopyResult(localResult, rows, result);
Expand All @@ -130,45 +298,75 @@ class JsonParseFunction : public exec::VectorFunction {
}

private:
simdjson::error_code parse(size_t size) const {
simdjson::error_code parse(size_t size, JsonViewPtr& jsonView) const {
simdjson::padded_string_view paddedInput(
paddedInput_.data(), size, paddedInput_.size());
SIMDJSON_ASSIGN_OR_RAISE(auto doc, simdjsonParse(paddedInput));
SIMDJSON_TRY(validate<simdjson::ondemand::document&>(doc));
SIMDJSON_TRY(validate<simdjson::ondemand::document&>(doc, jsonView));
if (!doc.at_end()) {
return simdjson::TRAILING_CONTENT;
}
return simdjson::SUCCESS;
}

template <typename T>
static simdjson::error_code validate(T value) {
static simdjson::error_code validate(T value, JsonViewPtr& jsonView) {
SIMDJSON_ASSIGN_OR_RAISE(auto type, value.type());
switch (type) {
case simdjson::ondemand::json_type::array: {
SIMDJSON_ASSIGN_OR_RAISE(auto array, value.get_array());

std::vector<JsonViewPtr> arrayPtr;
for (auto elementOrError : array) {
SIMDJSON_ASSIGN_OR_RAISE(auto element, elementOrError);
SIMDJSON_TRY(validate(element));
JsonViewPtr elementPtr;
SIMDJSON_TRY(validate(element, elementPtr));
arrayPtr.push_back(elementPtr);
}

jsonView = std::make_shared<JsonArrayView>(std::move(arrayPtr));
return simdjson::SUCCESS;
}
case simdjson::ondemand::json_type::object: {
SIMDJSON_ASSIGN_OR_RAISE(auto object, value.get_object());

std::vector<std::pair<StringView, JsonViewPtr>> objFields;
for (auto fieldOrError : object) {
SIMDJSON_ASSIGN_OR_RAISE(auto field, fieldOrError);
SIMDJSON_TRY(validate(field.value()));
JsonViewPtr elementPtr;
auto key = StringView(field.escaped_key());
SIMDJSON_TRY(validate(field.value(), elementPtr));
objFields.push_back({key, elementPtr});
}

jsonView = std::make_shared<JsonObjView>(objFields);
return simdjson::SUCCESS;
}
case simdjson::ondemand::json_type::number:
case simdjson::ondemand::json_type::number: {
std::string_view rawJsonv = value.raw_json_token();

jsonView = std::make_shared<JsonLeafView>(StringView(rawJsonv));
return value.get_double().error();
case simdjson::ondemand::json_type::string:
}
case simdjson::ondemand::json_type::string: {
auto rawJsonv = StringView(value.raw_json_token());

jsonView = std::make_shared<JsonLeafView>(rawJsonv);
return value.get_string().error();
case simdjson::ondemand::json_type::boolean:
}

case simdjson::ondemand::json_type::boolean: {
auto rawJsonv = StringView(value.raw_json_token());

jsonView = std::make_shared<JsonLeafView>(rawJsonv);
return value.get_bool().error();
}

case simdjson::ondemand::json_type::null: {
SIMDJSON_ASSIGN_OR_RAISE(auto isNull, value.is_null());
auto rawJsonv = StringView(value.raw_json_token());

jsonView = std::make_shared<JsonLeafView>(rawJsonv);
return isNull ? simdjson::SUCCESS : simdjson::N_ATOM_ERROR;
}
}
Expand Down
21 changes: 17 additions & 4 deletions velox/functions/prestosql/json/JsonStringUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ void testingEncodeUtf16Hex(char32_t codePoint, char*& out) {
encodeUtf16Hex(codePoint, out);
}

void escapeString(const char* input, size_t length, char* output) {
size_t
escapeString(const char* input, size_t length, char* output, bool skipAscii) {
char* pos = output;

auto* start = reinterpret_cast<const unsigned char*>(input);
Expand All @@ -117,7 +118,12 @@ void escapeString(const char* input, size_t length, char* output) {
int count = validateAndGetNextUtf8Length(start, end);
switch (count) {
case 1: {
encodeAscii(int8_t(*start), pos);
if (!skipAscii) {
encodeAscii(int8_t(*start), pos);
} else {
*pos++ = *start;
}

start++;
continue;
}
Expand Down Expand Up @@ -148,9 +154,11 @@ void escapeString(const char* input, size_t length, char* output) {
}
}
}

return (pos - output);
}

size_t escapedStringSize(const char* input, size_t length) {
size_t escapedStringSize(const char* input, size_t length, bool skipAscii) {
// 6 chars that is returned by `writeHex`.
constexpr size_t kEncodedHexSize = 6;

Expand All @@ -162,7 +170,12 @@ size_t escapedStringSize(const char* input, size_t length) {
int count = validateAndGetNextUtf8Length(start, end);
switch (count) {
case 1:
outSize += encodedAsciiSizes[int8_t(*start)];
if (!skipAscii) {
outSize += encodedAsciiSizes[int8_t(*start)];
} else {
outSize++;
}

break;
case 2:
case 3:
Expand Down
14 changes: 12 additions & 2 deletions velox/functions/prestosql/json/JsonStringUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,24 @@ namespace facebook::velox {
/// @param length: Length of the input string.
/// @param output: Output string to write the escaped input to. The caller is
/// responsible to allocate enough space for output.
void escapeString(const char* input, size_t length, char* output);
/// @param skipAscii: Do not consider ascii characters for encoding (used in
/// json_parse for example).
/// @return The number of bytes written to the output.
size_t escapeString(
const char* input,
size_t length,
char* output,
bool skipAscii = false);

/// Return the size of string after the unicode characters of `input` are
/// escaped using the method as in`escapeString`. The function will iterate
/// over `input` once.
/// @param input: Input string to escape that is UTF-8 encoded.
/// @param length: Length of the input string.
size_t escapedStringSize(const char* input, size_t length);
/// @param skipAscii: Do not consider ascii characters for encoding (used in
/// json_parse for example).
size_t
escapedStringSize(const char* input, size_t length, bool skipAscii = false);

/// For test only. Encode `codePoint` value by UTF-16 and write the one or two
/// prefixed hexadecimals to `out`. Move `out` forward by 6 or 12 chars
Expand Down
Loading

0 comments on commit 4b8671b

Please sign in to comment.